【发布时间】:2020-09-17 18:53:20
【问题描述】:
我想训练一个拥有庞大数据集的 VAE,并决定使用为时尚 MNIST 制作的 VAE code 以及使用我在 github 上找到的文件名对 batch-loading 进行的流行修改。我的研究协作笔记本是here 和dataset 的示例部分。
但是 VAE 类的编写方式没有调用函数,根据 keras documentation 应该存在。我收到错误 NotImplementedError: When subclassing the Model class, you should implement a call method.
class VAE(tf.keras.Model):
"""a basic vae class for tensorflow
Extends:
tf.keras.Model
"""
def __init__(self, **kwargs):
super(VAE, self).__init__()
self.__dict__.update(kwargs)
self.enc = tf.keras.Sequential(self.enc)
self.dec = tf.keras.Sequential(self.dec)
def encode(self, x):
mu, sigma = tf.split(self.enc(x), num_or_size_splits=2, axis=1)
return ds.MultivariateNormalDiag(loc=mu, scale_diag=sigma)
def reparameterize(self, mean, logvar):
eps = tf.random.normal(shape=mean.shape)
return eps * tf.exp(logvar * 0.5) + mean
def reconstruct(self, x):
mu, _ = tf.split(self.enc(x), num_or_size_splits=2, axis=1)
return self.decode(mu)
def decode(self, z):
return self.dec(z)
def compute_loss(self, x):
q_z = self.encode(x)
z = q_z.sample()
x_recon = self.decode(z)
p_z = ds.MultivariateNormalDiag(
loc=[0.] * z.shape[-1], scale_diag=[1.] * z.shape[-1]
)
kl_div = ds.kl_divergence(q_z, p_z)
latent_loss = tf.reduce_mean(tf.maximum(kl_div, 0))
recon_loss = tf.reduce_mean(tf.reduce_sum(tf.math.square(x - x_recon), axis=0))
return recon_loss, latent_loss
def compute_gradients(self, x):
with tf.GradientTape() as tape:
loss = self.compute_loss(x)
return tape.gradient(loss, self.trainable_variables)
@tf.function
def train(self, train_x):
gradients = self.compute_gradients(train_x)
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
并且编码器和解码器分别定义,编译为
N_Z = 8
filt_base = 32
DIMS = (128,128,3)
encoder = [
tf.keras.layers.InputLayer(input_shape=DIMS),
tf.keras.layers.Conv2D(
filters=filt_base, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*2, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*2, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*3, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*3, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*4, kernel_size=3, strides=(1, 1), activation="relu", padding="same"
),
tf.keras.layers.Conv2D(
filters=filt_base*4, kernel_size=3, strides=(2, 2), activation="relu", padding="same"
),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=N_Z*2),
]
decoder = [
tf.keras.layers.Dense(units=8 * 8 * 128, activation="relu"),
tf.keras.layers.Reshape(target_shape=(8, 8, 128)),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*4, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*4, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*3, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*3, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*2, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base*2, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=filt_base, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=1, kernel_size=3, strides=(1, 1), padding="SAME", activation="sigmoid"
),
]
optimizer = tf.keras.optimizers.Adam(1e-3)
model = VAE(
enc = encoder,
dec = decoder,
optimizer = optimizer,
)
model.compile(optimizer=optimizer)
并尝试使用 fit_generator 函数
训练模型num_epochs = 50
model.fit_generator(generator=my_training_batch_generator,
steps_per_epoch=(num_training_samples // batch_size),
epochs=num_epochs,
verbose=1,
validation_data=my_validation_batch_generator,
validation_steps=(num_validation_samples // batch_size),
use_multiprocessing=True,
workers=16,
max_queue_size=32)
我是机器学习的新手,如果能帮助解决这个问题,我将不胜感激。我认为问题在于 VAE 类中的 def train line。
一个可选的请求是,如果可以完成训练,以便我可以看到每个 epoch 之后的重建,将不胜感激。为此,我在研究协作笔记本中已经有一个 plot_reconstruction 函数需要调用。
【问题讨论】:
标签: python machine-learning keras