Source code for optrade.models.utils.weight_init

import math
import torch
import random


def set_seed(seed):
    """Seed the PyTorch and Python random number generators.

    Seeds ``torch`` first, then the standard-library ``random`` module, and
    finally every CUDA device when CUDA is reported as available.

    Args:
        seed: Seed value forwarded unchanged to each generator.

    Returns:
        None.
    """
    torch.manual_seed(seed)
    random.seed(seed)

    # Nothing more to seed on a machine without CUDA.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from a truncated normal.

    Values are drawn by inverse-CDF sampling: a uniform sample over the CDF
    interval that corresponds to ``[a, b]`` is mapped through ``erfinv`` and
    then shifted/scaled to ``mean``/``std``. A final clamp guarantees that
    every element lies in ``[a, b]``.

    Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    (adapted from the PyTorch implementation).

    Args:
        tensor: Tensor to overwrite in place.
        mean: Mean of the underlying (untruncated) normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        The same ``tensor`` object, after in-place modification.

    Warns:
        UserWarning: If ``mean`` lies more than two standard deviations
            outside ``[a, b]``. In that regime the CDF interval collapses
            numerically and most samples end up clamped to a bound.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import warnings

    def norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # CDF values of the lower and upper truncation bounds.
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill with values from [lower, upper], translated to
        # [2*lower - 1, 2*upper - 1], the input domain of erfinv.
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Inverse CDF transform -> truncated standard normal.
        tensor.erfinv_()

        # Rescale to the requested standard deviation and mean.
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp so that boundary/rounding cases still fall inside [a, b].
        tensor.clamp_(min=a, max=b)
        return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place with truncated-normal samples.

    Public entry point that forwards all arguments to
    ``_no_grad_trunc_normal_``.

    Args:
        tensor: Tensor to overwrite in place.
        mean: Mean of the underlying normal distribution. Defaults to 0.0.
        std: Standard deviation of the underlying normal distribution.
            Defaults to 1.0.
        a: Lower truncation bound. Defaults to -2.0.
        b: Upper truncation bound. Defaults to 2.0.

    Returns:
        Whatever ``_no_grad_trunc_normal_`` returns for these arguments.
    """
    filled = _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
    return filled
def xavier_init(m, seed=None):
    """Initialize the parameters of a single module in place.

    Behaviour by module type:

    * ``torch.nn.Linear``: Xavier-uniform weight, zero bias (if present).
    * ``torch.nn.LayerNorm``: weight set to ones and bias set to zeros, each
      only if the parameter exists.
    * ``torch.nn.LSTM``: parameters whose name contains ``"weight_ih"`` get
      Xavier-uniform, ``"weight_hh"`` get orthogonal init, and ``"bias"`` are
      filled with zero. Parameters matching none of these names are left
      unchanged.

    Modules of any other type are left untouched.

    Args:
        m: Module to initialize.
        seed: Optional seed. When not ``None``, ``set_seed(seed)`` is called
            before initialization, on every invocation. If this function is
            invoked once per sub-module with the same seed, the RNG is
            re-seeded before each one.

    Returns:
        None.
    """
    if seed is not None:
        set_seed(seed)

    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, torch.nn.LayerNorm):
        # LayerNorm registers weight/bias as None when built without affine
        # parameters (or without a bias); the init helpers cannot take None.
        if m.weight is not None:
            torch.nn.init.ones_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, torch.nn.LSTM):
        for name, param in m.named_parameters():
            if "weight_ih" in name:
                torch.nn.init.xavier_uniform_(param.data)
            elif "weight_hh" in name:
                torch.nn.init.orthogonal_(param.data)
            elif "bias" in name:
                param.data.fill_(0)