【问题标题】:How can torchaudio.transform.Resample be called without __call__ function inside?如何在没有 __call__ 函数的情况下调用 torchaudio.transform.Resample?
【发布时间】:2020-08-19 18:36:10
【问题描述】:
if sample_rate != sr:
        waveform = torchaudio.transforms.Resample(sample_rate, sr)(waveform)
        sample_rate = sr

我想知道这个 Resamle 是如何在那里工作的。因此,查看了 torchaudio 的文档。我以为会有 __call__ 功能。因为 Resample 被用作函数。我的意思是Resample()(waveform)。但在里面,只有 __init__ 和 forward 函数。我认为forward函数是工作函数,但我不知道为什么它被命名为'forward'而不是__call__。我错过了什么?

class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can be given.

    Args:
        orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
        new_freq (float, optional): The desired frequency. (Default: ``16000``)
        resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
    """

    def __init__(self,
                 orig_freq: int = 16000,
                 new_freq: int = 16000,
                 resampling_method: str = 'sinc_interpolation') -> None:
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Output signal of dimension (..., time).
        """
        if self.resampling_method == 'sinc_interpolation':

            # pack batch
            shape = waveform.size()
            waveform = waveform.view(-1, shape[-1])

            waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

            # unpack batch
            waveform = waveform.view(shape[:-1] + waveform.shape[-1:])

            return waveform

        raise ValueError('Invalid resampling method: %s' % (self.resampling_method))

--编辑--

我环顾了 torch.nn.module。没有定义 __call__。但只有 __call__ : Callable[..., Any] = _call_impl会这样吗?

【问题讨论】:

  • 嗯,它继承自torch.nn.Module,所以那里可能有一个__call__,或者层次结构中的某个地方
  • @juanpa.arrivillaga 嘿,和它有关系吗?即使那个 call 也不是 torch.nn.module 中的函数。你知道吗?
  • 看来肯定有电话,print(torch.nn.Module.__call__)
  • 太棒了!你让我@juanpa.arrivillaga。谢谢你。我不知道 call 可以在没有 'def' 的情况下声明。

标签: python class machine-learning deep-learning pytorch


【解决方案1】:

下面是forward 函数在PyTorch 中如何工作的简单类似演示。

检查一下:

from typing import Callable, Any

class parent:
    def _unimplemented_forward(self, *input):
        raise NotImplementedError

    def _call_impl(self, *args):
        # original nn.Module _call_impl function contains lot more code
        # to handle exceptions, to handle hooks and for other purposes
        self.forward(*args)
    
    forward : Callable[..., Any]  = _unimplemented_forward
    __call__ : Callable[..., Any] = _call_impl

class child(parent):
    def forward(self, *args):
        print('forward function')


class child_2(parent):
    pass

运行时:

>>> c1 = child_1()
>>> c1()
forward function
>>> c2 = child_2()
>>> c2()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".\callable.py", line 8, in _call_impl
    self.forward(*args)
  File ".\callable.py", line 5, in _unimplemented_forward
    raise NotImplementedError
NotImplementedError

【讨论】:

  • 对我真的很有帮助。谢谢,但你知道他们为什么这样做而不是使用 def call??
  • 我不太清楚。但也许他们正在使用这种复杂的实现来获得可以处理许多情况的通用代码。此外,他们需要通过构建dynamic 图表来处理AutoGrad。所以它可能会变得复杂。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多