【问题标题】:Keras/TensorFlow equivalent of PyTorch Conv1d相当于 PyTorch Conv1d 的 Keras/TensorFlow
【发布时间】:2020-09-24 19:05:38
【问题描述】:

我目前正在将 PyTorch 代码转换为 TensorFlow (Keras)。使用的层之一是 Conv1d,如何在 PyTorch 中使用它的描述如下:

torch.nn.Conv1d(in_channels: int, out_channels: int, kernel_size: Union[T, Tuple[T]], stride: Union[T, Tuple[T]] = 1, padding: Union[T, Tuple[T]] = 0, dilation: Union[T, Tuple[T]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros')

在 Keras (TF1.15) 中,描述为

tf.keras.layers.Conv1D(filters, kernel_size, strides=1, padding='valid', data_format='channels_last', dilation_rate=1, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, **kwargs)

我无法在 TensorFlow 中重现我在 PyTorch 中获得的相同输出。即 PyTorch 中的示例代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

B, K, L, N = 4, 10, 128, 64
mixture = torch.randint(3, (B, K, L), dtype=torch.float32)
# L2 Norm along L axis
EPS = 1e-8
norm_coef = torch.norm(mixture, p=2, dim=2, keepdim=True)  # B x K x 1
norm_mixture = mixture / (norm_coef + EPS)  # B x K x L
# 1-D gated conv
norm_mixture = torch.unsqueeze(norm_mixture.view(-1, L), 2)  # B*K x L x 1
conv1d_U = nn.Conv1d(L, N, kernel_size=1, stride=1, bias=False)
conv_out = conv1d_U(norm_mixture)
conv_out = F.relu(conv_out)  # B*K x N x 1
mixture_w = conv_out.view(B, K, N)  # B x K x N

weights = conv1d_U.weight.data

在 TensorFlow 中,为了获得相似的输出尺寸,可以在下面找到代码

import tensorflow as tf
import numpy as np

def build_net():
    # Encoder
    mixture = tf.keras.layers.Input(shape=(10, 128), name='mixture', batch_size=4)  # [B,K,L]
    norm_coef = tf.keras.backend.sqrt(tf.keras.backend.sum(mixture ** 2, axis=2, keepdims=True) + 1e-8)  # [B,K,1]
    norm_mixture = mixture / norm_coef  # [B, K, L]
    norm_mixture = tf.keras.backend.expand_dims(tf.keras.backend.reshape(norm_mixture, [-1, 128]), axis=2)  # [B*K,L,1]
    conv = tf.keras.layers.Conv1D(filters=64, kernel_size=1, activation='relu', use_bias=False, name='conv')(norm_mixture)  # [B*K,N,1]
    mixture_w = tf.keras.backend.reshape(conv, [4, -1, 64])  # [B, K, N]
    return tf.keras.models.Model(inputs=mixture, outputs=mixture_w)

model = build_net()
weights = model.get_weights()
inp = np.random.randn(4, 10, 128)
out = model.predict(inp)

比较两种情况下权重的维度,Conv1d 操作明显不同于 TensorFlow (Keras),应该如何更改 TF 代码以反映相同的操作?

【问题讨论】:

  • 为什么权重应该一样?在任何情况下,权重初始化取决于任一框架使用的伪随机生成器。
  • 我不希望权重相同,我希望复制操作。权重的维度是检查 Conv1d 内核是否会导致类似操作的代理。我改变了问题的形式以反映这一点
  • keras的Conv1D中没有提供in通道的数量(它是从上一层,本例中的输入层衍生而来的)。 filters 是输出通道数:conv = tf.keras.layers.Conv1D(filters=64, kernel_size=1, activation='relu', use_bias=False, name='conv')(norm_mixture)
  • @Max 这是我认为我应该从文档中执行的操作,但是 Conv1D 操作的输出维度是 (40,128,64) 但在 PyTorch 中我们得到维度 (40,64,1 )。所以再次对数据执行的操作不匹配。
  • 我明白了。问题是 keras 使用“通道最后”格式,即最后一个维度始终是卷积层中通道/过滤器/特征图的数量。之前的维度是一维卷积层中的序列维度。请在下面查看我的建议,实际上只需要修改 expand_dims 轴。

标签: python tensorflow keras pytorch


【解决方案1】:
import tensorflow as tf
import numpy as np

def build_net():
    # Encoder
    mixture = tf.keras.layers.Input(shape=(10, 128), name='mixture', batch_size=4)  # [B,K,L]
    norm_coef = tf.keras.backend.sqrt(tf.keras.backend.sum(mixture ** 2, axis=2, keepdims=True) + 1e-8)  # [B,K,1]
    norm_mixture = mixture / norm_coef  # [B, K, L]
    norm_mixture = tf.keras.backend.expand_dims(tf.keras.backend.reshape(norm_mixture, [-1, 128]), axis=1)  # [B*K,1,L]
    conv = tf.keras.layers.Conv1D(filters=64, kernel_size=1, activation='relu', use_bias=False, name='conv')(norm_mixture)  # [B*K,1,N]
    mixture_w = tf.keras.backend.reshape(conv, [4, -1, 64])  # [B, K, N]
    return tf.keras.models.Model(inputs=mixture, outputs=mixture_w)

model = build_net()
print(model.summary())
weights = model.get_weights()
inp = np.random.randn(4, 10, 128)
out = model.predict(inp)

【讨论】:

    猜你喜欢
    • 2021-12-04
    • 2021-03-19
    • 1970-01-01
    • 2022-07-11
    • 2023-02-22
    • 2019-11-11
    • 2019-09-21
    • 2020-11-06
    • 2018-06-21
    相关资源
    最近更新 更多