【问题标题】:What is a buffer in Pytorch?Pytorch 中的缓冲区是什么?
【发布时间】:2020-04-24 11:43:26
【问题描述】:

我了解register_buffer 的作用以及register_buffer and register_parameters 之间的区别。

但是 PyTorch 中缓冲区的准确定义是什么?

【问题讨论】:

  • 第 4 段可能是相关的:stackoverflow.com/a/57546078/365102
  • 听起来它们只是张量(又名数据),在训练期间没有被修改。
  • 没错,但我正在寻找更具体的定义。 IE。缓冲区是一个张量,requires_grad 等于 False?
  • @Berriel - 我喜欢你之前的回答...
  • @luminicentauri 是的

标签: python pytorch


【解决方案1】:

这可以通过implementation来回答:

def register_buffer(self, name, tensor):
    if '_buffers' not in self.__dict__:
        raise AttributeError(
            "cannot assign buffer before Module.__init__() call")
    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("buffer name should be a string. "
                        "Got {}".format(torch.typename(name)))
    elif '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    elif name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._buffers:
        raise KeyError("attribute '{}' already exists".format(name))
    elif tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError("cannot assign '{}' object to buffer '{}' "
                        "(torch Tensor or None required)"
                        .format(torch.typename(tensor), name))
    else:
        self._buffers[name] = tensor

即缓冲区的名称:

  • 必须是字符串:not isinstance(name, torch._six.string_classes)
  • 不能包含.(点):'.' in name
  • 不能为空字符串:name == ''
  • 不能是模块的属性:hasattr(self, name)
  • 应该是唯一的:name not in self._buffers

还有tensor(你猜怎么着?):

  • 应该是张量:isinstance(tensor, torch.Tensor)

因此,缓冲区只是具有这些属性的张量,注册在 Module_buffers 属性中;

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2012-06-21
    • 2010-10-13
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多