Close spark session

kssenii 2023-04-11 17:23:05 +02:00
parent 00282483c9
commit e32c98e412
4 changed files with 99 additions and 80 deletions
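
Taken together, the four files replace per-test get_spark() calls with a single SparkSession kept on the cluster object for the whole test module. A minimal sketch of the shared pattern, assembled from the hunks below (the Delta/Hudi/Iceberg-specific .config(...) options are omitted and the cluster setup is abbreviated to what the hunks show):

import logging

import pyspark
import pytest

from helpers.cluster import ClickHouseCluster
from helpers.s3_tools import prepare_s3_bucket


def get_spark():
    # Module-level helper; each storage test module adds its own catalog and
    # extension options to the builder before getOrCreate().
    return (
        pyspark.sql.SparkSession.builder.appName("spark_test")
        .master("local")
        .getOrCreate()
    )


@pytest.fixture(scope="module")
def started_cluster():
    try:
        cluster = ClickHouseCluster(__file__)
        cluster.add_instance(
            "node1",
            main_configs=["configs/config.d/named_collections.xml"],
            with_minio=True,
        )
        logging.info("Starting cluster...")
        cluster.start()
        prepare_s3_bucket(cluster)

        # Stop any session left over from a previously run test module so that
        # getOrCreate() cannot hand back a stopped or stale session, then store
        # the fresh one on the cluster for the tests to share.
        if cluster.spark_session is not None:
            cluster.spark_session.stop()
            cluster.spark_session._instantiatedContext = None
        cluster.spark_session = get_spark()

        yield cluster
    finally:
        cluster.shutdown()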

View File

@@ -447,6 +447,8 @@ class ClickHouseCluster:
self.minio_redirect_ip = None
self.minio_redirect_port = 8080
self.spark_session = None
self.with_azurite = False
# available when with_hdfs == True
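
The hunk above adds a spark_session attribute that defaults to None in ClickHouseCluster.__init__, which is what lets each test module's started_cluster fixture detect a session left behind by an earlier module. Sketched in isolation (the real constructor takes many more parameters, omitted here):

class ClickHouseCluster:
    # Heavily abbreviated: only the attributes shown in this hunk.
    def __init__(self):
        self.minio_redirect_ip = None
        self.minio_redirect_port = 8080
        # None until a test module's started_cluster fixture creates a session;
        # fixtures check this before stopping and recreating it.
        self.spark_session = None
        self.with_azurite = False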

View File

@@ -29,10 +29,23 @@ from pyspark.sql.window import Window
from helpers.s3_tools import prepare_s3_bucket, upload_directory, get_file_contents
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
def get_spark():
builder = (
pyspark.sql.SparkSession.builder.appName("spark_test")
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
)
.master("local")
)
return configure_spark_with_delta_pip(builder).master("local").getOrCreate()
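
configure_spark_with_delta_pip is used here without its import appearing in the hunk; in the delta-spark pip package it is exposed from the top-level delta module. A self-contained version of the new helper with that import added (the import line is the assumption, the builder follows the hunk):

import pyspark
from delta import configure_spark_with_delta_pip  # provided by the delta-spark package


def get_spark():
    builder = (
        pyspark.sql.SparkSession.builder.appName("spark_test")
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config(
            "spark.sql.catalog.spark_catalog",
            "org.apache.spark.sql.delta.catalog.DeltaCatalog",
        )
        .master("local")
    )
    # configure_spark_with_delta_pip() adds the matching delta-core artifacts
    # to spark.jars.packages before the session is built; the trailing
    # .master("local") repeats what the builder already set.
    return configure_spark_with_delta_pip(builder).master("local").getOrCreate()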
@pytest.fixture(scope="module")
def started_cluster():
try:
@@ -48,26 +61,17 @@ def started_cluster():
prepare_s3_bucket(cluster)
if cluster.spark_session is not None:
cluster.spark_session.stop()
cluster.spark_session._instantiatedContext = None
cluster.spark_session = get_spark()
yield cluster
finally:
cluster.shutdown()
def get_spark():
builder = (
pyspark.sql.SparkSession.builder.appName("spark_test")
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
)
.master("local")
)
return configure_spark_with_delta_pip(builder).master("local").getOrCreate()
def write_delta_from_file(spark, path, result_path, mode="overwrite"):
spark.read.load(path).write.mode(mode).option("compression", "none").format(
"delta"
@@ -139,9 +143,9 @@ def create_initial_data_file(
def test_single_log_file(started_cluster):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_single_log_file"
inserted_data = "SELECT number, toString(number + 1) FROM numbers(100)"
@@ -163,9 +167,9 @@ def test_single_log_file(started_cluster):
def test_partition_by(started_cluster):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_partition_by"
write_delta_from_df(
@@ -185,9 +189,9 @@ def test_partition_by(started_cluster):
def test_checkpoint(started_cluster):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_checkpoint"
write_delta_from_df(
@@ -258,9 +262,9 @@ def test_checkpoint(started_cluster):
def test_multiple_log_files(started_cluster):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_multiple_log_files"
write_delta_from_df(
@@ -296,9 +300,9 @@ def test_multiple_log_files(started_cluster):
def test_metadata(started_cluster):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_metadata"
parquet_data_path = create_initial_data_file(
@@ -328,8 +332,8 @@ def test_metadata(started_cluster):
def test_types(started_cluster):
spark = get_spark()
TABLE_NAME = "test_types"
spark = started_cluster.spark_session
result_file = f"{TABLE_NAME}_result_2"
delta_table = (

View File

@@ -28,6 +28,26 @@ from pyspark.sql.window import Window
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
def get_spark():
builder = (
pyspark.sql.SparkSession.builder.appName("spark_test")
.config(
"spark.jars.packages",
"org.apache.hudi:hudi-spark3.3-bundle_2.12:0.13.0",
)
.config(
"org.apache.spark.sql.hudi.catalog.HoodieCatalog",
)
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config(
"spark.sql.catalog.local", "org.apache.spark.sql.hudi.catalog.HoodieCatalog"
)
.config("spark.driver.memory", "20g")
.master("local")
)
return builder.master("local").getOrCreate()
@pytest.fixture(scope="module")
def started_cluster():
try:
@@ -44,6 +64,12 @@ def started_cluster():
prepare_s3_bucket(cluster)
logging.info("S3 bucket created")
if cluster.spark_session is not None:
cluster.spark_session.stop()
cluster.spark_session._instantiatedContext = None
cluster.spark_session = get_spark()
yield cluster
finally:
@@ -60,26 +86,6 @@ def run_query(instance, query, stdin=None, settings=None):
return result
def get_spark():
builder = (
pyspark.sql.SparkSession.builder.appName("spark_test")
.config(
"spark.jars.packages",
"org.apache.hudi:hudi-spark3.3-bundle_2.12:0.12.0",
)
.config(
"org.apache.spark.sql.hudi.catalog.HoodieCatalog",
)
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config(
"spark.sql.catalog.local", "org.apache.spark.sql.hudi.catalog.HoodieCatalog"
)
.config("spark.driver.memory", "20g")
.master("local")
)
return builder.master("local").getOrCreate()
def write_hudi_from_df(spark, table_name, df, result_path, mode="overwrite"):
if mode == "overwrite":
hudi_write_mode = "insert_overwrite"
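
The hunk truncates write_hudi_from_df after the overwrite branch; a minimal sketch of the full mode mapping, under the assumption that the non-overwrite case maps to upsert and that the write uses standard hoodie.* options (neither is shown in the diff):

def write_hudi_from_df(spark, table_name, df, result_path, mode="overwrite"):
    # Plain string equality; the fallback branch below is an assumption.
    if mode == "overwrite":
        hudi_write_mode = "insert_overwrite"
    else:
        hudi_write_mode = "upsert"
    # Illustrative write call: the real test configures more hoodie.* options.
    df.write.mode(mode).format("hudi").option(
        "hoodie.datasource.write.operation", hudi_write_mode
    ).option("hoodie.table.name", table_name).save(result_path)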
@@ -165,9 +171,9 @@ def create_initial_data_file(
def test_single_hudi_file(started_cluster):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_single_hudi_file"
inserted_data = "SELECT number as a, toString(number) as b FROM numbers(100)"
@@ -188,9 +194,9 @@ def test_single_hudi_file(started_cluster):
def test_multiple_hudi_files(started_cluster):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_multiple_hudi_files"
write_hudi_from_df(
@@ -258,9 +264,9 @@ def test_multiple_hudi_files(started_cluster):
def test_types(started_cluster):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_types"
data = [

View File

@@ -31,38 +31,6 @@ from helpers.s3_tools import prepare_s3_bucket, upload_directory, get_file_contents
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
@pytest.fixture(scope="module")
def started_cluster():
try:
cluster = ClickHouseCluster(__file__)
cluster.add_instance(
"node1",
main_configs=["configs/config.d/named_collections.xml"],
with_minio=True,
)
logging.info("Starting cluster...")
cluster.start()
prepare_s3_bucket(cluster)
logging.info("S3 bucket created")
yield cluster
finally:
cluster.shutdown()
def run_query(instance, query, stdin=None, settings=None):
# type: (ClickHouseInstance, str, object, dict) -> str
logging.info("Running query '{}'...".format(query))
result = instance.query(query, stdin=stdin, settings=settings)
logging.info("Query finished")
return result
def get_spark():
builder = (
pyspark.sql.SparkSession.builder.appName("spark_test")
@@ -82,6 +50,44 @@ def get_spark():
return builder.master("local").getOrCreate()
@pytest.fixture(scope="module")
def started_cluster():
try:
cluster = ClickHouseCluster(__file__)
cluster.add_instance(
"node1",
main_configs=["configs/config.d/named_collections.xml"],
with_minio=True,
)
logging.info("Starting cluster...")
cluster.start()
prepare_s3_bucket(cluster)
logging.info("S3 bucket created")
if cluster.spark_session is not None:
cluster.spark_session.stop()
cluster.spark_session._instantiatedContext = None
cluster.spark_session = get_spark()
yield cluster
finally:
cluster.shutdown()
def run_query(instance, query, stdin=None, settings=None):
# type: (ClickHouseInstance, str, object, dict) -> str
logging.info("Running query '{}'...".format(query))
result = instance.query(query, stdin=stdin, settings=settings)
logging.info("Query finished")
return result
def write_iceberg_from_file(
spark, path, table_name, mode="overwrite", format_version="1", partition_by=None
):
@@ -161,9 +167,9 @@ def create_initial_data_file(
@pytest.mark.parametrize("format_version", ["1", "2"])
def test_single_iceberg_file(started_cluster, format_version):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_single_iceberg_file_" + format_version
inserted_data = "SELECT number, toString(number) FROM numbers(100)"
@@ -174,6 +180,7 @@ def test_single_iceberg_file(started_cluster, format_version):
write_iceberg_from_file(
spark, parquet_data_path, TABLE_NAME, format_version=format_version
)
time.sleep(500)
files = upload_directory(
minio_client, bucket, f"/iceberg_data/default/{TABLE_NAME}/", ""
@@ -188,9 +195,9 @@ def test_single_iceberg_file(started_cluster, format_version):
@pytest.mark.parametrize("format_version", ["1", "2"])
def test_partition_by(started_cluster, format_version):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_partition_by_" + format_version
write_iceberg_from_df(
@@ -214,9 +221,9 @@ def test_partition_by(started_cluster, format_version):
@pytest.mark.parametrize("format_version", ["1", "2"])
def test_multiple_iceberg_files(started_cluster, format_version):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_multiple_iceberg_files_" + format_version
write_iceberg_from_df(
@@ -261,9 +268,9 @@ def test_multiple_iceberg_files(started_cluster, format_version):
@pytest.mark.parametrize("format_version", ["1", "2"])
def test_types(started_cluster, format_version):
instance = started_cluster.instances["node1"]
spark = started_cluster.spark_session
minio_client = started_cluster.minio_client
bucket = started_cluster.minio_bucket
spark = get_spark()
TABLE_NAME = "test_types_" + format_version
data = [
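
All three test modules end up with the same stop-and-recreate step in their started_cluster fixtures. A hypothetical helper that isolates just that step (the name reset_spark_session and the make_session argument are illustrative, the body mirrors the fixtures above):

def reset_spark_session(cluster, make_session):
    # Stop the previous module's session, if any. Clearing the private
    # _instantiatedContext attribute is intended to make the factory's
    # getOrCreate() build a brand-new session instead of returning the
    # stopped singleton.
    if cluster.spark_session is not None:
        cluster.spark_session.stop()
        cluster.spark_session._instantiatedContext = None
    cluster.spark_session = make_session()
    return cluster.spark_session

Each started_cluster fixture would then call reset_spark_session(cluster, get_spark) right after prepare_s3_bucket(cluster), which is the shape the three hunks implement inline.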