【问题标题】:Tensorflow 2 variable not trainableTensorflow 2变量不可训练
【发布时间】:2020-02-03 21:18:18
【问题描述】:

我在 tf2 中创建了一个简单的模型,它将输入“a”乘以变量“b”(初始化为 1)并返回输出“c”。然后我尝试在简单的数据集 a=1, c=5 上对其进行训练。我希望它能够学习 b=5。

import tensorflow as tf
from tensorflow.keras.models import Model

a = Input(shape=(1,))
b = tf.Variable(1., trainable=True)
c = a*b
model = Model(a,c)

loss = tf.keras.losses.MeanAbsoluteError()
model.compile(optimizer='adam', loss=loss)

model.fit([1.],[5.],batch_size=1, epochs=1)

但是,tf2 并不认为变量“b”是可训练的。摘要显示没有可训练的参数。

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 1)]               0         
_________________________________________________________________
tf_op_layer_mul (TensorFlowO [(None, 1)]               0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

为什么变量“b”没有训练?

【问题讨论】:

    标签: python tensorflow keras tensorflow2.0 tf.keras


    【解决方案1】:

    Keras 模型是 Layer 类的包装器。您必须将此变量包装为 keras 层,以便将其显示为模型中的可训练参数。

    您可以像这样为它创建一个微小的自定义层:

    class MyLayer(tf.keras.layers.Layer):
      def __init__(self):
        super(MyLayer, self).__init__()
    
        #your variable goes here
        self.variable = tf.Variable(1., trainable=True, dtype=tf.float64)
    
      def call(self, inputs, **kwargs):
    
        # your mul operation goes here
        x = inputs * self.variable
    
        return x
    

    这里call 方法将进行乘法运算。我们可以像使用输出模型中的任何其他层一样使用这一层。在这里,我正在创建一个 Sequential 模型,添加 aboce 乘法运算作为模型层。

    model = tf.keras.models.Sequential()
    mylayer_object = MyLayer()
    model.add(mylayer_object)
    
    loss = tf.keras.losses.MeanAbsoluteError()
    model.compile("adam", loss)
    
    model.fit([1.],[5.],batch_size=1, epochs=1)
    model.summary()
    '''
    Train on 1 samples
    1/1 [==============================] - 0s 426ms/sample - loss: 4.0000
    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    my_layer (MyLayer)           multiple                  1         
    =================================================================
    Total params: 1
    Trainable params: 1
    Non-trainable params: 0
    _________________________________________________________________
    '''
    

    在此之后,如果您可以列出模型的可训练参数。

    print(model.trainable_variables)
    # [<tf.Variable 'Variable:0' shape=() dtype=float64, numpy=1.0009999968852092>]
    

    【讨论】:

    • 这是文档中的任何地方吗?
    猜你喜欢
    • 1970-01-01
    • 2020-07-16
    • 2018-06-30
    • 2017-11-30
    • 1970-01-01
    • 2021-12-26
    • 1970-01-01
    • 2022-01-25
    • 2016-09-16
    相关资源
    最近更新 更多