【问题标题】:How do you send arguments to a generator function using tf.data.Dataset.from_generator()?如何使用 tf.data.Dataset.from_generator() 向生成器函数发送参数?
【发布时间】:2019-02-25 20:05:00
【问题描述】:

我想使用from_generator() 函数创建多个tf.data.Dataset。我想向生成器函数 (raw_data_gen) 发送一个参数。这个想法是生成器函数将根据发送的参数产生不同的数据。通过这种方式,我希望raw_data_gen 能够提供训练、验证或测试数据。

training_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([1]))

validation_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([2]))

test_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([3]))

当我尝试以这种方式调用from_generator() 时收到的错误消息是:

TypeError: from_generator() got an unexpected keyword argument 'args'

这里是raw_data_gen 函数,虽然我不确定你是否需要这个函数,因为我的直觉是问题出在from_generator() 的调用上:

def raw_data_gen(train_val_or_test):

    if train_val_or_test == 1:        
        #For every filename collected in the list
        for filename, lab in training_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 2:
        #For every filename collected in the list
        for filename, lab in validation_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 3:
        #For every filename collected in the list
        for filename, lab in test_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    else:
        print("generator function called with an argument not in [1, 2, 3]")
        raise ValueError()

【问题讨论】:

    标签: python python-3.x tensorflow tensorflow-datasets


    【解决方案1】:

    您需要基于raw_data_gen 定义一个不带任何参数的新函数。您可以使用lambda 关键字来执行此操作。

    training_dataset = tf.data.Dataset.from_generator(lambda: raw_data_gen(train_val_or_test=1), (tf.float32, tf.uint8), ([None, 1], [None]))
    ...
    

    现在,我们将一个不带任何参数的函数传递给from_generator,但它只是充当raw_data_gen,参数设置为1。您可以对验证集和测试集使用相同的方案,分别通过 2 和 3。

    【讨论】:

    • 太完美了。非常感谢。这是 from_generator() 不能接受参数的当前错误吗?还是我只是误解了文档?
    • 哦,对了,我忘了他们添加了args 的东西。这是一个最近的更新,显然是在 1.9 中引入的。也许您使用的是过时版本的 Tensorflow?
    • 该死,你是对的。我正在使用一个版本并阅读另一个版本的文档...再次感谢!
    【解决方案2】:

    对于 TensorFlow 2.4:

    training_dataset = tf.data.Dataset.from_generator(
         raw_data_gen, 
         args=(1), 
         output_types=(tf.float32, tf.uint8), 
         output_shapes=([None, 1], [None]))
    

    【讨论】:

    • 我遇到了这个解决方案的问题 - args 被转换为类似字节的对象
    猜你喜欢
    • 2021-05-06
    • 1970-01-01
    • 2022-01-05
    • 2012-08-22
    • 2017-05-01
    • 1970-01-01
    • 1970-01-01
    • 2010-12-05
    • 1970-01-01
    相关资源
    最近更新 更多