【问题标题】:How to get rows where at least two distinct values are in a column?如何获取列中至少有两个不同值的行?
【发布时间】:2022-01-04 22:07:56
【问题描述】:

我有一个日志文件,我想报告启动了一种以上(至少两种)protocol 连接的IP 地址,同时显示这些协议。我正在尝试使用 both DataFrames API 和 SparkSQL 来获得这些结果。

这是我的数据示例:

+----------------+--------+--------+---------------+--------------+---------+-------------+------+-----+
|       Timestamp|Duration|Protocol|BytesOriginator|ResponderBytes|LocalHost|   RemoteHost| State|Flags|
+----------------+--------+--------+---------------+--------------+---------+-------------+------+-----+
|748162802.427995| 1.24383|    smtp|              ?|             ?|        1| 128.97.154.3|   REJ|    L|
|748162802.803033| 3.96513|    smtp|           1173|           328|        3|  128.8.142.5|    SF| null|
|748162804.817224| 1.02839|    nntp|             58|           129|        2|   140.98.2.1|    SF|    L|
|748162812.254572| 138.168|    nntp|         363238|          1200|        4| 128.49.4.103|    SF|    L|
|748162817.478016| 10.0858|    nntp|            230|           100|        4| 128.32.133.1|    SF|    N|
|748162833.453963| 2.16477|    smtp|           2524|           306|        5|192.48.232.17|    SF| null|
|748162836.735788| 13.1779|    smtp|          16479|           174|       16| 128.233.1.12|RSTRS3|    L|
|748162839.930331| 6.69767|    smtp|           3104|           371|        8|   139.91.1.1|    SF|    L|
|748162841.854151| 2.07407|    smtp|           1172|           380|        6|  128.8.142.5|    SF| null|
|748162854.814153| 131.659|    nntp|         319292|          1220|        4| 128.110.4.25|    SF|    L|
|748162866.207165| 51.8406|    nntp|         135714|           280|        4| 128.110.4.25|    SF| null|
|748162866.600750|0.402045|    smtp|              ?|             ?|        1| 128.97.154.3|   REJ|    L|
|748162869.790751| 172.363|    smtp|              0|             0|       16|132.230.6.100|    SF|    L|
|748162873.491682|  102.88|    nntp|            346|           180|        4| 128.32.136.1|    SF|   LN|
|748162875.237378| 5.32943|    nntp|             90|            85|        4| 128.32.133.1|    SF|    N|
+----------------+--------+--------+---------------+--------------+---------+-------------+------+-----+

我试图过滤我的数据框,但我一直收到错误,我不知道我是否应该使用 Window 函数。通过使用 SparkSQL,到目前为止,我得到了 IPs 但没有 protocols

这就是我所做的:

custom_schema = StructType([
    StructField('Timestamp', StringType(), True),
    StructField('Duration', FloatType(), True),
    StructField('Protocol', StringType(), True),
    StructField('BytesOriginator', StringType(), True),
    StructField('ResponderBytes', StringType(), True),
    StructField('LocalHost', StringType(), True),
    StructField('RemoteHost', StringType(), True),
    StructField('State', StringType(), True),
    StructField('Flags', StringType(), True) 
])

logs = spark.read.csv('lbl-conn-7.csv', header=False, sep=' ', schema=custom_schema)

# I get an error
logs.select('RemoteHost', 'Protocol').distinct().filter(F.countDistinct('Protocol') > 1).show()

logs.createOrReplaceTempView("mytable")
sqlContext = SQLContext(sc)
df = sqlContext.sql("select remotehost, protocol FROM mytable GROUP BY  HAVING COUNT(distinct protocol) > 1")
# It doesn't show the protocols
df.show()

【问题讨论】:

    标签: apache-spark pyspark apache-spark-sql


    【解决方案1】:

    您可以按RemoteHost 分组并收集使用的不同Protocol 的列表。然后,使用协议数组的大小过滤生成的数据帧:

    import pyspark.sql.functions as F
    
    logs.groupBy("RemoteHost").agg(
        F.collect_set("Protocol").alias("Protocols")
    ).filter(
        F.size("Protocols") >= 2
    ).show()
    

    Spark SQL 等效查询:

    SELECT  RemoteHost, 
            collect_set(Protocol) AS Protocols
    FROM    mytable 
    GROUP BY  RemoteHost
    HAVING  size(Protocols) >= 2 -- or count(distinct Protocol)  >= 2
    

    如果要保留所有列,请使用带有collect_set 函数的Window:

    logs.withColumn(
        "Protocols",
        F.collect_set("Protocol").over((Window.partitionBy("RemoteHost")))
    ).filter(
        F.size("Protocols") >= 2
    ).drop("Protocols").show()
    

    【讨论】:

    • 非常感谢!它按预期工作。有什么办法可以使结果变平,以便获得 N 行而不是嵌套数组?我尝试使用 array_distinct(flatten()) 但它告诉我你不能给出一个字符串数组。
    • @cdaveau 是的,你可以使用explode,过滤后添加.select("RemoteHost", F.explode("Protocols").alias("Protocol")).show()。或者使用如上所示的 Window。
    猜你喜欢
    • 1970-01-01
    • 2017-06-09
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2013-08-13
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多