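【Posted】: 2021-07-19 18:37:10
【Question】:
It appears that Python's TensorFlow is not compatible with the currently released version of Jython, so I am writing my AI model (a GAN) in Java. I am following a Python guide that uses Keras and TensorFlow for my GAN model. I have already worked out how to set up the neural networks from the Python code in Java using DeepLearning4j. The problem is that I cannot work out how to set up the training function in DeepLearning4j.
Here is the Python training code I am following: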
import time

import tensorflow as tf

generator_optimizer = tf.keras.optimizers.Adam(generator_lr)  # learning rate for generator
discriminator_optimizer = tf.keras.optimizers.Adam(discriminator_lr)  # learning rate for discriminator
seed = tf.random.normal([num_examples_to_generate, noise_dim])
# ignore this; I already have dataset code set up
train_dataset = tf.data.Dataset.from_tensor_slices(train_images_scaled).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)  # generator model is already set up under the name "generator"
        real_output = discriminator(images, training=True)  # discriminator model is already set up as "discriminator"
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)  # predefined function that calculates the loss of the generator as a decimal value using the loss of the discriminator; uses crossentropy
        disc_loss = discriminator_loss(real_output, fake_output)  # predefined function that calculates the loss of the discriminator as a decimal value
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()  # time module of course
        for structure_batch in dataset:
            train_step(structure_batch)
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
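The two loss helpers called inside train_step are not shown in the question. As a rough sketch of what they might compute, assuming they use binary cross-entropy on raw logits with real samples labeled 1 and generated samples labeled 0 (the question only says "crossentropy", so the exact form is an assumption), the arithmetic can be written in plain Python without TensorFlow, which may make it easier to port to another language:

```python
import math


def bce_with_logits(logit, label):
    """Numerically stable binary cross-entropy for a single raw logit.

    Uses the identity max(x, 0) - x * z + log(1 + exp(-|x|)),
    where x is the logit and z is the target label (0.0 or 1.0).
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean_bce(logits, label):
    """Average the per-sample loss over a batch of logits."""
    return sum(bce_with_logits(x, label) for x in logits) / len(logits)


# Assumed form of the helper: the discriminator should score
# real samples as 1 and generated samples as 0.
def discriminator_loss(real_output, fake_output):
    return mean_bce(real_output, 1.0) + mean_bce(fake_output, 0.0)


# Assumed form of the helper: the generator is rewarded when the
# discriminator scores its samples as real (label 1).
def generator_loss(fake_output):
    return mean_bce(fake_output, 1.0)
```

Here real_output and fake_output are plain lists of discriminator logits, one per sample. In the TensorFlow version the same quantities are tensors and the gradients are taken automatically by the two GradientTape objects.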
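As for the purpose of the GAN: it generates structures in Minecraft, using one-hot encoding for the different block types, though I don't think that information is necessary.
I just want to "translate" the Python training code above into Java using the DeepLearning4j library. (I also have the TensorFlow library for Java, but I don't think it is directly compatible with DL4J.)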
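For context on the one-hot encoding of blocks mentioned above, here is a minimal sketch in plain Python. The block vocabulary, the nested-list layout, and the helper names (one_hot, encode_structure, decode_vector) are all hypothetical, since the question does not show the dataset code:

```python
# Hypothetical block vocabulary; the real list of block types is not given.
BLOCK_TYPES = ["air", "stone", "dirt", "oak_planks"]
BLOCK_INDEX = {name: i for i, name in enumerate(BLOCK_TYPES)}


def one_hot(block_name):
    """Return a vector with 1.0 at the block's index and 0.0 elsewhere."""
    vec = [0.0] * len(BLOCK_TYPES)
    vec[BLOCK_INDEX[block_name]] = 1.0
    return vec


def encode_structure(blocks):
    """Encode a 3-level nested list of block names as one-hot vectors."""
    return [[[one_hot(b) for b in row] for row in layer] for layer in blocks]


def decode_vector(scores):
    """Map one output vector back to the block type with the highest score."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return BLOCK_TYPES[best]
```

With this layout, each position in a structure becomes a vector whose length equals the number of block types, and a generated structure can be turned back into blocks by taking the highest-scoring entry at each position.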
【Discussion】:
Tags: tensorflow artificial-intelligence minecraft generative-adversarial-network deeplearning4j