Keras混合物密度网络层
使用TensorFlow的发行模块的Keras的混合密度网络(MDN)层。 这使得使用神经网络进行实验变得更加简单,该神经网络预测了多个可能包含多个可能值的实值变量。
该层可以帮助构建类似于 , ,甚至所使用的MDN- 。 您可以使用MDN做很多很酷的事情!
此实现的一个好处是您可以预测任意数量的实值。 TensorFlow的Mixture , Categorical和MultivariateNormalDiag分布函数用于生成损失函数(多元正态分布与对角协方差矩阵混合的概率密度函数)。 在以前的工作中,通常会手动指定损失函数,这对于1D或2D预测是合适的,但此后会变得更加烦人。
提供了两个重要的功能用于训练和预测:
get_mixture_loss_func(output_dim, num_mixtures) :此函数生成具有正确输出尺寸a和混合
1