Created November 22, 2024 21:25
Save danielvarga/d969c695b0360779c608b6b5a80ff66f to your computer and use it in GitHub Desktop.
Tensor network defining subject-relation-object triplets
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
class TensorNetworkBase(nn.Module):
    """Tucker-style tensor network that scores triplets of d-dimensional vectors.

    Each of the three inputs is projected from dimension ``d`` to ``k`` by its
    own learned matrix, and the three projections are contracted with a learned
    ``k x k x k`` core tensor.

    Index letters used in the einsum strings: uppercase ``O``, ``R``, ``S`` run
    over the original dimension ``d``; lowercase ``o``, ``r``, ``s`` run over the
    reduced dimension ``k``; ``b`` runs over the batch.

    Args:
        d: Dimension of each input vector.
        k: Reduced dimension of the projections and of the core tensor.
    """

    def __init__(self, d, k):
        super().__init__()
        self.d = d
        self.k = k
        self.proj1 = nn.Parameter(torch.randn(d, k))  # Oo: projects input1
        self.proj2 = nn.Parameter(torch.randn(d, k))  # Rr: projects input2
        self.proj3 = nn.Parameter(torch.randn(d, k))  # Ss: projects input3
        self.core_tensor = nn.Parameter(torch.randn(k, k, k))  # ors

    def three_vectors_to_scalar(self, input1, input2, input3):
        """Compute one scalar score per row of a batch of input triplets.

        Args:
            input1: Tensor of shape (batch, d), contracted with ``proj1``.
            input2: Tensor of shape (batch, d), contracted with ``proj2``.
            input3: Tensor of shape (batch, d), contracted with ``proj3``.

        Returns:
            Tensor of shape (batch,).
        """
        # Step 1: project each input to k dimensions with its *own* matrix, so
        # that every one of the three inputs contributes to the score.
        p1 = torch.einsum('Oo,bO->bo', self.proj1, input1)  # bo
        p2 = torch.einsum('Rr,bR->br', self.proj2, input2)  # br
        p3 = torch.einsum('Ss,bS->bs', self.proj3, input3)  # bs
        # Step 2: batch-wise contraction with the core tensor -> one scalar per row.
        return torch.einsum('bo,br,bs,ors->b', p1, p2, p3, self.core_tensor)

    def vector_to_matrix(self, input2):
        """Turn each row of ``input2`` into a d x d bilinear-form matrix.

        The returned ``M`` is consistent with :meth:`three_vectors_to_scalar`:
        ``einsum('bO,bOS,bS->b', input1, M, input3)`` equals
        ``three_vectors_to_scalar(input1, input2, input3)`` up to float error.

        Args:
            input2: Tensor of shape (batch, d).

        Returns:
            Tensor of shape (batch, d, d), indexed ``[b, O, S]``.
        """
        p2 = torch.einsum('Rr,bR->br', self.proj2, input2)  # br
        # Contract the middle core index first, giving a (batch, k, k) matrix
        # in the reduced space.
        reduced = torch.einsum('ors,br->bos', self.core_tensor, p2)
        # Lift both remaining indices back to dimension d. The output must keep
        # the uppercase O and S indices; dropping them would sum the projection
        # matrices over d instead of applying them.
        return torch.einsum('Oo,bos,Ss->bOS', self.proj1, reduced, self.proj3)
class TensorNetworkVectorToMatrix(TensorNetworkBase):
    """Module whose forward pass is ``TensorNetworkBase.vector_to_matrix``."""

    def forward(self, input2):
        # Pure delegation; all parameters live on the base class.
        return self.vector_to_matrix(input2)
class TensorNetworkThreeVectorsToScalar(TensorNetworkBase):
    """Module whose forward pass is ``TensorNetworkBase.three_vectors_to_scalar``."""

    def forward(self, input1, input2, input3):
        # Pure delegation; all parameters live on the base class.
        return self.three_vectors_to_scalar(input1, input2, input3)
# Example usage. Guarded so that importing this module does not allocate
# tensors, build models, or print.
if __name__ == "__main__":
    d = 100  # Original input dimension
    k = d // 10  # Reduced dimension
    batch_size = 32
    input1 = torch.randn(batch_size, d)
    input2 = torch.randn(batch_size, d)
    input3 = torch.randn(batch_size, d)

    model = TensorNetworkVectorToMatrix(d, k)
    output = model(input2)
    print(output.shape)

    model = TensorNetworkThreeVectorsToScalar(d, k)
    output = model(input1, input2, input3)
    print(output.shape)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.