【问题标题】:How to pass the input tensor of a model to a loss function?如何将模型的输入张量传递给损失函数?
【发布时间】:2024-01-19 13:12:01
【问题描述】:

我的目标是创建一个自定义损失函数,根据y_truey_pred 和模型输入层的张量计算损失:

import numpy as np
from tensorflow import keras as K

input_shape = (16, 16, 1)

input = K.layers.Input(input_shape)
dense = K.layers.Dense(16)(input)
output = K.layers.Dense(1)(dense)

model = K.Model(inputs=input, outputs=output)


def CustomLoss(y_true, y_pred):
  return K.backend.sum(K.backend.abs(y_true - model.input * y_pred))


model.compile(loss=CustomLoss)
model.fit(np.ones(input_shape), np.zeros(input_shape))

但是,此代码失败并显示以下错误消息:

TypeError: Cannot convert a symbolic Keras input/output to a numpy array. This error may indicate that you're trying to pass a symbolic value to a NumPy call, which is not supported. Or, you may be trying to pass Keras symbolic inputs/outputs to a TF API that does not register dispatching, preventing Keras from automatically converting the API call to a lambda layer in the Functional Model.

如何将模型的输入张量传递给损失函数?

Tensorflow 版本:2.4.1
Python 版本:3.8.8

【问题讨论】:

    标签: python tensorflow keras deep-learning loss-function


    【解决方案1】:

    您可以使用add_loss 将外部层传递给您的loss。举个例子:

    import numpy as np
    from tensorflow import keras as K
    
    def CustomLoss(y_true, y_pred, input_l):
        return K.backend.sum(K.backend.abs(y_true - input_l * y_pred))
    
    input_shape = (16, 16, 1)
    n_sample = 10
    
    X = np.random.uniform(0,1, (n_sample,) + input_shape)
    y = np.random.uniform(0,1, (n_sample,) + input_shape)
    
    inp = K.layers.Input(input_shape)
    dense = K.layers.Dense(16)(inp)
    out = K.layers.Dense(1)(dense)
    
    target = K.layers.Input(input_shape)
    model = K.Model(inputs=[inp,target], outputs=out)
    
    model.add_loss( CustomLoss( target, out, inp ) )
    model.compile(loss=None, optimizer='adam')
    model.fit(x=[X,y], y=None, epochs=3)
    

    在推理模式下使用模型(从输入中删除目标)

    final_model = K.Model(model.input[0], model.output)
    final_model.predict(X)
    

    【讨论】:

      最近更新 更多