【问题标题】:Print shape of tensor even when it is `None`打印张量的形状,即使它是“无”
【发布时间】:2025-12-13 08:45:01
【问题描述】:

我正在使用 Keras 中的自定义数据加载器创建自定义算法。我知道,当您尝试在模型的内部方法中访问张量时,您通常会在打印张量的形状时得到None,通常在批处理轴上,因为批处理大小可以是可变的。我创建了一个更新渐变的自定义方法,只是为了进行完整性检查,我试图在程序执行时打印该轴形状的实际值。我不知道该怎么做。

这里有一些代码,看看我在哪里写的THIS LINE。此代码将打印出以下输出,其中将批处理轴显示为NONE。仅出于调试目的,我实际上想看看代码运行时这个值是多少,我该怎么做?

(无, 4, 100) (无, 100) (无, 100, 100) (无, 100) (无, 100, 100) (无, 100) (无, 100, 1) (无, 1)

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        
        tao = 1

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
            
        gradients = tape.jacobian(loss, self.trainable_variables)
        
        new_gradients = []
        for grad in gradients:

            print(grad.shape) # <--- THIS LINE

            q1 = K.mean( grad[:env_siz], axis=0 )
            q2 = K.mean( grad[env_siz:], axis=0 )

            Q = K.mean( K.stack((K.sign(q1), K.sign(q2))), axis=0 ) # 1 means all gradients in same direction on that axis
            P = tf.where( tf.abs(Q) >= tao, K.mean( K.stack((q1, q2)), axis=0 ), 0)
#             print(P)
            new_gradients.append( P )

        # Compute gradients
        trainable_vars = self.trainable_variables
#         gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(new_gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

【问题讨论】:

    标签: python tensorflow keras


    【解决方案1】:

    您可以使用tf.print 而不是print 来查看“图形化”函数中的张量值。不要访问始终是静态已知形状的.shape 属性,而是使用tf.shape 来读取实际的张量形状。

    tf.print(tf.shape(grad))
    

    【讨论】:

    • 就是这样,我尝试打印 tf.shape 但我不知道 tf.print,谢谢!