【问题标题】:Tensorflow - re initializing weights and reshaping tensor of pretrained modelTensorflow - 重新初始化预训练模型的权重和重塑张量
【发布时间】:2017-03-25 12:52:22
【问题描述】:

我正在查看以下示例: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/4_Utils/save_restore_model.ipynb

我希望能够重新初始化隐藏层 2 的权重并将最终层(out_layer)重塑为 3 个类而不是 10 个。

我希望能够在第二次会话中做到这一点 - 也就是说,在我恢复训练好的模型之后。

我的主要目标是学习如何在 tensorflow 中进行迁移学习,我认为通过对这个示例执行此操作,我将能够学习到这一点。你能指出我应该做什么吗?我真的尝试过寻找但找不到任何类似的例子..

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    我解决了。

    可以通过执行以下操作重新初始化权重: 重要的部分是 set_value,它接收会话、张量流变量和新的权重值

    def _convert_string_dtype(dtype):
        if dtype == 'float16':
            return tf.float16
        if dtype == 'float32':
            return tf.float32
        elif dtype == 'float64':
            return tf.float64
        elif dtype == 'int16':
            return tf.int16
        elif dtype == 'int32':
            return tf.int32
        elif dtype == 'int64':
            return tf.int64
        elif dtype == 'uint8':
            return tf.int8
        elif dtype == 'uint16':
            return tf.uint16
        else:
            raise ValueError('Unsupported dtype:', dtype)
    
    def set_value(sess, x, value):
        """Sets the value of a variable, from a Numpy array.
        # Arguments
            x: Tensor to set to a new value.
            value: Value to set the tensor to, as a Numpy array
                (of the same shape).
        """
        value = np.asarray(value)
        tf_dtype = _convert_string_dtype(x.dtype.name.split('_')[0])
        if hasattr(x, '_assign_placeholder'):
            assign_placeholder = x._assign_placeholder
            assign_op = x._assign_op
        else:
            assign_placeholder = tf.placeholder(tf_dtype, shape=value.shape)
            assign_op = x.assign(assign_placeholder)
            x._assign_placeholder = assign_placeholder
            x._assign_op = assign_op
        return sess.run(assign_op, feed_dict={assign_placeholder: value})
    
    # Tensorflow variable name
    tf_var_name ="h2_weights"
    var = [var for var in tf.global_variables() if var.op.name==tf_var_name][0] 
    var_shape = var.get_shape().as_list()
    
    # Initialize to zero
    new_weights = np.zeros(var_shape)
    
    set_value(sess,var,new_weights)    
    

    【讨论】:

      猜你喜欢
      • 2021-06-08
      • 1970-01-01
      • 2020-10-27
      • 2016-06-29
      • 1970-01-01
      • 2019-04-14
      • 2021-10-02
      • 1970-01-01
      • 2020-02-09
      相关资源
      最近更新 更多