I am using PySpark and have implemented some pipelines using batch processing. These pipelines need to save some state between batches, so I created my own state manager (is there a better way in general?):
import json
import logging
import os

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType
from pyspark.sql.utils import AnalysisException


class StateManager:
    SCHEMA_KEY = 'schema'

    def __init__(self, spark: SparkSession, config: UMSConfig):
        self.spark = spark
        if not self.spark:
            raise ValueError("Spark is not set.")
        self.state_folder = config.UMS_STATE_FOLDER
        if not self.state_folder:
            raise ValueError("UMS_STATE_FOLDER is not set in config.")
        self.meta_file = os.path.join(self.state_folder, "ums_meta.json")
        self.states = self._load_metadata()

    def _load_metadata(self):
        try:
            with open(self.meta_file, 'r') as f:
                return json.load(f)
        except FileNotFoundError:
            logging.info("UMS state metadata file not found. Will create on demand.")
            return {}

    def _save_metadata(self):
        with open(self.meta_file, 'w') as f:
            json.dump(self.states, f)

    def update_state_metadata(self, state_id: str, schema, description: str):
        self.states[state_id] = {
            StateManager.SCHEMA_KEY: schema.jsonValue(),
            'description': description
        }
        self._save_metadata()

    def load_state(self, state_id: str, msn: str):
        if not msn:
            raise ValueError("MSN must be provided.")
        if state_id not in self.states:
            raise ValueError(f"State {state_id} not found. Use update_state_metadata first to register it.")
        path = os.path.join(self.state_folder, state_id, msn)
        schema = StructType.fromJson(self.states[state_id][StateManager.SCHEMA_KEY])
        try:
            return self.spark.read.schema(schema).parquet(path)
        except AnalysisException:
            logging.info(f"State {state_id} not found yet. Will create new empty state.")
            return self.spark.createDataFrame([], schema)

    def save_state(self, state_id: str, msn: str, new_state: DataFrame):
        if state_id not in self.states:
            raise ValueError(f"State {state_id} not found. Use update_state_metadata first to register it.")
        path = os.path.join(self.state_folder, state_id, msn)
        new_state.write.mode("overwrite").parquet(path)
I created a test that runs through a full cycle:
- add a new state schema
- save state
- read it
- change the data frame by union with new data
- save again <- ERROR

The last step fails with:

It is possible the underlying files have been updated. You can explicitly invalidate the cache in Spark by running 'REFRESH TABLE tableName' command in SQL or by recreating the Dataset/DataFrame involved.
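If I understand Spark's lazy evaluation correctly, loaded_state is never materialized; it still reads from the parquet files under path, and mode("overwrite") deletes exactly those files before (or while) the union is computed. A minimal standalone reproduction of what I think is happening (the local session and the /tmp path are just for illustration):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
path = "/tmp/state_demo"  # illustrative path

spark.range(3).write.mode("overwrite").parquet(path)
df = spark.read.parquet(path)          # lazy: only remembers the source path
updated = df.union(spark.range(3, 6))  # still lazy, still backed by `path`
# Depending on the Spark version this either fails fast with
# "Cannot overwrite a path that is also being read from" or fails
# mid-job with the REFRESH TABLE error above, because the overwrite
# removes the files the plan still has to read.
updated.write.mode("overwrite").parquet(path)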
Adding

self.spark.catalog.refreshByPath(path)

as the last line of save_state did not help.
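Presumably because the exception is raised by the write itself, so the refresh is never reached. For reference, save_state with the attempted fix looked like this:

def save_state(self, state_id: str, msn: str, new_state: DataFrame):
    if state_id not in self.states:
        raise ValueError(f"State {state_id} not found. Use update_state_metadata first to register it.")
    path = os.path.join(self.state_folder, state_id, msn)
    new_state.write.mode("overwrite").parquet(path)  # <- the exception is raised here
    self.spark.catalog.refreshByPath(path)           # <- attempted fix, never reached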
For the sake of completeness, here is the test:
import pytest
from mockito import mock
from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType


class TestStateManagement:
    @pytest.fixture(autouse=True)
    def init_test(self, spark_unit_test_fixture, tmp_path):
        self.spark = spark_unit_test_fixture
        self.path = tmp_path

    def test_state_manager(self):
        ums_config = mock(UMSConfig)
        ums_config.UMS_STATE_FOLDER = self.path
        state_manager = StateManager(self.spark, ums_config)
        schema = StructType([StructField("time", IntegerType()), StructField("value", DoubleType())])
        state_id = "last_closing_times"
        state_manager.update_state_metadata(state_id,
                                            schema,
                                            "stores the last 3 closing times")
        assert state_manager.states[state_id][StateManager.SCHEMA_KEY] == schema.jsonValue()

        new_state = self.spark.createDataFrame([(1, 1.0), (2, 2.0), (3, 3.0)], schema)
        state_manager.save_state(state_id, "pc24_666", new_state)
        loaded_state = state_manager.load_state(state_id, "pc24_666")
        assert loaded_state.count() == 3

        new_state = self.spark.createDataFrame([(4, 4.0), (5, 5.0), (6, 6.0)], schema)
        updated_state = loaded_state.union(new_state)
        state_manager.save_state(state_id, "pc24_666", updated_state)  # <- error occurs here
        updated_state = state_manager.load_state(state_id, "pc24_666")
        assert updated_state.count() == 6
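The only workaround I have found so far is to cut the lineage before overwriting, e.g. by collecting the state to the driver and rebuilding the DataFrame. A sketch of that variant (fine for my small states, but clearly not viable for large ones):

def save_state(self, state_id: str, msn: str, new_state: DataFrame):
    if state_id not in self.states:
        raise ValueError(f"State {state_id} not found. Use update_state_metadata first to register it.")
    path = os.path.join(self.state_folder, state_id, msn)
    schema = StructType.fromJson(self.states[state_id][StateManager.SCHEMA_KEY])
    # Materialize on the driver so the written DataFrame no longer
    # references the parquet files that the overwrite deletes.
    rows = new_state.collect()
    self.spark.createDataFrame(rows, schema).write.mode("overwrite").parquet(path)

This should work because the rebuilt DataFrame has no lineage back to path, but it feels wasteful.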
How can I solve this problem properly?