# Source code for optrade.models.pytorch.emforecaster

import torch
import torch.nn as nn
from typing import List, Optional
from pydantic import BaseModel

# Layers and models
from optrade.models.pytorch.dlinear import Model as DLinear
from optrade.models.pytorch.tsmixer import Model as TSMixer

# Weight initialization
from optrade.models.utils.weight_init import xavier_init

# Util Layers
from optrade.models.utils.revin import RevIN
from optrade.models.utils.patcher import Patcher


class Model(nn.Module):
    """Patch-based forecaster (EMForecaster).

    Pipeline: optional RevIN normalization -> patching -> optional activation ->
    one dedicated patch model per channel (DLinear or TSMixer) -> flatten +
    optional learnable positional encoding -> one linear head per forecast
    channel -> optional RevIN de-normalization.
    """

    def __init__(
        self,
        args: "BaseModel",
        seed: int = 42,
        seq_len: int = 336,
        pred_len: int = 96,
        num_channels: int = 1,
        revin: bool = True,
        revout: bool = True,
        revin_affine: bool = True,
        eps_revin: float = 1e-5,
        patch_model_id: str = "TSMixer",
        patch_norm: str = "none",
        patch_act: str = "GeLU",
        patch_dim: int = 24,
        patch_stride: int = 12,
        patch_embed_dim: int = 128,
        pos_enc: str = "none",
        return_head: bool = True,
        target_channels: Optional[list] = None,
    ) -> None:
        """Build the model.

        Args:
            args: Config object; ``args.emf`` supplies ``moving_avg`` (DLinear) or
                ``num_enc_layers``, ``d_model`` and ``dropout`` (TSMixer).
            seed: Seed forwarded to ``xavier_init`` for weight initialization.
            seq_len: Length of the input sequence.
            pred_len: Forecast horizon (output length per channel).
            num_channels: Number of input channels (features).
            revin: Whether to apply reversible instance normalization to the input.
            revout: Whether to de-normalize the forecast with RevIN (forced off
                when ``revin`` is False).
            revin_affine: Whether RevIN uses a learnable affine transform.
            eps_revin: Epsilon passed to RevIN for numerical stability.
            patch_model_id: Patch embedding model, "DLinear" or "TSMixer".
            patch_norm: Currently unused; kept for interface compatibility.
            patch_act: Activation applied to the patches before the patch model:
                "relu" or "gelu" (case-insensitive); any other value means none.
            patch_dim: Length of each patch.
            patch_stride: Stride between patches; -1 means ``patch_dim // 2`` and
                -2 means ``patch_dim`` (non-overlapping patches).
            patch_embed_dim: Embedding size produced for each patch.
            pos_enc: "learnable" for a learned additive encoding, "none" to disable.
            return_head: If False, ``forward`` returns the patch embeddings
                instead of forecasts.
            target_channels: Indices of the input channels to forecast; None
                means every input channel is forecast.

        Raises:
            ValueError: If ``pos_enc`` is neither "learnable" nor "none".
            NotImplementedError: If ``patch_model_id`` is not supported.
        """
        super().__init__()

        # Fail fast: an unknown value would otherwise leave self.pos_enc unset
        # and only surface as an AttributeError inside forward().
        if pos_enc not in ("learnable", "none"):
            raise ValueError(
                f"Unsupported pos_enc {pos_enc!r}; expected 'learnable' or 'none'."
            )

        # Parameters
        self.args = args
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.num_channels = num_channels
        self.patch_embed_dim = patch_embed_dim
        self.patch_model_id = patch_model_id
        self.patch_dim = patch_dim
        self.return_head = return_head
        self.target_channels = target_channels
        self.revout = revout
        self.revin_affine = revin_affine
        self.eps_revin = eps_revin

        # Negative strides are shorthands relative to the patch length.
        if patch_stride == -1:
            self.patch_stride = self.patch_dim // 2
        elif patch_stride == -2:
            self.patch_stride = self.patch_dim
        else:
            self.patch_stride = patch_stride
        # NOTE(review): the "+ 2" presumably accounts for end padding performed
        # inside Patcher — confirm against its implementation.
        self.num_patches = int((seq_len - patch_dim) / self.patch_stride) + 2

        # RevIN
        if revin:
            self._init_revin()
        else:
            self._revin = None
            self.revout = None

        # Layers
        self.patcher = Patcher(self.patch_dim, self.patch_stride)
        self.patch_model = self.get_patch_model()

        # Positional encoding (pos_enc was validated above)
        if pos_enc == "learnable":
            self.pos_enc = nn.Parameter(
                torch.randn(1, self.num_channels, self.num_patches * patch_embed_dim)
            )
        else:
            self.pos_enc = 0

        # Activation applied to the patches before the patch model
        self.patch_act = self._build_patch_act(patch_act)

        # One linear head per forecast channel
        num_output_channels = (
            self.num_channels if target_channels is None else len(target_channels)
        )
        self.head = nn.ModuleList(
            [
                nn.Linear(self.num_patches * self.patch_embed_dim, self.pred_len)
                for _ in range(num_output_channels)
            ]
        )

        # Weight initialization
        self.apply(lambda m: xavier_init(m, seed=seed))

    @staticmethod
    def _build_patch_act(name) -> nn.Module:
        """Map an activation name (case-insensitive) to a module; unknown -> Identity."""
        key = name.lower() if isinstance(name, str) else ""
        if key == "relu":
            return nn.ReLU()
        if key == "gelu":
            return nn.GELU()
        return nn.Identity()

    def forward_patch_model(self, x: torch.Tensor) -> torch.Tensor:
        """Embed each channel's patches with that channel's own patch model.

        Args:
            x: Patched input of shape (batch_size, num_channels, num_patches, patch_dim).

        Returns:
            Tensor of shape (batch_size, num_channels, 1, num_patches, patch_embed_dim).
        """
        # reshape (not view): patch-model outputs may be non-contiguous.
        outputs = [
            model(x[:, i, :, :]).reshape(
                -1, 1, self.num_patches, self.patch_embed_dim
            )
            for i, model in enumerate(self.patch_model)
        ]
        return torch.stack(outputs, dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the model.

        Args:
            x: Input series; assumed (batch_size, num_channels, seq_len), as in
                the module's ``__main__`` example.

        Returns:
            If ``return_head``: forecasts of shape (batch_size, K, pred_len), where
            K is ``len(target_channels)`` or ``num_channels`` when it is None.
            Otherwise the patch embeddings of shape
            (batch_size, num_channels, 1, num_patches, patch_embed_dim).
        """
        batch_size = x.shape[0]

        # RevIN
        if self._revin:
            x = self.revin(x, mode="norm")

        # Patching. An explicit reshape replaces a bare squeeze(), which would also
        # drop the batch / channel axes when either has size 1.
        # NOTE(review): assumes Patcher yields batch*channels*num_patches*patch_dim
        # elements laid out in that order — confirm against Patcher.
        x = self.patcher(x).reshape(
            batch_size, self.num_channels, self.num_patches, self.patch_dim
        )

        # Activation function (optional)
        x = self.patch_act(x)

        # Patch model: (batch_size, num_channels, 1, num_patches, patch_embed_dim)
        x = self.forward_patch_model(x)

        if not self.return_head:
            # Embeddings are not forecasts, so RevOUT de-normalization is skipped.
            return x

        # Flatten + positional encoding: (batch_size, num_channels, num_patches * patch_embed_dim)
        x = x.reshape(
            batch_size, self.num_channels, self.num_patches * self.patch_embed_dim
        )
        x = x + self.pos_enc

        # Head i forecasts channel target_channels[i] (or channel i when no
        # targets are given), so it must read that channel's embedding.
        channel_ids = (
            range(len(self.head))
            if self.target_channels is None
            else self.target_channels
        )
        forecasts = [head(x[:, c, :]) for head, c in zip(self.head, channel_ids)]
        x = torch.stack(forecasts, dim=1)  # (batch_size, K, pred_len)

        # RevOUT (presumably uses the statistics of the target channels — see RevIN)
        if self.revout:
            x = self.revin(x, mode="denorm")
        return x

    def _init_revin(self):
        """Enable RevIN and build the normalization layer."""
        self._revin = True
        self.revin = RevIN(
            num_channels=self.num_channels,
            target_channels=self.target_channels,
            eps=self.eps_revin,
            affine=self.revin_affine,
        )

    def get_patch_model(self) -> nn.Module:
        """Create one patch-embedding model per input channel.

        Returns:
            An ``nn.ModuleList`` with ``num_channels`` entries, each expected to map
            a (batch_size, num_patches, patch_dim) tensor to
            (batch_size, num_patches, patch_embed_dim) embeddings.

        Raises:
            NotImplementedError: For "Linear" (not implemented yet) or any other
                unsupported ``patch_model_id``.
        """
        if self.patch_model_id == "DLinear":
            return nn.ModuleList(
                [
                    nn.Sequential(
                        DLinear(
                            task="forecasting",
                            seq_len=self.patch_dim,
                            pred_len=self.patch_embed_dim,
                            num_channels=self.num_patches,
                            moving_avg=self.args.emf.moving_avg,
                            individual=False,
                            return_head=False,
                        )
                    )
                    for _ in range(self.num_channels)
                ]
            )
        if self.patch_model_id == "TSMixer":
            return nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Linear(self.patch_dim, self.patch_embed_dim),
                        TSMixer(
                            seq_len=self.patch_embed_dim,
                            pred_len=1,
                            num_enc_layers=self.args.emf.num_enc_layers,
                            d_model=self.args.emf.d_model,
                            num_channels=self.num_patches,
                            dropout=self.args.emf.dropout,
                            revin=False,
                            revin_affine=False,
                            revout=False,
                            return_head=False,
                        ),
                    )
                    for _ in range(self.num_channels)
                ]
            )
        if self.patch_model_id == "Linear":
            # Previously this *returned* the exception class instead of raising it.
            raise NotImplementedError("patch_model_id='Linear' is not implemented yet.")
        raise NotImplementedError(
            f"Unsupported patch_model_id: {self.patch_model_id!r}"
        )
if __name__ == "__main__":
    from optrade.config.config import Global

    # Smoke test: multivariate input, forecasting a subset of the channels.
    args = Global()

    batch_size, num_channels, seq_len, pred_len = 32, 7, 512, 96
    x = torch.randn(batch_size, num_channels, seq_len)

    model_kwargs = dict(
        args=args,
        seq_len=seq_len,
        pred_len=pred_len,
        num_channels=num_channels,
        patch_model_id="TSMixer",
        revin=True,
        revout=True,
        revin_affine=True,
        target_channels=[6, 3, 1],
    )
    model = Model(**model_kwargs)

    y = model(x)
    print(f"Model output: {y.shape}")