# Source code for optrade.models.utils.revin

import torch
import torch.nn as nn

from typing import Optional, List


class RevIN(nn.Module):
    """Reversible instance normalization layer for time-series tensors.

    ``forward(x, mode="norm")`` standardizes every (sample, channel) series
    along the time axis and caches the statistics on the module;
    ``forward(y, mode="denorm")`` reuses the cached statistics to map a
    tensor back to the original scale.
    """

    def __init__(
        self,
        num_channels: int,
        eps: float = 1e-5,
        affine: bool = True,
        target_channels: Optional[List[int]] = None,
    ) -> None:
        """
        Kim et al. (2022): Reversible instance normalization for accurate
        time-series forecasting against distribution shift. Provides a
        learnable instance normalization layer that is reversible. Code is
        from
        https://github.com/yuqinie98/PatchTST/blob/main/PatchTST_supervised/layers/RevIN.py
        with modifications, which was originally taken from
        https://github.com/ts-kim/RevIN.

        Args:
            num_channels: The number of features or channels.
            eps: A value added for numerical stability.
            affine: If True, RevIN applies learnable affine parameters
                (e.g., like LayerNorm).
            target_channels: List of target channels for the head layer. If
                not None, the affine parameters have one entry per target
                channel, the affine transform is applied only to those
                channels during normalization, and denormalization uses only
                the statistics of those channels, i.e. it expects a model
                output of shape (batch_size, len(target_channels), pred_len).
        """
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        self.target_channels_idx = target_channels

        # One affine entry per target channel when a subset is given,
        # otherwise one per channel. The parameters are registered regardless
        # of `affine`; they are only *used* when `affine` is True.
        num_affine = num_channels if target_channels is None else len(target_channels)
        self.affine_weight = nn.Parameter(torch.ones(num_affine))
        self.affine_bias = nn.Parameter(torch.zeros(num_affine))

        # Statistics cached by the most recent "norm" call. They start as
        # None so that a premature "denorm" call can be reported clearly.
        self.mean: Optional[torch.Tensor] = None
        self.stdev: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor, mode: str) -> torch.Tensor:
        """
        Forward pass for normalizing or denormalizing the time series with
        learnable affine transformations.

        Args:
            x: Input tensor of shape (batch_size, num_channels, seq_len). A
                tensor that is not 3-dimensional gets a singleton axis
                inserted at position 1 before processing.
            mode: 'norm' for normalization and 'denorm' for denormalization.

        Returns:
            The normalized or denormalized 3-dimensional tensor. For a
            3-dimensional input in 'norm' mode it has the same shape as the
            input.

        Raises:
            NotImplementedError: If `mode` is neither 'norm' nor 'denorm'.
            RuntimeError: If 'denorm' is requested before any 'norm' call.
        """
        if x.dim() != 3:
            x = x.unsqueeze(1)

        if mode == "norm":
            self._get_statistics(x)
            return self._normalize(x)
        if mode == "denorm":
            return self._denormalize(x)
        raise NotImplementedError(
            f"Unknown mode {mode!r}; expected 'norm' or 'denorm'."
        )

    def _get_statistics(self, x: torch.Tensor) -> None:
        """Cache the per-series mean and standard deviation over the last axis.

        Both statistics keep the reduced axis (keepdim=True) and are detached
        from the autograd graph.
        """
        self.mean = torch.mean(x, dim=2, keepdim=True).detach()
        self.stdev = torch.sqrt(
            torch.var(x, dim=2, keepdim=True, unbiased=False) + self.eps
        ).detach()

    def _normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Standardize `x` with the cached statistics, then apply the affine map."""
        # The subtraction creates a new tensor, so the in-place indexed write
        # below never touches the caller's input.
        x = x - self.mean
        x = x / self.stdev
        if not self.affine:
            return x

        # Reshape to (1, n, 1), where n is the number of affine entries, so
        # the parameters broadcast over the batch and time axes.
        weight = self.affine_weight.unsqueeze(0).unsqueeze(-1)
        bias = self.affine_bias.unsqueeze(0).unsqueeze(-1)

        if self.target_channels_idx is None:
            return x * weight + bias

        # Only the target channels are scaled and shifted; all other channels
        # stay plainly standardized.
        idx = self.target_channels_idx
        x[:, idx, :] = x[:, idx, :] * weight + bias
        return x

    def _denormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Invert the affine map (if enabled) and restore the cached scale."""
        if self.mean is None or self.stdev is None:
            raise RuntimeError(
                "RevIN statistics are not available: call forward(x, "
                "mode='norm') before mode='denorm'."
            )

        if self.affine:
            x = x - self.affine_bias.unsqueeze(0).unsqueeze(-1)
            # eps**2 keeps the division finite if a learned weight reaches 0.
            x = x / (
                self.affine_weight.unsqueeze(0).unsqueeze(-1)
                + self.eps * self.eps
            )

        mean, stdev = self.mean, self.stdev
        if self.target_channels_idx is not None:
            # `x` is expected to hold only the target channels, so select the
            # matching statistics.
            idx = self.target_channels_idx
            mean = mean[:, idx, :]
            stdev = stdev[:, idx, :]

        x = x * stdev
        x = x + mean
        return x
if __name__ == "__main__":
    # Smoke test: normalize a random multichannel batch, then denormalize a
    # random "prediction" that only contains the target channels.
    batch_size, num_channels = 32, 7
    seq_len, pred_len = 100, 48
    target_channels = [1, 3]

    x = torch.randn(batch_size, num_channels, seq_len)

    revin = RevIN(
        num_channels=num_channels,
        eps=1e-5,
        affine=True,
        target_channels=target_channels,
    )

    # Normalization caches the per-series statistics on the layer.
    normalized = revin(x, mode="norm")
    print(f"x_norm shape: {normalized.shape}")

    # Stand-in model output of shape (batch, len(target_channels), pred_len).
    y_hat = torch.randn(batch_size, len(target_channels), pred_len)
    restored = revin(y_hat, mode="denorm")
    print(f"x_denorm shape: {restored.shape}")