PySpark Mocking: Exception Test Succeeds but Exception is not handled

Question:

I am using python 2.7 (don’t ask me why, I am a contractor, I just work with what they give me).

I am trying to implement a PySpark function that uses the spark-bigquery connector to submit a simple query through the Spark SQL Data Source API.

I am experiencing the weirdest thing. I wrote the function and confirmed that it really works against the server when I actually run it. I wanted to make sure that, if the user provides a table name that does not exist, an exception is thrown by handling the one returned by the server, and I did (I know this is not TDD, but just go with it). I then proceeded to write the test for it, and I obviously had to generate a mock exception, which I did as follows:

module/query_bq

from py4j.protocol import Py4JJavaError
from pyspark.sql import SparkSession

def submit_bq_query(spark, table, filter_string):
    try:
        df = spark.read.format('bigquery').option('table', table).option('filter', filter_string).load()
        return df
    except Py4JJavaError as e:
        java_error_msg = str(e).split('\n')[1]
        if "java.lang.RuntimeException" in java_error_msg and ("{} not found".format(table)) in java_error_msg:
            raise Exception("RuntimeException: Table {} not found!".format(table))
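The string matching in the except branch can be pulled out into a pure function and unit-tested without Spark, py4j, or a mock gateway. The sketch below assumes, as the question's split('\n')[1] implies, that the Java exception sits on the second line of str(e); the helper name is_table_not_found and the sample message are hypothetical, written only to mirror that shape.

```python
def is_table_not_found(error_text, table):
    """Return True if a Py4J-style error string reports a missing table.

    Assumes (hypothetically) that the Java exception class and message sit
    on the second line of the text, e.g. ": java.lang.RuntimeException: ...".
    """
    lines = error_text.split("\n")
    if len(lines) < 2:
        return False  # no Java exception line to inspect
    java_line = lines[1]
    return ("java.lang.RuntimeException" in java_line
            and "{} not found".format(table) in java_line)


# Hypothetical sample shaped like "An error occurred while calling oNN.load.\n: <java exception>"
sample = ("An error occurred while calling o123.load.\n"
          ": java.lang.RuntimeException: Table spark_bq_test.false_table_name not found\n"
          "\tat some.java.Frame(Frame.java:1)")

print(is_table_not_found(sample, "spark_bq_test.false_table_name"))  # True
print(is_table_not_found(sample, "spark_bq_test.other_table"))       # False
```

Keeping the parsing separate from the Spark call means the fragile part (the message format) can be tested with plain strings.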

As I said, this works like a charm. Now, the test for it looks like this:

module/test_query_bq

import pytest
from mock import patch, mock
from py4j.java_gateway import GatewayProperty, GatewayClient, JavaObject
from py4j.protocol import Py4JJavaError
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StructType
from module.query_bq import submit_bq_query


def mock_p4j_java_error_generator(msg):
    gateway_property = GatewayProperty(auto_field="Mock", pool="Mock")
    client = GatewayClient(gateway_property=gateway_property)
    java_object = JavaObject("RunTimeError", client)
    exception = Py4JJavaError(msg, java_exception=java_object)
    return Exception(exception)


def test_exception_is_thrown_if_table_not_present():

    # Given
    mock_table_name = 'spark_bq_test.false_table_name'
    mock_filter = "word is 'V'"
    mock_errmsg = "Table {} not found".format(mock_table_name)

    # Mocking
    mock_spark = mock.Mock()
    mock_spark_reader = mock.Mock()

    # Mocking return-values setup
    mock_spark.read.format.return_value = mock_spark_reader
    mock_spark_reader.option.return_value = mock_spark_reader
    mock_spark_reader.load.side_effect = mock_p4j_java_error_generator(mock_errmsg)

    # When
    with pytest.raises(Exception) as exception:
        submit_bq_query(mock_spark, mock_table_name, mock_filter)
    assert exception.value.message.errmsg == mock_errmsg
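The mock setup above works because reader.option returns the reader itself, so any number of chained .option(...) calls resolve to the same object, and load's side_effect decides exactly which exception object is raised. A standalone sketch of that chain, assuming Python 3's unittest.mock (on 2.7 the mock backport has the same API); TableNotFound is a hypothetical stand-in exception.

```python
from unittest import mock  # on Python 2.7: `import mock` (the backport)


class TableNotFound(Exception):
    """Hypothetical stand-in for the error the real reader would raise."""


mock_spark = mock.Mock()
mock_reader = mock.Mock()

# spark.read.format(...) returns the reader mock ...
mock_spark.read.format.return_value = mock_reader
# ... every .option(...) returns that same reader, so the chain can be any length ...
mock_reader.option.return_value = mock_reader
# ... and .load() raises the exception instance assigned to side_effect
mock_reader.load.side_effect = TableNotFound("Table t not found")

caught = None
try:
    mock_spark.read.format("bigquery").option("table", "t").option("filter", "f").load()
except TableNotFound as exc:
    caught = exc

print(type(caught).__name__, caught)  # TableNotFound Table t not found
# The mocks also record their calls, which allows extra assertions
print(mock_reader.option.call_count)  # 2
mock_reader.option.assert_any_call("table", "t")
```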

The test passes, but when I debug it, just to follow the execution, I notice that the code right after the except clause:

module/query_bq

...
    except Py4JJavaError as e:
        java_error_msg = str(e).split('\n')[1]  # This line is never reached!
        if "java.lang.RuntimeException" in java_error_msg and ("{} not found".format(table)) in java_error_msg:
            raise Exception("RuntimeException: Table {} not found!".format(table))
...

is never reached. Nevertheless, the test still succeeds.

In short, the exception is mocked and thrown in the test as it should be. It is also caught, but it is not handled. The test's assertion passes and the test succeeds as if the exception had been handled when it was not, and I never get to inspect the inside of the mock exception. Once more, let me note that module/query_bq works against the server just fine: it returns DataFrames and handles exceptions correctly when a table is not present! The point here is testing.

I need to do additional things in the exception-handling part of module/query_bq, but I can't, because I don't know what is happening. Can anyone explain?

Answers:

After 3 days of struggling, I sorted it out. The main problem was that:

  • I was not properly mocking the spark.read process signature, and
  • I was not properly instantiating a mock instance of a Py4JJavaError.

Here is how I did both of those:

…/utils/bigquery_util.py

import logging

from py4j.protocol import Py4JJavaError, Py4JNetworkError


def load_bq_table(spark, table, filter_string):
    tries = 3
    for i in range(tries):
        try:
            logging.info("SQL statement being executed...")
            df = get_df(spark, table, filter_string)
            logging.info("Table-ID: {}, Rows:{} x Cols:{}".format(table, df.count(), len(df.columns)))
            logging.debug("Table-ID: {}, Schema: {}".format(table, df.schema))
            return df
        except Py4JJavaError as e:
            java_exception_str = get_exception_msg(e)
            is_runtime_exception = "java.lang.RuntimeException" in java_exception_str
            table_not_found = ("{} not found".format(table)) in java_exception_str
            if is_runtime_exception and table_not_found:
                logging.error(java_exception_str)
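                raise RuntimeError("Table {} not found!".format(table))
            raise  # re-raise any other Java-side error instead of silently retrying
        except Py4JNetworkError as ne:
            if i == tries - 1:  # last attempt: give up and report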
                java_exception_str = ne.cause
                runtime_error_str = "Error while trying to reach server... {}"
                logging.error(java_exception_str)
                raise EnvironmentError(runtime_error_str.format(java_exception_str))
            continue


def get_exception_msg(e):
    return str(e.java_exception)


def get_df(spark, table, filter_string):
    return (spark.read
            .format('bigquery')
            .option('table', table)
            .option('filter', filter_string)
            .load())
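The retry loop in load_bq_table (retry on network errors, give up on the last attempt, let other errors through) can be sketched and tested in isolation. This is a generic sketch, not the answer's code: call_with_retries, TransientError and flaky are hypothetical stand-ins for the loop, Py4JNetworkError and get_df.

```python
class TransientError(Exception):
    """Hypothetical stand-in for Py4JNetworkError."""


def call_with_retries(fn, tries=3, retriable=(TransientError,)):
    """Call fn(); retry only on `retriable` errors, re-raising after the last try."""
    for attempt in range(tries):
        try:
            return fn()
        except retriable:
            if attempt == tries - 1:  # compare ints with ==, not `is`
                raise
            # otherwise fall through to the next attempt


# Usage: a function that fails twice with a transient error, then succeeds
calls = []


def flaky():
    calls.append(1)
    if len(calls) < 3:
        raise TransientError("network hiccup")
    return "dataframe"


print(call_with_retries(flaky))  # dataframe
print(len(calls))                # 3
```

Errors that are not listed in retriable are never caught by the loop, so they propagate on the first attempt.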

As for testing: …/test/utils/test_bigquery_util.py

import pytest
from mock import patch, mock

from <...>.utils.bigquery_util import load_bq_table
from <...>.test.utils.mock_py4jerror import *


def test_runtime_error_exception_is_thrown_if_table_not_present():

    # Given
    mock_table_name = 'spark_bq_test.false_table_name'
    mock_filter = "word is 'V'"

    # Mocking
    py4j_error_exception = get_mock_py4j_error_exception(get_mock_gateway_client(), mock_target_id="o123")
    mock_errmsg = "java.lang.RuntimeException: Table {} not found".format(mock_table_name)

    # When
    with mock.patch('red_agent.common.utils.bigquery_util.get_exception_msg', return_value=mock_errmsg):
        with mock.patch('red_agent.common.utils.bigquery_util.get_df', side_effect=py4j_error_exception):
            with pytest.raises(RuntimeError):
                mock_spark = mock.Mock()
                df = load_bq_table(mock_spark, mock_table_name, mock_filter)
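The test patches red_agent.common.utils.bigquery_util.get_df rather than anything on spark because mock.patch replaces a name in the namespace where it is looked up at call time, and load_bq_table resolves get_df from its own module's globals. A self-contained sketch of that rule using an in-memory module; demo_bq_util, get_df and load are hypothetical names, and Python 3's unittest.mock is assumed (the mock backport on 2.7 has the same patch API).

```python
import sys
import types
from unittest import mock

# Build a throwaway module in memory so the example is self-contained;
# "demo_bq_util" is a hypothetical stand-in for bigquery_util.
demo = types.ModuleType("demo_bq_util")
exec(
    "def get_df(spark, table):\n"
    "    raise AssertionError('real get_df should not run in tests')\n"
    "\n"
    "def load(spark, table):\n"
    "    return get_df(spark, table)  # resolved in this module's globals at call time\n",
    demo.__dict__,
)
sys.modules["demo_bq_util"] = demo

# Patching the attribute on the module that *uses* get_df swaps it for load()
with mock.patch("demo_bq_util.get_df", return_value="fake df") as fake_get_df:
    result = demo.load(mock.Mock(), "some.table")

print(result)                  # fake df
print(fake_get_df.call_count)  # 1
```

This also suggests why moving the reader chain into its own get_df function simplifies the test: a single patch replaces the whole spark.read chain.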

And finally, for mocking the Py4JJavaError: …/test/utils/mock_py4jerror.py

import mock
from py4j.protocol import Py4JJavaError, Py4JNetworkError


def get_mock_gateway_client():
    mock_client = mock.Mock()
    mock_client.send_command.return_value = "0"
    mock_client.converters = []
    mock_client.is_connected.return_value = True
    mock_client.deque = mock.Mock()
    return mock_client


def get_mock_java_object(mock_client, mock_target_id):
    mock_java_object = mock.Mock()
    mock_java_object._target_id = mock_target_id
    mock_java_object._gateway_client = mock_client
    return mock_java_object


def get_mock_py4j_error_exception(mock_client, mock_target_id):
    mock_java_object = get_mock_java_object(mock_client, mock_target_id)
    mock_errmsg = "An error occurred while calling {}.load.".format(mock_target_id)
    return Py4JJavaError(mock_errmsg, java_exception=mock_java_object)


def get_mock_py4j_network_exception(mock_target_id):
    mock_errmsg = "An error occurred while calling {}.load.".format(mock_target_id)
    return Py4JNetworkError(mock_errmsg)
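The key difference from the question's mock_p4j_java_error_generator is that get_mock_py4j_error_exception returns the Py4JJavaError itself. The question's helper returned Exception(exception), a plain Exception that merely holds the Py4JJavaError as its argument; `except Py4JJavaError` does not match it, so the error went straight past the handler to pytest.raises(Exception) and the test passed without the except branch ever running. A minimal reproduction, using a hypothetical FakeJavaError in place of Py4JJavaError so py4j is not needed:

```python
class FakeJavaError(Exception):
    """Hypothetical stand-in for Py4JJavaError (avoids needing py4j installed)."""

    def __init__(self, msg):
        super().__init__(msg)
        self.errmsg = msg


def handler_reached(exc_to_raise):
    """Raise exc_to_raise and report whether `except FakeJavaError` caught it."""
    try:
        raise exc_to_raise
    except FakeJavaError:
        return True
    except Exception:
        return False  # propagated past the FakeJavaError handler


inner = FakeJavaError("Table t not found")
wrapped = Exception(inner)  # what the question's mock generator returned

# The specific handler runs for the real exception type ...
print(handler_reached(inner))    # True
# ... but not for the wrapper, which only a generic handler catches
print(handler_reached(wrapped))  # False
# The wrapper still carries the inner error as its first argument
print(wrapped.args[0].errmsg)    # Table t not found
```

Because the wrapper still carries the inner error, an assertion that digs into it (the question used Python 2.7's deprecated .message attribute to reach errmsg) can pass even though the handler never ran.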

Hope this will help someone…

I was getting the error TypeError: exceptions must derive from BaseException when I tried using something like create_autospec(spec=Py4JJavaError). The accepted answer helped me write this solution:

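# my_module.py (problematic_code and MyException are defined elsewhere in this module)
from py4j.protocol import Py4JJavaError


def my_function():
    try:
        problematic_code()
    except Py4JJavaError as e:
        if "java.lang.RuntimeException" in str(e.java_exception):
            raise MyException(f"The problematic code failed with {e}.") from e
        raise  # re-raise any other Java error unchanged


# test_my_module.py (test() is a method of a unittest.TestCase subclass)
from unittest.mock import MagicMock, patch

from py4j.protocol import Py4JJavaError

from my_module import MyException, my_function


def test(self):
    java_exception = MagicMock(_target_id="_")
    java_exception.__str__.return_value = "java.lang.RuntimeException"
    py4j_java_error = Py4JJavaError("_", java_exception=java_exception)

    # Py4JJavaError.__str__ would query the gateway client, so stub it out
    with patch.object(Py4JJavaError, "__str__", return_value="_"), \
            patch("my_module.problematic_code", side_effect=py4j_java_error):
        self.assertRaises(MyException, my_function)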

I realized the Java gateway client was only used in the __str__ method of Py4JJavaError, so I just mocked that whole method instead of mocking the client.

Reference: https://github.com/py4j/py4j/blob/master/py4j-python/src/py4j/protocol.py
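The idea in this answer, replacing __str__ on the exception class so no gateway is ever contacted, can be shown without py4j. GatewayBackedError and DeadClient below are hypothetical stand-ins for Py4JJavaError and its gateway client; the sketch assumes Python 3's unittest.mock.

```python
from unittest import mock


class GatewayBackedError(Exception):
    """Hypothetical stand-in for Py4JJavaError: __str__ needs a live gateway."""

    def __init__(self, msg, client):
        super().__init__(msg)
        self.client = client

    def __str__(self):
        # The real class sends a command over the gateway at this point
        return self.client.send_command("get stack trace")


class DeadClient:
    """Hypothetical client that fails, as a mock-less gateway would in tests."""

    def send_command(self, command):
        raise ConnectionError("no gateway running")


err = GatewayBackedError("boom", DeadClient())

with mock.patch.object(GatewayBackedError, "__str__", return_value="_"):
    text = str(err)  # the patched method never touches the client
print(text)  # _

# Outside the with-block the original __str__ is restored
try:
    str(err)
except ConnectionError as exc:
    print("restored:", exc)  # restored: no gateway running
```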

Answered By: Heather Sawatsky