import math
import torch
import random
[docs]
def set_seed(seed):
    """Seed the PyTorch and Python ``random`` generators with ``seed``.

    Args:
        seed: Value handed unchanged to ``torch.manual_seed``,
            ``random.seed`` and, when CUDA is available,
            ``torch.cuda.manual_seed_all``.
    """
    torch.manual_seed(seed)
    random.seed(seed)

    # Nothing more to do on machines where CUDA is not available.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with values from a truncated normal distribution.

    Cut & paste from PyTorch official master until it's in a few official
    releases - RW.
    Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Values are produced by sampling uniformly between the normal CDF values
    of the two bounds and pushing the samples through the inverse CDF, then
    rescaling to ``mean``/``std`` and clamping to ``[a, b]``.

    Args:
        tensor: Tensor to overwrite in place.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        The same ``tensor`` object, after being filled.

    Warns:
        UserWarning: If ``mean`` lies more than ``2 * std`` outside
            ``[a, b]``; the sampled values are then dominated by the final
            clamp and may not follow the intended distribution.
    """
    # Local import keeps this block self-contained; the module header does
    # not import ``warnings``.
    import warnings

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # CDF values of the lower and upper truncation bounds.
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill with values from [lower, upper], expressed in the
        # translated range [2*lower - 1, 2*upper - 1] that erfinv expects.
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Inverse CDF transform for the normal distribution gives a
        # truncated standard normal.
        tensor.erfinv_()

        # Transform to the requested mean and std.
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure the result is in the proper range.
        tensor.clamp_(min=a, max=b)
    return tensor
[docs]
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place with truncated-normal values.

    Public entry point that forwards every argument to
    ``_no_grad_trunc_normal_``.

    Args:
        tensor: Tensor to overwrite in place.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        Whatever ``_no_grad_trunc_normal_`` returns (the filled ``tensor``).
    """
    filled = _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
    return filled
[docs]
def xavier_init(m, seed=None):
    """Initialise the parameters of one module in place.

    Handled module types:

    * ``torch.nn.Linear``: Xavier-uniform weight, zero bias (if present).
    * ``torch.nn.LayerNorm``: weight set to one and bias set to zero, each
      only if the parameter exists.
    * ``torch.nn.LSTM``: Xavier-uniform for parameters whose name contains
      ``"weight_ih"``, orthogonal for ``"weight_hh"``, zero for ``"bias"``.

    Any other module type is left untouched.

    Args:
        m: Module to initialise.
        seed: If not ``None``, ``set_seed(seed)`` is called before the
            initialisation. This happens on every call, so passing the same
            seed for many modules restarts the generators each time.
    """
    if seed is not None:
        set_seed(seed)

    init = torch.nn.init

    if isinstance(m, torch.nn.Linear):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, torch.nn.LayerNorm):
        # weight and/or bias are None when the layer was built without the
        # affine parameters; skip them instead of crashing.
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, torch.nn.LSTM):
        for name, param in m.named_parameters():
            if "weight_ih" in name:
                init.xavier_uniform_(param.data)
            elif "weight_hh" in name:
                init.orthogonal_(param.data)
            elif "bias" in name:
                param.data.fill_(0)