I want to ingest data into an existing set of tables in a schema in Databricks. I have already created the full schema of several hundred tables in Databricks, and now I just need to run the initial data load and then periodically rerun the job for incremental loads.
Based on this tutorial, it should be a simple set of commands to import the DMS files from an S3 bucket into a Databricks schema, but the tutorial sets up a lot of extra things I don't need. My DMS S3 bucket already exists and contains the CSV files; I just need to import them.
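For reference, each table has its own prefix under the bucket path; the layout looks roughly like this (the table name here is a placeholder, the LOAD file is the DMS full load, and the timestamped files are CDC batches whose first column is the I/U/D operation flag):

# Hypothetical example of one table's prefix in my DMS bucket (table name made up)
example_keys = [
    "inbound/production_db/my_app/customers/LOAD00000001.csv",        # full load
    "inbound/production_db/my_app/customers/20240301-120000000.csv",  # CDC batch
    "inbound/production_db/my_app/customers/20240301-130000000.csv",  # CDC batch
]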
Here is the original script from that tutorial:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
input_file_name, regexp_extract, lit, current_timestamp,
when, to_timestamp, col, row_number
)
from pyspark.sql.window import Window
from pyspark.sql.types import (
StructType, StructField, StringType, TimestampType,
IntegerType, BooleanType, DoubleType
)
import boto3
from urllib.parse import urlparse
from datetime import datetime
# AWS credentials
aws_access_key = dbutils.secrets.get(scope="aws-credentials-ANONYMOUS", key="aws-access-key")
aws_secret_key = dbutils.secrets.get(scope="aws-credentials-ANONYMOUS", key="aws-secret-key")
# Set Spark configurations
spark.conf.set("fs.s3a.access.key", aws_access_key)
spark.conf.set("fs.s3a.secret.key", aws_secret_key)
spark.conf.set("fs.s3a.endpoint", "s3.us-west-1.amazonaws.com")
# Get all tables
s3_client = boto3.client(
's3',
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key,
region_name='us-west-1'
)
def get_s3_files(bucket, prefix):
"""Get all files in S3 bucket with given prefix."""
paginator = s3_client.get_paginator('list_objects_v2')
files = []
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
if 'Contents' in page:
files.extend(page['Contents'])
return sorted(files, key=lambda x: x['Key'])
def convert_schema_to_struct(schema_rows):
"""Convert schema rows from DESCRIBE TABLE to StructType."""
fields = []
for row in schema_rows:
if row.col_name != 'creation_dt': # Skip creation_dt as we handle it separately
fields.append(StructField(row.col_name, StringType(), True))
return StructType(fields)
def get_table_schema(table_name):
"""Get schema for a table from Databricks catalog."""
schema_rows = spark.sql(f"DESCRIBE bronze.myschema.{table_name}").collect()
return convert_schema_to_struct(schema_rows)
def add_creation_dt_column(table_name):
"""Add creation_dt column if it doesn't exist."""
columns = spark.sql(f"DESCRIBE bronze.myschema.{table_name}").collect()
if not any(row.col_name == 'creation_dt' for row in columns):
spark.sql(f"""
ALTER TABLE bronze.myschema.{table_name}
ADD COLUMN creation_dt TIMESTAMP
""")
def get_last_processed_timestamp(table_name):
"""Get the most recent creation_dt for a table."""
result = spark.sql(f"SELECT MAX(creation_dt) as max_dt FROM bronze.myschema.{table_name}").collect()
max_dt = result[0]['max_dt']
return max_dt if max_dt else None
def process_file(table_name, file_path, is_initial_load, schema):
"""Process a single DMS file."""
print(f"Processing file: {file_path}")
print(f"Schema: {schema}")
try:
if is_initial_load:
# Read CSV with proper quote handling
            df = (spark.read.format("csv")
                  .option("header", "false")
                  .option("quote", "\"")
                  .option("escape", "\"")
                  .option("multiLine", "true")
                  .load(file_path))
# Map _c0, _c1, etc. to schema columns
column_mappings = [
col(f"_c{i}").alias(field.name)
for i, field in enumerate(schema.fields)
]
mapped_df = df.select(column_mappings)
mapped_df.createOrReplaceTempView("source_data")
# Insert the data with current timestamp
spark.sql(f"""
INSERT INTO bronze.myschema.{table_name}
SELECT *, current_timestamp() as creation_dt
FROM source_data
""")
else:
# For incremental loads, read all columns
            raw_df = (spark.read.format("csv")
                      .option("header", "false")
                      .option("quote", "\"")
                      .option("escape", "\"")
                      .option("multiLine", "true")
                      .load(file_path))
# First column is operation, rest map to schema
# Map _c1, _c2, etc. to schema columns (offset by 1 due to operation column)
column_mappings = [
col(f"_c{i+1}").alias(field.name)
for i, field in enumerate(schema.fields)
]
# Handle deletes
delete_df = raw_df.filter(col("_c0") == "D").select(
col(f"_c1").alias(schema.fields[0].name) # Primary key is first schema column
)
if delete_df.count() > 0:
delete_df.createOrReplaceTempView("delete_records")
spark.sql(f"""
DELETE FROM bronze.myschema.{table_name}
WHERE {schema.fields[0].name} IN (SELECT {schema.fields[0].name} FROM delete_records)
""")
# Handle inserts and updates
            upsert_df = (raw_df.filter(col("_c0").isin(["I", "U"]))
                         .select(column_mappings))
if upsert_df.count() > 0:
# Deduplicate records by taking the last one for each primary key
window_spec = Window.partitionBy(schema.fields[0].name).orderBy(lit(1))
                deduped_records = (upsert_df
                                   .withColumn("rn", row_number().over(window_spec))
                                   .filter(col("rn") == 1)
                                   .drop("rn"))
deduped_records.createOrReplaceTempView("upsert_records")
# Generate column list for UPDATE SET clause
update_columns = [f for f in schema.fields if f.name != schema.fields[0].name]
update_sets = ", ".join([f"target.{f.name} = source.{f.name}" for f in update_columns])
# Perform merge for all records
spark.sql(f"""
MERGE INTO bronze.myschema.{table_name} target
USING (SELECT *, current_timestamp() as creation_dt FROM upsert_records) source
ON target.{schema.fields[0].name} = source.{schema.fields[0].name}
WHEN MATCHED THEN
UPDATE SET {update_sets}, target.creation_dt = source.creation_dt
WHEN NOT MATCHED THEN
INSERT *
""")
return True
except Exception as e:
print(f"Error processing file {file_path}: {str(e)}")
raise
return False
def format_timestamp(timestamp_str):
"""Convert filename timestamp to proper datetime format."""
try:
dt = datetime.strptime(timestamp_str, '%Y%m%d-%H%M%S%f')
return dt.strftime('%Y-%m-%d %H:%M:%S.%f')
except ValueError:
return None
def main():
# Get all tables in bronze.myschema schema
tables = spark.sql("SHOW TABLES IN bronze.myschema").select("tableName").collect()
bucket = "wp-staging-data-lake-west"
base_prefix = "inbound/production_db/my_app/"
for table in tables:
table_name = table.tableName
print(f"Processing table: {table_name}")
try:
# Ensure creation_dt column exists
add_creation_dt_column(table_name)
# Get table schema
schema = get_table_schema(table_name)
# Get last processed timestamp
last_timestamp = get_last_processed_timestamp(table_name)
# Get all files for this table
table_prefix = f"{base_prefix}{table_name}/"
s3_files = get_s3_files(bucket, table_prefix)
# Process initial load file if table is empty
initial_load_file = next((f for f in s3_files if 'LOAD00000001.csv' in f['Key']), None)
if last_timestamp is None and initial_load_file:
file_path = f"s3a://{bucket}/{initial_load_file['Key']}"
process_file(table_name, file_path, True, schema)
# Process incremental files
incremental_files = [f for f in s3_files if f['Key'].endswith('.csv') and 'LOAD00000001' not in f['Key']]
for file in incremental_files:
filename = file['Key'].split('/')[-1].split('.')[0]
if 'LOAD' not in filename:
formatted_timestamp = format_timestamp(filename)
if formatted_timestamp:
file_timestamp = datetime.strptime(formatted_timestamp, '%Y-%m-%d %H:%M:%S.%f')
if last_timestamp is None or file_timestamp > last_timestamp:
file_path = f"s3a://{bucket}/{file['Key']}"
process_file(table_name, file_path, False, schema)
except Exception as e:
print(f"Error processing table {table_name}: {str(e)}")
continue
if __name__ == "__main__":
main()
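To give a sense of the scale I was expecting, for the initial load I was picturing something on the order of this per table (the table name is a placeholder, I have not tested it, and it assumes the CSV columns line up with the existing table's columns):

# Untested sketch of the one-off initial load I was hoping for ("customers" is a placeholder)
df = (spark.read.format("csv")
      .option("header", "false")
      .load("s3://wp-staging-data-lake-west/inbound/production_db/my_app/customers/LOAD00000001.csv"))
df = df.toDF(*spark.table("bronze.myschema.customers").columns)  # rename _c0, _c1, ... to the real column names
df.write.mode("append").saveAsTable("bronze.myschema.customers")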
Here is my modified script, which tries to use the table definitions already present in Databricks to provide the schema information:
# Databricks notebook source
import dlt
from pyspark.sql.functions import *
from pyspark.sql.types import *
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('dms_pipeline')
# S3 path configuration
data = "s3://wp-staging-data-lake-west/inbound/production_db/my_app"
logger.info(f"Initializing DMS pipeline with data path: {data}")
def get_table_schema(table_name):
"""Get schema from existing bronze.mydb table"""
logger.info(f"Getting schema for table: {table_name}")
schema_df = spark.sql(f"DESCRIBE TABLE bronze.mydb.{table_name}")
return schema_df.collect()
def get_tables_from_catalog():
"""Get list of tables from bronze.mydb catalog"""
logger.info("Getting tables from bronze.mydb catalog")
tables_df = spark.sql("SHOW TABLES IN bronze.mydb")
tables = {}
for table_row in tables_df.collect():
table_name = table_row.tableName
# Get schema to determine primary key (first column)
schema = get_table_schema(table_name)
primary_key = schema[0].col_name
tables[table_name] = {'id': primary_key}
logger.info(f"Found table: {table_name} with primary key: {primary_key}")
return tables
def generate_tables(table, info):
logger.info(f"Generating CDC and target tables for: {table}")
@dlt.table(
name=f"{table}_cdc_raw",
table_properties={
"quality": "bronze",
"delta.autoOptimize.optimizeWrite": "true",
"delta.autoOptimize.autoCompact": "true"
},
comment=f"Raw DMS data for table: {table}",
temporary=True
)
def create_cdc_table():
logger.info(f"Creating CDC raw table for: {table}")
# Get schema from existing table
schema = get_table_schema(table)
column_names = [row.col_name for row in schema]
# Read stream from S3
        stream = (spark.readStream.format("cloudFiles")
                  .option("cloudFiles.format", "csv")
                  .load(f"{data}/{table}"))
# For initial load (LOAD00000001.csv)
initial_condition = input_file_name().contains("LOAD00000001.csv")
# Add operation column and timestamp
        stream = (stream
                  .withColumn("Op",
                              when(initial_condition, lit("I"))
                              .otherwise(col("_c0")))
                  .withColumn("dmsTimestamp",
                              when(initial_condition, current_timestamp())
                              .otherwise(to_timestamp(
                                  regexp_extract(input_file_name(), r"(\d{8}-\d{9})", 1),
                                  "yyyyMMdd-HHmmssSSS"))))
# Drop the operation column for initial load records and map columns
base_cols = []
for i, col_name in enumerate(column_names):
if initial_condition:
base_cols.append(col(f"_c{i}").alias(col_name))
else:
# For CDC files, skip the Op column by offsetting index
base_cols.append(col(f"_c{i+1}").alias(col_name))
stream = stream.select("Op", "dmsTimestamp", *base_cols)
logger.info(f"Mapped columns for {table}: {column_names}")
return stream.withColumn("_ingest_file_name", input_file_name())
@dlt.table(
name=f"{table}",
table_properties={
"quality": "silver",
"delta.autoOptimize.optimizeWrite": "true",
"delta.autoOptimize.autoCompact": "true"
},
comment=f"Merged DMS data for table: {table}"
)
@dlt.expect_all_or_drop({"valid_operation": "Op IN ('I', 'U', 'D') OR Op IS NULL"})
def create_merged_table():
logger.info(f"Creating/updating merged table for: {table}")
return dlt.apply_changes(
target=f"{table}",
source=f"{table}_cdc_raw",
keys=[info['id']],
sequence_by=col("dmsTimestamp"),
apply_as_deletes=expr("Op = 'D'"),
except_column_list=["Op", "dmsTimestamp", "_ingest_file_name"],
stored_as_scd_type=1
)
# Main execution
logger.info("Starting DMS pipeline execution")
# Get tables from bronze.mydb catalog
tables = get_tables_from_catalog()
logger.info(f"Processing tables: {', '.join(tables.keys())}")
# Generate tables
for table, info in tables.items():
try:
logger.info(f"Starting processing for table: {table}")
generate_tables(table, info)
logger.info(f"Successfully processed table: {table}")
except Exception as e:
logger.error(f"Error processing table {table}: {str(e)}")
raise
logger.info("DMS pipeline setup completed successfully")
Running it produces this error:
ERROR:SQLQueryContextLogger:[TABLE_OR_VIEW_NOT_FOUND] The table or view `bronze`.`mydb`.`delete_records` cannot be found. Verify the spelling and correctness of the schema and catalog.
If you did not qualify the name with a schema, verify the current_schema() output, or qualify the name with the correct schema and catalog.
To tolerate the error on drop use DROP VIEW IF EXISTS or DROP TABLE IF EXISTS. SQLSTATE: 42P01
Traceback (most recent call last):
File "/databricks/spark/python/pyspark/errors/exceptions/captured.py", line 263, in deco
return f(*a, **kw)
^^^^^^^^^^^
File "/databricks/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/protocol.py", line 326, in get_return_value
raise Py4JJavaError(
py4j.protocol.Py4JJavaError: An error occurred while calling z:com.databricks.pipelines.SQLPipelineHelper.sqlWithAnalysisApi.
: org.apache.spark.sql.catalyst.ExtendedAnalysisException: [TABLE_OR_VIEW_NOT_FOUND] The table or view `bronze`.`mydb`.`delete_records` cannot be found. Verify the spelling and correctness of the schema and catalog.
If you did not qualify the name with a schema, verify the current_schema() output, or qualify the name with the correct schema and catalog.
To tolerate the error on drop use DROP VIEW IF EXISTS or DROP TABLE IF EXISTS. SQLSTATE: 42P01; line 1 pos 15;
'DescribeRelation false, [col_name#184549, data_type#184550, comment#184551]
+- 'UnresolvedTableOrView [bronze, mydb, delete_records], DESCRIBE TABLE, true
at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.tableNotFound(package.scala:90)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$2(CheckAnalysis.scala:248)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$2$adapted(CheckAnalysis.scala:231)
at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:287)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$foreachUp$1(TreeNode.scala:286)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$foreachUp$1$adapted(TreeNode.scala:286)
at scala.collection.Iterator.foreach(Iterator.scala:943)
at scala.collection.Iterator.foreach$(Iterator.scala:943)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
at scala.collection.IterableLike.foreach(IterableLike.scala:74)
at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:286)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis0(CheckAnalysis.scala:231)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis0$(CheckAnalysis.scala:213)
at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis0(Analyzer.scala:388)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis$1(CheckAnalysis.scala:198)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis(CheckAnalysis.scala:185)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis$(CheckAnalysis.scala:185)
at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:388)
at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$2(Analyzer.scala:443)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:193)
at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:443)
at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:443)
at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:440)
at org.apache.spark.sql.execution.QueryExecution.$anonfun$analyzed$1(QueryExecution.scala:264)
at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:472)
at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$5(QueryExecution.scala:562)
at org.apache.spark.sql.execution.SQLExecution$.withExecutionPhase(SQLExecution.scala:144)
at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$4(QueryExecution.scala:562)
at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:1125)
at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$2(QueryExecution.scala:561)
at com.databricks.util.LexicalThreadLocal$Handle.runWith(LexicalThreadLocal.scala:63)
at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:557)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:1273)
at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:557)
at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:258)
at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:257)
at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:239)
at org.apache.spark.sql.Dataset$.$anonfun$ofRows$1(Dataset.scala:106)
at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:1273)
at org.apache.spark.sql.SparkSession.$anonfun$withActiveAndFrameProfiler$1(SparkSession.scala:1280)
at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
at org.apache.spark.sql.SparkSession.withActiveAndFrameProfiler(SparkSession.scala:1280)
at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:104)
at com.databricks.pipelines.SQLPipeline$.analyze(SQLPipeline.scala:639)
at com.databricks.pipelines.SQLPipeline$.sqlWithAnalysisApi(SQLPipeline.scala:494)
at com.databricks.pipelines.SQLPipelineHelper$.sqlWithAnalysisApi(SQLPipeline.scala:442)
at com.databricks.pipelines.SQLPipelineHelper.sqlWithAnalysisApi(SQLPipeline.scala)
at jdk.internal.reflect.GeneratedMethodAccessor1554.invoke(Unknown Source)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:568)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:397)
at py4j.Gateway.invoke(Gateway.java:306)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:199)
at py4j.ClientServerConnection.run(ClientServerConnection.java:119)
at java.base/java.lang.Thread.run(Thread.java:840)
I'm relatively new to Databricks, but it feels like this should be simpler to achieve.
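For the incremental files in particular, do I really need a DLT pipeline, or could I just run a MERGE per CDC file with the Delta API, along these lines (untested sketch; the table, its columns, and the customer_id primary key are placeholders)?

from delta.tables import DeltaTable

# Untested sketch; "customers", "customer_id", "name", "email" are placeholders
cdc = (spark.read.format("csv")
       .option("header", "false")
       .load("s3://wp-staging-data-lake-west/inbound/production_db/my_app/customers/20240301-120000000.csv")
       .toDF("Op", "customer_id", "name", "email"))  # first CSV column is the DMS operation flag

target = DeltaTable.forName(spark, "bronze.myschema.customers")
(target.alias("t")
 .merge(cdc.alias("s"), "t.customer_id = s.customer_id")
 .whenMatchedDelete(condition="s.Op = 'D'")
 .whenMatchedUpdate(set={"name": "s.name", "email": "s.email"})
 .whenNotMatchedInsert(condition="s.Op <> 'D'",
                       values={"customer_id": "s.customer_id",
                               "name": "s.name",
                               "email": "s.email"})
 .execute())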