【问题标题】:Defining python class method using arguments from __init__使用来自 __init__ 的参数定义 python 类方法
【发布时间】:2023-03-28 20:13:01
【问题描述】:

我想要达到的目标

我想定义一个有A和B两种模式的类,使该类的forward方法相应变化。

class MyClass():
    def __init__(self, constant):
        self.constant=constant
            
    def forward(self, x1,x2,function):
        if function=='A':
            return x1+self.constant
        
        elif function=='B':
            return x1*x2+self.constant
        
        else:
            print('please provide the correct function')
model1 = MyClass(2)
model1.forward(2, None, 'A')
output>>>4
model2 = MyClass(2)
model2.forward(2, 2, 'B')
output>>>6

它可以工作,但不是最优的,因为每次调用forward 方法时,它都会检查要使用的函数。但是,前向函数已经设置好了,一旦定义了类就永远不会更改,因此,在我的情况下,检查在forward 中使用哪个函数是超级多余的。 (对于那些注意到这一点的人,我正在使用 PyTorch 编写我的神经网络模型,两个模型共享 90% 的网络架构,仅有 10% 的差异是它们的前馈方式。

我想要的版本

我想在定义类的时候设置forward方法,这样就可以实现了

model1 = MyClass(2, 'A')
model1.forward(2)
output>>>4

model2 = MyClass(2, 'B')
model2.forward(2, 2)
output>>>6

所以我将我的课程改写为:

class MyClass():
    def __init__(self, constant, function):
        self.constant=constant # There would be a lot of shared parameters for the two methods
        self.function=function # This controls the feedforward method of this class
        
    if self.function=='A':
        def forward(self, x1):
            return x1+self.constant
        
    elif self.function=='B':
        def forward(self, x1, x2):
            return x1*x2+self.constant
        
    else:
        print('please provide the correct function')

但是,它给了我以下错误。

NameError: name 'self' 未定义

我该如何编写基于__init__ 的参数定义不同forward 方法的类?

【问题讨论】:

    标签: python oop pytorch


    【解决方案1】:

    您一直在尝试用您的代码重新定义 ,这样每个新对象都会在前后更改 所有 对象的 forward 定义。 幸运的是,你不知道该怎么做。

    相反,将所选函数设为对象的属性。编写所需的两个函数,然后在创建每个实例时分配所需的变体。

    class MyClass():
        def __init__(self, constant, function):
            self.constant=constant
            if function == 'A':
                self.forward = self.forwardA
            elif function=='B':
                self.forward = self.forwardB
            else:
                print('please provide the correct function')
                
        def forwardA(self, x1):
            return x1+self.constant
            
        def forwardB(self, x1, x2):
            return x1*x2+self.constant
    
    # Main
    model1 = MyClass(2, 'A')
    print(model1.forward(2))
    
    model2 = MyClass(2, 'B')
    print(model2.forward(2, 2))
    

    输出:

    4
    6
    

    【讨论】:

      【解决方案2】:

      您也可以尝试分解基类。它可能会与 mypy 一起玩得更好,并且更容易不混淆你正在使用的任何类。

      class MyClassBase():                                      
          def __init__(self, constant):                         
               self.constant=constant                           
                                                                
          def forward(self, *args, **kwargs):              
              raise NotImplementedError('use a derived class')  
                                                                
      class MyClassA(MyClassBase):                              
          def __init__(self, constant):                         
              super().__init__(constant)                        
                                                                
          def forward(self,x1):                                 
              return x1 + self.constant                         
                                                                
      class MyClassB(MyClassBase):                              
          def __init__(self, constant):                         
              super().__init__(constant)                        
                                                                
          def forward(self, x1, x2):                            
              return x1*x2 + self.constant                      
                                                                
      a = MyClassA(2)                                           
      b = MyClassB(2)                                           
                                                                
      print(a.forward(2))                                       
      print(b.forward(2,2))                                     
      

      【讨论】:

        【解决方案3】:

        我想说,如果您要在类方法的返回中使用常量变量名,我们可以通过这种方式定义它来调整逻辑。 (我们可以用不同的方式来做这件事,只和 Params 一起玩)希望这看起来不错。

        class MyClass():
            def __init__(self, constant, function):
                self.constant=constant
                self.function=function
            
            def forward(self, x1 = None, x2 = None):
                if self.function=='A':
                    return x1+self.constant
                elif self.function=='B':
                    return x1*x2+self.constant
                else:
                    print('please provide the correct function')
        
        model1 = MyClass(2, 'A')
        model1.forward(2)
        
        model2 = MyClass(2, 'B')
        model2.forward(2, 2)
        

        【讨论】:

          【解决方案4】:

          定义

          在这种情况下,我也会去继承;给定你想要的:

          model2 = MyClass(2, 'B')
          

          这样做会更容易(并且对其他人来说更具可读性):

          model2 = MyClassB(2)
          

          鉴于此,类似于 @Nathan Chappell provided in his answer 但更短(例如,无需重新定义 __init__):

          import torch
          
          
          class Base(torch.nn.Module):
              def __init__(self, constant):
                  super().__init__()
                  self.constant = constant
          
          
          class MyClassA(Base):
              def forward(self, x1):
                  return x1 + self.constant
          
          
          class MyClassB(Base):
              def forward(self, x1, x2):
                  return x1 * x2 + self.constant
          

          调用

          您应该使用torch.nn.Module__call__ 方法而不是forward,因为它可以与钩子一起正常工作(请参阅this answer),因此应该是:

          model1 = MyClassA(5)
          model1(torch.randn(10, 5))
          

          代替:

          model1 = MyClassA(5)
          model1.forward(torch.randn(10, 5))
          

          【讨论】:

            猜你喜欢
            • 2014-03-20
            • 1970-01-01
            • 2015-02-28
            • 1970-01-01
            • 2017-06-18
            • 2020-11-15
            • 1970-01-01
            • 2017-05-07
            • 2021-10-15
            相关资源
            最近更新 更多