Source code for tests.test_spark_transform

import pandas as pd
from pandas.testing import assert_frame_equal
from pyspark.sql import functions as F

from firelink.spark_transform import ConditionalMapping, Select, WithColumn


def test_withcolumn(test_spark_df):
    """Check the ``Country`` column produced by ``WithColumn``.

    The transformer is built with the column name ``"Country"`` and the
    expression string ``"F.lit('Canada')"``. After ``fit_transform`` the
    result is converted to pandas and its ``Country`` column must equal ten
    rows of ``"Canada"``.

    Args:
        test_spark_df: Spark DataFrame fixture (presumably 10 rows, given
            the expected frame below -- confirm against the fixture).
    """
    # NOTE(review): the expression is passed as a string; it presumably
    # relies on the module-level ``F`` import when evaluated -- confirm in
    # firelink.spark_transform.
    transformer = WithColumn("Country", "F.lit('Canada')")

    transformed = transformer.fit_transform(test_spark_df)
    actual = transformed.toPandas()[["Country"]]

    wanted = pd.DataFrame({"Country": ["Canada"] * 10})
    assert_frame_equal(actual, wanted)
def test_select(test_spark_df):
    """Check that ``Select(["a", "e"])`` yields exactly columns ``a`` and ``e``.

    The transformed Spark DataFrame is converted to pandas and compared
    against a frame whose ``a`` column is ``0..9`` and whose ``e`` column is
    the fixed list of labels below (first entry missing).

    Args:
        test_spark_df: Spark DataFrame fixture providing at least the
            columns ``a`` and ``e``.
    """
    picker = Select(["a", "e"])
    actual = picker.fit_transform(test_spark_df).toPandas()

    wanted_e = [None, "d", "a", "d", "e", "e", "a", "a", "d", "d"]
    wanted = pd.DataFrame({"a": range(10), "e": wanted_e})

    assert_frame_equal(actual, wanted)
def test_conditionalmappilng(test_spark_df):
    """Check the ``new_e`` column produced by ``ConditionalMapping``.

    The transformer is built as
    ``ConditionalMapping("e", "new_e", ["a", "d"], 0, 1)``. The resulting
    ``new_e`` column is cast to ``int64`` on the pandas side and compared
    with the fixed 0/1 sequence below.

    Args:
        test_spark_df: Spark DataFrame fixture containing column ``e``.
    """
    # NOTE(review): judging by the expected values, rows whose ``e`` is "a"
    # or "d" map to 1 and everything else (including missing) maps to 0 --
    # confirm the meaning of the last two positional arguments.
    mapper = ConditionalMapping("e", "new_e", ["a", "d"], 0, 1)

    as_pandas = mapper.fit_transform(test_spark_df).toPandas()
    # Cast so the comparison does not depend on the integer width that the
    # Spark-to-pandas conversion happens to produce.
    actual = as_pandas[["new_e"]].astype("int64")

    wanted = pd.DataFrame({"new_e": [0, 1, 1, 1, 0, 0, 1, 1, 1, 1]})
    assert_frame_equal(actual, wanted)