import torch
import torch.nn as nn
from typing import Optional
from optrade.models.utils.revin import RevIN
from optrade.models.utils.patcher import Patcher
from optrade.models.utils.pos_enc import PositionalEncoding
from optrade.models.utils.weight_init import xavier_init
from optrade.models.utils.utils import Reshape
# TODO: Reimplement patching
[docs]
class Model(nn.Module):
[docs]
def __init__(
self,
d_model,
num_enc_layers,
pred_len,
backbone_id,
bidirectional=False,
dropout=0.0,
seq_len=512,
patching=False,
patch_dim=16,
patch_stride=8,
num_channels=1,
head_type="linear",
norm_mode="layer",
revin=False,
revout=False,
revin_affine=False,
eps_revin=1e-5,
last_state=True,
avg_state=False,
return_head=True,
channel_independent=False,
target_channels: Optional[list] = None,
) -> None:
super(Model, self).__init__()
"""
A Recurrent Neural Network (RNN) class that host a variety of different recurrent architectures including LSTM, Mamba, GRU, and the classic RNN.
Args:
d_model (int): The number of expected features in the input (required).
num_enc_layers (int): Number of recurrent layers (required).
pred_len (int): The number of expected features in the output (required).
backbone_id (str): The type of recurrent architecture to use (required). Options: "LSTM", "Mamba",
bidirectional (bool): If True, becomes a bidirectional RNN. Default: False.
dropout (float): If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0.
seq_len (int): The length of the input sequence. Default: 512.
patching (bool): If True, the input sequence is patched. Default: False.
patch_dim (int): The dimension of the patch. Default: 16.
patch_stride (int): The stride of the patch. Default: 8.
num_channels (int): The number of channels in the input data. Default: 1.
head_type (str): The type of head to use Options: "linear", "mlp". Default: "linear".
norm_mode (str): The type of normalization to use. Default: "layer".
revin (bool): If True, applies RevIN to the input sequence. Default: False.
revout (bool): If True, applies RevIN to the output sequence. Default: False.
revin_affine (bool): If True, applies an affine transformation to the RevIN layer. Default: False.
eps_revin (float): The epsilon value for RevIN. Default: 1e-5.
last_state (bool): If True, returns the last state of the RNN. Default: True.
avg_state (bool): If True, returns the average state of the RNN. Default: False.
"""
# Parameters
self.backbone_id = backbone_id
self.num_patches = int((seq_len - patch_dim) / patch_stride) + 2
self.patch_dim = patch_dim
self.patch_stride = patch_stride
self.num_channels = num_channels
self.eps_revin = eps_revin
self.revin_affine = revin_affine
self.revout = revout
self.target_channels = target_channels
self.last_state = last_state
self.avg_state = avg_state
self.input_size = d_model if patching else num_channels
self.return_head = return_head
# RevIN
if revin:
self._init_revin()
else:
self._revin = None
self.revout = None
# Patching (only works for FIXED sequence length)
if patching:
self._patching = True
self.patcher = Patcher(patch_dim, patch_stride)
self.pos_enc = (
nn.Linear(patch_dim, d_model)
if avg_state
else PositionalEncoding(patch_dim, d_model, self.num_patches)
)
else:
self._patching = None
# Backbone
if self.backbone_id == "LSTM":
self.backbone = nn.LSTM(
input_size=self.input_size,
hidden_size=d_model,
num_layers=num_enc_layers,
batch_first=True,
dropout=dropout,
bidirectional=bidirectional,
)
elif self.backbone_id == "RNN":
self.backbone = nn.RNN(
input_size=self.input_size,
hidden_size=d_model,
num_layers=num_enc_layers,
batch_first=True,
dropout=dropout,
bidirectional=bidirectional,
)
elif self.backbone_id == "GRU":
self.backbone = nn.GRU(
input_size=self.input_size,
hidden_size=d_model,
num_layers=num_enc_layers,
batch_first=True,
dropout=dropout,
bidirectional=bidirectional,
)
else:
raise ValueError("Invalid backbone_id. Options: 'LSTM', 'RNN', 'GRU'.")
# Head
self.dropout = nn.Dropout(dropout)
if patching and not avg_state:
head_dim = self.num_patches * d_model
elif last_state or avg_state:
head_dim = d_model
num_output_channels = (
len(target_channels) if target_channels is not None else num_channels
)
if head_type == "linear":
self.head = nn.Sequential(
nn.Linear(head_dim, num_output_channels * pred_len),
Reshape(-1, num_output_channels, pred_len),
)
elif head_type == "mlp":
self.head = nn.Sequential(
nn.Linear(head_dim, head_dim // 2),
nn.GELU(),
nn.Linear(head_dim // 2, num_output_channels * pred_len),
Reshape(-1, num_output_channels, pred_len),
)
if not (last_state or avg_state):
self.head = nn.Sequential(
Reshape(-1, seq_len * d_model),
self.head,
)
self.flatten = nn.Flatten(start_dim=-2)
# Final Normalization Layer
norm_dim = d_model
self.norm = nn.LayerNorm(norm_dim) if norm_mode == "layer" else nn.Identity()
# Weight initialization
self.apply(xavier_init)
def _init_revin(self):
self._revin = True
self.revin = RevIN(
num_channels=self.num_channels,
eps=self.eps_revin,
affine=self.revin_affine,
target_channels=self.target_channels,
)
[docs]
def compute_backbone(self, x):
if self.backbone_id in {"RNN", "GRU"}:
out, hn = self.backbone(x)
last_hn = hn[-1]
elif self.backbone_id == "LSTM":
out, (hn, _) = self.backbone(x)
last_hn = hn[-1]
else:
raise ValueError("Invalid backbone_id. Options: 'LSTM', 'RNN', 'GRU',.")
return out, last_hn
[docs]
def forward(self, x):
"""
Forward pass of the model.
Args:
x (torch.Tensor): The input data. Shape: (batch_size, num_channels, seq_len) = (B, M, L).
"""
# Ensure input is correct
if len(x.shape) == 2:
x = (
x.unsqueeze(-2) if self._patching else x.unsqueeze(-1)
) #: (B, L) -> (B, L, 1)
# RevIN
if self._revin:
x = self.revin(
x, mode="norm"
) # Patched version:(B, M, L). Non-patched version: (B, L, 1)
x = x.transpose(1, 2)
# # Patching
# if self._patching:
# x = self.patcher(x) # (B, M, N, P)
# x = self.pos_enc(x) # (B, M, N, D)
# B, M, N, D = x.shape
# x = x.view(B*M, N, D) # (B*M, N, D)
# Backbone forward pass
out, last_hn = self.compute_backbone(x)
# Normalization
if self.last_state:
x = self.norm(last_hn) # Select last hidden state: (B, D)
elif self.avg_state:
x = self.norm(
torch.mean(out, dim=1)
) # Average over sequence length. Patched version: (B*M, D). Non-patched version: (B, D).
else:
x = self.norm(
out
) # Patched version: (B*M, N, D). Non-patched version: (B, L, D)
# # Reshape for patching
# if self._patching:
# x = x.view(B, M, -1) # avg state: (B, M, D). Non-avg state: (B, M, N*D)
# Head
if self.return_head:
# x = x.transpose(0,1)
print(f"Shape before head: {x.shape}")
x = self.head(self.dropout(x)) # (B, pred_len)
print(f"x after head: {x.shape}")
# RevOUT
if self.revout:
x = self.revin(x, mode="denorm")
return x
if __name__ == "__main__":
# <---Non-patched version (classification)--->
# Define model parameters
batch_size = 32
num_channels = 7
seq_len = 512
pred_len = 96
d_model = 64
num_enc_layers = 5
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(batch_size, num_channels, seq_len).to(device) # (B, M, L)
model = Model(
d_model=d_model,
num_enc_layers=num_enc_layers,
pred_len=pred_len,
backbone_id="GRU",
bidirectional=False,
dropout=0.1,
seq_len=seq_len,
patching=False,
# patch_dim=16,
# patch_stride=8,
num_channels=num_channels,
head_type="linear",
norm_mode="layer",
revin=True,
revout=True,
revin_affine=True,
last_state=False,
avg_state=True,
return_head=True,
target_channels=[0, 3, 5],
).to(device)
# Pass the data through the model
output = model(x)
output = output.to(device)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
# #<--Patched Version (forecasting)--->
# # Define model parameters
# patch_dim = 64
# patch_stride = 16
# batch_size = 1
# d_model = 128
# num_enc_layers = 5
# pred_len = 1
# seq_len = 16031
# num_channels = 1
# # Device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")
# # Create an instance of the LSTM model
# model = RecurrentModel(
# d_model=d_model,
# backbone_id="Mamba",
# num_enc_layers=num_enc_layers,
# pred_len=pred_len,
# seq_len=seq_len,
# num_channels=num_channels,
# revin=True,
# revin_affine=True,
# revout=True,
# head_type="linear",
# patching=True,
# last_state=False,
# avg_state=True,
# ).to(device)
# # Create sample input data
# x = torch.randn(batch_size, num_channels, seq_len).to(device) # (B, M, L)
# print(f"Input shape: {x.shape}")
# # Pass the data through the model
# output = model(x)
# output = output.to(device)
# print(f"Output shape: {output.shape}")