I need to “checkpoint” certain information during my batch processing with PySpark that is needed in the next batches.
For this use case, DataFrame.checkpoint seems to fit. While I found many places that explain how to create one, I did not find any that explain how to restore or read a checkpoint.
To test this, I created a simple test class with two tests. The first reads a CSV and creates a sum. The second should pick up that sum and continue summing:
import pytest
from pyspark.sql import functions as f


class TestCheckpoint:
    @pytest.fixture(autouse=True)
    def init_test(self, spark_unit_test_fixture, data_dir, tmp_path):
        self.spark = spark_unit_test_fixture
        self.dir = data_dir("")
        self.checkpoint_dir = tmp_path
        # tmp_path is used as the checkpoint directory (see below)
        self.spark.sparkContext.setCheckpointDir(str(tmp_path))

    def test_first(self):
        df = (self.spark.read.format("csv")
              .option("pathGlobFilter", "numbers.csv")
              .load(self.dir))
        sum = df.agg(f.sum("_c1").alias("sum"))
        sum.checkpoint()
        assert 1 == 1

    def test_second(self):
        df = (self.spark.read.format("csv")
              .option("pathGlobFilter", "numbers2.csv")
              .load(self.dir))
        sum = ...  # how do I get the checkpointed sum from the first test back here?
Creating the checkpoint in the first test works fine (tmp_path is set as the checkpoint directory) and I see a folder created with a file in it.
But how do I read it?
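The only workaround I have come up with so far is to not rely on the checkpoint at all and to persist the intermediate result myself, e.g. as Parquet, and read it back in the next batch. A rough sketch of that idea (the sum_state sub-folder under tmp_path is just made up for illustration):

    # in test_first: persist the aggregated sum explicitly
    state_path = str(self.checkpoint_dir / "sum_state")  # hypothetical state folder
    sum.write.mode("overwrite").parquet(state_path)

    # in test_second: read the previous state back and keep aggregating
    previous = self.spark.read.parquet(state_path)
    current = df.agg(f.sum("_c1").alias("sum"))
    total = previous.unionByName(current).agg(f.sum("sum").alias("sum"))

But that bypasses checkpoint() entirely, which is why I am asking whether the checkpoint itself can be read back.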
And how do you handle multiple checkpoints? For example, one checkpoint for the sum and another for the average?
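To illustrate what I mean by multiple checkpoints, in one batch I would have something like this (just a sketch of the intent):

    sum_df = df.agg(f.sum("_c1").alias("sum")).checkpoint()
    avg_df = df.agg(f.avg("_c1").alias("avg")).checkpoint()
    # both end up somewhere under the same checkpoint dir, so how would the
    # next batch know which folder belongs to the sum and which to the average?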
Are there better approaches to storing state across batches?
For the sake of completeness, the CSV looks like this:
1719228973,1
1719228974,2
This is only a minimal example to get it running; my real scenario is more complex.