【Question title】: numpy subclass will not accept arguments to __new__ from pythonically inheriting class
【Posted】: 2012-01-26 22:23:48
【Question description】:

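I created an ndarray subclass called "Parray" that takes two arguments: p and dimensionality. It works fine on its own. Now I want to create a class called SirPlotsAlot that inherits from Parray, without any fancy __new__, __array_finalize__, etc.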

import numpy as np

class Parray(np.ndarray):
    def __new__(self, p = Parameters(), dimensionality = 2):

        print "Initializing Parray with initial dimensionality %s..." % dimensionality

        self.p = p # store the parameters

        if dimensionality == 2:
            shape = (p.nx, p.ny)
            self.pshape = shape
        elif dimensionality == 3:
            shape=(p.nx, p.ny, p.nx)
            self.pshape = shape
        else:
            raise NotImplementedError, "dimensionality must be 2 or 3"

        # ...Set other variables (elided)

        subarr = np.ndarray.__new__(self, shape, dtype, buffer, offset, strides, order)
        subarr[::] = np.zeros(self.pshape) # initialize to zero
        return subarr
...

class SirPlotsAlot(Parray):
    def __init__(self, p = Parameters(), dimensions = 3):
        super(SirPlotsAlot, self).__new__(p, dimensions)     # (1)

Objects in my program share a set of parameters by passing the object p = Parameters() back and forth.

Now, when I type (the file is auxiliary.py):

import auxiliary
from parameters import Parameters
p = Parameters()
s = auxiliary.SirPlotsAlot(p, 3)

I expect to get a nice "Initializing Parray with initial dimensionality 3...", but I get "2" instead. But if I type:

import auxiliary
s = auxiliary.SirPlotsAlot()

I get

---> 67             shape = (p.nx, p.ny)
"AttributeError: 'int' object has no attribute 'nx'"

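It thinks "p" is an int, but it isn't. If I play around with it, I get lots of strange, seemingly unrelated errors. The int it thinks p is, is "2". I'm completely lost.

I have tried both with and without the # (1) line (the super call).

Other errors include "AttributeError: 'list' object has no attribute 'p'", "TypeError: __new__() takes exactly 2 arguments (1 given)", and "ValueError: need more than 0 values to unpack" (that last one after I replaced __new__'s arguments with *args, which I don't really understand).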
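A likely explanation, consistent with the comments below: SirPlotsAlot(p, 3) first runs the inherited Parray.__new__ normally, then runs __init__, whose line (1) calls __new__ a second time. Because __new__ is implicitly a static method, the function obtained through super() is not bound to anything, so p fills the class slot, 3 fills the p slot, and dimensionality falls back to its default of 2. The pure-Python sketch below (no NumPy; Base, Child and calls are made-up names for illustration) records what each parameter actually receives.

```python
# Hypothetical stand-ins for Parray / SirPlotsAlot; every __new__ call is logged.
calls = []


class Base(object):
    def __new__(cls, p="default-p", dimensionality=2):
        # __new__ is implicitly a static method: the class is an explicit argument.
        calls.append((cls, p, dimensionality))
        if isinstance(cls, type):
            return object.__new__(cls)
        return None  # `cls` was not a class, so there is nothing to build


class Child(Base):
    def __init__(self, p="default-p", dimensions=3):
        # Mirrors line (1) of the question: calling __new__ again from __init__.
        super(Child, self).__new__(p, dimensions)


Child("my-params", 3)

print(calls[0])  # the implicit call made by Child(...): (Child, 'my-params', 3)
print(calls[1])  # the call from __init__, shifted one slot: ('my-params', 3, 2)
```

In the question's Parray, that second call leaves p holding the integer passed as dimensions, so p.nx raises the AttributeError shown above.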

【Question discussion】:

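  • If Python thinks p is an int, it is probably right. Use pdb to put a breakpoint on that line and see what you actually have. Print a stack trace to see how you got there. Pay close attention to whether the files mentioned in the trace are the ones you think they should be; more than one person has been tripped up by a slightly wrong library path.
  • One thing that hints at problems elsewhere in your code: the "p=Parameters()" default in __init__ probably doesn't do what you think. It does not create a new Parameters instance each time Parray.__new__ is called; instead, it creates one when the function is first defined. In other words, every Parray created without an explicit p shares a single Parameters instance, which seems unlikely to be what you intend. [I don't see how this could be the problem here, but it may cause problems elsewhere.]
  • Your combined use of __new__() and __init__() is... unorthodox, to say the least. Why are you using __new__() in your Parray class at all? To me it looks like an ordinary __init__() method; there is no reason to write it as __new__(). And why are you storing attributes on the class? (The first argument to __new__() is a reference to the class, not to an instance, because no instance exists yet.)
  • @DSM: Good catch; I didn't even look at that, but it's one of those subtle bugs that really trips up Python newcomers (and sometimes veterans too).
  • To be clear, __new__ behaves like a class method (strictly, it is a static method that receives the class as its first argument), so the first argument of __new__ is not the instance (i.e. self), it is the class. It's better to write something like def __new__(cls, ... to remind yourself, and anyone else reading the code, that you are operating on the class rather than an instance. It looks like you may drop __new__ in favor of __init__, but this is good to know for the future.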
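The default-argument point raised in the comments can be checked directly. In this sketch Parameters is a hypothetical stand-in for the question's class that only counts how often it is constructed; the None sentinel in fresh_default is the usual idiom for getting a new object on every call.

```python
class Parameters(object):
    """Hypothetical stand-in for the question's Parameters class."""
    created = 0

    def __init__(self):
        Parameters.created += 1


def shared_default(p=Parameters()):
    # The default expression ran once, when this `def` statement executed.
    return p


def fresh_default(p=None):
    # A None sentinel defers construction to call time.
    if p is None:
        p = Parameters()
    return p


a, b = shared_default(), shared_default()
c, d = fresh_default(), fresh_default()
print(a is b)  # True: both calls received the same object
print(c is d)  # False: each call built its own
```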

Tags: python inheritance numpy attributeerror default-arguments


【Answer 1】:

I'll echo kindall and say "don't use __new__". Your Parray.__new__ method looks more like an initializer and should use __init__, just like its subclass.

【Discussion】:

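  • Are you sure? I thought that when subclassing np.ndarray you should use __new__ rather than __init__: docs.scipy.org/doc/numpy/user/basics.subclassing.html
  • That's exactly what I thought, and it's why I was hoping that subclasses of my new ndarray subclass would not have to follow ndarray's rules. One avenue I'm currently exploring is to treat Parray as an ndarray. As a quick fix, I removed the default-argument statements from the class initializer and created a helper function: def NewSirPlotsAlot(p=Parameters(), dimensions=3): return SirPlotsAlot(p, dimensions). Not what I wanted, but it seems to work.
  • @keflavich: Good point. I haven't seen much Python code that needs __new__. What you probably don't want to do is call __new__ from __init__, because it has already been called by then.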
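For an ndarray subclass, the shape has to be decided in __new__, because the array already exists by the time __init__ runs; the NumPy subclassing guide linked above uses this pattern. The following is a minimal sketch, assuming a Parameters object with nx and ny attributes (stubbed here as a hypothetical class), of how Parray and a subclass could pass arguments along: the subclass overrides __new__ and forwards cls explicitly, and __array_finalize__ carries p over to views and slices.

```python
import numpy as np


class Parameters(object):
    """Hypothetical stand-in for the question's Parameters class."""

    def __init__(self, nx=4, ny=5):
        self.nx = nx
        self.ny = ny


class Parray(np.ndarray):
    def __new__(cls, p=None, dimensionality=2):
        if p is None:
            p = Parameters()  # fresh instance per call, not one shared default
        if dimensionality == 2:
            shape = (p.nx, p.ny)
        elif dimensionality == 3:
            shape = (p.nx, p.ny, p.nx)
        else:
            raise NotImplementedError("dimensionality must be 2 or 3")
        # View-cast a zero-filled array to whichever class is being built.
        obj = np.zeros(shape).view(cls)
        obj.p = p  # stored on the instance, not on the class
        return obj

    def __array_finalize__(self, obj):
        # Also called for views and slices: copy `p` from the source array.
        if obj is None:
            return
        self.p = getattr(obj, 'p', None)


class SirPlotsAlot(Parray):
    def __new__(cls, p=None, dimensions=3):
        # Forward `cls` explicitly so Parray.__new__ builds a SirPlotsAlot.
        return super(SirPlotsAlot, cls).__new__(cls, p, dimensions)


p = Parameters(nx=2, ny=3)
s = SirPlotsAlot(p, 3)
print(type(s).__name__, s.shape)  # SirPlotsAlot (2, 3, 2)
print(s[0].p is p)                # True: slices keep the parameters
```

No __init__ is needed here; SirPlotsAlot only changes the default dimensionality and lets Parray.__new__ do the work.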
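
【Answer 2】:

It's been ten years and I left this project long ago, but I solved the problem by writing helper functions that create new instances of the class and set them up. In the code sample below, see the definitions at the bottom of the file; I imported and used those.

Hat tip to Matthew Schinckel for pointing out that __new__ has already been called by the time __init__ runs, and thanks to everyone else for their ideas.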

# -*- coding: utf-8 -*-
"""
Era's Plotting Functionality. This module exports SirPlotsAlot and company:
    
class SirPlotsAlot: Array with 2D, 3D, animated plotting capability, and a pyrism Parameters object.

def NewSirPlotsAlot(shape): returns a zero-filled SirPlotsAlot; arguments may be left unspecified.

def NewPsirPlotsAlot(dimensionality, p): returns a zero-filled SirPlotsAlot shaped from a Parameters object.

def returns_SirPlotsAlot: decorator forcing an ndarray-returning function to return a SirPlotsAlot instead.

Created on Thu Jul 12 18:46:15 2012
@author: Era
"""

# SirPlotsAlot
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.pyplot as pyplot
import numpy as np
import scipy as scipy
import logging

lprint = logging.getLogger('pyrism')



class SirPlotsAlot(np.ndarray):
    """
    An array with 2D, 3D, animated plotting capability, and a pyrism Parameters object.

    Inherits: numpy's ndarray

    Input:
        shape: A tuple of ints. The shape of the underlying ndarray.

    """

    # class variables
    currentSlice = 0    # for _updateSlice and animated plots

    
    def __new__(cls, shape):
        """
        Creates a new SirPlotsAlot for us to use.

        SirPlotsAlot inherits ndarray. ndarray is written in C and creates
        (allocates) the array in __new__, so the shape must be handled here
        rather than in __init__.
        
        Args:
            shape: A tuple of ints. The shape of the underlying ndarray.
            
        Returns:
            an ndarray
            
        Author / Date:
            Erasmus Alcarin   /   January 23rd, 2012
            Erasmus Alcarin   /      July 13th, 2012
            
        """

       
        ### Specify the exact parameters of the array this class implements
        dtype=float         # dtype: data type. Optional
                                # Any object can be interpreted as
                                # a numpy datatype
        buffer=None         # buffer: object exposing buffer interface. Optional
                                # Used to fill array with data
        offset=0            # offset: int. Optional.
                                # offset of array in data buffer
        strides=None        # strides : tuple of ints. Optional
                                # Strides of data in memory
        order=None          # order : {'C', 'F'}. Optional
                                # Row-major or column-major order.
        
        # Instantiate new ndarray (this class). Temporarily called sub_array.
        subarr = np.ndarray.__new__(cls,   # cls is crucial.
                                            # it creates a ndarray that
                                            # is of type THIS CLASS
                                            # instead of type ndarray
                shape, dtype, buffer, offset, strides, order)
        
        # Return the successfully created instance for this class to use!
        return subarr


    def __init__(self, shape):
        """Says hello!

        Args:
            shape: A tuple of ints. The shape of the underlying ndarray.

        Returns:
            None

        Author:
            Erasmus Alcarin   /   January 23rd, 2012
            Erasmus Alcarin   /      July 13th, 2012

        """
        lprint.debug("Ah, kind sir! Thy bidding be done!")


    def __array_finalize__(self, obj):
        """Set up new instances created by view casting or from a template.

        Purpose:  ndarray creates new arrays in several ways besides an
                  explicit constructor call, specifically views (view
                  casting) and so-called "new-from-template" (for example,
                  slices). NumPy calls this method in all of those cases,
                  so it is the place to initialize any extra attributes
                  this class needs.

        Args:
            obj: The object the new array was created from. For example,
                this function is called if we type
                myArr = myIntensityMap[1:]
                (myArr is self, and myIntensityMap is obj)

        Returns:
            None

        Author / Date:
            Erasmus Alcarin   /   January 23rd, 2012

        """
        if obj is None: return
        

    def __array_wrap__(self, out_arr, context=None):
        """Return ufunc results as SirPlotsAlot rather than plain ndarray.

        Purpose:  NumPy calls this method at the end of a ufunc (array
                  adding, multiplying, etc.) so the subclass can wrap the
                  raw output. Here the output is viewed as type(self), so
                  arithmetic keeps returning SirPlotsAlot.

        Args:
            out_arr: The output array of the operation being performed.
            context: Optional. Information about the ufunc call, passed
                by NumPy as a tuple (ufunc, arguments, output index).

        Returns:
            See ndarray.__array_wrap__()

        Author / Date:
            Erasmus Alcarin   /   January 23rd, 2012

        """
        # Call ndarray's __array_wrap__ method.
        return np.ndarray.__array_wrap__(self, out_arr, context).view(type(self))


    def _enforceXD(self, X):
        """
        a helper function returning true if this
        SirPlotsAlot has dimensionality X, otherwise
        raising a ValueError.

        Args:
            X: An int. The underlying ndarray dimensionality being tested for.

        Returns:
            True if this array has dimensionality X.
            ValueError is raised otherwise.

        """
        if self.ndim == X:
            return True
        else:
            raise ValueError("A %sD array was required. A %sD array was supplied." % (X, self.ndim))


    def _checkXD(self, X):
        """
        a helper function returning True if this
        SirPlotsAlot has dimensionality X, otherwise
        returning False.

        Args:
            X: An int. The underlying ndarray dimensionality being tested for.

        Returns:
            True if this array has dimensionality X.
            False otherwise.

        """
        return self.ndim == X


    def _checkLabelInfo(self, label = None):
        """
        a helper utility function to decide which
        of the accepted formats for plot label the user
        has specified.

        Args:
            label: The user's input (Valid formats are String, Tuple)

        Returns:
            Nothing
        
        """
        if len(label) >= 1 and type(label[0]) == str:
            pyplot.title(label[0])
        if len(label) >= 2 and type(label[1]) == str:
            pyplot.xlabel(label[1])
        if len(label) >= 3 and type(label[2]) == str:
            pyplot.ylabel(label[2])
        if len(label) >= 4 and type(label[3]) == str:
            # pyplot has no zlabel(); set it on the current axes (3D axes only)
            pyplot.gca().set_zlabel(label[3])
    
    
    def _add_labels(self, label = None, caller_label = 'none'):
        """
        A utility function to quickly add labels to
        any of the graphing utilities embedded in
        SirPlotsAlot.

        Args:
            label: (Str, Tuple) The label being supplied by the user.
            caller_label: A string. Each plotting function has its own
                axes to label. This identifies the plotting function.

        Returns:
            Nothing

        """
        if label is None:
            raise ValueError("_add_labels violated: no label supplied")
        else:
            lprint.debug("going on to labelling")

        if type(label) == str:
            pyplot.title(label)
        elif type(label) == tuple:
            self._checkLabelInfo(label)

        elif hasattr(self, caller_label):
            # Fall back to a label stored on the array, e.g. self.plot2D_label
            label = getattr(self, caller_label)

            if type(label) == str:
                pyplot.title(label)
            elif type(label) == tuple:
                self._checkLabelInfo(label)

        else:
            raise ValueError("_add_labels requires string or tuple of strings, got %r" % (label,))
            

    def _updateSlice(self):
        """
        a helper function for animate2D(), this controls the
        progression (speed, sampling) of the animation by
        returning the next image to be presented in the animation.

        Args:
            None

        Returns:
            2D slice of this array.

        """
        if self._enforceXD(3):
            self.currentSlice += 1
        
            return self[self.currentSlice]


    def plot1D(self, label = None):
        """
        Plot 1-axis SirPlotsAlot in 2D, plotting array contents as y (up).

        Args:
            label: String or tuple labelling the plot.

        Returns:
            Nothing

        """
        if self._enforceXD(1):
            pyplot.figure()
            
            # self._add_labels(label, 'plot1D_label')
            
            pyplot.plot(self)
             
            pyplot.show()


    def plot2D(self, label = None):
        """
        Plot 2-axis SirPlotsAlot in 2D, plotting array contents as color.

        Args:
            label: String or tuple labelling the plot.

        Returns:
            Nothing

        """
        if self._enforceXD(2):
            # Do not produce huge output
            #lprint.debug("We're plotting this up:\n%s" % self)
            lprint.debug("We're plotting you some goodies!")

            fig = pyplot.figure()
            if type(label) == str:      # if label is supplied, apply it.
                pyplot.title(label)
            elif hasattr(self, 'plot2D_label'):
                pyplot.title(self.plot2D_label)

            plot = pyplot.imshow(self)
            
            fig.colorbar(plot)
            #colorbar.ax.set_yticklabels(["%.2f" % self.min(), '0', "%.2f" % self.max()])

            pyplot.gca().invert_yaxis()

            pyplot.xlabel('x')
            pyplot.ylabel('y')

            pyplot.show()


    def save_plot2D(self, file = None, label = None, cbar_ticks = None):
        """
        Saves plot of 2-axis SirPlotsAlot in 2D, plotting array contents as color,
        in .png format.

        Args:
            file: A string. The filename to save to. Default: ``output``
            label: String or tuple labelling the plot.
            cbar_ticks: Colorbar ticks for plot. Default: auto.

        Returns:
            Nothing

        """
        if self._enforceXD(2):
            # Do not produce huge output
            #lprint.debug("We're plotting this up:\n%s" % self)
            lprint.debug("We're plotting you some goodies!")

            fig = pyplot.figure()
            if label is not None:
                self._add_labels(label, 'plot2D_label')

            plot = pyplot.imshow(self)
            
            if cbar_ticks == None:
                fig.colorbar(plot)
            else:
                cbar = fig.colorbar(plot, ticks=cbar_ticks) # Numbers
                cbar.ax.set_yticklabels(map(str, cbar_ticks))   # Label

            pyplot.gca().invert_yaxis()

            if file == None:
                file = 'output'
            pyplot.savefig(file)


    # nice!
    def plot3D(self, label = None):
        """
        Plots 2-axis SirPlotsAlot in 3D, plotting array contents as 3rd dimension (up).

        Args:
            label: String or tuple labelling the plot.

        Returns:
            Nothing

        """
        if self._enforceXD(2):
            # Do not produce huge output
            #lprint.debug("We're plotting this up:\n%s" % self)
            lprint.debug("We're plotting you some goodies!")

            # make an x/y grid matching the array's columns and rows
            x = np.linspace(0, self.shape[1], self.shape[1])
            y = np.linspace(0, self.shape[0], self.shape[0])
            [x, y] = np.meshgrid(x, y)       # this is the same as make_2d

            fig = pyplot.figure()
            if type(label) == str:      # if label is supplied, apply it.
                pyplot.title(label)
            elif hasattr(self, 'plot3D_label'):
                pyplot.title(self.plot3D_label)

            ax = fig.add_subplot(111, projection='3d')   # make a 3D axis
            ax.plot_surface(x, y, self)

            pyplot.xlabel('x')
            pyplot.ylabel('y')

            pyplot.show()


    def plot3D_2(self, label = None):
        """
        Plots 2-axis SirPlotsAlot in 3D, plotting array contents as 3rd dimension (up),
        with contours projected onto each 2D cross-section of the 3D plot.

        Args:
            label: String or tuple labelling the plot.

        Returns:
            Nothing

        """
        if self._enforceXD(2):
            # Do not produce huge output
            #lprint.debug("We're plotting this up:\n%s" % self)
            lprint.debug("We're plotting you some goodies!")

            # make an x/y grid matching the array's columns and rows
            x = np.linspace(0, self.shape[1], self.shape[1])
            y = np.linspace(0, self.shape[0], self.shape[0])
            [x, y] = np.meshgrid(x, y)

            fig = pyplot.figure()
            if type(label) == str:
                pyplot.title(label)
            elif hasattr(self, 'plot3D_2_label'):
                pyplot.title(self.plot3D_2_label)

            ax = fig.add_subplot(111, projection='3d')

            ax.plot_surface(x, y, self, rstride=8, cstride=8, alpha=0.3)
            ax.contour(x, y, self, zdir='z', offset=self.min())
            ax.contour(x, y, self, zdir='x', offset=0)
            ax.contour(x, y, self, zdir='y', offset=self.shape[0])

            ax.set_xlabel('x')
            ax.set_xlim(0, self.shape[1])
            ax.set_ylabel('y')
            ax.set_ylim(0, self.shape[0])
            ax.set_zlabel('z')
            ax.set_zlim(self.min(), self.max())

            pyplot.show()

        # the following is probably deprecated code for the above.
        '''if self._enforceXD(2):
            print "We're plotting this up:\n%s" % self

            # make grid from min to max with interval nx
            x = y = scipy.linspace(self.min(), self.max(), self.shape[0])
            [x, y] = scipy.meshgrid(x, y)       # this is the same as make_2d

            fig = pyplot.figure()
            if type(label) == str:      # if label is supplied, apply it.
                pyplot.title(label)
            elif hasattr(self, 'plot3D_label'):
                pyplot.title(self.plot3D_label)

            ax = Axes3D(fig)                    # make a 3D axis
            ax.plot_surface(x, y, self)

            pyplot.show()
        '''


    def aniPlot2D(self):
        """
        Generate successive 2D color plots using color for the data. Then play these
        plots in series, creating an animation. Requires 3D SirPlotsAlot.

        Args:
            None

        Returns:
            Nothing

        """
        self.tplot = 0

        fig = pyplot.figure()

        #x = np.arange(0, self.shape[1])
        #y = np.arange(0, self.shape[0]).reshape(-1,1)

        ims = []
        imsappend = ims.append      # optimization
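        for t in np.arange(self.shape[0]):      # one frame per slice along axis 0
            imsappend((pyplot.imshow(self[t]),))

        # Keep a reference: an unreferenced animation is garbage-collected before show()
        ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=3000, blit=True)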

        pyplot.show()



def NewSirPlotsAlot(shape = (512, 512)):
    """
    Returns instance of SirPlotsAlot explicitly initialized to all zeros;
    arguments may be left unspecified.

    Args:
        shape: A tuple of ints. The shape of the underlying ndarray.

    Returns:
        SirPlotsAlot

    Author / Date:
        Erasmus Alcarin   /   July 13, 2012

    """
    s = SirPlotsAlot(shape)
    s[:] = np.zeros(s.shape)

    lprint.info("SirPlotsAlot has been populated with zeros.")

    return s


# Aliases for NewSirPlotsAlot
splot = NewSirPlotsAlot


def NewPsirPlotsAlot(dimensionality = 3, p = None):
    """
    Returns instance of SirPlotsAlot explicitly initialized to all zeros;
    arguments may be left unspecified.

    Args:
        dimensionality: An int. Number of dimensions for array. (optional)
        p: A Parameters object. Simulation parameters. (optional)

    Returns:
        SirPlotsAlot

    Author / Date:
        Erasmus Alcarin   /   July 13, 2012

    """
    lprint.debug("Initializing SirPlotsAlot with initial dimensionality %s..." % dimensionality)
    # NewSirPlotsAlot()
    try:
        import pyrism.parameters as par
    except ImportError:
        lprint.error("Use of pyrism as non-package detected. You must remain in the pyrism directory.")
        import parameters as par


    if p is None:
        p = par.Parameters.Instance()

    # extract size from parameters file, assuming size nx, ny
    if dimensionality == 2:
        shape = (p.ny, p.nx)
    elif dimensionality == 3:
        shape = (p.nx, p.ny, p.nx)
    else:
        raise NotImplementedError("dimensionality must be 2 or 3")
        
    # Make and Get object
    s = SirPlotsAlot(shape)
    s[:] = np.zeros(s.shape)

    lprint.info("SirPlotsAlot has been populated with zeros.")

    return s
        
        
# Aliases for NewPsirPlotsAlot
psplot = NewPsirPlotsAlot


def returns_SirPlotsAlot(fn):
    """
    A decorator that changes an ndarray to a SirPlotsAlot by
    means of the ndarray view function. (Returns SirPlotsAlot)

    """
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs).view(SirPlotsAlot)
    return wrapped


# Aliases for returns_SirPlotsAlot
returns_splot = returns_SirPlotsAlot
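The returns_SirPlotsAlot decorator above relies on view casting, and both view casting and slicing run __array_finalize__. The sketch below uses a hypothetical TaggedArray in place of SirPlotsAlot (so it runs without matplotlib) to show which argument is which: self is the newly created array and obj is the array it was made from. The np.asarray call and functools.wraps are additions not in the original decorator; they let list-returning functions work and keep the wrapped function's name.

```python
import functools

import numpy as np


class TaggedArray(np.ndarray):
    """Hypothetical minimal subclass standing in for SirPlotsAlot."""

    def __array_finalize__(self, obj):
        # self: the new array; obj: the array it was viewed or sliced from.
        if obj is None:
            return
        self.tag = getattr(obj, 'tag', 'untagged')


def returns_tagged(fn):
    """Decorator: view an ndarray-returning function's result as TaggedArray."""
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        return np.asarray(fn(*args, **kwargs)).view(TaggedArray)
    return wrapped


@returns_tagged
def ramp(n):
    return np.arange(n, dtype=float)


r = ramp(4)
r.tag = 'ramp'       # attribute set after creation
tail = r[1:]         # slicing calls tail.__array_finalize__(r)
print(type(tail).__name__, tail.tag)  # TaggedArray ramp
print(ramp.__name__)                   # ramp
```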
