Source code for optrade.models.pytorch.dlinear

import torch
import torch.nn as nn

import torch
import torch.nn as nn
import torch.nn.functional as F
from optrade.models.utils.revin import RevIN

from typing import Optional


[docs] class Model(nn.Module): """ Paper link: https://arxiv.org/pdf/2205.13504.pdf Taken from: https://github.com/thuml/Time-Series-Library/blob/main/models/DLinear.py """
[docs] def __init__( self, task: str = "forecasting", seq_len: int = 512, pred_len: int = 96, num_channels: int = 7, num_classes: int = 2, moving_avg: int = 25, individual: bool = False, return_head: bool = True, revin: bool = True, revout: bool = False, revin_affine: bool = False, eps_revin: float = 1e-5, target_channels: Optional[list] = None, ) -> None: """ Args: task (str): Task name among 'classification', 'anomaly_detection', 'imputation', or 'forecasting'. seq_len (int): Length of input sequence. pred_len (int): Length of output forecasting. num_channels (int): Number of input channels (features). num_classes (int): Number of classes for classification task. moving_avg (int): Window size of moving average. individual (bool): Whether shared model among different variates. """ super(Model, self).__init__() self.task = task self.seq_len = seq_len self.return_head = return_head self.target_channels = target_channels self.revout = revout self.revin_affine = revin_affine self.eps_revin = eps_revin self.num_channels = num_channels # RevIN if revin: self._init_revin() else: self._revin = None self.revout = None if self.task == "classification": self.pred_len = seq_len elif self.task == "forecasting": self.pred_len = pred_len else: raise ValueError(f"Task name '{self.task}' not supported.") # Series decomposition block from Autoformer self.decomposition = series_decomp(moving_avg) self.individual = individual self.channels = num_channels if self.individual: self.Linear_Seasonal = nn.ModuleList() self.Linear_Trend = nn.ModuleList() for i in range(self.channels): self.Linear_Seasonal.append(nn.Linear(self.seq_len, self.pred_len)) self.Linear_Trend.append(nn.Linear(self.seq_len, self.pred_len)) self.Linear_Seasonal[i].weight = nn.Parameter( (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]) ) self.Linear_Trend[i].weight = nn.Parameter( (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]) ) else: self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len) self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len) self.Linear_Seasonal.weight = nn.Parameter( (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]) ) self.Linear_Trend.weight = nn.Parameter( (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]) ) if self.task == "classification": self.head = nn.Linear(num_channels * seq_len, num_classes) elif self.task == "forecasting": self.target_channels = target_channels in_dim = ( num_channels * pred_len if target_channels is None else pred_len * len(target_channels) ) out_dim = ( num_channels * pred_len if target_channels is None else pred_len * len(target_channels) ) self.head = nn.Linear(in_dim, out_dim)
def _init_revin(self): self._revin = True self.revin = RevIN( num_channels=self.num_channels, eps=self.eps_revin, affine=self.revin_affine, target_channels=self.target_channels, )
[docs] def encoder(self, x): seasonal_init, trend_init = self.decomposition(x) seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute( 0, 2, 1 ) if self.individual: seasonal_output = torch.zeros( [seasonal_init.size(0), seasonal_init.size(1), self.pred_len], dtype=seasonal_init.dtype, ).to(seasonal_init.device) trend_output = torch.zeros( [trend_init.size(0), trend_init.size(1), self.pred_len], dtype=trend_init.dtype, ).to(trend_init.device) for i in range(self.channels): seasonal_output[:, i, :] = self.Linear_Seasonal[i]( seasonal_init[:, i, :] ) trend_output[:, i, :] = self.Linear_Trend[i](trend_init[:, i, :]) else: seasonal_output = self.Linear_Seasonal(seasonal_init) trend_output = self.Linear_Trend(trend_init) x = seasonal_output + trend_output return x.permute(0, 2, 1)
[docs] def forecast(self, x_enc): # Ensure correct size (3D tensor) if len(x_enc.size()) == 2: x_enc = x_enc.unsqueeze( 1 ) # (batch_size, seq_len) -> (batch_size, 1, seq_len) batch_size, _, seq_len = x_enc.size() assert ( seq_len == self.seq_len ), f"Input sequence length {seq_len} is not equal to the model sequence length {self.seq_len}." if self._revin: x_enc = self.revin(x_enc, mode="norm") output = self.encoder( x_enc.permute(0, 2, 1) ) # (batch_size, seq_len, num_channels) if self.target_channels is not None: output = output[:, :, self.target_channels] if self.return_head: output = output.reshape(output.shape[0], -1) output = self.head(output) # (batch_size, num_channels*pred_len) output = output.reshape( output.shape[0], self.pred_len, -1 ) # (batch_size, pred_len, num_channels) # RevOUT if self.revout: output = self.revin(output.permute(0, 2, 1), mode="denorm").permute(0, 2, 1) return output.permute(0, 2, 1)
[docs] def classification(self, x_enc): # Encoder output = self.encoder(x_enc) # (batch_size, seq_len, num_channels) if self.target_channels is not None: output = output[:, :, self.target_channels] if self.return_head: output = output.reshape(output.shape[0], -1) output = self.head(output) # (batch_size, num_classes) return output
[docs] def forward(self, x_enc): if self.task == "forecasting": output = self.forecast(x_enc) # (batch_size, pred_len, num_channels) elif self.task == "classification": output = self.classification( x_enc ) # (batch_size, num_classes) or (batch_size,) for binary classification else: raise ValueError(f"Task name '{self.task}' not supported.") return output
[docs] class moving_avg(nn.Module): """ Moving average block to highlight the trend of time series """
[docs] def __init__(self, kernel_size, stride): super(moving_avg, self).__init__() self.kernel_size = kernel_size self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
[docs] def forward(self, x): # padding on the both ends of time series front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) x = torch.cat([front, x, end], dim=1) x = self.avg(x.permute(0, 2, 1)) x = x.permute(0, 2, 1) return x
[docs] class series_decomp(nn.Module): """ Series decomposition block """
[docs] def __init__(self, kernel_size): super(series_decomp, self).__init__() self.moving_avg = moving_avg(kernel_size, stride=1)
[docs] def forward(self, x): moving_mean = self.moving_avg(x) res = x - moving_mean return res, moving_mean
# Test if __name__ == "__main__": # Forecasting batch_size = 32 seq_len = 512 num_channels = 7 task = "forecasting" pred_len = 96 x = torch.randn(batch_size, num_channels, seq_len) forecasting_model = Model( task=task, seq_len=seq_len, pred_len=pred_len, num_channels=num_channels, target_channels=[1, 3], return_head=True, revin=True, revin_affine=True, revout=True, ) y = forecasting_model(x) print(f"x: {x.shape}") print(f"y: {y.shape}")