【发布时间】:2020-10-08 01:13:38
【问题描述】:
您好,我是 Pytorch 和 Torch 张量的新手。我正在阅读 yolo_v3 代码并遇到这个问题。我认为它与... 的张量索引有关,但很难通过谷歌搜索...,所以我想在这里问它。代码是:
prediction = (
x.view(num_samples, self.num_anchors, self.num_classes + 5, grid_size, grid_size)
.permute(0, 1, 3, 4, 2)
.contiguous()
)
print (prediction.shape)
# Get outputs
x = torch.sigmoid(prediction[..., 0]) # Center x
y = torch.sigmoid(prediction[..., 1]) # Center y
w = prediction[..., 2] # Width
h = prediction[..., 3] # Height
pred_conf = torch.sigmoid(prediction[..., 4]) # Conf
pred_cls = torch.sigmoid(prediction[..., 5:]) # Cls pred.
我的理解是预测将是一个形状为 [batch, anchor, x_grid, y_grid, class] 的张量。但是 prediction[..., x] 做了什么(x=0,1,2,3,4,5)?它是否类似于 [:, x] 的 numpy 索引?如果是这样,x、y、w、h、pred_conf 和 pred_cls 的计算就没有意义了。
【问题讨论】:
标签: python numpy pytorch tensor