构建人工智能模型基础:TFDS和Keras的完美搭配
ztj100 2024-12-28 16:50 23 浏览 0 评论
上一篇:《数据工程师,转型人工智能岗位的理想时空通道》
序言:本节将带您深入探索 TensorFlow 提供的关键工具和方法,涵盖数据集管理和神经网络模型的构建与训练。在现代人工智能框架中,TensorFlow 的数据集接口 (TensorFlow Datasets, 简称 TFDS) 与 Keras 模型库为深度学习任务提供了极大的便利。本章将具体展示如何使用 TFDS 和 Keras 配合构建神经网络架构,以实现高效的数据处理和模型训练。通过本节的实践操作,您将掌握从数据加载、预处理到模型搭建的核心流程,为进一步的人工智能模型研发奠定坚实的基础。
使用TFDS与Keras模型
在第2章中,你学到了如何使用TensorFlow和Keras创建一个简单的计算机视觉模型,使用Keras内置的数据集(包括Fashion MNIST),代码如下所示:
mnist = tf.keras.datasets.fashion_mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()
使用TFDS时,代码非常相似,但有一些小的变化。Keras的数据集直接给我们返回了可以在model.fit中原生使用的ndarray类型,但使用TFDS时,我们需要进行一些转换:
(training_images, training_labels), (test_images, test_labels) = tfds.as_numpy(tfds.load('fashion_mnist', split=['train', 'test'], batch_size=-1, as_supervised=True))
在这里,我们使用了tfds.load,将fashion_mnist作为所需的数据集传递给它。我们知道这个数据集有训练集和测试集的划分,所以在数组中传入这些划分项将返回包含图像和标签的适配器数组。使用tfds.as_numpy在调用tfds.load时会将数据返回为Numpy数组。指定batch_size=-1会让我们获取所有数据,而as_supervised=True则确保返回的格式为(输入,标签)元组。
完成这些操作后,我们基本上获得了与Keras数据集相同的数据格式,但有一个区别——在TFDS中,数据的形状是(28, 28, 1),而在Keras数据集中是(28, 28)。
这意味着代码需要做一些小的改动来指定输入数据的形状为(28, 28, 1),而不是(28, 28):
import tensorflow as tf
import tensorflow_datasets as tfds
(training_images, training_labels), (test_images, test_labels) = tfds.as_numpy(tfds.load('fashion_mnist', split=['train', 'test'], batch_size=-1, as_supervised=True))
training_images = training_images / 255.0
test_images = test_images / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
tf.keras.layers.Dense(128, activation=tf.nn.relu),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(training_images, training_labels, epochs=5)
对于更复杂的示例,可以参考第3章中使用的“马或人”数据集,这个数据集在TFDS中也可以使用。以下是使用它来训练模型的完整代码:
import tensorflow as tf
import tensorflow_datasets as tfds
data = tfds.load('horses_or_humans', split='train', as_supervised=True)
train_batches = data.shuffle(100).batch(10)
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(300, 300, 3)),
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
history = model.fit(train_batches, epochs=10)
正如你所见,这相当直接:只需调用 tfds.load,传入你想要的分割(在本例中是训练集),然后在模型中使用它。数据被批处理并打乱顺序,以便更有效地进行训练。
“Horses or Humans”数据集被划分为训练集和测试集,因此如果你想在训练时验证模型,可以通过TFDS加载单独的验证集,方法如下:
val_data = tfds.load('horses_or_humans', split='test', as_supervised=True)
你需要像处理训练集一样批处理它。例如:
validation_batches = val_data.batch(32)
然后在训练时,指定这些批次作为验证数据。你还需要显式设置每个周期要使用的验证步数,否则TensorFlow会报错。如果不确定,设置为1即可,如下:
history = model.fit(train_batches, epochs=10, validation_data=validation_batches, validation_steps=1)
加载特定版本
所有在TFDS中存储的数据集都使用MAJOR.MINOR.PATCH编号系统。其保证如下:如果仅PATCH更新,则调用返回的数据相同,但底层组织可能发生变化。这种变化对开发者应是无感知的。如果MINOR更新,则数据保持不变,但可能会在每条记录中增加新的特性(非破坏性更改)。此外,对特定切片(参见第74页的“使用自定义切片”)的数据不会重新排序。如果MAJOR更新,则记录格式及其位置可能会发生变化,因此特定切片可能会返回不同的值。
在检查数据集时,你会看到何时有不同版本可用——例如,cnn_dailymail数据集就是这样。如果你不想要默认版本(本文撰写时是3.0.0),而是希望使用早期版本,例如1.0.0,可以按以下方式加载:
data, info = tfds.load("cnn_dailymail:1.0.0", with_info=True)
请注意,如果你在Colab上使用TFDS,建议检查Colab使用的TFDS版本。本文撰写时,Colab预装的TFDS是2.0版,但其中存在一些加载数据集的bug(包括cnn_dailymail数据集),这些问题在TFDS 2.1及之后的版本中已修复,因此建议使用这些版本,或者至少在Colab中安装它们,而不是依赖内置的默认版本。
使用映射函数进行数据增强
在前面的章节中,你见到了使用ImageDataGenerator为模型提供训练数据时的一些有用增强工具。你可能想知道如何在使用TFDS时实现同样的功能,因为这时你不是从子目录流式读取图像。实现此功能的最佳方法(或任何其他形式的转换)是对数据适配器使用映射函数。让我们看看如何实现这一点。
之前,我们对Horses or Humans数据集的处理只是从TFDS中加载数据并为其创建批次,如下所示:
data = tfds.load('horses_or_humans', split='train', as_supervised=True)
train_batches = data.shuffle(100).batch(10)
要对数据进行变换并将其映射到数据集,你可以创建一个映射函数。这只是标准的Python代码。例如,假设你创建了一个名为augmentimages的函数,并让它进行一些图像增强,如下所示:
def augmentimages(image, label):
image = tf.cast(image, tf.float32)
image = (image/255)
image = tf.image.random_flip_left_right(image)
return image, label
然后你可以将其映射到数据上,创建一个名为train的新数据集:
train = data.map(augmentimages)
之后,在创建批次时,使用train而不是data,如下:
train_batches = train.shuffle(100).batch(32)
在augmentimages函数中,你可以看到使用tf.image.random_flip_left_right(image)对图像进行左右随机翻转。tf.image库中有很多可用于增强的函数;详细内容请参阅文档。
使用TensorFlow Addons
TensorFlow Addons库包含更多可用函数。ImageDataGenerator增强中的一些功能(如旋转)仅在此库中可用,因此建议查看它。
使用TensorFlow Addons非常简单——只需安装库即可:
pip install tensorflow-addons
安装完成后,可以将Addons混入到你的映射函数中。以下是将旋转Addons用于前面映射函数的示例:
import tensorflow_addons as tfa
def augmentimages(image, label):
image = tf.cast(image, tf.float32)
image = (image/255)
image = tf.image.random_flip_left_right(image)
image = tfa.image.rotate(image, 40, interpolation='NEAREST')
return image, label
使用自定义分割
到目前为止,你一直使用的是预先分割为训练集和测试集的数据集。例如,Fashion MNIST有60,000和10,000条记录,分别用于训练和测试。但如果你不想使用这些分割呢?如果你想根据自己的需求分割数据呢?TFDS的一个强大之处就在于——它提供了一个API,允许你精细地控制数据的分割方式。
实际上你已经见过这种方式了,例如像这样加载数据时:
data = tfds.load('cats_vs_dogs', split='train', as_supervised=True)
注意split参数是一个字符串,这里你请求了train分割,它恰好是整个数据集。如果你熟悉Python的切片符号,也可以使用它。这种符号可以总结为在方括号内定义你想要的切片,如下所示:[<start>: <stop>: <step>]。它是一种相当复杂的语法,赋予了很大的灵活性。
例如,如果你希望train的前10,000条记录作为训练数据,可以省略<start>,直接调用train[:10000](一个有用的记忆技巧是将前导冒号读作“前”,所以这将读作“train前10,000条记录”):
data = tfds.load('cats_vs_dogs', split='train[:10000]', as_supervised=True)
你还可以使用%来指定分割。例如,如果你希望前20%的记录用于训练,可以像这样使用:20%:
data = tfds.load('cats_vs_dogs', split='train[:20%]', as_supervised=True)
你甚至可以更进一步,组合多个分割。也就是说,如果你希望训练数据是前1000条记录和最后1000条记录的组合,可以这样做(-1000:表示“最后1000条记录”,“:1000”表示“前1000条记录”):
data = tfds.load('cats_vs_dogs', split='train[-1000:]+train[:1000]', as_supervised=True)
Dogs vs. Cats数据集没有固定的训练、测试和验证分割,但使用TFDS,创建自定义分割非常简单。假设你希望分割为80%、10%、10%。可以这样创建三个数据集:
train_data = tfds.load('cats_vs_dogs', split='train[:80%]', as_supervised=True)
validation_data = tfds.load('cats_vs_dogs', split='train[80%:90%]', as_supervised=True)
test_data = tfds.load('cats_vs_dogs', split='train[-10%:]', as_supervised=True)
一旦你有了它们,就可以像使用任何命名分割一样使用它们。
需要注意的是,由于返回的数据集无法被探测其长度,因此通常很难确认你是否正确地分割了原始数据集。要查看你在某个分割中的记录数量,你必须遍历整个数据集并一条条计数。以下是对你刚创建的训练集进行计数的代码:
train_length = [i for i, _ in enumerate(train_data)][-1] + 1
print(train_length)
这可能是一个较慢的过程,因此请确保仅在调试时使用它。
本节总结: 本章介绍了如何使用 TensorFlow Datasets(TFDS)和 Keras 搭建神经网络模型,并探索了数据增强、分割和预处理等关键操作。通过将 TFDS 数据集与 Keras 模型结合,您学会了如何高效加载和转换数据,为神经网络模型的构建和训练做好准备。此外,我们还介绍了如何灵活地自定义数据分割和应用数据增强,为模型提供更加多样化的训练数据。掌握了这些技能后,您将能够更自如地应用 TensorFlow 和 Keras 进行各种深度学习项目的开发,为复杂模型的搭建奠定坚实的基础。下一节我们将会为大家介绍从TFDS中下载出来的数据集是以什么形式保存下来的—TFRecord!
相关推荐
-
- SpringBoot如何实现优雅的参数校验
-
平常业务中肯定少不了校验,如果我们把大量的校验代码夹杂到业务中,肯定是不优雅的,对于一些简单的校验,我们可以使用java为我们提供的api进行处理,同时对于一些...
-
2025-05-11 19:46 ztj100
- Java中的空指针怎么处理?
-
#暑期创作大赛#Java程序员工作中遇到最多的错误就是空指针异常,无论你多么细心,一不留神就从代码的某个地方冒出NullPointerException,令人头疼。...
- 一坨一坨 if/else 参数校验,被 SpringBoot 参数校验组件整干净了
-
来源:https://mp.weixin.qq.com/s/ZVOiT-_C3f-g7aj3760Q-g...
- 用了这两款插件,同事再也不说我代码写的烂了
-
同事:你的代码写的不行啊,不够规范啊。我:我写的代码怎么可能不规范,不要胡说。于是同事打开我的IDEA,安装了一个插件,然后执行了一下,规范不规范,看报告吧。这可怎么是好,这玩意竟然给我挑出来这么...
- SpringBoot中6种拦截器使用场景
-
SpringBoot中6种拦截器使用场景,下面是思维导图详细总结一、拦截器基础...
- 用注解进行参数校验,spring validation介绍、使用、实现原理分析
-
springvalidation是什么在平时的需求开发中,经常会有参数校验的需求,比如一个接收用户注册请求的接口,要校验用户传入的用户名不能为空、用户名长度不超过20个字符、传入的手机号是合法的手机...
- 快速上手:SpringBoot自定义请求参数校验
-
作者:UncleChen来源:http://unclechen.github.io/最近在工作中遇到写一些API,这些API的请求参数非常多,嵌套也非常复杂,如果参数的校验代码全部都手动去实现,写起来...
- 分布式微服务架构组件
-
1、服务发现-Nacos服务发现、配置管理、服务治理及管理,同类产品还有ZooKeeper、Eureka、Consulhttps://nacos.io/zh-cn/docs/what-is-nacos...
- 优雅的参数校验,告别冗余if-else
-
一、参数校验简介...
- Spring Boot断言深度指南:用断言机制为代码构筑健壮防线
-
在SpringBoot开发中,断言(Assert)如同代码的"体检医生",能在上线前精准捕捉业务逻辑漏洞。本文将结合企业级实践,解析如何通过断言机制实现代码自检、异常预警与性能优化三...
- 如何在项目中优雅的校验参数
-
本文看点前言验证数据是贯穿所有应用程序层(从表示层到持久层)的常见任务。通常在每一层实现相同的验证逻辑,这既费时又容易出错。为了避免重复这些验证,开发人员经常将验证逻辑直接捆绑到域模型中,将域类与验证...
- SpingBoot项目使用@Validated和@Valid参数校验
-
一、什么是参数校验?我们在后端开发中,经常遇到的一个问题就是入参校验。简单来说就是对一个方法入参的参数进行校验,看是否符合我们的要求。比如入参要求是一个金额,你前端没做限制,用户随便过来一个负数,或者...
- 28个验证注解,通过业务案例让你精通Java数据校验(收藏篇)
-
在现代软件开发中,数据验证是确保应用程序健壮性和可靠性的关键环节。JavaBeanValidation(JSR380)作为一个功能强大的规范,为我们提供了一套全面的注解工具集,这些注解能够帮...
- Springboot @NotBlank参数校验失效汇总
-
有时候明明一个微服务里的@Validated和@NotBlank用的好好的,但就是另一个里不能用,这时候问题是最不好排查的,下面列举了各种失效情况的汇总,供各位参考:1、版本问题springbo...
- 这可能是最全面的Spring面试八股文了
-
Spring是什么?Spring是一个轻量级的控制反转(IoC)和面向切面(AOP)的容器框架。...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)