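You can create your own layers in Keras by subclassing `Layer`. This lets you control the weights inside a layer, for example whether each weight is trainable.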
```python
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # suppress TensorFlow log messages
import tensorflow as tf
from tensorflow.keras.layers import Layer, Input, Dense
from tensorflow.keras.models import Sequential

# Your custom layer
class Linear(Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Kernel w: trainable, so the optimizer updates it
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        # Bias b: non-trainable, so it keeps its initial value
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=False
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b
```
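A quick way to confirm which weights an optimizer will see is to inspect `trainable_weights` and `non_trainable_weights` once the layer has been built. The sketch below assumes TensorFlow 2.x with the TensorFlow Keras backend, repeats a compact copy of `Linear` so it runs on its own, and builds the layer by calling it on an arbitrary `(2, 5)` batch (the batch itself is only an illustrative choice).

```python
import tensorflow as tf
from tensorflow import keras


class Linear(keras.layers.Layer):
    """Compact copy of the Linear layer: trainable w, non-trainable b."""

    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer="random_normal", trainable=True)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer="random_normal", trainable=False)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


layer = Linear(units=4)
_ = layer(tf.ones((2, 5)))  # calling the layer once triggers build()

# Only w is listed as trainable; b is tracked by the layer but not trained.
print("trainable:    ", [tuple(v.shape) for v in layer.trainable_weights])
print("non-trainable:", [tuple(v.shape) for v in layer.non_trainable_weights])
```

Both variables belong to the layer (and are saved with it), but only the `(5, 4)` kernel appears in `trainable_weights`, which is the list a training loop normally differentiates against.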
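In `Linear`, the weight `w` is trainable, but the bias `b` is not. Below, I train a small model on dummy data with a custom training loop and print the layer's weights at every step so the updates can be observed.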
```python
batch_size = 10
input_shape = (batch_size, 5, 5)  # shape of the dummy batch, batch dimension included

## model
model = Sequential()
# Input() expects the per-sample shape, so drop the batch dimension
model.add(Input(shape=input_shape[1:]))
model.add(Linear(units=4, name='my_linear_layer'))
model.add(Dense(1))

## dummy dataset
x = tf.random.normal(input_shape)  # dummy input
y = tf.ones((batch_size, 1))       # dummy target

## loss function and optimizer
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)

### training loop
epochs = 3
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    # Print [w, b] of the custom layer before this update
    tf.print(model.get_layer('my_linear_layer').get_weights())

    # Open a GradientTape to record the operations run
    # during the forward pass, which enables auto-differentiation.
    with tf.GradientTape() as tape:
        # Run the forward pass; the operations applied to the
        # inputs are recorded on the GradientTape.
        preds = model(x, training=True)  # predictions for this minibatch
        # Compute the loss value for this minibatch.
        loss_value = loss_fn(y, preds)

    # Gradients of the loss with respect to the trainable weights only;
    # the non-trainable bias b is not in model.trainable_weights.
    grads = tape.gradient(loss_value, model.trainable_weights)
    # Run one step of gradient descent on those weights.
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
```
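Passing `model.trainable_weights` to `tape.gradient` is what keeps `b` out of the update. A related detail: by default, `tf.GradientTape` only watches trainable variables, so a `tf.Variable` created with `trainable=False` gets a `None` gradient unless it is watched explicitly with `tape.watch`. The sketch below illustrates this with two plain `tf.Variable`s standing in for `w` and `b`; these stand-ins, the shapes, and the toy loss are assumptions for illustration, not the Keras layer itself.

```python
import tensorflow as tf

# Hypothetical stand-ins for the layer's kernel and bias.
w = tf.Variable(tf.random.normal((5, 4)), trainable=True)
b = tf.Variable(tf.random.normal((4,)), trainable=False)
x = tf.random.normal((10, 5))

# Default tape: only trainable variables are watched automatically.
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) + b))
grad_w, grad_b = tape.gradient(loss, [w, b])
print(grad_w is None, grad_b is None)  # b was never watched, so grad_b is None

# Explicitly watching b makes its gradient available again.
with tf.GradientTape() as tape:
    tape.watch(b)
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) + b))
grad_w2, grad_b2 = tape.gradient(loss, [w, b])
print(grad_w2 is None, grad_b2 is None)
```

So even if the loop accidentally asked for gradients of every weight, a variable created with `trainable=False` would still receive no gradient from the default tape.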
The training loop above prints the following output:
```
Start of epoch 0
[array([[ 0.08920084, -0.04294993, 0.06111819, 0.08334437],
       [-0.0369432 , -0.05014499, 0.0305218 , -0.07486793],
       [-0.01227043, 0.09460627, -0.0560123 , 0.01324316],
       [-0.00255878, 0.00214959, -0.02924518, 0.04721532],
       [-0.05532415, -0.02014978, -0.06785563, -0.07330619]],
      dtype=float32),
 array([ 0.02154647, 0.05153348, -0.00128291, -0.06794706], dtype=float32)]

Start of epoch 1
[array([[ 0.08961578, -0.04327399, 0.06152926, 0.08325274],
       [-0.03829437, -0.04908974, 0.02918325, -0.07456956],
       [-0.01417133, 0.09609085, -0.05789544, 0.01366292],
       [-0.00236284, 0.00199657, -0.02905108, 0.04717206],
       [-0.05536905, -0.02011472, -0.06790011, -0.07329627]],
      dtype=float32),
 array([ 0.02154647, 0.05153348, -0.00128291, -0.06794706], dtype=float32)]

Start of epoch 2
[array([[ 0.09001605, -0.04358549, 0.06192534, 0.08316355],
       [-0.03960795, -0.04806747, 0.02788337, -0.07427685],
       [-0.01599812, 0.09751251, -0.05970317, 0.01406999],
       [-0.00217021, 0.00184666, -0.02886046, 0.04712913],
       [-0.05540781, -0.02008455, -0.06793848, -0.07328764]],
      dtype=float32),
 array([ 0.02154647, 0.05153348, -0.00128291, -0.06794706], dtype=float32)]
```
As you can see, the weight `w` changes at every step, while the bias `b` stays exactly the same.
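Instead of comparing the printed arrays by eye, the same observation can be checked programmatically: snapshot `w` and `b`, run a few updates, and compare. This is a self-contained sketch under a few assumptions: TensorFlow 2.x with the TensorFlow Keras backend, a flat `(10, 5)` dummy input instead of `(10, 5, 5)`, a fixed random seed, and `np.allclose` / `np.array_equal` as the comparison.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

tf.random.set_seed(0)


class Linear(keras.layers.Layer):
    """Same idea as the layer above: trainable kernel w, non-trainable bias b."""

    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer="random_normal", trainable=True)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer="random_normal", trainable=False)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


model = keras.Sequential([
    keras.Input(shape=(5,)),
    Linear(units=4, name="my_linear_layer"),
    keras.layers.Dense(1),
])
layer = model.get_layer("my_linear_layer")

x = tf.random.normal((10, 5))  # dummy input
y = tf.ones((10, 1))           # dummy target
loss_fn = keras.losses.MeanSquaredError()
optimizer = keras.optimizers.SGD(learning_rate=1e-2)

# Snapshot the weights as independent NumPy copies.
w_before = layer.w.numpy().copy()
b_before = layer.b.numpy().copy()

for _ in range(3):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

w_changed = not np.allclose(w_before, layer.w.numpy())
b_unchanged = np.array_equal(b_before, layer.b.numpy())
print("w changed:", w_changed, "| b unchanged:", b_unchanged)
```

Using exact equality for `b` is deliberate: since it never receives an update, its values should match bit for bit, whereas `w` only needs to differ beyond floating-point tolerance.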