4.4. Software Testing of Machine Learning Projects#
Machine learning code is very hard to test.
Due to the nature of our models, we often have soft failures in the model that are difficult to test against: the model looks like it's doing what it's supposed to, but it secretly isn't because of some bug.
Writing software tests in science is already incredibly hard, so in this section we'll touch on:

- some fairly simple tests we can implement to ensure the consistency of our input data
- ways to avoid bad bugs in data loading procedures
- some strategies to probe our models
First we’ll split the data from the Data notebook and load the model from the Sharing notebook.
from pathlib import Path
DATA_FOLDER = Path("..", "..") / "data"
DATA_FILEPATH = DATA_FOLDER / "penguins_clean.csv"
import pandas as pd
penguins = pd.read_csv(DATA_FILEPATH)
penguins.head()
|   | Culmen Length (mm) | Culmen Depth (mm) | Flipper Length (mm) | Sex | Species |
|---|---|---|---|---|---|
| 0 | 39.1 | 18.7 | 181.0 | MALE | Adelie Penguin (Pygoscelis adeliae) |
| 1 | 39.5 | 17.4 | 186.0 | FEMALE | Adelie Penguin (Pygoscelis adeliae) |
| 2 | 40.3 | 18.0 | 195.0 | FEMALE | Adelie Penguin (Pygoscelis adeliae) |
| 3 | 36.7 | 19.3 | 193.0 | FEMALE | Adelie Penguin (Pygoscelis adeliae) |
| 4 | 39.3 | 20.6 | 190.0 | MALE | Adelie Penguin (Pygoscelis adeliae) |
from sklearn.model_selection import train_test_split
num_features = ["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)"]
cat_features = ["Sex"]
features = num_features + cat_features
target = ["Species"]
X_train, X_test, y_train, y_test = train_test_split(penguins[features], penguins[target], stratify=penguins[target[0]], train_size=.7, random_state=42)
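We stratify the split on the species so the class proportions stay the same in the train and test sets, and fix random_state so the split is reproducible.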
from sklearn.svm import SVC
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from joblib import load
MODEL_FOLDER = Path("..", "..") / "model"
MODEL_EXPORT_FILE = MODEL_FOLDER / "svc.joblib"
clf = load(MODEL_EXPORT_FILE)
clf.score(X_test, y_test)
1.0
4.4.1. Deterministic Tests#
When I work with neural networks, implementing a new layer, method, or fancy thing, I try to write a test for that layer. The Conv2D layer in Keras and PyTorch, for example, should always do exactly the same thing when it convolves a kernel with an image.
Consider writing a small pytest test that takes a simple numpy array and tests against a known output.
You can check out the keras test suite here for an example of how they validate the input and output shapes.
Admittedly, this isn't always easy to do and can go beyond what research scripts need.
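As a minimal sketch of such a deterministic test (this assumes TensorFlow is installed; the test name and values are invented for illustration): a Conv2D layer loaded with an identity kernel should return its input unchanged.
import numpy as np
import tensorflow as tf

def test_conv2d_identity_kernel():
    # A kernel whose only non-zero weight is a 1 in the centre acts as
    # an identity, so with "same" padding the output equals the input.
    image = np.arange(25, dtype="float32").reshape(1, 5, 5, 1)
    kernel = np.zeros((3, 3, 1, 1), dtype="float32")
    kernel[1, 1, 0, 0] = 1.0
    layer = tf.keras.layers.Conv2D(filters=1, kernel_size=3,
                                   padding="same", use_bias=False)
    layer.build(image.shape)
    layer.set_weights([kernel])
    np.testing.assert_allclose(layer(image).numpy(), image)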
4.4.2. Data Tests for Models#
An even easier test essentially reuses the notebook from the Model Evaluation section: we write a test function around samples we already know the answer to.
def test_penguins(clf):
    # Define data you definitely know the answer to
    test_data = pd.DataFrame([[34.6, 21.1, 198.0, "MALE"],
                              [46.1, 18.2, 178.0, "FEMALE"],
                              [52.5, 15.6, 221.0, "MALE"]],
                             columns=["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)", "Sex"])
    # Define the expected targets for the data
    test_target = ['Adelie Penguin (Pygoscelis adeliae)',
                   'Chinstrap penguin (Pygoscelis antarctica)',
                   'Gentoo penguin (Pygoscelis papua)']
    # Assert that the model gets all of these right
    assert clf.score(test_data, test_target) == 1
test_penguins(clf)
This works because the data contains samples that we know beyond doubt belong to one class, and we can use exactly those samples to test the model.
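If you want pytest to pick this up automatically, you could move the function into a test module. Here is a sketch, assuming a hypothetical tests/test_model.py file and the model path used above:
# tests/test_model.py (hypothetical file in this project layout)
from pathlib import Path

import pandas as pd
import pytest
from joblib import load

@pytest.fixture(scope="module")
def clf():
    # Load the exported pipeline once per test module
    return load(Path("model") / "svc.joblib")

def test_penguins_known_sample(clf):
    test_data = pd.DataFrame([[34.6, 21.1, 198.0, "MALE"]],
                             columns=["Culmen Length (mm)", "Culmen Depth (mm)",
                                      "Flipper Length (mm)", "Sex"])
    # An unambiguous Adelie sample should be classified correctly
    assert clf.predict(test_data)[0] == "Adelie Penguin (Pygoscelis adeliae)"
Running pytest from the project root would then execute it alongside your other tests.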
4.4.3. Automated Testing of Docstring Examples#
There is an even easier way to run simple tests, which is particularly useful when we write dedicated functions to pre-process our data. In the Model Sharing notebook, we looked into auto-generating docstrings.
We can upgrade our docstring and get free software tests out of it!
This is called doctest, and it is useful both for keeping docstring examples up to date and for writing quick unit tests for a function.
This makes future users (including yourself from the future) quite happy.
def shorten_class_name(df: pd.DataFrame) -> pd.DataFrame:
    """Shorten the class names of the penguins to the shortest version

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing the Species column with penguins

    Returns
    -------
    pd.DataFrame
        Normalised dataframe with shortened names

    Examples
    --------
    >>> shorten_class_name(pd.DataFrame([[1,2,3,"Adelie Penguin (Pygoscelis adeliae)"]], columns=["1","2","3","Species"]))
       1  2  3  Species
    0  1  2  3   Adelie
    """
    df["Species"] = df.Species.str.split(r" [Pp]enguin", n=1, expand=True)[0]
    return df
import doctest
doctest.testmod()
TestResults(failed=0, attempted=1)
shorten_class_name(penguins).head()
|   | Culmen Length (mm) | Culmen Depth (mm) | Flipper Length (mm) | Sex | Species |
|---|---|---|---|---|---|
| 0 | 39.1 | 18.7 | 181.0 | MALE | Adelie |
| 1 | 39.5 | 17.4 | 186.0 | FEMALE | Adelie |
| 2 | 40.3 | 18.0 | 195.0 | FEMALE | Adelie |
| 3 | 36.7 | 19.3 | 193.0 | FEMALE | Adelie |
| 4 | 39.3 | 20.6 | 190.0 | MALE | Adelie |
So the docstring now gives a nice usage example, documents the expected output, and provides a first test case that our test suite validates.
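If you move such functions into a Python module, you can also run all of its doctests without writing any extra code: python -m doctest yourmodule.py -v runs them from the command line, and pytest collects them when invoked with the --doctest-modules flag.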
4.4.4. Input Data Validation#
Input data validation means checking that the data users provide matches what your model expects.
These tools are often used in production systems to determine whether API calls and user inputs are formatted correctly.
Example tools are Great Expectations and Pandera; here we'll use Pandera.
import pandera as pa
# data to validate
X_train.describe()
|   | Culmen Length (mm) | Culmen Depth (mm) | Flipper Length (mm) |
|---|---|---|---|
| count | 233.000000 | 233.000000 | 233.000000 |
| mean | 43.982403 | 17.228755 | 201.412017 |
| std | 5.537146 | 1.994191 | 13.929695 |
| min | 33.500000 | 13.100000 | 172.000000 |
| 25% | 39.000000 | 15.700000 | 190.000000 |
| 50% | 44.400000 | 17.300000 | 198.000000 |
| 75% | 48.800000 | 18.800000 | 213.000000 |
| max | 59.600000 | 21.200000 | 231.000000 |
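The describe() output above suggests sensible bounds for the checks we define next. As a side note, Pandera can also draft a schema from the data for you, which you can then refine by hand; a quick sketch:
# Bootstrap a draft schema from the training data and print it as Python code
draft_schema = pa.infer_schema(X_train)
print(draft_schema.to_script())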
The following code is supposed to fail, so that we can see what happens when the schema doesn't match!
# define schema
schema = pa.DataFrameSchema({
    "Culmen Length (mm)": pa.Column(float, checks=[pa.Check.ge(30),
                                                   pa.Check.le(60)]),
    "Culmen Depth (mm)": pa.Column(float, checks=[pa.Check.ge(13),
                                                  pa.Check.le(22)]),
    "Flipper Length (mm)": pa.Column(float, checks=[pa.Check.ge(170),
                                                    pa.Check.le(235)]),
    "Sex": pa.Column(str, checks=pa.Check.isin(["MALE","FEMALE"])),
})
validated_test = schema(X_test)
---------------------------------------------------------------------------
SchemaError                               Traceback (most recent call last)
Cell In[11], line 12
      1 # define schema
      2 schema = pa.DataFrameSchema({
      3     "Culmen Length (mm)": pa.Column(float, checks=[pa.Check.ge(30),
      4                                                    pa.Check.le(60)]),
    (...)
      9     "Sex": pa.Column(str, checks=pa.Check.isin(["MALE","FEMALE"])),
     10 })
---> 12 validated_test = schema(X_test)

[... pandera-internal stack frames elided ...]
SchemaError: <Schema Column(name=Sex, type=DataType(str))> failed element-wise validator 0:
<Check isin: isin(['MALE', 'FEMALE'])>
failure cases:
   index failure_case
0    259            .
X_test.Sex.unique()
array(['FEMALE', 'MALE', '.'], dtype=object)
X_test.loc[259]
Culmen Length (mm) 44.5
Culmen Depth (mm) 15.7
Flipper Length (mm) 217.0
Sex .
Name: 259, dtype: object
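Instead of failing on the first violation, Pandera can also collect every failure in one pass; a small sketch using lazy validation:
# Gather all schema violations at once instead of raising on the first one
try:
    schema.validate(X_test, lazy=True)
except pa.errors.SchemaErrors as err:
    print(err.failure_cases)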
Can you fix the data to conform to the schema?
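If you need a hint, here is one possible approach (a sketch; dropping the affected rows is just one option) that treats the "." entries as missing values:
# Treat "." as a missing value and drop the affected rows
X_test_clean = X_test.replace(".", pd.NA).dropna(subset=["Sex"])
y_test_clean = y_test.loc[X_test_clean.index]  # keep targets aligned with features
validated_test = schema(X_test_clean)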