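It depends on the use case and on how the driver is used.
Suppose you want to collect some N records (tweets) from Spark Structured Streaming, store them in PostgreSQL, and stop the stream once the count exceeds N records.
One way to do this is with an accumulator and a Python thread.
- Create a Python thread that takes the streaming query object and the accumulator, and stops the query once the count is exceeded.
- When starting the streaming query, pass in the accumulator variable and update its value for every batch of the stream.
Sharing a code snippet for understanding/illustration purposes...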
import threading
import time
def check_n_stop_streaming(query, acc, num_records=3500):
    # Poll the accumulator once per second; stop the query once the count exceeds num_records.
    # print_info is the author's own logging helper (not shown here).
    while True:
        print_info(f"Number of records received so far {acc.value}")
        if acc.value > num_records:
            query.stop()
            break
        time.sleep(1)
...
count_acc = spark.sparkContext.accumulator(0)
...
def postgresql_all_tweets_data_dump(df,
                                    epoch_id,
                                    raw_tweet_table_name,
                                    count_acc):
    # Called once per micro-batch via foreachBatch, so it runs on the driver.
    # The self._postgresql_* attributes come from the enclosing class, which is not shown.
    print_info("Raw Tweets...")
    df.select(["text"]).show(50, False)
    # Add this batch's row count to the shared accumulator.
    count_acc.add(df.count())
    mode = "append"
    url = "jdbc:postgresql://{}:{}/{}".format(self._postgresql_host,
                                              self._postgresql_port,
                                              self._postgresql_database)
    properties = {"user": self._postgresql_user,
                  "password": self._postgresql_password,
                  "driver": "org.postgresql.Driver"}
    df.write.jdbc(url=url, table=raw_tweet_table_name, mode=mode, properties=properties)
...
query = tweet_stream.writeStream.outputMode("append"). \
    foreachBatch(lambda df, epoch_id:
                 postgresql_all_tweets_data_dump(df=df,
                                                 epoch_id=epoch_id,
                                                 raw_tweet_table_name=raw_tweet_table_name,
                                                 count_acc=count_acc)).start()
# The arguments must match check_n_stop_streaming(query, acc, num_records).
stop_thread = threading.Thread(target=check_n_stop_streaming,
                               args=(query, count_acc, num_records),
                               daemon=True)
stop_thread.start()
query.awaitTermination()
stop_thread.join()
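
To try the watcher-thread pattern above without a Spark cluster, here is a minimal sketch in plain Python. `FakeQuery` and `Counter` are hypothetical stand-ins (not Spark classes) that model only the parts used above: `stop()`/`isActive` on the query and `add()`/`value` on the accumulator. The batch size, threshold, and sleep intervals are made-up values for illustration.

```python
import threading
import time


class FakeQuery:
    """Hypothetical stand-in for a streaming query: only stop() and isActive."""

    def __init__(self):
        self._stopped = threading.Event()

    def stop(self):
        self._stopped.set()

    @property
    def isActive(self):
        return not self._stopped.is_set()


class Counter:
    """Hypothetical stand-in for an accumulator: add() plus a value attribute."""

    def __init__(self):
        self._lock = threading.Lock()
        self.value = 0

    def add(self, n):
        with self._lock:
            self.value += n


def stop_when_exceeded(query, acc, num_records, poll_seconds=0.01):
    # Same idea as check_n_stop_streaming: poll the counter, stop once it exceeds N.
    while query.isActive:
        if acc.value > num_records:
            query.stop()
            break
        time.sleep(poll_seconds)


def run_demo(batch_size=10, num_records=35):
    query, acc = FakeQuery(), Counter()
    watcher = threading.Thread(target=stop_when_exceeded,
                               args=(query, acc, num_records),
                               daemon=True)
    watcher.start()
    batches = 0
    # Simulated micro-batches: each one adds batch_size records to the counter.
    while query.isActive:
        acc.add(batch_size)
        batches += 1
        time.sleep(0.05)
    watcher.join()
    return acc.value, batches


if __name__ == "__main__":
    total, batches = run_demo()
    print(f"stopped after {batches} batches, {total} records")
```

Note that the count is only updated once per batch and the watcher only polls periodically, so the final total can overshoot N by one or more batches; the pattern guarantees "at least N", not "exactly N".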