# Source code for optrade.models.pytorch.mamba

import torch
import torch.nn as nn
from mambapy.mamba import Mamba as MambaBackbone
from mambapy.mamba import MambaConfig
from optrade.models.utils.revin import RevIN
from optrade.models.utils.patcher import Patcher
from optrade.models.utils.weight_init import xavier_init
from optrade.models.utils.pos_enc import PositionalEncoding
from optrade.models.utils.utils import Reshape

from typing import Optional


class Mamba(nn.Module):
    """Mamba-backbone model for time-series forecasting.

    The model takes inputs of shape (B, M, L) and supports two pipelines:

    * Patched (``patching=True``): each channel is split into N patches by
      ``Patcher``, embedded to ``d_model`` by ``PositionalEncoding``, and run
      through the backbone as an independent length-N sequence (channels are
      folded into the batch). The head maps the flattened N * d_model
      features of each channel to ``pred_len`` values.
    * Non-patched (``patching=False``): the input is transposed to (B, L, M)
      and fed to the backbone directly, so ``d_model`` must equal
      ``num_channels``. The head maps each channel's length-L sequence to
      ``pred_len`` values. Heads are sized with ``seq_len``, so L must equal
      ``seq_len``.

    Args:
        d_model: Backbone hidden size.
        num_enc_layers: Number of layers in the Mamba backbone (``n_layers``).
        pred_len: Number of output steps produced per output channel.
        num_channels: Number of input channels M.
        revin: Apply RevIN normalization to the input.
        revout: Apply RevIN denormalization to the output (forced off when
            ``revin`` is False).
        revin_affine: Passed to RevIN as ``affine``.
        eps_revin: Passed to RevIN as ``eps``.
        head_type: ``"linear"`` or ``"mlp"``.
        norm_mode: ``"layer"`` for a final LayerNorm, anything else for none.
        patching: Use the patched pipeline.
        patch_dim: Patch length P.
        patch_stride: Stride between consecutive patches.
        seq_len: Fixed input length L, used to size the head.
        last_state: Must be True when ``patching`` is False; otherwise
            ``forward`` raises ValueError.
        dropout: Currently unused; kept for API compatibility.
        channel_independent: For the linear head, share one linear map across
            channels instead of one joint map over all output channels.
        target_channels: Indices of the channels to predict; all channels
            when None. Also passed to RevIN.

    Raises:
        ValueError: If ``head_type`` is not ``"linear"``/``"mlp"``, or if
            ``patching`` is False and ``d_model != num_channels``.
    """

    def __init__(
        self,
        d_model,
        num_enc_layers,
        pred_len,
        num_channels: int = 1,
        revin: bool = False,
        revout: bool = False,
        revin_affine: bool = False,
        eps_revin: float = 1e-5,
        head_type: str = "linear",
        norm_mode: str = "layer",
        patching: bool = False,
        patch_dim: int = 16,
        patch_stride: int = 8,
        seq_len: int = 512,
        last_state: bool = True,
        dropout: float = 0.0,
        channel_independent: bool = False,
        target_channels: Optional[list] = None,
    ) -> None:
        super().__init__()

        # Fail fast on configurations that could only break later in forward().
        if head_type not in ("linear", "mlp"):
            raise ValueError(
                f"Unsupported head_type {head_type!r}; expected 'linear' or 'mlp'."
            )
        if not patching and d_model != num_channels:
            # Without patching the raw (B, L, M) series goes straight into the
            # backbone, whose feature size is d_model.
            raise ValueError(
                f"Without patching, d_model ({d_model}) must equal "
                f"num_channels ({num_channels})."
            )

        # Parameters
        # NOTE(review): "+ 2" presumably accounts for a padding patch added by
        # Patcher — confirm against the Patcher implementation.
        self.num_patches = int((seq_len - patch_dim) / patch_stride) + 2
        self.num_channels = num_channels
        self.eps_revin = eps_revin
        self.revin_affine = revin_affine
        self.revout = revout
        self.target_channels = target_channels
        self.last_state = last_state

        # Reversible instance normalization
        if revin:
            self._init_revin()
        else:
            self._revin = None
            self.revout = None  # denormalization needs the RevIN module

        # Patching (only works for a FIXED sequence length)
        if patching:
            self._patching = True
            self.patcher = Patcher(patch_dim, patch_stride)
            self.pos_enc = PositionalEncoding(patch_dim, d_model, self.num_patches)
        else:
            self._patching = None

        # Mamba backbone
        config = MambaConfig(d_model=d_model, n_layers=num_enc_layers)
        self.backbone = MambaBackbone(config)

        # Head. Its input is the last dimension produced in forward():
        # N * D per channel when patching, the full length L otherwise.
        head_dim = self.num_patches * d_model if patching else seq_len
        num_output_channels = (
            len(target_channels) if target_channels is not None else num_channels
        )
        self.head = self._build_head(
            head_type, head_dim, pred_len, num_output_channels, channel_independent
        )

        # Final normalization layer
        self.norm = nn.LayerNorm(d_model) if norm_mode == "layer" else nn.Identity()

        # Weight initialization
        self.apply(xavier_init)

    @staticmethod
    def _build_head(
        head_type, head_dim, pred_len, num_output_channels, channel_independent
    ):
        """Build a head mapping (B, C, head_dim) -> (B, C, pred_len)."""
        if head_type == "linear":
            if channel_independent:
                # One linear map shared by every channel (acts on the last dim).
                return nn.Linear(head_dim, pred_len)
            # One joint linear map over all output channels, flattened together.
            flat_in = num_output_channels * head_dim
            flat_out = num_output_channels * pred_len
            return nn.Sequential(
                Reshape(-1, flat_in),
                nn.Linear(flat_in, flat_out),
                Reshape(-1, num_output_channels, pred_len),
            )
        # "mlp": two-layer MLP acting on the last dim, i.e. per channel.
        return nn.Sequential(
            nn.Linear(head_dim, head_dim), nn.GELU(), nn.Linear(head_dim, pred_len)
        )

    def _init_revin(self):
        """Create the RevIN module from the stored configuration."""
        self._revin = True
        self.revin = RevIN(
            num_channels=self.num_channels,
            eps=self.eps_revin,
            affine=self.revin_affine,
            target_channels=self.target_channels,
        )

    def forward(self, x):
        """Compute the forward pass of the Mamba model.

        Args:
            x (torch.Tensor): Input of shape (B, M, L) in both modes.

        Returns:
            torch.Tensor: Predictions of shape (B, C, pred_len), where C is
            ``len(target_channels)`` when target channels are set, else M.

        Raises:
            ValueError: If ``patching`` is False and ``last_state`` is False.

        Legend: B: batch_size, M: num_channels, L: seq_len, N: num_patches,
        P: patch_dim, D: d_model, C: number of output channels.
        """
        if not self._patching:
            x = x.transpose(1, 2)  # (B, M, L) -> (B, L, M)

        # RevIN. Patched: (B, M, L). Non-patched: (B, L, M).
        # NOTE(review): RevIN sees different layouts in the two modes — confirm
        # it normalizes along the intended axis in both.
        if self._revin:
            x = self.revin(x, mode="norm")

        # Patching
        if self._patching:
            x = self.patcher(x)  # (B, M, N, P)
            x = self.pos_enc(x)  # (B, M, N, D)
            B, M, N, D = x.shape
            # Fold channels into the batch: each channel is its own sequence.
            x = x.reshape(B * M, N, D)

        # Backbone + normalization.
        # Patched: (B*M, N, D). Non-patched: (B, L, M).
        x = self.backbone(x)
        x = self.norm(x)

        # Arrange features per channel for the head.
        if self._patching:
            x = x.reshape(B, M, N * D)  # (B, M, N*D)
        elif self.last_state:
            x = x.transpose(1, 2)  # (B, M, L)
        else:
            raise ValueError(
                "Invalid configuration for the Mamba model. Please check the parameters."
            )

        # Keep only the channels the head was sized for.
        if self.target_channels is not None:
            x = x[:, self.target_channels, :]  # (B, C, head_dim)

        x = self.head(x)  # (B, C, pred_len)

        # RevOUT
        # NOTE(review): denormalizes a (B, C, pred_len) tensor — confirm this
        # matches the layout RevIN expects in "denorm" mode.
        if self.revout:
            x = self.revin(x, mode="denorm")

        return x
if __name__ == "__main__":
    # Smoke test of the non-patched configuration. In this mode the backbone
    # consumes the channels directly, so d_model must equal num_channels.
    batch_size = 32
    num_enc_layers = 2
    pred_len = 96
    seq_len = 512
    num_channels = d_model = 7

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create an instance of the Mamba model
    model = Mamba(
        d_model=d_model,
        num_enc_layers=num_enc_layers,
        pred_len=pred_len,
        seq_len=seq_len,
        num_channels=num_channels,
        revin=True,
        head_type="linear",
        patching=False,
        last_state=True,
        channel_independent=False,
        target_channels=[0, 3, 5],
    ).to(device)

    x = torch.randn(batch_size, num_channels, seq_len, device=device)  # (B, M, L)

    # Inference only: no need to track gradients.
    with torch.no_grad():
        output = model(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    # Patched (forecasting) configuration, for reference:
    #   Mamba(d_model=128, num_enc_layers=5, pred_len=96, seq_len=512,
    #         num_channels=7, revin=True, revin_affine=True, revout=True,
    #         head_type="linear", patching=True, patch_dim=16, patch_stride=8,
    #         last_state=False, target_channels=None)
    #   with input x of shape (batch_size, num_channels, seq_len).