Last active
January 29, 2024 14:49
Lower-triangular circulant embedding of pytorch vector
"""
Partially inspired by https://stackoverflow.com/a/70686229
"""
import torch as pt


def lowtr_circulant(diag: pt.Tensor, dim: int = -1) -> pt.Tensor:
    r"""
    From a vector ``diag``, computes its flipped circulant embedding with zeros
    above the diagonal, i.e. a lower-triangular Toeplitz matrix.
    Can be batched: the vector lives along ``dim`` and the matrix columns are
    appended as a new last dimension.

    .. math::

        \lambda \to \begin{matrix}
        \lambda_1 & & & \\
        \lambda_2 & \ddots & 0 & \\
        \vdots & \ddots & \ddots & \\
        \lambda_n & \dots & \lambda_2 & \lambda_1 \\
        \end{matrix}
    """
    d = diag.shape[dim]  # Length of the vector along ``dim``
    # Give the padding the same batch shape as ``diag`` so that the
    # concatenation also works for batched inputs.
    pad_shape = list(diag.shape)
    pad_shape[dim] = d - 1
    zero_pads = diag.new_zeros(pad_shape)
    # Each length-d sliding window of the padded vector is one reversed row.
    return pt.cat([zero_pads, diag], dim=dim).unfold(dim, d, 1).flip((-1,))
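The matrix in the docstring can also be read as an index rule: entry (i, j) equals the vector element at offset i - j when i >= j, and zero otherwise. The sketch below builds the same matrix directly from that rule, as a cross-check of the `unfold` approach. The helper name `lowtr_reference` is hypothetical and not part of the gist, and it assumes the vector lives in the last dimension.

```python
import torch as pt


def lowtr_reference(diag: pt.Tensor) -> pt.Tensor:
    """Index-based construction (hypothetical helper, not from the gist).

    Entry (i, j) is diag[..., i - j] when i >= j and 0 otherwise.
    Assumes the vector lives in the last dimension.
    """
    d = diag.shape[-1]
    i = pt.arange(d, device=diag.device).unsqueeze(1)  # row indices, shape (d, 1)
    j = pt.arange(d, device=diag.device).unsqueeze(0)  # column indices, shape (1, d)
    offset = (i - j).clamp(min=0)                      # negative offsets are masked below
    # Indexing the last dimension with a (d, d) index gives shape (..., d, d);
    # tril() then zeroes every entry above the diagonal.
    return diag[..., offset].tril()


lam = pt.tensor([1.0, 2.0, 3.0])
print(lowtr_reference(lam))

# Leading dimensions are treated as batch dimensions.
batch = pt.arange(8.0).reshape(2, 4)
print(lowtr_reference(batch).shape)
```

For small inputs, comparing this reference against the output of `lowtr_circulant` is a quick sanity test. The indexing version materialises a full d-by-d index tensor, so the `unfold` version is likely the better choice when d is large.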