4.4. Testing#
Machine learning is very hard to test. Due to the nature of our models, we often have soft failures: the model runs but quietly produces worse results, which is difficult to test against.
Writing software tests in science is already incredibly hard, so in this section we'll touch on:

- some fairly simple tests we can implement to ensure consistency of our input data
- how to avoid bad bugs in data loading procedures
- some strategies to probe our models
from pathlib import Path
DATA_FOLDER = Path("..", "..") / "data"
DATA_FILEPATH = DATA_FOLDER / "penguins_clean.csv"
import pandas as pd
penguins = pd.read_csv(DATA_FILEPATH)
penguins.head()
| | Culmen Length (mm) | Culmen Depth (mm) | Flipper Length (mm) | Sex | Species |
| --- | --- | --- | --- | --- | --- |
| 0 | 39.1 | 18.7 | 181.0 | MALE | Adelie Penguin (Pygoscelis adeliae) |
| 1 | 39.5 | 17.4 | 186.0 | FEMALE | Adelie Penguin (Pygoscelis adeliae) |
| 2 | 40.3 | 18.0 | 195.0 | FEMALE | Adelie Penguin (Pygoscelis adeliae) |
| 3 | 36.7 | 19.3 | 193.0 | FEMALE | Adelie Penguin (Pygoscelis adeliae) |
| 4 | 39.3 | 20.6 | 190.0 | MALE | Adelie Penguin (Pygoscelis adeliae) |
from sklearn.model_selection import train_test_split
num_features = ["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)"]
cat_features = ["Sex"]
features = num_features + cat_features
target = ["Species"]
X_train, X_test, y_train, y_test = train_test_split(penguins[features], penguins[target], stratify=penguins[target[0]], train_size=.7, random_state=42)
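Data loading and splitting procedures deserve tests of their own. As a hedged sketch on toy data (not the penguins frame), we can check that a split with a fixed `random_state` is reproducible and that `stratify` preserves the class balance:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# toy data standing in for the penguins frame: 50/50 class balance
df = pd.DataFrame({"x": range(100), "y": ["a"] * 50 + ["b"] * 50})

s1 = train_test_split(df[["x"]], df["y"], stratify=df["y"], train_size=.7, random_state=42)
s2 = train_test_split(df[["x"]], df["y"], stratify=df["y"], train_size=.7, random_state=42)

# a fixed random_state makes the split reproducible ...
assert s1[0].index.equals(s2[0].index)
# ... and stratify keeps the 50/50 class proportions in the training set
assert s1[2].value_counts(normalize=True)["a"] == 0.5
```

If a refactor of the data loading code ever breaks determinism or stratification, a test like this catches it immediately.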
from sklearn.svm import SVC
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from joblib import load
MODEL_FOLDER = Path("..", "..") / "model"
MODEL_EXPORT_FILE = MODEL_FOLDER / "svc.joblib"
clf = load(MODEL_EXPORT_FILE)
clf.score(X_test, y_test)
1.0
4.4.1. Deterministic Tests#
When I work with neural networks, implementing a new layer, method, or fancy thing, I try to write a test for that layer. The Conv2D
layer in Keras and PyTorch, for example, should always do the exact same thing when it convolves a kernel with an image.
Consider writing a small pytest
test that takes a simple numpy array and tests against a known output.
You can check out the keras
test suite here for an example of how they validate the input and output shapes.
Admittedly this isn't always easy to do and can go beyond the needs of research scripts.
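As a hedged sketch, such a deterministic test could look like the following. The `conv2d_valid` helper is a hypothetical reference implementation written for the example, not the Keras or PyTorch layer; the point is testing against an output you computed by hand:

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Minimal 2D 'valid' cross-correlation, written only for this test."""
    kh, kw = kernel.shape
    ih, iw = image.shape
    out = np.zeros((ih - kh + 1, iw - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def test_conv2d_known_output():
    image = np.array([[1., 2., 3.],
                      [4., 5., 6.],
                      [7., 8., 9.]])
    kernel = np.ones((2, 2))  # an all-ones kernel sums each 2x2 patch
    # expected values computed by hand: 1+2+4+5=12, 2+3+5+6=16, ...
    expected = np.array([[12., 16.],
                         [24., 28.]])
    np.testing.assert_allclose(conv2d_valid(image, kernel), expected)

test_conv2d_known_output()
```

Because the input and expected output are hand-computed constants, this test is fully deterministic: it either passes or it has caught a real change in behaviour.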
4.4.2. Data Tests for Models#
An even easier test is to reuse the notebook from the Model Evaluation section and write a test function for it.
def test_penguins(clf):
    # Define data you definitely know the answer to
    test_data = pd.DataFrame([[34.6, 21.1, 198.0, "MALE"],
                              [46.1, 18.2, 178.0, "FEMALE"],
                              [52.5, 15.6, 221.0, "MALE"]],
                             columns=["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)", "Sex"])
    # Define the target for the data
    test_target = ['Adelie Penguin (Pygoscelis adeliae)',
                   'Chinstrap penguin (Pygoscelis antarctica)',
                   'Gentoo penguin (Pygoscelis papua)']
    # Assert the model gets these right
    assert clf.score(test_data, test_target) == 1
test_penguins(clf)
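One drawback of a single `clf.score(...) == 1` assertion is that a failure does not tell you which row went wrong. A hedged sketch of a row-by-row variant follows; the `DummyModel` class is a stand-in invented here so the example runs without the trained SVC, and with the real model you would call `check_rows(clf)` instead:

```python
import pandas as pd

COLUMNS = ["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)", "Sex"]
CASES = [
    ([34.6, 21.1, 198.0, "MALE"], "Adelie Penguin (Pygoscelis adeliae)"),
    ([46.1, 18.2, 178.0, "FEMALE"], "Chinstrap penguin (Pygoscelis antarctica)"),
    ([52.5, 15.6, 221.0, "MALE"], "Gentoo penguin (Pygoscelis papua)"),
]

def check_rows(clf):
    """Check each known row individually so a failure names the culprit."""
    for row, expected in CASES:
        pred = clf.predict(pd.DataFrame([row], columns=COLUMNS))[0]
        assert pred == expected, f"{row} predicted {pred!r}, expected {expected!r}"

class DummyModel:
    """Stand-in classifier keyed on culmen length, for demonstration only."""
    def predict(self, X):
        def rule(length):
            if length < 40:
                return "Adelie Penguin (Pygoscelis adeliae)"
            if length < 50:
                return "Chinstrap penguin (Pygoscelis antarctica)"
            return "Gentoo penguin (Pygoscelis papua)"
        return [rule(l) for l in X["Culmen Length (mm)"]]

check_rows(DummyModel())  # with the real model: check_rows(clf)
```

The assertion message includes the offending row, which saves a debugging round-trip when the model regresses on a single species.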
4.4.3. Automated Testing of Docstring Examples#
There is an even easier way to run simple tests. This can be useful when we write specific functions to pre-process our data. In the Model Sharing notebook, we looked into auto-generating docstrings.
We can upgrade our docstrings and get free software tests out of it!
This is called doctest, and it is useful both for keeping docstring examples up to date and for writing quick unit tests for a function.
This makes future users (including yourself from the future) quite happy.
def shorten_class_name(df: pd.DataFrame) -> pd.DataFrame:
    """Shorten the class names of the penguins to the shortest version

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing the Species column with penguins

    Returns
    -------
    pd.DataFrame
        Normalised dataframe with shortened names

    Examples
    --------
    >>> shorten_class_name(pd.DataFrame([[1,2,3,"Adelie Penguin (Pygoscelis adeliae)"]], columns=["1","2","3","Species"]))
       1  2  3 Species
    0  1  2  3  Adelie
    """
    df["Species"] = df.Species.str.split(r" [Pp]enguin", n=1, expand=True)[0]
    return df
import doctest
doctest.testmod()
TestResults(failed=0, attempted=1)
shorten_class_name(penguins).head()
| | Culmen Length (mm) | Culmen Depth (mm) | Flipper Length (mm) | Sex | Species |
| --- | --- | --- | --- | --- | --- |
| 0 | 39.1 | 18.7 | 181.0 | MALE | Adelie |
| 1 | 39.5 | 17.4 | 186.0 | FEMALE | Adelie |
| 2 | 40.3 | 18.0 | 195.0 | FEMALE | Adelie |
| 3 | 36.7 | 19.3 | 193.0 | FEMALE | Adelie |
| 4 | 39.3 | 20.6 | 190.0 | MALE | Adelie |
The docstring thus carries a usage example, an expected output, and a first test case that is validated by our test suite.
4.4.4. Input Data Validation#
Input data validation means checking that the data users provide matches what your model expects.
Such tools are often used in production systems to determine whether API usage and user inputs are formatted correctly.
One example tool is pandera, which we use below.
import pandera as pa
# data to validate
X_train.describe()
| | Culmen Length (mm) | Culmen Depth (mm) | Flipper Length (mm) |
| --- | --- | --- | --- |
| count | 233.000000 | 233.000000 | 233.000000 |
| mean | 43.982403 | 17.228755 | 201.412017 |
| std | 5.537146 | 1.994191 | 13.929695 |
| min | 33.500000 | 13.100000 | 172.000000 |
| 25% | 39.000000 | 15.700000 | 190.000000 |
| 50% | 44.400000 | 17.300000 | 198.000000 |
| 75% | 48.800000 | 18.800000 | 213.000000 |
| max | 59.600000 | 21.200000 | 231.000000 |
# define schema
schema = pa.DataFrameSchema({
"Culmen Length (mm)": pa.Column(float, checks=[pa.Check.ge(30),
pa.Check.le(60)]),
"Culmen Depth (mm)": pa.Column(float, checks=[pa.Check.ge(13),
pa.Check.le(22)]),
"Flipper Length (mm)": pa.Column(float, checks=[pa.Check.ge(170),
pa.Check.le(235)]),
"Sex": pa.Column(str, checks=pa.Check.isin(["MALE","FEMALE"])),
})
validated_test = schema(X_test)
---------------------------------------------------------------------------
SchemaError                               Traceback (most recent call last)
Cell In[11], line 12
      1 # define schema
      2 schema = pa.DataFrameSchema({
    (...)
     10 })
---> 12 validated_test = schema(X_test)

SchemaError: <Schema Column(name=Sex, type=DataType(str))> failed element-wise validator 0:
<Check isin: isin(['MALE', 'FEMALE'])>
failure cases:
   index failure_case
0    259            .
X_test.Sex.unique()
array(['FEMALE', 'MALE', '.'], dtype=object)
X_test.loc[259]
Culmen Length (mm) 44.5
Culmen Depth (mm) 15.7
Flipper Length (mm) 217.0
Sex .
Name: 259, dtype: object
Can you fix the data to conform to the schema?
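One possible fix, sketched below on a small stand-in frame rather than the real `X_test`: treat the `"."` entry as missing and drop those rows, after which the remaining values satisfy the `isin(["MALE", "FEMALE"])` check. Whether dropping (rather than imputing) is the right call depends on your use case.

```python
import pandas as pd

# stand-in for X_test: index 259 carries the offending Sex == "." row
df = pd.DataFrame({"Sex": ["FEMALE", "MALE", "."]}, index=[1, 2, 259])

# keep only rows whose Sex value is an allowed category
clean = df[df["Sex"].isin(["MALE", "FEMALE"])]

# the cleaned frame now passes the schema's isin check
assert set(clean["Sex"]) <= {"MALE", "FEMALE"}
```

With the real data you would run `schema(X_test_clean)` afterwards to confirm the cleaned frame validates.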