【发布时间】:2022-02-24 18:56:09
【问题描述】:
请对您的想法添加最低限度的评论,以便我改进查询。谢谢你。 -)
我正在尝试了解和实施一项关于 Triple Attention Learning 的研究工作,其中包括
- channel-wise attention (a)
- element-wise attention (b)
- scale-wise attention (c)
该机制在DenseNet 模型中进行了实验性集成。整个模型图的拱门是here。 channel-wise 注意力模块只不过是 squeeze 和激发 块。这为 element-wise 注意力模块提供了sigmoid 输出。下面是这些模块(a、b 和 c)的更精确的功能流程图。
理论
在大多数情况下,我能够理解并实现它,但在Element-Wise 注意部分(上图中的b 部分)有点迷失。这是我需要你帮助的地方。 -)
这里有一个关于这个主题的小理论,可以让您大致了解这一切是关于什么的。请注意,该论文现在无法公开访问,但在出版商页面上发布的早期阶段它是免费获取的,我当时保存了它。为了公平起见,我将与您分享,Link。无论如何,从论文(第 4.3 节)中可以看出:
所以首先,f(att)函数(在第一个inplace图中,左中部分或b)由三个具有512的卷积层组成 内核使用1 x 1,512 内核使用3 x 3 和C 内核使用1 x 1。这里C是分类器的编号。并通过Softmax 激活!
接下来,它适用于Channel-Wise 注意力模块,我们提到它只是一个SENet 模块并给出了sigmoid 概率分数,即X(CA)。因此,从f(att) 的函数中,我们得到C 乘以softmax 的概率分数,每个分数都乘以sigmoid 输出,最后生成特征图A(根据等式4 上图)。
第二,有一个C 线性分类器,它实现为1 x 1 - C 内核卷积层。该层还应用于SENet 模块的输出,即X(CA),以像素为单位应用于每个特征向量。最后,它给出了特征图S的输出(等式5如下图所示)。
并且第三,他们将每个置信度分数(S)与相应的注意力元素A 相乘。这种乘法是故意的。他们这样做是为了防止对特征图产生不必要的关注。为了使其有效,他们还使用weighted cross-entropy 损失函数在分类ground truth和score vector之间最小化它。
我的查询
大多数情况下,我没有正确获得网络中间的最小化策略。我希望有人能给我正确理解和实施上述文书工作(4.3节)中提出的这种“元素注意机制”的详细信息。
实施
这是开始使用的最低代码。我猜应该够了。这是一种肤浅的实现,但与原始的元素模块相距甚远。我不确定如何正确实施它。现在,我希望它作为一个可以即插即用任何模型的层。我正在尝试使用 MNIST 和一个简单的 Conv 网络。
总而言之,对于 MNIST,我们应该有一个包含 channel-wise 和 element-wise 注意力模型的网络,然后是最后一个 10 单元 softmax 层。比如:
Net: Conv2D - Attentions-Module - GAP - Softmax(10)
Attention-Module 由以下两部分组成:Channel-wise 和 Element-wise,Element-wise 也应该有 Softmax,以将加权 CE 损失函数最小化为 ground-truth 和 score vector来自这个模块(根据文件,上面也已经描述过)。该模块还将加权特征图传递给连续层。为了更清楚,这里是我们正在寻找的简单示意图
好的,对于 channel-wise 注意,它应该给我们一个单一的概率分数 (sigmoid),为了简单起见,我们现在使用一个假层:
class FakeSE(tf.keras.layers.Layer):
def __init__(self):
super(Block, self).__init__()
# conv layer
self.conv = tf.keras.layers.Conv2D(10, padding='same',
kernel_size=3)
def call(self, input_tensor, training=False):
x = self.conv(input_tensor)
return tf.math.sigmoid(x)
对于element-wise 注意部分,以下是迄今为止失败的尝试:
class ElementWiseAttention(tf.keras.layers.Layer):
def __init__(self):
# for simplicity the f(attn) function here has 2 convolution instead of 3
# self.conv1, and self.conv2
self.conv1 = tf.keras.layers.Conv2D(16,
kernel_size=1,
strides=1, padding='same',
use_bias=True, activation=tf.nn.silu)
self.conv2 = tf.keras.layers.Conv2D(10,
kernel_size=1,
strides=1, padding='same',
use_bias=False, activation=tf.keras.activations.softmax)
# fake SENet or channel-wise attention module
self.cam = FakeSE()
# a linear layer
self.linear = tf.keras.layers.Conv2D(10,
kernel_size=1,
strides=1, padding='same',
use_bias=True, activation=None)
super(ElementWiseAttention, self).__init__()
def call(self, inputs):
# 2 stacked conv layer (in paper, it's 3. we set 2 for simplicity)
# this is the f(att)
x = self.conv1(inputs)
x = self.conv2(x)
# this is the A = f(att)*X(CA)
camx = self.cam(x)*x
# this is S = X(CA)*Linear_Classifier
linx = self.cam(self.linear(inputs))
# element-wise multiply to prevent unnecessary attention
# suppose to minimize with weighted cross entorpy loss
out = tf.multiply(camx, linx)
return out
上面是兴趣层。如果我正确理解了论文的话,这一层不仅应该将加权损失函数最小化为gt 和score_vector,而且还应该生成一些加权特征图(2D)。
运行
这是玩具数据
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1)
x_train = x_train.astype('float32') / 255
x_train = tf.image.resize(x_train, [32,32]) # if we want to resize
y_train = tf.keras.utils.to_categorical(y_train , num_classes=10)
# Model
input = tf.keras.Input(shape=(32,32,1))
efnet = tf.keras.applications.DenseNet121(weights=None,
include_top = False,
input_tensor = input)
em = ElementWiseAttention()(efnet.output)
# Now that we apply global max pooling.
gap = tf.keras.layers.GlobalMaxPooling2D()(em)
# classification layer.
output = tf.keras.layers.Dense(10, activation='softmax')(gap)
# bind all
func_model = tf.keras.Model(efnet.input, output)
func_model.compile(
loss = tf.keras.losses.CategoricalCrossentropy(),
metrics = tf.keras.metrics.CategoricalAccuracy(),
optimizer = tf.keras.optimizers.Adam())
# fit
func_model.fit(x_train, y_train, batch_size=32, epochs=3, verbose = 1)
【问题讨论】:
标签: python tensorflow machine-learning keras deep-learning