4.4. Software Testing of Machine Learning Projects

Machine learning code is very hard to test.

Due to the nature of our models, they often fail softly in ways that are difficult to test against: the model looks like it is doing what it is supposed to, but a hidden bug means it is not.

Writing software tests in science is already incredibly hard, so in this section we'll touch on:

  • some fairly simple tests we can implement to ensure consistency of our input data

  • ways to avoid bad bugs in data loading procedures

  • some strategies to probe our models

First, we'll load and split the data from the Data notebook and load the model from the Model Sharing notebook.

from pathlib import Path

import pandas as pd

DATA_FOLDER = Path("..", "..") / "data"
DATA_FILEPATH = DATA_FOLDER / "penguins_clean.csv"

penguins = pd.read_csv(DATA_FILEPATH)
penguins.head()

   Culmen Length (mm)  Culmen Depth (mm)  Flipper Length (mm)     Sex                              Species
0                39.1               18.7                181.0    MALE  Adelie Penguin (Pygoscelis adeliae)
1                39.5               17.4                186.0  FEMALE  Adelie Penguin (Pygoscelis adeliae)
2                40.3               18.0                195.0  FEMALE  Adelie Penguin (Pygoscelis adeliae)
3                36.7               19.3                193.0  FEMALE  Adelie Penguin (Pygoscelis adeliae)
4                39.3               20.6                190.0    MALE  Adelie Penguin (Pygoscelis adeliae)
from sklearn.model_selection import train_test_split

num_features = ["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)"]
cat_features = ["Sex"]
features = num_features + cat_features
target = ["Species"]

# Stratify on the species so train and test sets keep the same class balance
X_train, X_test, y_train, y_test = train_test_split(
    penguins[features],
    penguins[target],
    stratify=penguins[target[0]],
    train_size=.7,
    random_state=42,
)
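from joblib import load
# The saved model is a scikit-learn pipeline built from these classes
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.svm import SVC

MODEL_FOLDER = Path("..", "..") / "model"

# Hypothetical filename: use the file you saved in the Model Sharing notebook
clf = load(MODEL_FOLDER / "svc.joblib")

clf.score(X_test, y_test)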

4.4.1. Deterministic Tests

When I work with neural networks and implement a new layer, method, or other fancy component, I try to write a test for it. Keras' Conv2D and PyTorch's Conv2d layers, for example, should always do exactly the same thing when they convolve a kernel with an image.

Consider writing a small pytest test that feeds a simple NumPy array through the layer and compares the result against a known output.
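A minimal sketch of such a test, assuming a hand-written NumPy layer: `conv2d_valid` is a hypothetical stand-in for the layer under test, not the Keras or PyTorch implementation. Like those layers, it computes a "valid" cross-correlation (the kernel is not flipped), so the expected output can be worked out by hand.

```python
import numpy as np


def conv2d_valid(image: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Hypothetical layer under test: 'valid' 2D cross-correlation, no padding."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=float)
    for i in range(out_h):
        for j in range(out_w):
            # Multiply the kernel with the image patch under it and sum
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


def test_conv2d_known_output():
    # image = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    image = np.arange(9, dtype=float).reshape(3, 3)
    # This kernel computes image[i, j] - image[i + 1, j + 1], which is -4 everywhere
    kernel = np.array([[1.0, 0.0], [0.0, -1.0]])
    expected = np.array([[-4.0, -4.0], [-4.0, -4.0]])
    np.testing.assert_allclose(conv2d_valid(image, kernel), expected)


def test_conv2d_output_shape():
    # A 3x3 kernel without padding shrinks each spatial dimension by 2
    assert conv2d_valid(np.ones((5, 7)), np.ones((3, 3))).shape == (3, 5)
```

Saved in a file such as `test_layers.py`, running `pytest` collects both `test_` functions automatically. The hand-computed case is deliberately tiny so that anyone can verify the expected values on paper.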

You can check out the Keras test suite for an example of how they validate the input and output shapes.

Admittedly, this isn't always easy to do and can go beyond what research scripts need.

4.4.2. Data Tests for Models

An even easier test essentially reuses the code from the Model Evaluation notebook inside a test function.

def test_penguins(clf):
    # Define data you definitely know the answer to
    test_data = pd.DataFrame([[34.6, 21.1, 198.0, "MALE"],
                              [46.1, 18.2, 178.0, "FEMALE"],
                              [52.5, 15.6, 221.0, "MALE"]],
                             columns=["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)", "Sex"])
    # Define the target for these samples
    test_target = ['Adelie Penguin (Pygoscelis adeliae)',
                   'Chinstrap penguin (Pygoscelis antarctica)',
                   'Gentoo penguin (Pygoscelis papua)']
    # Assert the model gets all of them right
    assert clf.score(test_data, test_target) == 1

# Run the test on the loaded model (under pytest, `clf` would come from a fixture)
test_penguins(clf)

The idea is to pick a few samples whose class we know for certain and use them to test the model.
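When `clf.score(...) == 1` fails, it does not say which penguin was misclassified. A small hypothetical helper, `assert_known_predictions`, can compare predictions row by row and list the failures in the assertion message. It only assumes that `clf` has a scikit-learn style `predict` method.

```python
import numpy as np
import pandas as pd


def assert_known_predictions(clf, samples: pd.DataFrame, expected) -> None:
    """Assert that `clf` predicts every reference sample correctly.

    Unlike `clf.score(...) == 1`, the error message shows the failing rows.
    """
    predicted = np.asarray(clf.predict(samples))
    expected = np.asarray(expected)
    # Positions of the rows where prediction and known label disagree
    wrong = np.flatnonzero(predicted != expected)
    assert wrong.size == 0, (
        f"{wrong.size} reference sample(s) misclassified:\n"
        + samples.iloc[wrong]
        .assign(expected=expected[wrong], predicted=predicted[wrong])
        .to_string()
    )
```

Inside `test_penguins`, the final line could then become `assert_known_predictions(clf, test_data, test_target)`, so a failing run prints the measurements of the misclassified penguin alongside the expected and predicted species.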

4.4.3. Automated Testing of Docstring Examples

There is an even easier way to run simple tests. This can be useful when we write specific functions to pre-process our data. In the Model Sharing notebook, we looked into auto-generating docstrings.

We can upgrade our docstring and get free software tests out of it!

This is called a doctest. It is useful for keeping docstring examples up to date and for writing quick unit tests for a function.

This makes future users (including your future self) quite happy.

def shorten_class_name(df: pd.DataFrame) -> pd.DataFrame:
    """Shorten the class names of the penguins to the shortest version

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing the Species column with penguins

    Returns
    -------
    pd.DataFrame
        Normalised dataframe with shortened names

    Examples
    --------
    >>> shorten_class_name(pd.DataFrame([[1,2,3,"Adelie Penguin (Pygoscelis adeliae)"]], columns=["1","2","3","Species"]))
       1  2  3 Species
    0  1  2  3  Adelie
    """
    df["Species"] = df.Species.str.split(r" [Pp]enguin", n=1, expand=True)[0]

    return df

import doctest
doctest.testmod()

TestResults(failed=0, attempted=1)

shorten_class_name(penguins).head()

   Culmen Length (mm)  Culmen Depth (mm)  Flipper Length (mm)     Sex  Species
0                39.1               18.7                181.0    MALE   Adelie
1                39.5               17.4                186.0  FEMALE   Adelie
2                40.3               18.0                195.0  FEMALE   Adelie
3                36.7               19.3                193.0  FEMALE   Adelie
4                39.3               20.6                190.0    MALE   Adelie

So the docstring now gives an example of usage, its expected output, and a first test case that our test suite validates.
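Doctests compare printed output character by character, so examples that print whole DataFrames can break when pandas changes how it formats them. A sketch of a more robust variant, assuming a hypothetical per-string helper `shorten_species` that applies the same regular expression: its examples return plain strings, and the hypothetical `run_doctests` function runs just that function's examples with doctest's `DocTestFinder` and `DocTestRunner` instead of `testmod`.

```python
import doctest
import re


def shorten_species(name: str) -> str:
    """Shorten a penguin species name to its first word.

    >>> shorten_species("Adelie Penguin (Pygoscelis adeliae)")
    'Adelie'
    >>> shorten_species("Gentoo penguin (Pygoscelis papua)")
    'Gentoo'
    """
    # Same pattern as shorten_class_name: cut at " Penguin" or " penguin"
    return re.split(r" [Pp]enguin", name, maxsplit=1)[0]


def run_doctests(func) -> doctest.TestResults:
    """Run only the docstring examples of `func` and return (failed, attempted)."""
    finder = doctest.DocTestFinder()
    runner = doctest.DocTestRunner()
    failed = attempted = 0
    # Pass globs explicitly so the examples can call `func` by name
    for test in finder.find(func, globs={func.__name__: func}):
        result = runner.run(test)
        failed += result.failed
        attempted += result.attempted
    return doctest.TestResults(failed, attempted)


print(run_doctests(shorten_species))  # TestResults(failed=0, attempted=2)
```

Outside a notebook, `python -m doctest your_module.py -v` or pytest's `--doctest-modules` flag runs every docstring example in a module (`your_module.py` is a placeholder name).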

4.4.4. Input Data Validation

Input data validation checks that the data users provide matches what your model expects.

Validation tools are often used in production systems to check whether API requests and user inputs are formatted correctly.
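Before reaching for a library, the core idea can be sketched in plain pandas. `find_invalid_rows` and its `ranges` and `allowed` arguments are hypothetical names for illustration; dedicated tools such as pandera add type checks, nullability handling, and readable error reports on top of this.

```python
import pandas as pd


def find_invalid_rows(df: pd.DataFrame, ranges: dict, allowed: dict) -> pd.DataFrame:
    """Return the rows of `df` that break any numeric range or allowed-value rule."""
    bad = pd.Series(False, index=df.index)
    for column, (low, high) in ranges.items():
        # between() is inclusive on both ends; NaN also counts as invalid here
        bad |= ~df[column].between(low, high)
    for column, values in allowed.items():
        bad |= ~df[column].isin(values)
    return df[bad]
```

Called as `find_invalid_rows(X_test, {"Culmen Length (mm)": (30, 60)}, {"Sex": ["MALE", "FEMALE"]})`, it returns all offending rows at once instead of stopping at the first error.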

One example tool is pandera, which we use below.

import pandera as pa

# data to validate
X_train.describe()

       Culmen Length (mm)  Culmen Depth (mm)  Flipper Length (mm)
count          233.000000         233.000000           233.000000
mean            43.982403          17.228755           201.412017
std              5.537146           1.994191            13.929695
min             33.500000          13.100000           172.000000
25%             39.000000          15.700000           190.000000
50%             44.400000          17.300000           198.000000
75%             48.800000          18.800000           213.000000
max             59.600000          21.200000           231.000000

The following code is expected to fail, so we can see what happens when the data doesn't match the schema!

# define schema
schema = pa.DataFrameSchema({
    "Culmen Length (mm)": pa.Column(float, checks=[pa.Check.ge(30),
                                                   pa.Check.le(60)]),
    "Culmen Depth (mm)": pa.Column(float, checks=[pa.Check.ge(13)]),
    "Flipper Length (mm)": pa.Column(float, checks=[pa.Check.ge(170)]),
    "Sex": pa.Column(str, checks=pa.Check.isin(["MALE","FEMALE"])),
})

validated_test = schema(X_test)
SchemaError                               Traceback (most recent call last)
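SchemaError: Column 'Sex' failed element-wise validator number 0: isin(['MALE', 'FEMALE']) failure cases: .

The failure case is the placeholder value ".". The unique values of the Sex column and the affected sample (index 259) show where it comes from:

array(['FEMALE', 'MALE', '.'], dtype=object)

Culmen Length (mm)      44.5
Culmen Depth (mm)       15.7
Flipper Length (mm)    217.0
Sex                        .
Name: 259, dtype: object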

Can you fix the data to conform to the schema?
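If you want to compare against your own answer, here is one possible fix, sketched as a hypothetical helper `drop_unknown_sex`: it treats any Sex value other than "MALE" or "FEMALE" as missing and drops those rows. Fixing the value upstream in the Data notebook, or declaring the column with pandera's `nullable=True` after converting "." to NaN, are equally valid choices.

```python
import pandas as pd


def drop_unknown_sex(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `df` without rows whose Sex is not MALE or FEMALE."""
    cleaned = df.copy()
    # where() keeps valid values and turns everything else (e.g. ".") into NaN
    cleaned["Sex"] = cleaned["Sex"].where(cleaned["Sex"].isin(["MALE", "FEMALE"]))
    return cleaned.dropna(subset=["Sex"])
```

With it, `schema(drop_unknown_sex(X_test))` should pass the Sex check. Remember to drop the same rows from the labels, for example `y_test.loc[drop_unknown_sex(X_test).index]`, so that features and targets stay aligned.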