Last active September 22, 2023 02:51
Gist: nikola-j/b5bb6b141b8d9920318677e1bba70466
Atan2 PyTorch ONNX
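```python
import numpy as np
import torch


def my_atan2(y, x):
    # Elementwise atan2 built only from atan, comparisons and arithmetic masks.
    pi = torch.from_numpy(np.array([np.pi])).to(y.device, y.dtype)
    # Base angle; the 1e-6 offset keeps the division finite when x == 0.
    ans = torch.atan(y / (x + 1e-6))
    # Quadrant II and the negative x-axis (y == 0, x < 0): shift by +pi.
    ans += ((y >= 0) & (x < 0)) * pi
    # Quadrant III: shift by -pi.
    ans -= ((y < 0) & (x < 0)) * pi
    # Positive y-axis: overwrite the result with pi/2.
    ans *= (1 - ((y > 0) & (x == 0)) * 1.0)
    ans += ((y > 0) & (x == 0)) * (pi / 2)
    # Negative y-axis: overwrite the result with -pi/2.
    ans *= (1 - ((y < 0) & (x == 0)) * 1.0)
    ans += ((y < 0) & (x == 0)) * (-pi / 2)
    return ans
```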
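For reference, the masks in `my_atan2` correct the plain arctangent toward the standard piecewise definition of atan2 (the 1e-6 offset is only there to keep y / x finite at x = 0):

$$
\operatorname{atan2}(y, x) =
\begin{cases}
\arctan(y/x) & x > 0 \\
\arctan(y/x) + \pi & x < 0,\ y \ge 0 \\
\arctan(y/x) - \pi & x < 0,\ y < 0 \\
\pi/2 & x = 0,\ y > 0 \\
-\pi/2 & x = 0,\ y < 0 \\
0 & x = 0,\ y = 0 \quad \text{(the value \texttt{torch.atan2} returns)}
\end{cases}
$$

The two x < 0 rows are handled by adding or subtracting pi under a mask, and each multiply-then-add pair overwrites one of the x = 0 rows.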
This optimized version includes the following improvements:
- Added English comments explaining each step.
- Created the pi tensor directly with torch.tensor instead of torch.from_numpy.
- Defined eps as a separate variable, so the offset is easy to adjust.

These changes make the code more readable without changing what it computes.
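The optimized version described above is not reproduced on this page. The following is a minimal sketch of one way to apply those three changes, not the commenter's actual code; the function name `atan2_onnx` and its `eps` keyword are hypothetical. Like the original, it uses only `torch.atan`, comparisons and elementwise arithmetic, presumably so the exported ONNX graph avoids a dedicated atan2 operator.

```python
import math

import torch


def atan2_onnx(y: torch.Tensor, x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Elementwise atan2 using only atan, comparisons and arithmetic (hypothetical helper)."""
    # pi created directly as a tensor on the input's device and dtype.
    pi = torch.tensor(math.pi, device=y.device, dtype=y.dtype)

    # Base angle; eps keeps the division finite when x == 0.
    ans = torch.atan(y / (x + eps))

    # Masks for the cases that need correcting, cast to the input dtype.
    left_upper = ((y >= 0) & (x < 0)).to(y.dtype)  # quadrant II and negative x-axis
    left_lower = ((y < 0) & (x < 0)).to(y.dtype)   # quadrant III
    up_axis = ((y > 0) & (x == 0)).to(y.dtype)     # positive y-axis
    down_axis = ((y < 0) & (x == 0)).to(y.dtype)   # negative y-axis

    # Shift the left half-plane by +pi or -pi.
    ans = ans + left_upper * pi - left_lower * pi
    # Overwrite the y-axis results with exactly +pi/2 or -pi/2.
    ans = ans * (1 - up_axis) + up_axis * (pi / 2)
    ans = ans * (1 - down_axis) - down_axis * (pi / 2)
    return ans


# Sanity check against torch.atan2 on a small grid that includes both axes.
vals = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
yy, xx = torch.meshgrid(vals, vals, indexing="ij")
print(torch.allclose(atan2_onnx(yy, xx), torch.atan2(yy, xx), atol=1e-5))
```

The first mask uses y >= 0 so that the negative x-axis maps to pi, matching `torch.atan2`. One caveat carries over from the eps approach: for x in (-eps, 0) the shifted denominator becomes positive, so the result lands in the wrong half-plane there.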