【发布时间】:2020-03-24 06:28:02
【问题描述】:
我有一个关于 python 中迭代行为的具体问题。我的可迭代是 pytorch 中自定义构建的 Dataset 类:
import torch
from torch.utils.data import Dataset
class datasetTest(Dataset):
def __init__(self, X):
self.X = X
def __len__(self):
return len(self.X)
def __getitem__(self, x):
print('***********')
print('getitem x = ', x)
print('###########')
y = self.X[x]
print('getitem y = ', y)
return y
现在,当我初始化该 datasetTest 类的特定实例时,就会出现奇怪的行为。根据我作为参数 X 传递的数据结构,当我调用 list(datasetTestInstance) 时,它的行为会有所不同。特别是,当传递一个 torch.tensor 作为参数时没有问题,但是当传递一个 dict 作为参数时,它会抛出一个 KeyError。原因是 list(iterable) 不仅调用 i=0, ..., len(iterable)-1,而且调用 i=0, ..., len(iterable)。也就是说,它将迭代直到(包括)索引等于可迭代的长度。显然,这个索引在任何 python 数据结构中都没有定义,因为最后一个元素总是有索引 len(datastructure)-1 而不是 len(datastructure)。如果 X 是 torch.tensor 或列表,则不会出现错误,即使我认为应该是错误。即使对于索引为 len(datasetTestinstance) 的(不存在的)元素,它仍然会调用 getitem,但它不会计算 y=self.X[len(datasetTestInstance]。有谁知道 pytorch 是否在内部以某种方式优雅地处理这个问题?
当将 dict 作为数据传递时,它将在最后一次迭代中抛出错误,此时 x=len(datasetTestInstance)。这实际上是我猜想的预期行为。但为什么这只发生在 dict 而不是 list 或 torch.tensor?
if __name__ == "__main__":
a = datasetTest(torch.randn(5,2))
print(len(a))
print('++++++++++++')
for i in range(len(a)):
print(i)
print(a[i])
print('++++++++++++')
print(list(a))
print('++++++++++++')
b = datasetTest({0: 12, 1:35, 2:99, 3:27, 4:33})
print(len(b))
print('++++++++++++')
for i in range(len(b)):
print(i)
print(b[i])
print('++++++++++++')
print(list(b))
如果您想更好地理解我所观察到的内容,可以尝试使用该 sn-p 代码。
我的问题是:
1.) 为什么 list(iterable) 会迭代直到(包括)len(iterable)? for 循环不会这样做。
2.) 在作为数据 X 传递的 torch.tensor 或列表的情况下:为什么即使调用索引 len(datasetTestInstance) 的 getitem 方法也不会引发错误,因为它实际上应该超出范围未定义为张量/列表中的索引?或者,换句话说,当到达索引 len(datasetTestInstance) 然后进入 getitem 方法时,究竟发生了什么?它显然不再调用'y = self.X[x]'(否则会有一个IndexError),但它确实进入了getitem方法,我可以看到它从getitem方法中打印索引x。那么这种方法会发生什么?为什么它的行为会根据是 torch.tensor/list 还是 dict 而有所不同?
【问题讨论】:
-
关于第1点。通常我们使用
for item in b:通过for循环迭代可迭代类型。在这种情况下,python 期望b引发IndexError以指示已到达列表末尾。 (有关特定文档链接,请参阅我的答案)。 -
你为什么不使用迭代器协议?以这种方式使对象可迭代是非常过时的,并且只保留不破坏向后兼容性,AFAIK
标签: python list dictionary pytorch iterable