【问题标题】:Calculate Jacobian Matrix in TensorFlow v.2 with GradientTape()在 TensorFlow v.2 中使用 GradientTape() 计算雅可比矩阵
【发布时间】:2023-04-02 22:33:01
【问题描述】:

我目前正在尝试使用 TensorFlow 2 中的 GradientTape()batch_jacobian 在我的训练循环中计算雅可比矩阵。遗憾的是,我只获得了 None 值...

我目前的尝试是这样的:

for step, (batch_x, batch_y) in enumerate(train_data):

            with tf.GradientTape(persistent=True) as g:
                g.watch(batch_x)
                g.watch(batch_y)
                logits = self.retrained(batch_x, is_training=True)
                loss = lstm.cross_entropy_loss(logits, batch_y)
                acc = lstm.accuracy(logits, batch_y)
            avg_loss += loss
            avg_acc += acc

            gradients = g.gradient(loss, self.retrained.trainable_variables)
            J = g.batch_jacobian(logits, batch_x, experimental_use_pfor=False)
            print(J.numpy())
            self.optimizer.apply_gradients(zip(gradients, self.retrained.trainable_variables))

【问题讨论】:

  • self.retrained 是内置函数还是您定义的?你解决了吗?我有类似的问题,但对于tape.gradient(),原因是我定义了自己的损失函数。
  • 是的 - 我使用了我自己的损失函数 :) 使用 tf 中实现的一个可以解决它。非常感谢!

标签: tensorflow machine-learning deep-learning lstm


【解决方案1】:

以下代码使用 tensorflow 2:

import tensorflow as tf

在这里,我创建了一个简单的神经网络,然后使用它的偏导数 w.r.t。输入:

model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(2,1)),
tf.keras.layers.Dense(3),
tf.keras.layers.Dense(2)])

现在我使用 GradientTape 计算雅可比矩阵(对于输入:x=2.0,y=3.0):

x = tf.Variable([[2.0]])
y = tf.Variable([[3.0]])

with tf.GradientTape(persistent=True) as t:
    t.watch([x,y])
    z = tf.concat([x,y],1)
    f1 = model(z)[0][0]
    f2 = model(z)[0][1]


df1_dx = t.gradient(f1, x).numpy()
df1_dy = t.gradient(f1, y).numpy()
df2_dx = t.gradient(f2, x).numpy()
df2_dy = t.gradient(f2, y).numpy()

del t
print(df1_dx,df1_dy)
print(df2_dx,df2_dy)

考虑到神经网络的权重是随机初始化的,雅可比矩阵或打印输出如下:

[[-0.832729]] [[-0.19699946]]
[[-0.5562407]] [[0.53551793]]

我试图更详细地解释如何计算函数(明确编写)和神经网络的雅可比矩阵here

【讨论】:

  • 谢谢 :) 这使我可以保留自定义损失函数,并且不需要将其替换为 tf 中定义的损失函数。
  • 你好@Ehsan你能检查这个问题并帮助我吗? link
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2022-07-24
  • 1970-01-01
  • 2020-07-14
  • 1970-01-01
  • 2013-12-03
  • 2018-02-07
相关资源
最近更新 更多