【问题标题】:Spatial Join between pyspark dataframe and polygons (geopandas)pyspark 数据框和多边形(geopandas)之间的空间连接
【发布时间】:2020-03-27 08:37:32
【问题描述】:

问题:

我想在以下之间进行空间连接:

  • 带有(例如道路上的点)的大型 Spark Dataframe(500M 行)
  • 带有多边形(例如区域边界)的小型geojson(20000 个形状)。

这是我目前所拥有的,我发现它很慢(很多调度程序延迟,可能是由于 communes 没有广播):

@pandas_udf(schema_out, PandasUDFType.GROUPED_MAP)
def join_communes(traces):   
    geometry = gpd.points_from_xy(traces['longitude'], traces['latitude'])
    gdf_traces = gpd.GeoDataFrame(traces, geometry=geometry, crs = communes.crs)
    joined_df = gpd.sjoin(gdf_traces, communes, how='left', op='within')
    return joined_df[columns]

pandas_udf 将一点 points 数据帧(trace)作为 pandas 数据帧,将其转换为带有 geopandas 的 GeoDataFrame,并使用 polygons进行空间连接> GeoDataFrame(因此受益于 Geopandas 的 Rtree 连接)

问题:

有没有办法让它更快?我知道我的 communes 地理数据帧在 Spark Driver 的内存中,并且每个工作人员都必须为每次调用 udf 下载它,这是正确的吗?

但是我不知道如何使这个 GeoDataFrame 直接对工作人员可用(如在广播连接中)

有什么想法吗?

【问题讨论】:

  • 你广播过公社吗?您应该广播公社,然后使用 communes.value 访问 json
  • 这就是我最终做的事

标签: python pandas pyspark pyspark-sql geopandas


【解决方案1】:

一年后,这就是我最终按照@ndricca 的建议做的事情,诀窍是广播公社,但您不能直接广播GeoDataFrame,因此您必须将其加载为 Spark DataFrame,然后在广播之前将其转换为 JSON。然后使用shapely.wkt(众所周知的文本:一种将几何对象编码为文本的方法)在UDF中重建GeoDataFrame

另一个技巧是在 groupby 中使用盐来确保数据在集群中的平均重新分区

import geopandas as gpd
from shapely import wkt
from pyspark.sql.functions import broadcast
communes = gpd.load_file('...communes.geojson')
# Use a previously created spark session
traces= spark_session.read_csv('trajectoires.csv')
communes_spark = spark.createDataFrame(communes[['insee_comm', 'wkt']])
communes_json = provinces_spark.toJSON().collect()
communes_bc = spark.sparkContext.broadcast(communes_json)

@pandas_udf(schema_out, PandasUDFType.GROUPED_MAP)
def join_communes_bc(traces):
    communes = pd.DataFrame.from_records([json.loads(c) for c in communes_bc.value])
    polygons = [wkt.loads(w) for w in communes['wkt']]
    gdf_communes = gpd.GeoDataFrame(communes, geometry=polygons, crs=crs )
    geometry = gpd.points_from_xy(traces['longitude'], traces['latitude'])
    gdf_traces = gpd.GeoDataFrame(traces , geometry=geometry, crs=crs)
    joined_df = gpd.sjoin(gdf_traces, gdf_communes, how='left', op='within')
    return joined_df[columns]
    

traces = traces.groupby(salt).apply(join_communes_bc)

【讨论】:

  • 您好,我正在尝试实现相同的功能,但 pyarrow 存在一些问题。我遵循了在线建议的解决方案(将 pyarrow 降级到 0.14.1 并添加 env 变量:ARROW_PRE_0_15_IPC_FORMAT = 1)。但我总是遇到与 pyarrow 相关的错误。你有没有遇到类似的错误? pyarrow.lib.ArrowInvalid:输入对象不是 NumPy 数组
  • 没关系,我发现解决方案是永远不要返回 wkt 列。在某种程度上,它会导致此错误。
猜你喜欢
  • 1970-01-01
  • 2018-08-01
  • 2019-02-22
  • 1970-01-01
  • 1970-01-01
  • 2018-12-13
  • 1970-01-01
  • 1970-01-01
  • 2019-02-16
相关资源
最近更新 更多