使用TensorFlow io拆分训练/测试子集中的自定义二进制数据集

问题描述

我正在尝试使用本地二进制数据训练网络以执行regression inference。

每个本地二进制数据的布局如下:

并且整个数据由几个具有上述布局的*.bin文件组成。每个文件具有数量可变的403*4字节的序列。我能够使用以下代码读取其中一个文件:

import tensorflow as tf

RAW_N = 2 + 20*20 + 1

def convert_binary_to_float_array(register):
     return tf.io.decode_raw(register, out_type=tf.float32)

raw_dataset = tf.data.FixedLengthRecordDataset(filenames=['mydata.bin'],record_bytes=RAW_N*4)
raw_dataset = raw_dataset.map(map_func=convert_binary_to_float_array)

现在,我需要创建4个数据集train_datatrain_labelstest_datatest_labels如下:

train_data, train_labels, test_data, test_labels = prepare_ds(raw_dataset, 0.8)

并使用它们进行培训和评估:

model = build_model()

history = model.fit(train_data, train_labels, ...)

loss, mse = model.evaluate(test_data, test_labels)

我的问题是:如何实现函数prepare_ds(dataset, frac)

def prepare_ds(dataset, frac):
    ...

我已尝试使用tf.shapetf.reshapetf.slice、订阅[:],但未成功。我意识到这些函数不能正常工作,因为在map()调用raw_dataset之后是一个MapDataset(由于急切的执行问题)。


解决方案

如果假设元数据是您输入的一部分(我假设是这样),您可以尝试如下操作:

import random
import struct
import tensorflow as tf
import numpy as np

RAW_N = 2 + 20*20 + 1

bytess = random.sample(range(1, 5000), RAW_N*4)
with open('mydata.bin', 'wb') as f:
  f.write(struct.pack('1612i', *bytess))

def decode_and_prepare(register):
  register = tf.io.decode_raw(register, out_type=tf.float32)
  inputs = register[:402]
  label = register[402:]
  return inputs, label

total_data_entries = 8
raw_dataset = tf.data.FixedLengthRecordDataset(filenames=['/content/mydata.bin', '/content/mydata.bin'], record_bytes=RAW_N*4)
raw_dataset = raw_dataset.map(decode_and_prepare)
raw_dataset = raw_dataset.shuffle(buffer_size=total_data_entries)

train_ds_size = int(0.8 * total_data_entries)
test_ds_size = int(0.2 * total_data_entries)

train_ds = raw_dataset.take(train_ds_size)
remaining_data = raw_dataset.skip(train_ds_size)  
test_ds = remaining_data.take(test_ds_size)
请注意,出于演示目的,我使用了两次相同的bin文件。运行该代码片段后,您可以将数据集提供给您的模型,如下所示:

model = build_model()

history = model.fit(train_ds, ...)

loss, mse = model.evaluate(test_ds)

因为每个数据集都包含输入和相应的标签。

相关文章