Pytorch学習ウェイトの初期化
Pytorch Learning Weight Initialization
重みの初期化は、ニューラルネットワークのトレーニングにとって非常に重要です。適切な初期化の重みは、勾配の消失などの問題を効果的に回避できます。
参照用にpytorchを使用する場合、重みを初期化する方法はいくつかあります。
注:最初の方法はお勧めしません。後者の2つの方法を使用してみてください。
# not recommend def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0)
# recommend def initialize_weights(m): if isinstance(m, nn.Conv2d): m.weight.data.normal_(0, 0.02) m.bias.data.zero_() elif isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.02) m.bias.data.zero_()
# recommend def weights_init(m): if isinstance(m, nn.Conv2d): nn.init.xavier_normal_(m.weight.data) nn.init.xavier_normal_(m.bias.data) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight,1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm1d): nn.init.constant_(m.weight,1) nn.init.constant_(m.bias, 0)
よく書かれているweights_init
関数の後、モデルを使用できますapply
メソッドはモデルの重みの初期化を実行します。
net = Residual() # generate an instance network from the Net class net.apply(weights_init) # apply weight init
参考資料
- https://discuss.pytorch.org/t/weight-initilzation/157
- https://discuss.pytorch.org/t/init-parameters-weight-init-not-defined/22935
転載:https://www.jianshu.com/p/adf427f4fcdf