import torch
import torch.nn as nn
[docs]
def get_num_examples(data_loader):
    """Return the number of examples one pass over ``data_loader`` yields.

    The count is based on the loader's sampler, which may visit only part of
    the dataset (e.g. a subset or distributed sampler).  When the sampler has
    no length (as for iterable-style datasets) or the object has no
    ``sampler`` attribute, ``len(data_loader.dataset)`` is used instead.
    When ``drop_last`` is set, the trailing incomplete batch is excluded.

    Args:
        data_loader: A ``torch.utils.data.DataLoader`` or any object exposing
            ``dataset``, ``drop_last``, ``batch_size`` and, optionally,
            ``sampler``.

    Returns:
        int: Number of examples produced per epoch.
    """
    try:
        # The sampler decides which indices are actually visited; it can be
        # smaller (or larger) than the dataset itself.
        total_samples = len(getattr(data_loader, "sampler", None))
    except TypeError:
        # No sampler, or a sampler without __len__: fall back to the dataset.
        total_samples = len(data_loader.dataset)

    if not data_loader.drop_last:
        return total_samples

    batch_size = data_loader.batch_size
    # Only complete batches are produced when drop_last is True.
    return (total_samples // batch_size) * batch_size
[docs]
class Transpose(nn.Module):
    """Module form of ``Tensor.transpose`` that swaps two fixed dimensions.

    Being an ``nn.Module``, it can be placed inside ``nn.Sequential``.

    Args:
        dim1: First dimension to swap.
        dim2: Second dimension to swap.
    """

    def __init__(self, dim1, dim2):
        super().__init__()
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x):
        """Return ``x`` with dimensions ``dim1`` and ``dim2`` exchanged."""
        first, second = self.dim1, self.dim2
        return x.transpose(first, second)
[docs]
class Unsqueeze(nn.Module):
    """Module form of ``Tensor.unsqueeze`` that inserts a size-1 dimension.

    Args:
        dim: Position at which the new dimension is inserted.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` with a length-one axis inserted at ``self.dim``."""
        insert_at = self.dim
        return x.unsqueeze(insert_at)
[docs]
class Reshape(nn.Module):
    """Module form of ``Tensor.reshape`` with a fixed target shape.

    The target shape is passed as positional arguments, e.g. ``Reshape(-1, 8)``.
    """

    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        """Return ``x`` reshaped to the stored target shape."""
        target_shape = self.shape
        return x.reshape(*target_shape)
[docs]
class Norm(nn.Module):
    """Normalization layer whose kind is selected by ``norm_mode``.

    Args:
        norm_mode: One of ``"batch1d"``, ``"batch2d"`` or ``"layer"``.

            * ``"batch1d"``: ``nn.BatchNorm1d(d_model)`` applied to a
              ``(N, L, d_model)`` input, which is transposed to
              ``(N, d_model, L)`` for the norm and transposed back.
            * ``"batch2d"``: ``nn.BatchNorm2d(num_channels)`` applied after
              regrouping the input to ``(B, num_channels, seq_len, d_model)``.
            * ``"layer"``: ``nn.LayerNorm(d_model)`` over the last dimension.
        num_channels: Number of channels; used by ``"batch2d"``.
        seq_len: Sequence length; used by ``"batch2d"`` to regroup the input.
        d_model: Size of the last (feature) dimension.

    Raises:
        ValueError: If ``norm_mode`` is not one of the supported modes.
    """

    def __init__(self, norm_mode, num_channels, seq_len, d_model):
        super().__init__()
        self.norm_mode = norm_mode
        self.num_channels = num_channels
        self.seq_len = seq_len
        self.d_model = d_model
        if norm_mode == "batch1d":
            # BatchNorm1d expects the feature dimension at position 1.
            self.norm = nn.Sequential(
                Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2)
            )
        elif norm_mode == "batch2d":
            self.norm = nn.BatchNorm2d(num_channels)
        elif norm_mode == "layer":
            self.norm = nn.LayerNorm(d_model)
        else:
            raise ValueError(
                "Please select a valid normalization mode: 'batch1d', 'batch2d', or 'layer'."
            )

    def forward(self, x):
        """Normalize ``x``.

        In ``"batch2d"`` mode the input may be either
        ``(B * num_channels, seq_len, d_model)`` or
        ``(B, num_channels, seq_len, d_model)``, and the result is always
        ``(B * num_channels, seq_len, d_model)``.  The other modes return a
        tensor of the same shape as ``x``.
        """
        if self.norm_mode != "batch2d":
            return self.norm(x)

        leading = x.shape[0]
        # A 3-D input packs batch and channels together in its first
        # dimension; recover B so the regrouped tensor has the right size.
        batch_size = leading // self.num_channels if x.dim() == 3 else leading
        # reshape (not view) so non-contiguous inputs are accepted too.
        x = x.reshape(batch_size, self.num_channels, self.seq_len, self.d_model)
        x = self.norm(x)
        return x.reshape(batch_size * self.num_channels, self.seq_len, self.d_model)