【问题标题】:Custom cache with iterator does not work as intended带有迭代器的自定义缓存无法按预期工作
【发布时间】:2022-01-20 21:27:03
【问题描述】:

我有以下课程,其中:

iterable 是传递的参数,例如range(20)n_max 是一个可选值,它限制缓存应具有的元素数量,iterator 是一个使用可迭代对象启动的字段, cache 是我要填充的列表,finishedbool,它表示迭代器是否为“空”。这是一个示例输入:

>>> iterable = range(20)
>>> cachedtuple = CachedTuple(iterable)
>>> print(cachedtuple[0])
0
>>> print(len(cachedtuple.cache))
1
>>> print(cachedtuple[10])
10
>>> print(len(cachedtuple.cache))
11
>>> print(len(cachedtuple))
20
>>> print(len(cachedtuple.cache))
20
>>> print(cachedtuple[25])


@dataclass
class CachedTuple:
    iterable: Iterable = field(init=True)
    n_max: Optional[int] = None
    iterator: Iterator = field(init=False)
    cache: list = field(default_factory=list)
    finished: bool = False

    def __post_init__(self):
        self.iterator = iter(self.iterable)

    def cache_next(self):
        
        if self.n_max and self.n_max <= len(self.cache):
            self.finished = True
        else:
            try:
                nxt = next(self.iterator)
                self.cache.append(nxt)

            except StopIteration:
                self.finished = True


    def __getitem__(self, item: int):

        match item:
            case item if type(item) != int:
                raise IndexError

            case item if item < 0:
                raise IndexError

            case item if self.finished or self.n_max and item > self.n_max:
                raise IndexError(f"Index {item} out of range")

            case item if item >= len(self.cache):
                while item - len(self.cache) >= 0:
                    self.cache_next()

                return self.__getitem__(item)

            case item if item < len(self.cache):
                return self.cache[item]


    def __len__(self):

        while not self.finished:
            self.cache_next()
        return len(self.cache)

虽然这段代码肯定不好,但至少它适用于几乎所有场景,但以 Python 的 range 函数为例。如果我使用例如

cachedtuple = CachedTuple(range(20))
for element in cachedtuple:
    print(element)

我得到元素直到19,然后程序无限循环。我认为一个问题可能是我的代码中没有raise StopIteration。所以我有点迷失如何解决这个烂摊子。

【问题讨论】:

  • 我一定要问,match语句的意义何在?

标签: python caching python-dataclasses


【解决方案1】:

您的错误是由于以下几行造成的:

case item if item >= len(self.cache):
    while item - len(self.cache) >= 0:
        self.cache_next()

基本上,CachedTuple((1,2,3))[50] 会无限循环,因为50 大于缓存的长度,self.cache_next() 不会生成任何新值。

添加self.finished 检查的简单更改将起作用:

case item if item >= len(self.cache):
    while item - len(self.cache) >= 0 and not self.finished:
        self.cache_next()

不过,我相信您的代码还有许多其他问题,我认为您可以极大地改进它:

  1. 删除匹配语句。它什么都不做。
  2. 使用__iter__ 实现迭代,而不是依赖__getitem__ 的旧迭代机制。
  3. collections.abc.Sequence 继承并遵守Sequence 协议。
  4. 删除数据类。这不是数据类。您似乎很喜欢这些令人愉悦的新语言功能,但不幸的是,它们都不相关,这会导致您的代码更长、更不清晰,并且无法按预期工作。

请记住,简单易读的代码比使用新的语言功能更重要。


我冒昧地花了几个小时创建了一个符合collections.abc.Sequence 的示例代码。尽情享受吧!

from collections.abc import Sequence
import itertools
from typing import Iterable, Iterator, Optional, TypeVar, overload

_T_co =TypeVar("_T_co", covariant=True)

class CachedIterable(Sequence[_T_co]):
    def __init__(self, iterable: Iterable[_T_co], *, max_length: int = None) -> None:
        self._cache: list[_T_co] = []
        
        if max_length is not None:
            if max_length <= 0:
                raise ValueError('max_length must be > 0')
            iterable = itertools.islice(iterable, max_length)
        else:
            try:
                # Attempt to optimize and get a length.
                max_length = len(iterable)  # type: ignore
            except TypeError:
                max_length = None

        self._max_length = max_length
        self._iterator: Optional[Iterator] = iter(iterable)
    
    def __repr__(self) -> str:
        return (f'<{self.__class__.__name__} {self._cache!r}'
                f'{"+" if self._iterator else ""}>')
    
    def _exhaust_iterator(self) -> None:
        """Fully exhaust the iterator."""
        assert self._iterator
        try:
            self._cache.extend(self._iterator)
        finally:
            self._iterator = None

    def _advance_iterator(self, n: int) -> None:
        """Attempt to advance the iterator by n steps.

        May advance by less than n steps if the iterator is exhausted.
        """
        assert self._iterator
        
        pre_advance_length = len(self._cache)

        try:
            self._cache.extend(itertools.islice(self._iterator, n))
        except Exception:
            # Iterator threw an exception.
            self._iterator = None
            raise

        # If iterator exhausted, clear it.
        if pre_advance_length + n > len(self._cache):
            self._iterator = None
        
    def _grow_cache(self, size: int) -> None:
        """Atttempt grow the cache to be at least size.
        
        May grow to less than size if the iterator is exhausted.
        """
        if size <= len(self._cache):
            return

        if self._max_length and size >= self._max_length:
            self._exhaust_iterator()
            return
        
        self._advance_iterator(size - len(self._cache))
    
    @overload
    def __getitem__(self, i: int) -> _T_co: ...

    @overload
    def __getitem__(self, s: slice) -> Sequence[_T_co]: ...
        
    def __getitem__(self, index):
        if not isinstance(index, (slice, int)):
            raise TypeError(f'index must be int or slice, not {index!r}')

        if not self._iterator:
            return self._cache[index]

        if isinstance(index, slice):
            # Stop might be less than start if step is negative.
            max_index = max(index.stop or 0, index.start or 0)
            
            # If we're counting from the end, exaust the iterator.
            if (index.stop is not None and index.stop < 0 or
                    index.start is not None and index.start < 0):
                self._exhaust_iterator()
            
            else:
                self._grow_cache(max_index + 1)

            return self._cache[index]

        # Asking for a number beyond the limit.
        if self._max_length and index > self._max_length:
            raise IndexError(f'index {index} out of range')

        # If we're counting from the end, exaust the iterator.
        if index < 0:
            self._exhaust_iterator()
        else:
            self._grow_cache(index + 1)

        return self._cache[index]
    
    def __iter__(self) -> Iterator[_T_co]:
        if not self._iterator:
            yield from self._cache
            return
        
        yield from self._cache
        while True:
            try:
                item = next(self._iterator)
                # Iterator threw an exception.
            except StopIteration:
                self._iterator = None
                return
            except BaseException:
                self._iterator = None
                raise
            
            self._cache.append(item)
            # Prevent capturing GeneratorExit and other gen.throw() exceptions.
            yield item


    def __len__(self) -> int:
        # TODO: Can optimize for known lengths.
        if not self._iterator:
            return len(self._cache)

        self._exhaust_iterator()
        return len(self._cache)

【讨论】:

  • 我同意,我只是在练习这些功能以了解它们,但非常感谢您的建议!那么这就是标准列表函数的 getitem 的工作原理吗?它会增加索引,直到索引“太大”?
  • @kklaw 是正确的。
猜你喜欢
  • 2015-07-10
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2013-04-05
  • 1970-01-01
  • 2018-10-02
  • 2017-11-15
  • 1970-01-01
相关资源
最近更新 更多