【问题标题】:How to plot correlation heatmap when using pyspark+databricks使用 pyspark+databricks 时如何绘制相关热图
【发布时间】:2019-08-28 00:16:30
【问题描述】:

我正在研究数据块中的 pyspark。我想生成一个相关热图。假设这是我的数据:

myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                               (2.5,4.6,3.1),
                               (6.5,7.2,10.0)],
                              ['col1','col2','col3'])

这是我的代码:

import pyspark
from pyspark.sql import SparkSession
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from ggplot import *
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
from pyspark.mllib.stat import Statistics

myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                               (2.5,4.6,3.1),
                               (6.5,7.2,10.0)],
                              ['col1','col2','col3'])
vector_col = "corr_features"
assembler = VectorAssembler(inputCols=['col1','col2','col3'], 
                            outputCol=vector_col)
myGraph_vector = assembler.transform(myGraph).select(vector_col)
matrix = Correlation.corr(myGraph_vector, vector_col)
matrix.collect()[0]["pearson({})".format(vector_col)].values

直到这里,我才能得到相关矩阵。结果如下:

现在我的问题是:

  1. 如何将矩阵传输到数据框?我试过How to convert DenseMatrix to spark DataFrame in pyspark?How to get correlation matrix values pyspark的方法。但这对我不起作用。
  2. 如何生成如下所示的相关热图:

因为我刚刚学习了 pyspark 和 databricks。 ggplot 或 matplotlib 都可以解决我的问题。

【问题讨论】:

    标签: ggplot2 pyspark heatmap correlation databricks


    【解决方案1】:

    我认为你感到困惑的地方是:

    matrix.collect()[0]["pearson({})".format(vector_col)].values
    

    调用密集矩阵的 .values 会为您提供所有值的列表,但您实际上要查找的是表示相关矩阵的列表列表。

    import matplotlib.pyplot as plt
    from pyspark.ml.feature import VectorAssembler
    from pyspark.ml.stat import Correlation
    
    columns = ['col1','col2','col3']
    
    myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                                   (2.5,4.6,3.1),
                                   (6.5,7.2,10.0)],
                                  columns)
    vector_col = "corr_features"
    assembler = VectorAssembler(inputCols=['col1','col2','col3'], 
                                outputCol=vector_col)
    myGraph_vector = assembler.transform(myGraph).select(vector_col)
    matrix = Correlation.corr(myGraph_vector, vector_col)
    

    到目前为止,它基本上是您的代码。您应该使用 .toArray().tolist() 而不是调用 .values 来获取表示相关矩阵的列表:

    matrix = Correlation.corr(myGraph_vector, vector_col).collect()[0][0]
    corrmatrix = matrix.toArray().tolist()
    print(corrmatrix)
    

    输出:

    [[1.0, 0.9582184104641529, 0.9780872729407004], [0.9582184104641529, 1.0, 0.8776695567739841], [0.9780872729407004, 0.8776695567739841, 1.0]]
    

    这种方法的优点是您可以轻松地将列表列表转换为数据框:

    df = spark.createDataFrame(corrmatrix,columns)
    df.show()
    

    输出:

    +------------------+------------------+------------------+ 
    |              col1|              col2|              col3| 
    +------------------+------------------+------------------+ 
    |               1.0|0.9582184104641529|0.9780872729407004|
    |0.9582184104641529|               1.0|0.8776695567739841| 
    |0.9780872729407004|0.8776695567739841|               1.0|  
    +------------------+------------------+------------------+
    

    回答你的第二个问题。只是绘制热图的众多解决方案之一(如thisthis 使用seaborn 更好)。

    def plot_corr_matrix(correlations,attr,fig_no):
        fig=plt.figure(fig_no)
        ax=fig.add_subplot(111)
        ax.set_title("Correlation Matrix for Specified Attributes")
        ax.set_xticklabels(['']+attr)
        ax.set_yticklabels(['']+attr)
        cax=ax.matshow(correlations,vmax=1,vmin=-1)
        fig.colorbar(cax)
        plt.show()
    
    plot_corr_matrix(corrmatrix, columns, 234)
    

    【讨论】:

    • Crnoik - 值是否必须为 INT 格式?我正在尝试 FLOAT 值之间的相关性并在结果矩阵中获取 NaN。
    • 不,没有必要。在上面的示例中,我也使用了浮点格式。你能打开你自己的问题并向我们展示你的代码吗?我去看看。
    猜你喜欢
    • 2015-06-17
    • 1970-01-01
    • 2020-03-29
    • 2021-05-24
    • 2020-04-14
    • 2019-10-21
    • 2023-03-22
    • 2021-09-01
    • 2017-07-12
    相关资源
    最近更新 更多