【问题标题】:Tensorflow equivalent to torch.Tensor.index_copyTensorflow 等效于 torch.Tensor.index_copy
【发布时间】:2021-05-31 16:26:15
【问题描述】:

我正在实现 等效于模型 here 最初使用 实现的。一切都很顺利,直到我遇到了这行代码。

batch_current = Variable(torch.zeros(size, self.embedding_dim))

# self.embedding and self.W_c are pytorch network layers I have created
batch_current = self.W_c(batch_current.index_copy(0, Variable(torch.LongTensor(index)),
                                                         self.embedding(Variable(self.th.LongTensor(current_node)))))

如果搜索index_copy 的文档,似乎它所做的只是在某个索引和公共轴上复制一组元素并将其分配给另一个张量。但我真的不想写一些有问题的代码,所以在尝试任何自我实现之前,我想知道你们是否知道我可以如何去实现它。

模型来自这个paper,是的,我搜索了其他 实现,但它们对我来说似乎没有多大意义。

【问题讨论】:

    标签: tensorflow pytorch tensorflow python tensorflow pytorch


    【解决方案1】:

    您需要的是 中的tf.tensor_scatter_nd_update 来获得类似Tensor.index_copy_ 的等效操作。下面是一个演示。

    ,你有

    import torch 
    
    tensor = torch.zeros(5, 3)
    indices = torch.tensor([0, 4, 2])
    updates= torch.tensor([[1, 2, 3], 
                           [4, 5, 6], 
                           [7, 8, 9]], dtype=torch.float)
    tensor.index_copy_(0, indices, updates)
    
    tensor([[1., 2., 3.],
            [0., 0., 0.],
            [7., 8., 9.],
            [0., 0., 0.],
            [4., 5., 6.]])
    

    ,你可以这样做

    import tensorflow as tf
    
    tensor = tf.zeros([5,3])
    indices = tf.constant([[0], [4], [2]])
    updates  = tf.constant([[1, 2, 3], 
                            [4, 5, 6], 
                            [7, 8, 9]], dtype=tf.float32)
    tensor = tf.tensor_scatter_nd_update(tensor, indices, updates)
    tensor.numpy()
    array([[1., 2., 3.],
           [0., 0., 0.],
           [7., 8., 9.],
           [0., 0., 0.],
           [4., 5., 6.]], dtype=float32)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-08-01
      • 2020-09-26
      • 2020-03-28
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2022-06-21
      • 2021-10-05
      相关资源
      最近更新 更多