【发布时间】:2021-08-20 06:33:51
【问题描述】:
我正在尝试在自定义损失函数中进行一些自定义计算。但是当我从自定义损失函数中记录语句时,似乎自定义损失函数只被调用一次(在 .fit() 方法的开头)。
损失函数示例:
def loss(y_true, y_pred):
print("--- Starting of the loss function ---")
print(y_true)
loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
print("--- Ending of the loss function ---")
return loss
使用回调检查批处理的开始和结束时间:
class monitor(Callback):
def on_batch_begin(self, batch, logs=None):
print("\n >> Starting a new batch (batch index) :: ", batch)
def on_batch_end(self, batch, logs=None):
print(">> Ending a batch (batch index) :: ", batch)
.fit() 方法用作:
history = model.fit(
x=[inputs],
y=[outputs],
shuffle=False,
batch_size=BATCH_SIZE,
epochs=NUM_EPOCH,
verbose=1,
callbacks=[monitor()]
)
使用的参数:
BATCH_SIZE = 128
NUM_EPOCH = 3
inputs.shape = (512, 8)
outputs.shape = (512, 2)
还有输出:
Epoch 1/3
>> Starting a new batch (batch index) :: 0
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
1/4 [======>.......................] - ETA: 0s - loss: 0.5551
>> Ending a batch (batch index) :: 0
>> Starting a new batch (batch index) :: 1
>> Ending a batch (batch index) :: 1
>> Starting a new batch (batch index) :: 2
>> Ending a batch (batch index) :: 2
>> Starting a new batch (batch index) :: 3
>> Ending a batch (batch index) :: 3
4/4 [==============================] - 0s 5ms/step - loss: 0.5307
Epoch 2/3
>> Starting a new batch (batch index) :: 0
1/4 [======>.......................] - ETA: 0s - loss: 0.5443
>> Ending a batch (batch index) :: 0
>> Starting a new batch (batch index) :: 1
>> Ending a batch (batch index) :: 1
>> Starting a new batch (batch index) :: 2
>> Ending a batch (batch index) :: 2
>> Starting a new batch (batch index) :: 3
>> Ending a batch (batch index) :: 3
4/4 [==============================] - 0s 5ms/step - loss: 0.5246
Epoch 3/3
>> Starting a new batch (batch index) :: 0
1/4 [======>.......................] - ETA: 0s - loss: 0.5433
>> Ending a batch (batch index) :: 0
>> Starting a new batch (batch index) :: 1
>> Ending a batch (batch index) :: 1
>> Starting a new batch (batch index) :: 2
>> Ending a batch (batch index) :: 2
>> Starting a new batch (batch index) :: 3
>> Ending a batch (batch index) :: 3
4/4 [==============================] - 0s 4ms/step - loss: 0.5219
为什么自定义损失函数只在开始时调用,而不是每次批量计算都调用?我还想知道何时调用/触发损失函数?
【问题讨论】:
-
仔细观察输出,每批都会调用损失函数。
-
@ShubhamPanchal 那为什么
print("--- Starting of the loss function ---")和print("--- Ending of the loss function ---")语句不是每次调用都打印出来的? -
我明白你的观点@Rahul。我不知道为什么您没有看到所有时代都印有这些声明。然而,在每个时期都会打印损失值这一事实让我认为损失是正确计算的。打印可能无法正常工作。您可以尝试改用日志记录
标签: tensorflow keras tensorflow2.0 tf.keras loss-function