我就废话不多说了,大家还是直接看代码吧~
import torch
input_tensor = torch.tensor([1,2,3,4,5])
print(input_tensor>3)
mask = (input_tensor>3).nonzero()
print(mask)
print(input_tensor.index_select(0,mask))
tensor([0, 0, 0, 1, 1], dtype=torch.uint8)
tensor([3, 4])
tensor([4, 5])
补充知识:pytorch tensor筛选满足条件的行或列(使用与或)
我就废话不
1