输出补丁而不是完整图像的ImageDataGenerator

问题描述

我有一个很大的数据集,我想用它来训练带有Kera的CNN(太大了,无法将其加载到内存中)。我总是使用ImageDataGenerator.flow_from_dataframe进行培训,因为我将图像放在不同的目录中,如下所示。

datagen = ImageDataGenerator(
    rescale=1./255.
)
train_gen=datagen.flow_from_dataframe(
    dataframe=train_df),
    x_col="filepath",
    class_mode="input",
    shuffle=True,
    seed=1)
但是,这一次我不想使用完整的映像,而是使用映像的随机补丁,即,我希望选择一个随机映像,并每次随机获取该映像的32x32的补丁。我如何才能做到这一点?

我想过使用tf.extract_image_patchessklearn.feature_extraction.image.extract_patches_2d,但我不知道是否可以将它们集成到flow_from_dataframe中。

如有任何帮助,我们将不胜感激。


解决方案

您可以尝试使用ImageDataGeneratortf.image.extract_patches结合使用的预处理函数:

import tensorflow as tf
import matplotlib.pyplot as plt

BATCH_SIZE = 32

def get_patches():
    def _get_patches(image):
            image = tf.expand_dims(image,0)
            patches = tf.image.extract_patches(images=image,
                                    sizes=[1, 32, 32, 1],
                                    strides=[1, 32, 32, 1],
                                    rates=[1, 1, 1, 1],
                                    padding='VALID')

            patches = tf.reshape(patches, (1, 256, 256, 3))
            return patches
    return _get_patches

def reshape_data(images, labels):
      ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
      for b in tf.range(BATCH_SIZE):
        i = tf.random.uniform((), maxval=int(256/32), dtype=tf.int32)
        j = tf.random.uniform((), maxval=int(256/32), dtype=tf.int32)
        patched_image = tf.reshape(images[b], (8, 8, 3072))
        ta = ta.write(ta.size(), tf.reshape(patched_image[i, j], shape=(32, 32 ,3)))
      return ta.stack(), labels

preprocessing = get_patches()
flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20, preprocessing_function = preprocessing)


ds = tf.data.Dataset.from_generator(
    lambda: img_gen.flow_from_directory(flowers, batch_size=BATCH_SIZE, shuffle=True),
    output_types=(tf.float32, tf.float32))

ds = ds.map(reshape_data)
images, _ = next(iter(ds.take(1)))

image = images[0] # (32, 32, 3)

plt.imshow(image.numpy())
问题是ImageDataGeneratorpreprocessing_function需要与输入形状相同的输出形状。因此,我首先创建面片,并基于面片构建与原始图像相同的输出形状。稍后,在reshape_data方法中,我将图像从(256,256,3)重塑到(8,8,3072),提取一个随机面片,然后将其与形状(32,32,3)一起返回。

相关文章