pytorch-mdn
This repo contains the code for mixture density networks (MDNs).
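For reference, here is the standard MDN formulation, assuming K Gaussian components with diagonal covariance (one standard deviation per component and output dimension). The network maps an input x to mixing weights π_k(x), means μ_k(x) and standard deviations σ_k(x), and training minimizes the mean negative log-likelihood of the targets. Assuming the `mdn.MDN` constructor takes (input features, output dimensions, number of Gaussians), `mdn.MDN(6, 7, 20)` in the usage example would mean D = 7 and K = 20.

```latex
\begin{aligned}
p(\mathbf{y} \mid \mathbf{x}) &= \sum_{k=1}^{K} \pi_k(\mathbf{x}) \prod_{d=1}^{D} \mathcal{N}\!\left(y_d \mid \mu_{kd}(\mathbf{x}),\, \sigma_{kd}^2(\mathbf{x})\right),
\qquad \sum_{k=1}^{K} \pi_k(\mathbf{x}) = 1,\ \pi_k \ge 0,\ \sigma_{kd} > 0 \\
\mathcal{L} &= -\frac{1}{N} \sum_{n=1}^{N} \log p(\mathbf{y}_n \mid \mathbf{x}_n)
\end{aligned}
```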
Usage:
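import torch.nn as nn
import torch.optim as optim
import mdn

# initialize the model
model = nn.Sequential(
    nn.Linear(5, 6),
    nn.Tanh(),
    mdn.MDN(6, 7, 20)  # 6 input features, 7 output dimensions, 20 Gaussians
)
optimizer = optim.Adam(model.parameters())

# train the model
# (train_set is assumed to yield (minibatch, labels) pairs of tensors)
for minibatch, labels in train_set:
    model.zero_grad()
    pi, sigma, mu = model(minibatch)
    # negative log-likelihood of the labels under the predicted mixture
    loss = mdn.mdn_loss(pi, sigma, mu, labels)
    loss.backward()
    optimizer.step()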
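To make the training objective concrete, here is a minimal pure-Python sketch of the negative log-likelihood for a mixture of K diagonal Gaussians. It is an illustration, not this repo's implementation: `mixture_nll` and `batch_mdn_loss` are hypothetical names, `pi` is assumed to hold strictly positive weights summing to 1 (as a softmax would produce), and `sigma` and `mu` are assumed to be K x D nested lists for a single example.

```python
import math
from typing import Sequence


def mixture_nll(pi: Sequence[float],
                sigma: Sequence[Sequence[float]],
                mu: Sequence[Sequence[float]],
                y: Sequence[float]) -> float:
    """Hypothetical helper: -log p(y) under a diagonal Gaussian mixture.

    pi:    K strictly positive mixing weights summing to 1 (assumption)
    sigma: K x D standard deviations, all > 0
    mu:    K x D means
    y:     D-dimensional target
    """
    log_terms = []
    for k in range(len(pi)):
        # log(pi_k) + sum over d of log N(y_d | mu_kd, sigma_kd^2)
        log_p = math.log(pi[k])
        for d in range(len(y)):
            z = (y[d] - mu[k][d]) / sigma[k][d]
            log_p += -0.5 * z * z - math.log(sigma[k][d]) - 0.5 * math.log(2 * math.pi)
        log_terms.append(log_p)
    # log-sum-exp over components, shifted by the max for numerical stability
    m = max(log_terms)
    log_likelihood = m + math.log(sum(math.exp(t - m) for t in log_terms))
    return -log_likelihood


def batch_mdn_loss(pis, sigmas, mus, ys) -> float:
    """Hypothetical helper: mean negative log-likelihood over a batch."""
    losses = [mixture_nll(p, s, m, y) for p, s, m, y in zip(pis, sigmas, mus, ys)]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    # one standard normal component evaluated at its mean: 0.5 * log(2*pi)
    print(mixture_nll([1.0], [[1.0]], [[0.0]], [0.0]))
```

Averaging over the batch mirrors the 1/N in the objective, and the log-sum-exp step avoids underflow when every component assigns the target a very small density.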