【问题标题】:Issue with modifying a Keras class to include call function修改 Keras 类以包含调用函数的问题
【发布时间】:2020-09-17 18:53:20
【问题描述】:

我想训练一个拥有庞大数据集的 VAE,并决定使用为时尚 MNIST 制作的 VAE code 以及使用我在 github 上找到的文件名对 batch-loading 进行的流行修改。我的研究协作笔记本是heredataset 的示例部分。

但是 VAE 类的编写方式没有调用函数,根据 keras documentation 应该存在。我收到错误 NotImplementedError: When subclassing the Model class, you should implement a call method.

class VAE(tf.keras.Model):
"""a basic vae class for tensorflow
Extends:
    tf.keras.Model
"""

def __init__(self, **kwargs):
    super(VAE, self).__init__()
    self.__dict__.update(kwargs)

    self.enc = tf.keras.Sequential(self.enc)
    self.dec = tf.keras.Sequential(self.dec)

def encode(self, x):
    mu, sigma = tf.split(self.enc(x), num_or_size_splits=2, axis=1)
    return ds.MultivariateNormalDiag(loc=mu, scale_diag=sigma)

def reparameterize(self, mean, logvar):
    eps = tf.random.normal(shape=mean.shape)
    return eps * tf.exp(logvar * 0.5) + mean

def reconstruct(self, x):
    mu, _ = tf.split(self.enc(x), num_or_size_splits=2, axis=1)
    return self.decode(mu)

def decode(self, z):
    return self.dec(z)

def compute_loss(self, x):

    q_z = self.encode(x)
    z = q_z.sample()
    x_recon = self.decode(z)
    p_z = ds.MultivariateNormalDiag(
      loc=[0.] * z.shape[-1], scale_diag=[1.] * z.shape[-1]
      )
    kl_div = ds.kl_divergence(q_z, p_z)
    latent_loss = tf.reduce_mean(tf.maximum(kl_div, 0))
    recon_loss = tf.reduce_mean(tf.reduce_sum(tf.math.square(x - x_recon), axis=0))

    return recon_loss, latent_loss

def compute_gradients(self, x):
    with tf.GradientTape() as tape:
        loss = self.compute_loss(x)
    return tape.gradient(loss, self.trainable_variables)

@tf.function
def train(self, train_x):
    gradients = self.compute_gradients(train_x)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

并且编码器和解码器分别定义,编译为

N_Z = 8
filt_base = 32
DIMS = (128,128,3)

encoder = [
tf.keras.layers.InputLayer(input_shape=DIMS),
tf.keras.layers.Conv2D(
    filters=filt_base, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*2, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*2, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*3, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*3, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*4, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
    filters=filt_base*4, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=N_Z*2),
]

decoder = [
tf.keras.layers.Dense(units=8 * 8 * 128, activation="relu"),
tf.keras.layers.Reshape(target_shape=(8, 8, 128)),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*4, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*4, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*3, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*3, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*2, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base*2, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=filt_base, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
    filters=1, kernel_size=3, strides=(1, 1), padding="SAME", activation="sigmoid"
),
]

optimizer = tf.keras.optimizers.Adam(1e-3)

model = VAE(
  enc = encoder,
  dec = decoder,
  optimizer = optimizer,
)
model.compile(optimizer=optimizer)

并尝试使用 fit_generator 函数

训练模型
num_epochs = 50
model.fit_generator(generator=my_training_batch_generator,
                                      steps_per_epoch=(num_training_samples // batch_size),
                                      epochs=num_epochs,
                                      verbose=1,
                                      validation_data=my_validation_batch_generator,
                                      validation_steps=(num_validation_samples // batch_size),
                                      use_multiprocessing=True,
                                      workers=16,
                                      max_queue_size=32)

我是机器学习的新手,如果能帮助解决这个问题,我将不胜感激。我认为问题在于 VAE 类中的 def train line。

一个可选的请求是,如果可以完成训练,以便我可以看到每个 epoch 之后的重建,将不胜感激。为此,我在研究协作笔记本中已经有一个 plot_reconstruction 函数需要调用。

【问题讨论】:

    标签: python machine-learning keras


    【解决方案1】:

    APaul31,

    特别是在您的代码中,我建议将 call() 函数添加到 VAE 类中:

    def call(self, x):
        q_z = self.encode(x)
        z = q_z.sample()
        x_recon = self.decode(z)
    

    我还建议对您的任务使用更标准的方法,尤其是作为初学者:

    1. 使用 tf.keras.preprocessing.image_dataset_from_directory() 加载图像。教程here.

    2. 使用自定义 Model.train_step() 来计算 VAE 损失,而不是 VAE 类中的多个函数。示例here

    【讨论】:

    • 以上答案在 Github github.com/tensorflow/tensorflow/issues/43173987654323@ 讨论后更新
    • 嗨,Ruslan,非常感谢您详细调查问题。我按照您的建议修改了主要的 VAE 类,我想我解决了调用函数的问题。但是我在自定义文件加载类方面遇到了一些问题。我无法使用 image_dataset_from_directory(),因为它会加载整个数据集,并且协作笔记本的内存将耗尽。是否可以快速查看更新的笔记本?目前我遇到 TypeError 的错误:'NoneType' object is not callable
    • 1) 请接受并支持原始答案 2) 关于内存耗尽问题:尝试将自定义生成器编写为 keras.utils.Sequence 子类,而不是 image_dataset_from_directory()。在其__getitem__() 函数中,您可以控制数据加载过程以确保仅将单个批次加载到内存中。请参阅此处的示例:[medium.com/@mrgarg.rajat/…
    • 我的声望低于 15,不能投票给你 :-|
    • 这会过去的)
    【解决方案2】:

    目前我遇到 TypeError 的错误:'NoneType' object is not callable

    问题在于fit 方法,当您传递数据生成器时,使用fit_generator 而不是fit。在协作中它调用fit

    另外,请注意,您可以使用flow_from_directory 方法而不是image_dataset_from_directory 来懒惰地生成批次,它不会将整个数据加载到内存中 https://keras.io/api/preprocessing/image/#flowfromdirectory-method.

    【讨论】:

    • 嗨,Surj,我修改了我的代码以包含 flow_from_directory 功能,它能够加载数据并启动第一个 epoch,但现在遇到新错误ValueError: Layer sequential_1 expects 1 inputs, but it received 2 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, None, None, None) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(None, None, None, None) dtype=float32>] 任何帮助将不胜感激。数据集是here。也标记 Ruslan @Ruslan S.
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2022-08-16
    • 1970-01-01
    • 1970-01-01
    • 2012-03-23
    • 2022-07-08
    • 1970-01-01
    相关资源
    最近更新 更多