同步批处理标准PyTorch
PyTorch中的同步批处理规范化实现。
此模块与内置的PyTorch BatchNorm不同,因为在训练过程中所有设备的均值和标准差都减小了。
例如,当在训练期间使用nn.DataParallel封装网络时,PyTorch的实现仅使用该设备上的统计信息对每个设备上的张量进行归一化,这加快了计算速度,并且易于实现,但统计信息可能不准确。 相反,在此同步版本中,将对分布在多个设备上的所有训练样本进行统计。
请注意,对于单GPU或仅CPU的情况,此模块的行为与内置的PyTorch实现完全相同。
该模块目前仅是用于研究用途的原型版本。 如下所述,它有其局限性,甚至可能会遇到一些设计问题。 如果您有任何疑问或建议,请随时或。
为什么要同步BatchNorm?
尽管在多个设备(GPU)上运行BatchNorm的典型实现速度很快(没有通信开销),但不可避免地会
2022-01-27 18:44:23
17KB
Python
1