TyXe:面向Pytorch用户的基于Pyro的BNN
TyXe旨在通过利用的模型定义和推理功能来简化将神经网络转变为贝叶斯神经网络的过程。 我们的核心设计原则是将神经体系结构的构造,先验,推理分布和可能性完全分开,从而实现灵活的工作流,其中每个组件都可以独立交换。 在TyXe中定义BNN只需5行代码:
net = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))
prior = tyxe.priors.IIDPrior(dist.Normal(0, 1))
likelihood = tyxe.observation_models.HomoskedasticGaussian(scale=0.1)
inference = pyro.infer.autoguides.AutoDiagonalNormal
bnn = t
1