这是一个手把手教你用 Tensorflow 构建卷机网络(CNN)进行图像分类的教程。教程并没有使用 MNIST 数据集,而是使用了真实的图片文件,并且教程代码包含了模型的保存、加载等功能,因此希望在日常项目中使用 Tensorflow 的朋友可以参考这篇教程。
概述
---
• 代码利用卷积网络完成一个图像分类的功能
• 训练完成后,模型保存在 model 文件中,可直接使用模型进行线上分类
• 同一个代码包括了训练和测试阶段,通过修改 train 参数为 True 和 False 控制训练和测试
数据准备
---
教程的图片从 Cifar 数据集中获取,download_cifar.py 从 Keras 自带的 Cifar 数据集中获取了部分 Cifar 数据集,并将其转换为 jpg 图片。
默认从 Cifar 数据集中选取了 3 类图片,每类 50 张图,分别是
• 0 => 飞机
• 1 => 汽车
• 2 => 鸟
图片都放在 data 文件夹中,按照 label_id.jpg 进行命名,例如 2_111.jpg 代表图片类别为 2(鸟),id 为 111。
1