火炬指标
PyTorch的模型评估指标
火炬指标作为自定义库,以提供Pytorch共同ML评价指标,类似于tf.keras.metrics 。
如,Pytorch没有用于模型评估指标的内置库torch.metrics 。 这类似于的指标库。
用法
pip install --upgrade torch-metrics
from torch_metrics import Accuracy
## define metric ##
metric = Accuracy ( from_logits = False )
y_pred = torch . tensor ([ 1 , 2 , 3 , 4 ])
y_true = torch . tensor ([ 0 , 2 , 3 , 4 ])
print ( metric ( y_pred , y_true ))
## define metri
1