标签转换器
在 Pytorch 中实现 ,表格数据的注意力网络。 这种简单的架构与 GBDT 的性能相差无几。
安装
$ pip install tab-transformer-pytorch
用法
import torch
from tab_transformer_pytorch import TabTransformer
cont_mean_std = torch . randn ( 10 , 2 )
model = TabTransformer (
categories = ( 10 , 5 , 6 , 5 , 8 ), # tuple containing the number of unique values within each category
num_continuous = 10 , # number of co
1