【发布时间】:2020-12-01 14:37:23
【问题描述】:
我正在将 TensorFlow 代码转换为 PyTorch 代码。
下面几行是我还不能解决的问题。
我对 PyTorch 不熟悉,所以在 PyTorch 库中找到匹配的方法并不容易。
谁能帮帮我?
p.s. alpha的形状是(batch, N)。
alpha_cumsum = tf.cumsum(alpha, axis = 1)
len_batch = tf.shape(alpha_cumsum)[0]
rand_prob = tf.random_uniform(shape = [len_batch, 1], minval = 0., maxval = 1.)
alpha_relu = tf.nn.relu(rand_prob - alpha_cumsum)
alpha_index = tf.count_nonzero(alpha_relu, 1)
alpha_hard = tf.one_hot(alpha_index, len(a))
【问题讨论】:
标签: tensorflow pytorch