【Question title】: How to train a keras.concatenate model with the tf.data.Dataset API?
【Posted】: 2020-06-24 02:40:08
【Question】:

Here is my Keras model:

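from tensorflow import keras

n1, n2, num_classes = 16, 16, 10  # example sizes

input1 = keras.layers.Input(shape=(100,))
# ...
# Dense applied to input1 (or to the output of the elided layers above)
net1 = keras.layers.Dense(n1, activation='relu')(input1)


input2 = keras.layers.Input(shape=(50,))
# ...
# Dense applied to input2 (or to the output of the elided layers above)
net2 = keras.layers.Dense(n2, activation='relu')(input2)


merge = keras.layers.concatenate([net1, net2])
output = keras.layers.Dense(num_classes, activation='softmax')(merge)
# Each Input appears once: input1 (100 features) and input2 (50 features).
model = keras.Model(inputs=[input1, input2], outputs=[output])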

and the TFRecord data (a large dataset):

import tensorflow as tf

feature_description = {
    'f1': tf.io.FixedLenFeature([100], tf.int64),
    'f2': tf.io.FixedLenFeature([50], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
}

def parser(example_proto):
    # Returns a dict: {'f1': ..., 'f2': ..., 'label': ...}
    return tf.io.parse_single_example(example_proto, feature_description)

# file_pattern: glob matching the TFRecord files (placeholder name)
filename_queue = tf.data.Dataset.list_files(file_pattern, shuffle=True)
dataset = tf.data.TFRecordDataset(filename_queue).map(parser)

How do I fit this data to the model?

I know you are supposed to do it like this:

model.fit([f1,f2], y, epochs=epochs, batch_size=batch_size,validation_split=0.3)

where f1 and f2 are pandas DataFrames.
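For reference, here is a minimal sketch of how that DataFrame-based call maps onto tf.data. It assumes the frames have already been converted to NumPy arrays (for example with `DataFrame.to_numpy()`); the arrays below are random stand-ins with made-up sizes. Since `fit()` does not support `validation_split` when `x` is a `tf.data.Dataset`, the 30% validation set is carved out explicitly with `take`/`skip`:

```python
import numpy as np
import tensorflow as tf

# Random stand-ins for the two feature DataFrames and the labels
# (hypothetical sizes; on real frames use df.to_numpy()).
n_examples, num_classes = 1000, 10
f1 = np.random.randint(0, 500, size=(n_examples, 100)).astype('float32')
f2 = np.random.randint(0, 500, size=(n_examples, 50)).astype('float32')
y = np.random.randint(0, num_classes, size=(n_examples,))

# Nest the two inputs in a tuple: every element is ((f1_row, f2_row), label),
# matching a model with two inputs and one output.
examples = tf.data.Dataset.from_tensor_slices(((f1, f2), y))

# fit() rejects validation_split for datasets, so split explicitly.
# Split before shuffling so the two subsets never share examples.
n_val = int(0.3 * n_examples)
val_ds = examples.take(n_val).batch(32)
train_ds = examples.skip(n_val).shuffle(1000).batch(32)

# With the model from the question (once compiled):
# model.fit(train_ds, validation_data=val_ds, epochs=epochs)
```

For a TFRecord stream, `n_examples` has to be known in advance, and the split should come from a source whose order is stable across epochs (by default `list_files(..., shuffle=True)` reshuffles the file order every epoch, which would mix the subsets). Keeping separate validation files avoids both issues.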

【Question discussion】:

    Tags: python keras deep-learning concatenation tensorflow2.x


    【Answer 1】:

    Here is a snippet I wrote that creates a dummy dataset, writes it to a TFRecord file, and builds the model.

    import tensorflow as tf
    
    num_classes = 10
    n_samples = 10000
    
    f1 = tf.random.uniform(shape=[n_samples, 100], maxval=500, dtype=tf.int32).numpy()
    f2 = tf.random.uniform(shape=[n_samples, 50], maxval=500, dtype=tf.int32).numpy()
    labels = tf.random.uniform(shape=[n_samples], maxval=num_classes, dtype=tf.int32).numpy()
    
    
    
    def make_example(f1, f2, label):
        feature = {
            'f1': tf.train.Feature(int64_list=tf.train.Int64List(value=f1)),
            'f2': tf.train.Feature(int64_list=tf.train.Int64List(value=f2)),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature))
    
    def write_tfrecord(f1, f2, labels, tfrecord_path):
        with tf.io.TFRecordWriter(tfrecord_path) as writer:
            for i in range(len(f1)):
                example = make_example(f1[i], f2[i], labels[i])
                writer.write(example.SerializeToString())
    
    write_tfrecord(f1, f2, labels, 'test.tfrecord')
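Before building the input pipeline, it can be worth checking that a record decodes back to what was written. A small sketch, assuming a throwaway file in a temporary directory; it reuses the answer's `make_example` encoding and decodes one record with `tf.train.Example.FromString`:

```python
import os
import tempfile

import tensorflow as tf


def make_example(f1, f2, label):
    # Same encoding as the answer: two int64 vectors and a scalar label.
    feature = {
        'f1': tf.train.Feature(int64_list=tf.train.Int64List(value=f1)),
        'f2': tf.train.Feature(int64_list=tf.train.Int64List(value=f2)),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Write one known example to a temporary file (illustrative path).
path = os.path.join(tempfile.mkdtemp(), 'roundtrip.tfrecord')
row1, row2 = list(range(100)), list(range(50))
with tf.io.TFRecordWriter(path) as writer:
    writer.write(make_example(row1, row2, 3).SerializeToString())

# Read the first serialized record back and decode it as a tf.train.Example.
raw = next(iter(tf.data.TFRecordDataset(path)))
decoded = tf.train.Example.FromString(raw.numpy())
feats = decoded.features.feature
decoded_f1 = list(feats['f1'].int64_list.value)
decoded_f2 = list(feats['f2'].int64_list.value)
decoded_label = feats['label'].int64_list.value[0]
print(len(decoded_f1), len(decoded_f2), decoded_label)  # 100 50 3
```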
    
    n1 = 16
    n2 = 16
    input1 = tf.keras.layers.Input(shape=(100,))
    input2 = tf.keras.layers.Input(shape=(50,))
    
    net1 = tf.keras.layers.Dense(n1, activation='relu')(input1)
    net2 = tf.keras.layers.Dense(n2, activation='relu')(input2)
    
    merge = tf.keras.layers.concatenate([net1, net2])
    output = tf.keras.layers.Dense(num_classes, activation='softmax')(merge)
    model = tf.keras.Model(inputs=[input1, input2], outputs=[output])
    


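    Now parsing your TFRecord file and creating a tf.data.Dataset object should be straightforward. Since your model has two inputs and one output, your tf.data.Dataset needs a matching structure: each element should be ((f1, f2), label). Here is how I did it: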

    feature_description = {
        'f1': tf.io.FixedLenFeature([100], tf.int64),
        'f2': tf.io.FixedLenFeature([50], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    
    def parser(example_proto):
        parsed_example = tf.io.parse_single_example(example_proto, feature_description)
        f1 = parsed_example['f1']
        f2 = parsed_example['f2']
        label = parsed_example['label']
        return (f1, f2), label
    
    dataset = tf.data.TFRecordDataset('test.tfrecord')
    dataset = dataset.map(parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Shuffle individual examples before batching; shuffling after batch()
    # only reorders whole batches.
    dataset = dataset.shuffle(16)
    dataset = dataset.batch(4)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    

    When you print the dataset, you should see the following output:

    <PrefetchDataset shapes: (((None, 100), (None, 50)), (None,)), types: ((tf.int64, tf.int64), tf.int64)>
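The same structure is available programmatically as `dataset.element_spec` (newer TensorFlow releases print `element_spec=...` in the repr instead of `shapes:`/`types:`). The helper below is a hypothetical sketch, not a TensorFlow API: it compares the per-example input shapes with the model's `Input` shapes, which catches mismatches such as a second input expecting 100 features instead of 50, the kind of error the question's `inputs=[input1, input1]` would introduce:

```python
import tensorflow as tf


def matches_model_inputs(dataset, model):
    # Hypothetical helper: compare the dataset's input specs with the
    # model's Input shapes, ignoring the batch dimension.
    input_specs = dataset.element_spec[0]
    if len(input_specs) != len(model.inputs):
        return False
    return all(
        spec.shape[1:].as_list() == list(inp.shape[1:])
        for spec, inp in zip(input_specs, model.inputs)
    )


# Small two-input model with the same input shapes as the answer.
in1 = tf.keras.layers.Input(shape=(100,))
in2 = tf.keras.layers.Input(shape=(50,))
merged = tf.keras.layers.concatenate(
    [tf.keras.layers.Dense(4)(in1), tf.keras.layers.Dense(4)(in2)])
toy_model = tf.keras.Model(inputs=[in1, in2],
                           outputs=tf.keras.layers.Dense(10)(merged))

labels = tf.zeros([8], tf.int64)
good_ds = tf.data.Dataset.from_tensor_slices(
    ((tf.zeros([8, 100]), tf.zeros([8, 50])), labels)).batch(4)
swapped_ds = tf.data.Dataset.from_tensor_slices(
    ((tf.zeros([8, 50]), tf.zeros([8, 100])), labels)).batch(4)

print(matches_model_inputs(good_ds, toy_model))     # True
print(matches_model_inputs(swapped_ds, toy_model))  # False
```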
    

    Then make sure everything works and the training loop runs without errors:

    model.summary()
    model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
    model.fit(dataset, steps_per_epoch=10, epochs=2)
    

    Here is the final output:

    Model: "model_2"
    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to                     
    ==================================================================================================
    input_7 (InputLayer)            [(None, 100)]        0                                            
    __________________________________________________________________________________________________
    input_8 (InputLayer)            [(None, 50)]         0                                            
    __________________________________________________________________________________________________
    dense_6 (Dense)                 (None, 16)           1616        input_7[0][0]                    
    __________________________________________________________________________________________________
    dense_7 (Dense)                 (None, 16)           816         input_8[0][0]                    
    __________________________________________________________________________________________________
    concatenate_2 (Concatenate)     (None, 32)           0           dense_6[0][0]                    
                                                                     dense_7[0][0]                    
    __________________________________________________________________________________________________
    dense_8 (Dense)                 (None, 10)           330         concatenate_2[0][0]              
    ==================================================================================================
    Total params: 2,762
    Trainable params: 2,762
    Non-trainable params: 0
    __________________________________________________________________________________________________
    Train for 10 steps
    Epoch 1/2
    10/10 [==============================] - 0s 12ms/step - loss: 4735.8942
    Epoch 2/2
    10/10 [==============================] - 0s 1ms/step - loss: 2.7339
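An alternative that does not depend on tuple order is to name the `Input` layers after the feature keys and feed a dict, since `tf.keras.Model` accepts dict-keyed inputs. A minimal sketch under that assumption: `to_model_inputs` is a hypothetical helper, the random dict stands in for the output of the question's original `parser`, and the features are cast to `float32` to match the `Input` layers:

```python
import numpy as np
import tensorflow as tf

num_classes = 10  # example value, as in the answer

# Name each Input after the feature key it should receive.
f1_in = tf.keras.layers.Input(shape=(100,), name='f1')
f2_in = tf.keras.layers.Input(shape=(50,), name='f2')
net1 = tf.keras.layers.Dense(16, activation='relu')(f1_in)
net2 = tf.keras.layers.Dense(16, activation='relu')(f2_in)
merge = tf.keras.layers.concatenate([net1, net2])
output = tf.keras.layers.Dense(num_classes, activation='softmax')(merge)
# Dict-keyed inputs: Keras routes each dataset entry to the Input with that key.
dict_model = tf.keras.Model(inputs={'f1': f1_in, 'f2': f2_in}, outputs=output)
dict_model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')


def to_model_inputs(parsed):
    # Hypothetical helper: split the dict returned by parse_single_example
    # into ({'f1': ..., 'f2': ...}, label) and cast features to float32.
    inputs = {'f1': tf.cast(parsed['f1'], tf.float32),
              'f2': tf.cast(parsed['f2'], tf.float32)}
    return inputs, parsed['label']


# Random stand-in for the parsed TFRecord stream (same keys and shapes).
fake_parsed = tf.data.Dataset.from_tensor_slices({
    'f1': np.random.randint(0, 500, size=(8, 100)).astype('int64'),
    'f2': np.random.randint(0, 500, size=(8, 50)).astype('int64'),
    'label': np.random.randint(0, num_classes, size=(8,)).astype('int64'),
})
dict_ds = fake_parsed.map(to_model_inputs).batch(4)
history = dict_model.fit(dict_ds, epochs=1, verbose=0)
```

With real files, the helper would be chained after the question's unmodified parser, e.g. `tf.data.TFRecordDataset(filename_queue).map(parser).map(to_model_inputs)`.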
    

    【Discussion】:

    • Thanks, “return (f1, f2), label” was the key. I had tried “return [f1,f2],label”, but that failed.
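As far as I know, this matches how tf.data handles nesting: tuples and dicts are treated as nested structure, but Python lists nested inside the returned value are not; tf.data tries to convert such a list into a single tensor by stacking its items, which cannot work when f1 has 100 values and f2 has 50. A minimal sketch of the tuple version, plus converting a list to a tuple when the inputs happen to arrive as a list (the dummy zero tensors are stand-ins for parsed features):

```python
import tensorflow as tf

# Dummy source whose elements are (f1, f2, label), like the parsed features.
source = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([4, 100], tf.int64), tf.zeros([4, 50], tf.int64),
     tf.zeros([4], tf.int64)))


def parser_with_tuple(f1, f2, label):
    # A tuple is kept as nested structure: one component per model input.
    return (f1, f2), label


def parser_from_list(f1, f2, label):
    # If the inputs arrive as a list (e.g. built by other code),
    # convert it to a tuple before returning.
    inputs = [f1, f2]
    return tuple(inputs), label


tuple_ds = source.map(parser_with_tuple)
list_fixed_ds = source.map(parser_from_list)
print(tuple_ds.element_spec)
```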