【问题标题】:Not able to access class methods from pyspark RDD's map method无法从 pyspark RDD 的 map 方法访问类方法
【发布时间】:2018-02-21 00:30:29
【问题描述】:

在我的应用程序代码库中集成 pyspark 时,我无法在 RDD 的 map 方法中引用类的方法。我用一个简单的例子重复了这个问题,如下

这是我定义的一个虚拟类,它只是将一个数字添加到从作为类属性的 RDD 派生的 RDD 的每个元素中:

class Test:

    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    def add(self, a, b):
        return a + b

    def test_func(self, b):
        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        v = c_r.map(lambda l: self.add(l[1], b))
        v_c = v.collect()
        return v_c

test_func() 在 RDD v 上调用 map() 方法,而后者又在 v 的每个元素上调用 add() 方法。调用test_func() 会抛出以下错误:

pickle.PicklingError: Could not serialize object: Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

现在,当我将 add() 方法移出类时,例如:

def add(self, a, b):
    return a + b

class Test:

    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    def test_func(self, b):

        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        v = c_r.map(lambda l: add(l[1], b))
        v_c = v.collect()

        return v_c

现在调用test_func() 可以正常工作。

[7, 9, 11]

为什么会这样?如何将类方法传递给 RDD 的 map() 方法?

【问题讨论】:

    标签: python pyspark rdd


    【解决方案1】:

    发生这种情况是因为当 pyspark 尝试序列化您的函数(将其发送给工作人员)时,它还需要序列化您的 Test 类的实例(因为您传递给 map 的函数引用了这个实例在self)。此实例具有对火花上下文的引用。您需要确保SparkContextRDDs 没有被序列化并发送给工作人员的任何对象引用。 SparkContext 只需要住在驱动程序中。

    这应该可行:

    在文件testspark.py:

    class Test(object):
        def add(self, a, b):
            return a + b
    
        def test_func(self, a_r, b):
            c_r = a_r.map(lambda l: (l[0], l[1] * 2))
            # now `self` has no reference to the SparkContext()
            v = c_r.map(lambda l: self.add(l[1], b)) 
            v_c = v.collect()
            return v_c
    

    在您的主脚本中:

    from pyspark import SparkContext
    from testspark import Test
    
    sc = SparkContext()
    a = [('a', 1), ('b', 2), ('c', 3)]
    a_r = sc.parallelize(a)
    
    test = Test()
    test.test_func(a_r, 5) # should give [7, 9, 11]
    

    【讨论】:

    • 这真的可以将方法传递给工人吗?如果你的类是在 python 内核中定义的,是的,但是如果你试图访问一个模块(a .py)你将不得不使用.addPyfiles 因为import内核将不起作用(so ref )。
    • @ohailolcat 这是一个很好的观点。在这个答案中,我假设工作人员中的 python 环境与驱动程序中的环境相同(即:testspark.py 已部署到工作人员,并且它位于 PYTHONPATH 中)否则 testspark.py 将需要通过.addPyfiles
    猜你喜欢
    • 1970-01-01
    • 2016-07-19
    • 1970-01-01
    • 1970-01-01
    • 2021-10-18
    • 2019-09-09
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多