使用说明
分对话系统和机器翻译两部分
data为数据集
model为训练的模型
translation文件夹下又分了Seq2Seq和transformer两个模型,大家按需查看使用
以transformer文件夹为例,attention.py主要实现了注意力机制,transformer.py实现了transformer的主体架构,data.py为数据的预处理以及生成了词典、dataset、dataloader,readdata.py运行可以查看数据形状,train.py为训练模型,predict.py为预测,config.py为一些参数的定义。
transformer机器翻译的模型是用cuda:1训练的,如果要使用可能需要修改代码
如:gpu->cpu,即在CPU上使用
torch.load('trans_encoder.mdl', map_location= lambda storage, loc: storage)
torch.load('trans_decoder.mdl', map_location= lambda storage, loc: storage)
1