Keras 模型可以完全动态实现,以支持您提到的高效路由。以下示例显示了一种可以完成此操作的方法。该示例是在以下前提下编写的:
- 假设有两个专家(
LayerA 和 LayerB)
- 假设混合专家模型 (
MixOfExpertsModel) 根据 Keras Dense 层的每个示例输出在两个专家层类之间动态切换
- 它满足了对模型进行批量训练的需要。
注意代码中的cmets,看看是怎么切换的。
import numpy as np
import tensorflow as tf
# This is your Expert A class.
class LayerA(tf.keras.layers.Layer):
def build(self, input_shape):
self.weight = self.add_weight("weight_a", shape=input_shape[1:])
@tf.function
def call(self, x):
return x + self.weight
# This is your Expert B class.
class LayerB(tf.keras.layers.Layer):
def build(self, input_shape):
self.weight = self.add_weight("weight_b", shape=input_shape[1:])
@tf.function
def call(self, x):
return x * self.weight
class MixOfExpertsModel(tf.keras.models.Model):
def __init__(self):
super(MixOfExpertsModel, self).__init__()
self._expert_a = LayerA()
self._expert_b = LayerB()
self._gating_layer = tf.keras.layers.Dense(1, activation="sigmoid")
@tf.function
def call(self, x):
z = self._gating_layer(x)
# The switching logic:
# - examples with gating output <= 0.5 are routed to expert A
# - examples with gating output > 0.5 are routed to expert B.
mask_a = tf.squeeze(tf.less_equal(z, 0.5), axis=-1)
mask_b = tf.squeeze(tf.greater(z, 0.5), axis=-1)
# `input_a` is a subset of slices of the original input (`x`).
# So is `input_b`. As such, no compute is wasted.
input_a = tf.boolean_mask(x, mask_a, axis=0)
input_b = tf.boolean_mask(x, mask_b, axis=0)
if tf.size(input_a) > 0:
output_a = self._expert_a(input_a)
else:
output_a = tf.zeros_like(input_a)
if tf.size(input_b) > 0:
output_b = self._expert_b(input_b)
else:
output_b = tf.zeros_like(input_b)
# Return `mask_a`, and `mask_b`, so that the caller can know
# which example is routed to which expert and whether its output
# appears in `output_a` or `output_b`. # This is necessary
# for writing a (custom) loss function for this class.
return output_a, output_b, mask_a, mask_b
# Create an intance of the mix-of-experts model.
mix_of_experts_model = MixOfExpertsModel()
# Generate some dummy data.
num_examples = 32
xs = np.random.random([num_examples, 8]).astype(np.float32)
# Call the model.
print(mix_of_experts_model(xs))
我没有编写自定义损失函数来支持该课程的训练。但这可以通过使用MixOfExpertsModel.call() 的返回值来实现,即输出和掩码。