使用Pytorch框架进行深度学习任务,特别是分类任务时,经常会用到如下:
import torch.nn as nn
criterion = nn.CrossEntropyLoss().cuda()
loss = criterion(output, target)
即使用torch.nn.CrossEntropyLoss()作为损失函数。
那nn.CrossEntropyLoss()内部到底是啥??
nn.CrossEntropyLoss()是torch.nn中包装好的一个类,对应torch.nn.functional中的cross_entropy。
此外,nn.CrossEntropyL
2022-04-03 21:28:23
71KB
c
hot
op
1