不同格式的数据的导入
Numpy 数据的导入
这种导入非常直白,就是使用 Numpy 把外部的数据进行导入,然后转换成 tf.Tensor
,之后使用 Dataset.from_tensor_slices()
。就可以成功导入了。简单的案例如下:
1 |
|
上面的简单的实例有一个很大的问题,就是 features
和 labels
会作为 tf.constant()
指令嵌入在 Tensorflow 的图中,会浪费很多内存。所以我们可以根据 tf.palceholder()
来定义 Dataset
,同时在对数据集初始化的时候送入 Numpy 数组。
1 |
|
TFRecord 数据的导入
TFRecord 是一种面向记录的简单二进制格式,很多 Tensorflow 应用采用这种方式来训练数据。这个也是推荐的做法。将它做成 Dataset 的方式也非常简单,就是单纯的通过 tf.data.TFRecordDataset
类就可以实现。
1 |
|
同样我们也能设定成,在初始化迭代器的时候导入数据。其中需要注意的是 filenames
需要设置成 tf.String
类。
1 |
|
Dataset 的预处理
Dataset.map()
Dataset.map(f)
转换通过将指定函数 f
应用于输入数据集的每个元素来生成新数据集。
简单的实例(解码图片数据并调整大小)如下:
1 |
|
至此为止,我们对图片的处理还是使用的是 TensorFlow 中的 API,那么我们想用 Python 自带的奇奇怪怪的包应该怎么做呢。TensorFlow 给了我们 tf.py_func()
这个选项来使用任意 Python 逻辑。我们只用在 Dataset.map()
中调用 tf.py_func()
指令就可以了。简单的例子如下:
1 |
|
批处理数据
- 简单的批处理
简单的批处理我们直接调用
Dataset.batch()
这种 API 即可,但是它有一个限制就是对于每个组件 i,所有元素的张量形状都必须完全相同。1
2
3
4
5
6
7
8
9
10
11inc_dataset = tf.data.Dataset.range(100) dec_dataset = tf.data.Dataset.range(0, -100, -1) dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset)) batched_dataset = dataset.batch(4) iterator = batched_dataset.make_one_shot_iterator() next_element = iterator.get_next() print(sess.run(next_element)) # ==> ([0, 1, 2, 3], [ 0, -1, -2, -3]) print(sess.run(next_element)) # ==> ([4, 5, 6, 7], [-4, -5, -6, -7]) print(sess.run(next_element)) # ==> ([8, 9, 10, 11], [-8, -9, -10, -11])
- 填充批处理张量
和简单批处理相比,这种方式可以对具有不同大小的张量进行批处理。这种方法的 API 为
Dataset.padded_batch()
。简单的实例展示如下:1
2
3
4
5
6
7
8
9
10
11
12dataset = tf.data.Dataset.range(100) dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x)) dataset = dataset.padded_batch(4, padded_shapes=[None]) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() print(sess.run(next_element)) # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]] print(sess.run(next_element)) # ==> [[4, 4, 4, 4, 0, 0, 0], # [5, 5, 5, 5, 5, 0, 0], # [6, 6, 6, 6, 6, 6, 0], # [7, 7, 7, 7, 7, 7, 7]]
可以通过
Dataset.padded_batch()
转换为每个组件的每个维度设置不同的填充,并且可以采用可变长度(在上面的示例中用None
表示)或恒定长度。也可以替换填充值,默认设置为 0。
训练工作流程
处理多个周期
有时候我们希望我们的数据集能训练很多个周期,简单的方法是使用Dataset.repeat()
API。
1 |
|
上述例子中,我们将 dataset 重复了 10 个周期,值得注意的是如果 repeat 中没有参数代表中无限次地重复使用,即不会在一个周期结束和下一个周期开始时发出信号。
如果我们想在每个周期结束时收到信号,则可以编写在数据集结束时捕获 tf.errors.OutOfRangeError
的训练循环。此时,就可以收集关于该周期的一些统计信息。
1 |
|
随机重排数据
有时候我们希望能随机的选取 Dataset 中的元素,则可以使用 Dataset.shuffle()
。
1 |
|