Improving the MLModel Base Class

Or, how to make ML models easier to install, document, and release

In general, I want to show how to make ML code easier to install and use.

Making the Iris Model into a Python Package

A common problem with ML code is that it is hard to use and deploy. Turning the Iris model into a Python package starts with a project layout like this:

- project_root
    - docs (a folder, package documentation goes in here)
    - iris_model (a folder, iris package code goes in here)
        - model_files (a folder, the model files go in here)
        - __init__.py
        - iris_predict.py (the prediction code goes here)
        - iris_train.py (the training script goes here)
    - tests (unit tests for iris_model package go here)
    - ml_model_abc.py (the MLModel base class goes here)
    - requirements.txt
    - setup.py (the package installation script goes here)

Adding Package Versioning

__version_info__ = (0, 1, 0)
__version__ = ".".join([str(n) for n in __version_info__])

Adding a CLI to the Training Script

import argparse
import os
import sys
import traceback


def argument_parser():
    """Build the command line parser for the training script."""
    parser = argparse.ArgumentParser(
        description='Command to train the Iris model.')
    parser.add_argument('-gamma', action="store", dest="gamma",
                        type=float, help='Gamma value used to train the SVM model.')
    parser.add_argument('-c', action="store", dest="c",
                        type=float, help='C value used to train the SVM model.')
    return parser


def main():
    # train() is the training function of this script (not shown here)
    parser = argument_parser()
    results = parser.parse_args()
    try:
        # pass along only the hyperparameters given on the command line,
        # so that train() uses its own defaults for the rest
        if results.gamma is None and results.c is None:
            train()
        elif results.gamma is not None and results.c is None:
            train(gamma=results.gamma)
        elif results.gamma is None and results.c is not None:
            train(c=results.c)
        else:
            train(gamma=results.gamma, c=results.c)
    except Exception:
        # print the stack trace and exit with a failure code (Unix only)
        traceback.print_exc()
        sys.exit(os.EX_SOFTWARE)
    sys.exit(os.EX_OK)
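The four-way branch in main() grows quickly as options are added. One possible simplification, sketched here under the assumption that train() accepts `gamma` and `c` as keyword arguments with defaults, is to forward only the options that were actually supplied. The `train` stand-in, its default values, and the `train_kwargs` helper are all hypothetical.

```python
import argparse


def argument_parser():
    """Same options as the training script: two optional floats."""
    parser = argparse.ArgumentParser(
        description="Command to train the Iris model.")
    parser.add_argument("-gamma", action="store", dest="gamma", type=float,
                        help="Gamma value used to train the SVM model.")
    parser.add_argument("-c", action="store", dest="c", type=float,
                        help="C value used to train the SVM model.")
    return parser


def train(gamma=0.001, c=1.0):
    """Stand-in for the real train(); these default values are assumptions."""
    return {"gamma": gamma, "c": c}


def train_kwargs(args):
    """Keep only the options the user supplied (argparse leaves the rest as None)."""
    return {name: value for name, value in vars(args).items()
            if value is not None}


# Only -c is given, so gamma falls back to the default inside train().
args = argument_parser().parse_args(["-c", "10.0"])
print(train(**train_kwargs(args)))  # {'gamma': 0.001, 'c': 10.0}
```

With this shape, adding a third hyperparameter needs one more `add_argument` call and no new branches.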

Adding Sphinx Documentation

.. jsonschema:: ../build/input_schema.json

.. argparse::
   :module: iris_model.iris_train
   :func: argument_parser
   :prog: iris_train
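Neither directive is built into Sphinx, so both need an extension enabled in the documentation's conf.py. The sketch below assumes they come from the sphinx-jsonschema and sphinx-argparse packages; the project name and the path adjustment are illustrative and may differ from the actual repository.

```python
# docs/conf.py (sketch; everything except the extension names is illustrative)
import os
import sys

# Make iris_model importable so the argparse directive can load
# iris_model.iris_train and call argument_parser().
sys.path.insert(0, os.path.abspath(".."))

project = "iris_model"

extensions = [
    "sphinx-jsonschema",  # provides the `.. jsonschema::` directive
    "sphinxarg.ext",      # provides the `.. argparse::` directive (sphinx-argparse)
]
```

Because the argparse directive imports the training module at build time, the module has to be importable without side effects, which is one reason to expose the parser through a separate argument_parser() function.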

Adding a setup.py File

packages=["iris_model"],
py_modules=["ml_model_abc"],
package_data={'iris_model': [
    'model_files/svc_model.pickle'
]},
include_package_data=True,
entry_points={
    'console_scripts': [
        'iris_train=iris_model.iris_train:main',
    ]
}
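The fragment above does not show how the package version reaches setup(). One option, offered here as an assumption rather than the article's method, is to read `__version_info__` from the package source as text, which avoids importing the package (and its dependencies) at install time. The `read_version` helper is hypothetical.

```python
import ast
import re


def read_version(init_source):
    """Return the dotted version string from the text of an __init__.py.

    Looks for a line of the form `__version_info__ = (0, 1, 0)` and
    evaluates only the tuple literal, so no package code is executed.
    """
    match = re.search(r"^__version_info__\s*=\s*(\(.*?\))\s*$",
                      init_source, re.MULTILINE)
    if match is None:
        raise ValueError("__version_info__ not found")
    version_info = ast.literal_eval(match.group(1))
    return ".".join(str(n) for n in version_info)


# In setup.py the text would come from reading iris_model/__init__.py.
example_source = '__version_info__ = (0, 1, 0)\n__version__ = "0.1.0"\n'
print(read_version(example_source))  # 0.1.0
```

The result could then be passed to setup() as its `version` argument.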

With setup.py in place, the package can be installed into a fresh virtual environment directly from the GitHub repository:

mkdir example
cd example
# creating a virtual environment
python3 -m venv venv
# activating the virtual environment, on macOS or Linux
source venv/bin/activate
# installing the iris_model package from the github repository
pip install git+https://github.com/schmidtbri/ml-model-abc-improvements#egg=iris_model

Once installed, the model class can be imported and inspected from the Python interpreter:

>>> from iris_model.iris_predict import IrisModel
>>> model = IrisModel()
>>> model
<iris_model.iris_predict.IrisModel object at 0x105d1e940>
>>> model.input_schema
Schema({'sepal_length': <class 'float'>, 'sepal_width': <class 'float'>, 'petal_length': <class 'float'>, 'petal_width': <class 'float'>})
>>> model.output_schema
Schema({'species': <class 'str'>})

The console_scripts entry point also makes the training script available as a command:

iris_train -c=10.0 -gamma=0.01

Model Metadata in the MLModel Base Class

from abc import ABC, abstractmethod


class MLModel(ABC):
    """Abstract base class that every model must implement."""

    @property
    @abstractmethod
    def display_name(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def qualified_name(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def description(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def major_version(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def minor_version(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def input_schema(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def output_schema(self):
        raise NotImplementedError()

    @abstractmethod
    def __init__(self):
        raise NotImplementedError()

    @abstractmethod
    def predict(self, data):
        raise NotImplementedError()
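Declaring the metadata as abstract properties means Python refuses to instantiate a model that leaves any of them out, while a subclass can satisfy them with plain class attributes. The sketch below shows this with a trimmed-down base class; `IncompleteModel` and `CompleteModel` are hypothetical examples.

```python
from abc import ABC, abstractmethod


class MLModel(ABC):
    """Trimmed-down base class with two metadata properties."""

    @property
    @abstractmethod
    def display_name(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def major_version(self):
        raise NotImplementedError()

    @abstractmethod
    def predict(self, data):
        raise NotImplementedError()


class IncompleteModel(MLModel):
    """Hypothetical subclass that forgets to define major_version."""
    display_name = "Incomplete Model"

    def predict(self, data):
        return {}


class CompleteModel(MLModel):
    """Hypothetical subclass: class attributes satisfy the abstract properties."""
    display_name = "Complete Model"
    major_version = 0

    def predict(self, data):
        return {}


# Instantiating the incomplete subclass fails before any prediction is made.
try:
    IncompleteModel()
except TypeError as error:
    print("rejected:", error)

print(CompleteModel().display_name)  # Complete Model
```

The check happens at instantiation time, so a model with missing metadata fails as soon as it is created rather than when the metadata is first read.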

The metadata values themselves are defined in the package's __init__.py:

# a display name for the model
__display_name__ = "Iris Model"

# returning the package name as the qualified name for the model
__qualified_name__ = __name__.split(".")[0]

# a description of the model
__description__ = "A machine learning model for predicting the species of a flower based on its measurements."
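
The IrisModel class in iris_predict.py then exposes these values, along with its schemas, as class attributes:

# the imports other than ml_model_abc and iris_model are not shown in the
# original listing and are filled in here
import os
import pickle

from numpy import array
from schema import Schema

# MLModelSchemaValidationException is assumed to live next to MLModel
from ml_model_abc import MLModel, MLModelSchemaValidationException
from iris_model import __version_info__, __display_name__, \
    __qualified_name__, __description__


class IrisModel(MLModel):
    # accessing the package metadata
    display_name = __display_name__
    qualified_name = __qualified_name__
    description = __description__
    major_version = __version_info__[0]
    minor_version = __version_info__[1]

    # stating the input schema of the model as a Schema object
    input_schema = Schema({'sepal_length': float,
                           'sepal_width': float,
                           'petal_length': float,
                           'petal_width': float})

    # stating the output schema of the model as a Schema object
    output_schema = Schema({'species': str})

    def __init__(self):
        # loading the pickled SVM model that ships inside the package
        dir_path = os.path.dirname(os.path.realpath(__file__))
        with open(os.path.join(dir_path, "model_files",
                               "svc_model.pickle"), 'rb') as file:
            self._svm_model = pickle.load(file)

    def predict(self, data):
        # validating the input before it reaches the model
        try:
            self.input_schema.validate(data)
        except Exception as e:
            raise MLModelSchemaValidationException(
                "Failed to validate input data: {}".format(str(e)))

        # building a single-row feature array in the order used for training
        X = array([data["sepal_length"],
                   data["sepal_width"],
                   data["petal_length"],
                   data["petal_width"]]).reshape(1, -1)

        # mapping the predicted class index to the species name
        y_hat = int(self._svm_model.predict(X)[0])
        targets = ['setosa', 'versicolor', 'virginica']
        species = targets[y_hat]
        return {"species": species}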

The metadata can then be read from any instance of the model:

>>> from iris_model.iris_predict import IrisModel
>>> iris_model = IrisModel()
>>> iris_model.qualified_name
'iris_model'
>>> iris_model.display_name
'Iris Model'
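Because every model exposes the same attributes, code that manages models does not need to know which model it is handling. The sketch below illustrates that idea with a hypothetical `describe` function and a hypothetical `FakeModel` stand-in that mimics IrisModel's metadata; neither is part of the package.

```python
class FakeModel:
    """Hypothetical stand-in exposing the same metadata attributes as IrisModel."""
    display_name = "Iris Model"
    qualified_name = "iris_model"
    description = ("A machine learning model for predicting the species "
                   "of a flower based on its measurements.")
    major_version = 0
    minor_version = 1


def describe(model):
    """Collect a model's metadata into a plain dict, e.g. for a listing or a log."""
    return {
        "display_name": model.display_name,
        "qualified_name": model.qualified_name,
        "description": model.description,
        "version": "{}.{}".format(model.major_version, model.minor_version),
    }


info = describe(FakeModel())
print(info["qualified_name"], info["version"])  # iris_model 0.1
```

The same function would work unchanged for any other class that follows the MLModel interface.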

Future Improvements
