# Source code for optrade.models.utils.pos_enc

import torch
import torch.nn as nn
import torch.nn.init as init


class PositionalEncoding(nn.Module):
    """Project patches to ``d_model`` features and add a learned positional encoding.

    Each patch of length ``patch_dim`` is mapped to ``d_model`` features by a
    learned linear layer (``projection``). A learned ``(num_patches, d_model)``
    matrix (``pos_encoding``) is then added, giving every patch position its
    own trainable offset. Channels are processed independently: the same
    projection and positional encoding are shared by all channels.

    Args:
        patch_dim: Length ``P`` of each input patch.
        d_model: Output feature size ``D``.
        num_patches: Number of patches ``N`` per channel; this fixes the
            number of rows of the positional-encoding matrix.
    """

    def __init__(self, patch_dim: int = 16, d_model: int = 128, num_patches: int = 64):
        super().__init__()
        # P x D projection, applied to the last (patch) dimension.
        self.projection = nn.Linear(patch_dim, d_model)
        # N x D learned positional encoding, one row per patch position.
        self.pos_encoding = nn.Parameter(torch.empty(num_patches, d_model))

        # Weight initialization: Xavier-uniform for the projection weight
        # (the bias keeps nn.Linear's default init) and a small uniform range
        # for the positional encoding.
        init.xavier_uniform_(self.projection.weight)
        init.uniform_(self.pos_encoding, -0.02, 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed patches and add the positional encoding.

        Args:
            x: Input tensor of shape ``(B, M, N, P)``, where B = batch_size,
                M = num_channels, N = num_patches and P = patch_dim. Any number
                of leading dimensions is accepted as long as the last two are
                ``(N, P)``.

        Returns:
            Tensor of shape ``(B, M, N, D)`` (in general ``(..., N, D)``),
            where D = d_model.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions or its
                patch dimension (``x.shape[-2]``) differs from ``num_patches``.
        """
        num_patches = self.pos_encoding.shape[0]
        # Validate explicitly: with N == 1 the addition below would otherwise
        # silently broadcast to num_patches rows instead of failing.
        if x.dim() < 2 or x.shape[-2] != num_patches:
            raise ValueError(
                f"expected input of shape (..., {num_patches}, patch_dim), "
                f"got {tuple(x.shape)}"
            )
        # nn.Linear acts on the last dimension and the (N, D) encoding
        # broadcasts over all leading dimensions, so no reshape is needed.
        # This also works for non-contiguous inputs, where .view() would raise.
        return self.projection(x) + self.pos_encoding
# # Test PatchtstOG vs ours

# # Hyperparameters
# batch_size, num_channels, num_patches, patch_dim = 3, 4, 17, 6
# d_model = 5
# x = torch.randn(batch_size, num_channels, num_patches, patch_dim)

# #==========================================MANUAL======================================================
# # Define parameters
# W_P_og = nn.Linear(patch_dim, d_model)
# W_pos_og = torch.empty((num_patches, d_model))

# # Initialization for W_P and W_pos
# nn.init.xavier_uniform_(W_P_og.weight)
# nn.init.uniform_(W_pos_og, -0.02, 0.02)

# # Fix random parameters
# # Print out parameters (out_features, in_features) format
# print(f"OG weight shape {W_P_og.weight.data.shape}")
# print(f"OG bias shape {W_P_og.bias.data.shape}")
# W_P_og.weight.data = torch.randn(d_model, patch_dim) * 5
# W_P_og.bias.data = torch.randn(d_model) * 17.3
# print(f"OG pos enc shape {W_pos_og.data.shape}")
# W_pos_og.data = torch.randn(num_patches, d_model) * (-2.4)

# #==========================================OURS======================================================
# # Initialize matrices to the same values
# positional_encoding = PositionalEncoding(patch_dim=patch_dim, d_model=d_model, num_patches=num_patches)
# print(f"(Before update) Ours weight shape {positional_encoding.projection.weight.data.shape}, "
#       f"bias shape {positional_encoding.projection.bias.data.shape}, "
#       f"and pos enc shape {positional_encoding.pos_encoding.data.shape}")
# positional_encoding.projection.weight.data = W_P_og.weight.data
# positional_encoding.projection.bias.data = W_P_og.bias.data
# positional_encoding.pos_encoding.data = W_pos_og.data
# print(f"(After update) Ours weight shape {positional_encoding.projection.weight.data.shape}, "
#       f"bias shape {positional_encoding.projection.bias.data.shape}, "
#       f"and pos enc shape {positional_encoding.pos_encoding.data.shape}")

# # OG
# # Project + Positional Encoding
# print(f"x shape: {x.shape}")
# print(f"W_P_og shape: {W_P_og.weight.data.shape}")
# u = W_P_og(x)  # x: [bs x nvars x patch_num x d_model]
# u = torch.reshape(u, (u.shape[0]*u.shape[1],u.shape[2],u.shape[3]))  # u: [bs * nvars x patch_num x d_model]
# u = u + W_pos_og  # u: [bs * nvars x patch_num x d_model]

# # OURS
# print(u.shape)
# v = positional_encoding(x)
# print(v.shape)
# assert torch.allclose(u, v.view(batch_size*num_channels, num_patches, -1)), "Error: Positional Encoding not equal to OG"