【问题标题】:Importing pre-trained embeddings into Tensorflow's Embedding Feature Column将预训练的嵌入导入 TensorFlow 的嵌入特征列
【发布时间】:2020-02-05 06:06:14
【问题描述】:

我有一个 TF 估计器,它在其输入层使用特征列。其中之一是和EmbeddingColumn,我一直在随机初始化(默认行为)。

现在我想在 gensim 中预训练我的嵌入并将学习到的嵌入转移到我的 TF 模型中。 embedding_column 接受一个初始化参数,该参数期望一个可以调用的可调用对象,使用 tf.contrib.framework.load_embedding_initializer 可以是 created

但是,该函数需要一个保存的 TF 检查点,而我没有,因为我在 gensim 中训练了我的嵌入。

问题是:如何将 gensim 词向量(它们是 numpy 数组)保存为 TF 检查点格式的张量,以便我可以使用它来初始化我的嵌入列?

【问题讨论】:

    标签: tensorflow gensim word2vec transfer-learning


    【解决方案1】:

    想通了!这在 TensorFlow 1.14.0 中有效。

    您首先需要将嵌入向量转换为tf.Variable。然后使用tf.train.Saver 将其保存在检查点中。

    import tensorflow as tf
    import numpy as np
    
    
    ckpt_name = 'gensim_embeddings'
    vocab_file = 'vocab.txt'
    tensor_name = 'embeddings_tensor'
    vocab = ['A', 'B', 'C']
    embedding_vectors = np.array([
        [1,2,3],
        [4,5,6],
        [7,8,9]
    ], dtype=np.float32)
    
    embeddings = tf.Variable(initial_value=embedding_vectors)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver({tensor_name: embeddings})
    with tf.Session() as sess:
        sess.run(init_op)
        saver.save(sess, ckpt_name)
    
    # writing vocab file
    with open(vocab_file, 'w') as f:
        f.write('\n'.join(vocab))
    

    要使用此检查点来初始化嵌入特征列:

    cat = tf.feature_column.categorical_column_with_vocabulary_file(
        key='cat', vocabulary_file=vocab_file)
    
    embedding_initializer = tf.contrib.framework.load_embedding_initializer(
        ckpt_path=ckpt_name,
        embedding_tensor_name='embeddings_tensor',
        new_vocab_size=3,
        embedding_dim=3,
        old_vocab_file=vocab_file,
        new_vocab_file=vocab_file
    )
    
    emb = tf.feature_column.embedding_column(cat, dimension=3, initializer=embedding_initializer, trainable=False)
    

    我们可以测试以确保它已正确初始化:

    def test_embedding(feature_column, sample):
        feature_layer = tf.keras.layers.DenseFeatures(feature_column)
        print(feature_layer(sample).numpy())
    
    tf.enable_eager_execution()
    
    sample = {'cat': tf.constant(['B', 'A'], dtype=tf.string)}
    
    test_embedding(item_emb, sample)
    

    正如预期的那样,输出是:

    [[4. 5. 6.]
     [1. 2. 3.]]
    

    分别是“B”和“A”的嵌入。

    【讨论】:

      猜你喜欢
      • 2018-12-16
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2016-06-20
      • 1970-01-01
      • 2021-01-18
      • 2019-12-30
      • 1970-01-01
      相关资源
      最近更新 更多