【发布时间】:2023-03-28 20:13:01
【问题描述】:
我想要达到的目标
我想定义一个有A和B两种模式的类,使该类的forward方法相应变化。
class MyClass():
def __init__(self, constant):
self.constant=constant
def forward(self, x1,x2,function):
if function=='A':
return x1+self.constant
elif function=='B':
return x1*x2+self.constant
else:
print('please provide the correct function')
model1 = MyClass(2)
model1.forward(2, None, 'A')
output>>>4
model2 = MyClass(2)
model2.forward(2, 2, 'B')
output>>>6
它可以工作,但不是最优的,因为每次调用forward 方法时,它都会检查要使用的函数。但是,前向函数已经设置好了,一旦定义了类就永远不会更改,因此,在我的情况下,检查在forward 中使用哪个函数是超级多余的。 (对于那些注意到这一点的人,我正在使用 PyTorch 编写我的神经网络模型,两个模型共享 90% 的网络架构,仅有 10% 的差异是它们的前馈方式。
我想要的版本
我想在定义类的时候设置forward方法,这样就可以实现了
model1 = MyClass(2, 'A')
model1.forward(2)
output>>>4
model2 = MyClass(2, 'B')
model2.forward(2, 2)
output>>>6
所以我将我的课程改写为:
class MyClass():
def __init__(self, constant, function):
self.constant=constant # There would be a lot of shared parameters for the two methods
self.function=function # This controls the feedforward method of this class
if self.function=='A':
def forward(self, x1):
return x1+self.constant
elif self.function=='B':
def forward(self, x1, x2):
return x1*x2+self.constant
else:
print('please provide the correct function')
但是,它给了我以下错误。
NameError: name 'self' 未定义
我该如何编写基于__init__ 的参数定义不同forward 方法的类?
【问题讨论】: