4.4. Software Testing of Machine Learning Projects#
Machine learning code is notoriously hard to test.
Due to the nature of our models, we often have soft failures in the model that are difficult to test against. That basically means the model looks like it is doing what it is supposed to do, but secretly it is not, because of some bug.
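A classic example of such a soft failure (an illustrative sketch, not taken from this project): shuffling features and labels with two independent permutations. Everything still runs and the shapes still match, but the labels no longer correspond to the features.

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.normal(size=(100, 3))
y = (X[:, 0] > 0).astype(int)  # labels depend on the first feature

# Bug: features and labels are shuffled with *different* permutations,
# silently destroying the feature-label association
X_shuffled = X[rng.permutation(len(X))]
y_shuffled = y[rng.permutation(len(y))]

# No error is raised and the shapes still match, but a model trained on
# (X_shuffled, y_shuffled) can only reach chance-level accuracy
```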
Writing software tests in science is already incredibly hard, so in this section we'll touch on:

- some fairly simple tests we can implement to ensure the consistency of our input data,
- ways to avoid bad bugs in data loading procedures, and
- some strategies to probe our models.
First we’ll split the data from the Data notebook and load the model from the Sharing notebook.
from pathlib import Path
DATA_FOLDER = Path("..", "..") / "data"
DATA_FILEPATH = DATA_FOLDER / "penguins_clean.csv"
import pandas as pd
penguins = pd.read_csv(DATA_FILEPATH)
penguins.head()
|   | Culmen Length (mm) | Culmen Depth (mm) | Flipper Length (mm) | Sex | Species |
|---|---|---|---|---|---|
| 0 | 39.1 | 18.7 | 181.0 | MALE | Adelie Penguin (Pygoscelis adeliae) |
| 1 | 39.5 | 17.4 | 186.0 | FEMALE | Adelie Penguin (Pygoscelis adeliae) |
| 2 | 40.3 | 18.0 | 195.0 | FEMALE | Adelie Penguin (Pygoscelis adeliae) |
| 3 | 36.7 | 19.3 | 193.0 | FEMALE | Adelie Penguin (Pygoscelis adeliae) |
| 4 | 39.3 | 20.6 | 190.0 | MALE | Adelie Penguin (Pygoscelis adeliae) |
from sklearn.model_selection import train_test_split
num_features = ["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)"]
cat_features = ["Sex"]
features = num_features + cat_features
target = ["Species"]
X_train, X_test, y_train, y_test = train_test_split(
    penguins[features],
    penguins[target],
    stratify=penguins[target[0]],
    train_size=.7,
    random_state=42,
)
from sklearn.svm import SVC
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from joblib import load
MODEL_FOLDER = Path("..", "..") / "model"
MODEL_EXPORT_FILE = MODEL_FOLDER / "svc.joblib"
clf = load(MODEL_EXPORT_FILE)
clf.score(X_test, y_test)
1.0
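Since the score is computed here anyway, it can double as a cheap regression test that catches a silently broken or stale model export. A small sketch; the 0.97 threshold is an arbitrary choice for illustration, not from the original:

```python
def test_model_export_regression():
    # Guard against a silently degraded model artifact;
    # the threshold is a deliberate, hypothetical choice
    assert clf.score(X_test, y_test) >= 0.97
```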
4.4.1. Deterministic Tests#
When I work with neural networks and implement a new layer, method, or other fancy thing, I try to write a test for that component. The Conv2D layer in Keras and PyTorch, for example, should always do the exact same thing when it convolves a kernel with an image.
Consider writing a small pytest test that takes a simple numpy array and checks it against a known output.
You can check out the keras test suite here for an example of how they validate the input and output shapes.
Admittedly, this isn't always easy to do and can go beyond the needs of research scripts.
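A minimal sketch of such a deterministic test, assuming PyTorch is installed (the same idea applies to Keras): fix the kernel weights so the convolution output is known in advance.

```python
import torch

def test_conv2d_known_output():
    # A 3x3 single-channel image of ones
    image = torch.ones(1, 1, 3, 3)
    conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)
    # Fix the kernel to all ones so the expected output is known
    with torch.no_grad():
        conv.weight.copy_(torch.ones(1, 1, 3, 3))
    out = conv(image)
    # A 3x3 all-ones kernel over a 3x3 all-ones image sums nine ones
    assert out.shape == (1, 1, 1, 1)
    assert torch.allclose(out, torch.full((1, 1, 1, 1), 9.0))
```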
4.4.2. Data Tests for Models#
An even easier test can be built by reusing the examples from the Model Evaluation notebook and writing a test function around them.
def test_penguins(clf):
    # Define data you definitely know the answer to
    test_data = pd.DataFrame([[34.6, 21.1, 198.0, "MALE"],
                              [46.1, 18.2, 178.0, "FEMALE"],
                              [52.5, 15.6, 221.0, "MALE"]],
                             columns=["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)", "Sex"])
    # Define the target labels for that data
    test_target = ['Adelie Penguin (Pygoscelis adeliae)',
                   'Chinstrap penguin (Pygoscelis antarctica)',
                   'Gentoo penguin (Pygoscelis papua)']
    # Assert that the model gets all of these right
    assert clf.score(test_data, test_target) == 1

test_penguins(clf)
This means we have some samples in the data for which we clearly know the correct class, and we can use these to test the model.
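In a standalone pytest suite, the same idea reads nicely with `pytest.mark.parametrize`, so each known penguin becomes its own test case. A sketch; the fixture, model path, and file layout are hypothetical:

```python
import pandas as pd
import pytest
from joblib import load

@pytest.fixture(scope="module")
def clf():
    return load("model/svc.joblib")  # hypothetical path to the exported model

CASES = [
    ([34.6, 21.1, 198.0, "MALE"], "Adelie Penguin (Pygoscelis adeliae)"),
    ([46.1, 18.2, 178.0, "FEMALE"], "Chinstrap penguin (Pygoscelis antarctica)"),
    ([52.5, 15.6, 221.0, "MALE"], "Gentoo penguin (Pygoscelis papua)"),
]

@pytest.mark.parametrize("features,species", CASES)
def test_penguin_prediction(clf, features, species):
    row = pd.DataFrame([features],
                       columns=["Culmen Length (mm)", "Culmen Depth (mm)",
                                "Flipper Length (mm)", "Sex"])
    assert clf.predict(row)[0] == species
```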
4.4.3. Automated Testing of Docstring Examples#
There is an even easier way to run simple tests. This can be useful when we write specific functions to pre-process our data. In the Model Sharing notebook, we looked into auto-generating docstrings.
We can upgrade our docstring and get free software tests out of it!
This is called doctest, and it is useful for keeping docstring examples up to date and for writing quick unit tests for a function.
This makes future users (including yourself from the future) quite happy.
def shorten_class_name(df: pd.DataFrame) -> pd.DataFrame:
    """Shorten the class names of the penguins to the shortest version

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing the Species column with penguins

    Returns
    -------
    pd.DataFrame
        Normalised dataframe with shortened names

    Examples
    --------
    >>> shorten_class_name(pd.DataFrame([[1,2,3,"Adelie Penguin (Pygoscelis adeliae)"]], columns=["1","2","3","Species"]))
       1  2  3 Species
    0  1  2  3  Adelie
    """
    df["Species"] = df.Species.str.split(r" [Pp]enguin", n=1, expand=True)[0]
    return df
import doctest
doctest.testmod()
TestResults(failed=0, attempted=1)
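In a larger project you would typically run doctests from the test suite rather than calling `testmod()` inline; pytest can collect them with `pytest --doctest-modules`. A sketch of a stricter inline variant, where `my_module` is a hypothetical module containing `shorten_class_name`:

```python
import doctest
import my_module  # hypothetical module holding shorten_class_name

results = doctest.testmod(my_module)
assert results.failed == 0, f"{results.failed} doctest(s) failed"
```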
shorten_class_name(penguins).head()
|   | Culmen Length (mm) | Culmen Depth (mm) | Flipper Length (mm) | Sex | Species |
|---|---|---|---|---|---|
| 0 | 39.1 | 18.7 | 181.0 | MALE | Adelie |
| 1 | 39.5 | 17.4 | 186.0 | FEMALE | Adelie |
| 2 | 40.3 | 18.0 | 195.0 | FEMALE | Adelie |
| 3 | 36.7 | 19.3 | 193.0 | FEMALE | Adelie |
| 4 | 39.3 | 20.6 | 190.0 | MALE | Adelie |
So the docstring now gives a nice usage example, an expected output, and a first test case that is validated by our test suite.
4.4.4. Input Data Validation#
Input data validation means checking that the data users provide matches what your model expects.
Such tools are often used in production systems to determine whether API calls and user inputs are formatted correctly.
Example tools are Great Expectations and pandera; we'll use pandera below.
import pandera as pa
# data to validate
X_train.describe()
|       | Culmen Length (mm) | Culmen Depth (mm) | Flipper Length (mm) |
|-------|--------------------|-------------------|----------------------|
| count | 233.000000 | 233.000000 | 233.000000 |
| mean  | 43.982403 | 17.228755 | 201.412017 |
| std   | 5.537146 | 1.994191 | 13.929695 |
| min   | 33.500000 | 13.100000 | 172.000000 |
| 25%   | 39.000000 | 15.700000 | 190.000000 |
| 50%   | 44.400000 | 17.300000 | 198.000000 |
| 75%   | 48.800000 | 18.800000 | 213.000000 |
| max   | 59.600000 | 21.200000 | 231.000000 |
The following code is supposed to fail, so we can see what happens when the schema doesn't match!
# define schema
schema = pa.DataFrameSchema({
    "Culmen Length (mm)": pa.Column(float, checks=[pa.Check.ge(30),
                                                   pa.Check.le(60)]),
    "Culmen Depth (mm)": pa.Column(float, checks=[pa.Check.ge(13),
                                                  pa.Check.le(22)]),
    "Flipper Length (mm)": pa.Column(float, checks=[pa.Check.ge(170),
                                                    pa.Check.le(235)]),
    "Sex": pa.Column(str, checks=pa.Check.isin(["MALE", "FEMALE"])),
})

validated_test = schema(X_test)
---------------------------------------------------------------------------
SchemaError                               Traceback (most recent call last)
Cell In[11], line 12
      1 # define schema
      2 schema = pa.DataFrameSchema({
    (...)
---> 12 validated_test = schema(X_test)

[... long pandera-internal traceback truncated ...]

SchemaError: Column 'Sex' failed element-wise validator number 0: isin(['MALE', 'FEMALE']) failure cases: .
X_test.Sex.unique()
array(['FEMALE', 'MALE', '.'], dtype=object)
X_test.loc[259]
Culmen Length (mm) 44.5
Culmen Depth (mm) 15.7
Flipper Length (mm) 217.0
Sex .
Name: 259, dtype: object
Can you fix the data to conform to the schema?
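One possible approach (a sketch, not the only valid fix): treat the `.` as a missing value, drop the affected rows, and validate again.

```python
import numpy as np

# Treat "." as missing and drop rows without a valid Sex entry
X_test_clean = X_test.replace(".", np.nan).dropna(subset=["Sex"])
validated_test = schema(X_test_clean)  # now conforms to the schema
```

If you drop rows from `X_test`, remember to drop the matching rows from `y_test` before scoring the model again.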