1、项目要求:
基于 pytorch 搭建神经网络分类模型识别花的种类,输入一张花的照片,输出显示最有可能的前八种花的名称和该种花的照片。
2、分三大步骤操作:
数据集预处理操作:
读取数据集数据
构建神经网络的数据集
1)数据增强:torchvision中transforms模块自带功能,将数据集中照片进行旋转、翻折、放大…得到更多的数据
2)数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可
3)处理好的数据集保存在DataLoader模块中,可直接读取batch数据
网络模型训练操作:
迁移pytorch官网中models提供的resnet模型,torchvision中有很多经典网络架构,调用起来十分方便,并且可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习
选择GPU计算、选择训练哪些层、优化器设置、损失函数设置…
训练全连接层......
详细介绍见:https://blog.csdn.net/zhaohaobingniu/article/details/119922606?spm=1001.2014.3001.5501
1