【发布时间】:2020-08-19 18:36:10
【问题描述】:
if sample_rate != sr:
waveform = torchaudio.transforms.Resample(sample_rate, sr)(waveform)
sample_rate = sr
我想知道这个 Resamle 是如何在那里工作的。因此,查看了 torchaudio 的文档。我以为会有 __call__ 功能。因为 Resample 被用作函数。我的意思是Resample()(waveform)。但在里面,只有 __init__ 和 forward 函数。我认为forward函数是工作函数,但我不知道为什么它被命名为'forward'而不是__call__。我错过了什么?
class Resample(torch.nn.Module):
r"""Resample a signal from one frequency to another. A resampling method can be given.
Args:
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
"""
def __init__(self,
orig_freq: int = 16000,
new_freq: int = 16000,
resampling_method: str = 'sinc_interpolation') -> None:
super(Resample, self).__init__()
self.orig_freq = orig_freq
self.new_freq = new_freq
self.resampling_method = resampling_method
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (Tensor): Tensor of audio of dimension (..., time).
Returns:
Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
# unpack batch
waveform = waveform.view(shape[:-1] + waveform.shape[-1:])
return waveform
raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
--编辑--
我环顾了 torch.nn.module。没有定义 __call__。但只有
__call__ : Callable[..., Any] = _call_impl会这样吗?
【问题讨论】:
-
嗯,它继承自
torch.nn.Module,所以那里可能有一个__call__,或者层次结构中的某个地方 -
@juanpa.arrivillaga 嘿,和它有关系吗?即使那个 call 也不是 torch.nn.module 中的函数。你知道吗?
-
看来肯定有电话,
print(torch.nn.Module.__call__) -
太棒了!你让我@juanpa.arrivillaga。谢谢你。我不知道 call 可以在没有 'def' 的情况下声明。
标签: python class machine-learning deep-learning pytorch