【发布时间】:2019-01-28 20:14:27
【问题描述】:
来自 TensorFlow,我觉得在 Keras 中实现除基本顺序模型之外的任何其他东西都可能非常棘手。自动发生的事情太多了。在 TensorFlow 中,您始终知道您的占位符(输入/输出)、形状、结构……因此很容易,例如,设置自定义损失。
定义多个输出和自定义损失函数的简洁方法是什么?
我们以一个简单的自动编码器为例,使用 MNIST:
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
X_train = X_train.reshape(-1, 28, 28, 1)
短卷积编码器:
enc_in = Input(shape=(28, 28, 1), name="enc_in")
x = Conv2D(16, (3, 3))(enc_in)
x = LeakyReLU()(x)
x = MaxPool2D()(x)
x = Conv2D(32, (3, 3))(x)
x = LeakyReLU()(x)
x = Flatten()(x)
z = Dense(100, name="z")(x)
enc = Model(enc_in, z, name="encoder")
解码器的类似架构。我们不关心填充和卷积导致的维度减少,所以我们只是在最后应用双线性调整大小以再次匹配(batch, 28, 28, 1):
def resize_images(inputs, dims_xy):
x, y = dims_xy
return Lambda(lambda im: K.tf.image.resize_images(im, (y, x)))(inputs)
# decoder
dec_in = Input(shape=(100,), name="dec_in")
x = Dense(14 * 14 * 8)(dec_in)
x = LeakyReLU()(x)
x = Reshape((14, 14, 8))(x)
x = Conv2D(32, (3, 3))(x)
x = LeakyReLU()(x)
x = UpSampling2D()(x)
x = Conv2D(16, (3, 3))(x)
x = LeakyReLU()(x)
x = Conv2D(1, (3, 3), activation="linear")(x)
dec_out = resize_images(x, (28, 28))
dec = Model(dec_in, dec_out, name="decoder")
我们定义了自己的 MSE 以提供一个简单的示例...
def custom_loss(y_true, y_pred):
return K.mean(K.square(y_true - y_pred))
...最后构建我们的完整模型:
outputs = dec(enc(enc_in))
ae = Model(enc_in, outputs, name="ae")
ae.compile(optimizer=Adam(lr=1e-4), loss=custom_loss)
# training
ae.fit(x=X_train, y=X_train, batch_size=256, epochs=10)
如果我在解码器的最后一层定义 activation="sigmoid" 以获得漂亮的图像(输出间隔 [0.0, 1.0]),则训练损失会发散,因为 Keras 没有自动使用 logits,而是将 sigmoid 激活输入到失利。因此,在最后一层使用activation="linear" 进行训练会更好更快。在 TensorFlow 中,我只需定义两个张量 logits=x 和 output=sigmoid(x) 以便能够在任何自定义损失函数中使用 logits 并在绘图或其他应用程序中使用 output。
我怎么会在 Keras 中做这样的事情?
此外,如果我有多个输出,如何在自定义损失函数中使用它们?就像 VAE 的 KL 散度或 GAN 的损失项。
functional API guide 不是很有帮助(尤其是当您将其与 TensorFlow 的超广泛指南进行比较时),因为它仅涵盖基本的 LSTM 示例,您不必必须自己定义任何内容,但是只使用预定义的损失函数。
【问题讨论】:
-
你的意思是一个输出只是'悬空'而不用于训练,对吧?
-
@mrgloom 是的,完全正确。
-
我认为可以通过定义具有一个头和一些悬空输出的模型,然后您可以创建从
keras.callbacks.Callback派生的CustomCallback,其中在纪元结束时,您可以像这样从悬空输出层获得输出stackoverflow.com/questions/41711190/… 并将其传递给 tensorboard。
标签: python keras deep-learning