如何使用tf.unction从TensorFlow中的函数集中随机选择

问题描述

我的问题是:在预处理过程中,我希望使用tf.data.Datasettf.functionAPI将从一组函数中随机选择的函数应用于数据集示例。

具体地说,我的数据是3D体积,我希望从一组24个预定义的旋转函数中应用旋转。我想在tf.function中编写这段代码,这样就限制了numpy和列表索引之类的包的使用。

例如,我想做这样的事情:

import tensorflow as tf

@tf.function
def func1(tensor):
    # Apply some rotation here
    ...

@tf.function
def func2(tensor):
    ...

...

@tf.function
def func24(tensor):
    ...


@tf.function
def apply(tensor):
    list_of_funcs = [func1, func2, ..., func24]

    # Randomly sample from 0-23
    a = tf.random.uniform([1], minval=0, maxval=23, dtype=tf.int32)
    
    return list_of_funcs[a](tensor)
但是,我无法将list_of_funcs索引为TypeError: list indices must be integers or slices, not Tensor。此外,我无法将这些函数(AFAIK)收集到tf.Tensor中并使用tf.gather

所以我的问题是:我如何在tf.function中合理而灵活地从这些函数中进行采样?


解决方案

可以使用 tf.switch_case

def func1(tensor):
    return tensor * 1

def func2(tensor):
    return tensor * 2

def func24(tensor):
    return tensor * 24

class Lambda:
    def __init__(self, func, arg):
        self._func = func
        self._arg = arg
        
    def __call__(self):
        return self._func(self._arg)

@tf.function
def apply(tensor):
    list_of_funcs = [func1, func2, func24]

    branch_index = tf.random.uniform(shape=[], minval=0, maxval=len(list_of_funcs), dtype=tf.int32)
    output = tf.switch_case(
        branch_index=branch_index, 
        branch_fns=[Lambda(func, tensor) for func in list_of_funcs], 
    )
    
    return output

修饰符@tf.function仅用于您希望优化的整个函数,在本例中为apply。如果使用applyInsidetf.data.Dataset.map,则根本不需要装饰符。

参见 this discussion 了解为什么我们必须在此处定义类Lambda

相关文章