A RESTful ML Model Service

Building a Service for Deploying ML Models

Introduction

Package Structure

- rest_model_service
  - __init__.py
  - configuration.py # data models for configuration
  - generate_openapi.py # script to generate an OpenAPI spec
  - main.py # entry point for service
  - routes.py # controllers for routes
  - schemas.py # service schemas
- tests
- requirements.txt
- setup.py
- test_requirements.txt

FastAPI

Model Metadata Endpoint

class ModelMetadata(BaseModel):
    """Metadata of a model."""

    # NOTE: the docstring above is kept short on purpose; pydantic uses a
    # model's docstring as its schema description.
    display_name: str = Field(description="The display name of the model.")
    qualified_name: str = Field(description="The qualified name of the model.")
    description: str = Field(description="The description of the model.")
    version: str = Field(description="The version of the model.")
class ModelMetadataCollection(BaseModel):
    """Collection of model metadata."""

    # Response body of the model-listing route: one ModelMetadata entry
    # per model known to the ModelManager.
    models: List[ModelMetadata] = Field(description="A collection of model description.")
async def get_models():
    """List of models available."""
    # NOTE: the docstring is kept to one line because FastAPI presumably
    # publishes it as this route's OpenAPI description.
    #
    # Returns a 200 JSONResponse whose body is a ModelMetadataCollection.
    # Any exception is reported as a 500 JSONResponse carrying an Error
    # body of type "ServiceError".
    try:
        model_manager = ModelManager()
        models_metadata = model_manager.get_models()
        payload = ModelMetadataCollection(models=models_metadata).dict()
        return JSONResponse(status_code=200, content=payload)
    except Exception as e:
        # Top-level request boundary: convert every failure into a JSON error.
        error = Error(type="ServiceError", message=str(e)).dict()
        return JSONResponse(status_code=500, content=error)

Prediction Endpoint

class PredictionController(object):
    # Callable route handler wrapping a single MLModel: calling an instance
    # runs the model's ``predict`` and turns the outcome into a JSONResponse.

    def __init__(self, model: MLModel) -> None:
        """Store the model that will serve prediction requests.

        Args:
            model: The model instance whose ``predict`` method is called.
        """
        self._model = model

    def __call__(self, data):
        """Run a prediction on ``data`` and wrap the result as JSON.

        Args:
            data: The request payload, passed unchanged to ``predict``.
                Its type is left unannotated here; presumably the caller
                sets the annotation to the model's input schema.

        Returns:
            A JSONResponse with status 200 and the prediction; status 400
            with an Error of type "SchemaValidationError" when the model
            raises MLModelSchemaValidationException; or status 500 with an
            Error of type "ServiceError" for any other exception.
        """
        try:
            prediction = self._model.predict(data).dict()
            return JSONResponse(status_code=200, content=prediction)
        except MLModelSchemaValidationException as e:
            # The input did not satisfy the model's schema: a client error.
            error = Error(type="SchemaValidationError", message=str(e)).dict()
            return JSONResponse(status_code=400, content=error)
        except Exception as e:
            # Anything else is reported as a server-side failure.
            error = Error(type="ServiceError", message=str(e)).dict()
            return JSONResponse(status_code=500, content=error)

Application Startup

# Application startup: locate the YAML configuration, validate it and build
# the module-level ``app`` object.
# The REST_CONFIG environment variable takes precedence; otherwise
# "rest_config.yaml" relative to the working directory is used.
file_path = os.environ.get("REST_CONFIG", "rest_config.yaml")

# isfile() is False for missing paths, so a separate exists() check is redundant.
if path.isfile(file_path):
    with open(file_path, encoding="utf-8") as config_file:
        # NOTE(review): yaml.full_load can construct some Python objects; the
        # configuration is plain data, so yaml.safe_load should suffice if
        # the config file could ever come from an untrusted source.
        raw_configuration = yaml.full_load(config_file)
    # An empty file loads as None; report that clearly instead of failing
    # inside Configuration(**...) with a TypeError.
    if not isinstance(raw_configuration, dict):
        raise ValueError(
            "Configuration file '{}' must contain a YAML mapping.".format(file_path))
    configuration = Configuration(**raw_configuration)
    app = create_app(configuration.service_title, configuration.models)
else:
    raise ValueError("Could not find configuration file '{}'.".format(file_path))
def create_app(service_title: str, models: List[Model]) -> FastAPI:
    """Build the FastAPI application that serves the configured models.

    Registers the root route, the model-listing route, and one POST
    prediction route per configured model whose ``create_endpoint`` flag
    is set.

    Args:
        service_title: Title given to the FastAPI application.
        models: Model configuration entries. Every entry is loaded through
            the ModelManager by its ``class_path``.

    Returns:
        The configured FastAPI application.
    """
    app: FastAPI = FastAPI(title=service_title, version=__version__)

    app.add_api_route("/", get_root, methods=["GET"])
    app.add_api_route(
        "/api/models",
        get_models,
        methods=["GET"],
        response_model=ModelMetadataCollection,
        responses={500: {"model": Error}},
    )

    model_manager = ModelManager()
    for model_config in models:
        # Every configured model is loaded, even when no endpoint is created for it.
        model_manager.load_model(model_config.class_path)

        if not model_config.create_endpoint:
            logger.info(
                "Skipped creating an endpoint for model %s",
                model_config.qualified_name,
            )
            continue

        ml_model = model_manager.get_model(model_config.qualified_name)
        controller = PredictionController(model=ml_model)

        # Type the ``data`` parameter with the model's input schema so
        # FastAPI can (presumably) parse and document the request body.
        # NOTE(review): ``controller.__call__`` is a bound method, so this
        # writes into the annotations of ``PredictionController.__call__``,
        # which all controller instances share. It relies on add_api_route
        # reading the signature right below, before the next iteration
        # overwrites it. Confirm before reordering these statements.
        controller.__call__.__annotations__["data"] = ml_model.input_schema
        app.add_api_route(
            "/api/models/{}/prediction".format(ml_model.qualified_name),
            controller,
            methods=["POST"],
            response_model=ml_model.output_schema,
            description=ml_model.description,
            responses={
                400: {"model": Error},
                500: {"model": Error},
            },
        )
    return app

Creating a Package

# Install the rest_model_service package with pip.
pip install rest_model_service

Using the Service

class IrisModelInput(BaseModel):
    # Input schema of IrisModel. gt/lt are exclusive bounds, so values equal
    # to a bound fail validation. No docstring is added because pydantic
    # would publish it as the schema description.
    sepal_length: float = Field(
        gt=5.0, lt=8.0, description="Length of the sepal of the flower.")
    sepal_width: float = Field(
        gt=2.0, lt=6.0, description="Width of the sepal of the flower.")
    petal_length: float = Field(
        gt=1.0, lt=6.8, description="Length of the petal of the flower.")
    petal_width: float = Field(
        gt=0.0, lt=3.0, description="Width of the petal of the flower.")
class Species(str, Enum):
    # Closed set of species the model can predict. The str mixin makes each
    # member compare equal to (and serialize as) its display name.
    iris_setosa = "Iris setosa"
    iris_versicolor = "Iris versicolor"
    iris_virginica = "Iris virginica"
class IrisModelOutput(BaseModel):
    # Output schema of IrisModel: a single predicted species.
    species: Species = Field(description="Predicted species of the flower.")
class IrisModel(MLModel):
    """Placeholder iris model used to exercise the service.

    ``predict`` ignores its input and always returns "Iris setosa".
    """

    display_name = "Iris Model"
    qualified_name = "iris_model"
    description = "Model for predicting the species of a flower based on its measurements."
    version = "1.0.0"
    input_schema = IrisModelInput
    output_schema = IrisModelOutput

    def __init__(self):
        # Nothing to load: this model has no parameters or artifacts.
        pass

    def predict(self, data):
        """Return a fixed prediction regardless of ``data``."""
        return IrisModelOutput(species="Iris setosa")
# examples/rest_config.yaml — service title plus one entry per model to load.
service_title: REST Model Service
models:
  - qualified_name: iris_model
    class_path: tests.mocks.IrisModel
    create_endpoint: true
# Make the repository root importable (the config loads tests.mocks.IrisModel by class path).
export PYTHONPATH=./
# Point the service at the example configuration.
export REST_CONFIG=examples/rest_config.yaml
# Start the app; --reload restarts the server when source files change.
uvicorn rest_model_service.main:app --reload

Generating the OpenAPI Contract

# Same environment as for running the service, so the configured models can be imported.
export PYTHONPATH=./
export REST_CONFIG=examples/rest_config.yaml
# Write the OpenAPI contract for the configured service to example.yaml.
generate_openapi --output_file=example.yaml
# Excerpt of the generated example.yaml; the remainder of the document is omitted here.
info:
  title: REST Model Service
  version: <version_placeholder>
openapi: 3.0.2
paths:
  /:
    get:
      description: Root of API.
      operationId: get_root__get
      responses:
        '200':
          content:
            application/json:
              schema: {}
          description: Successful Response
      summary: Get Root
  /api/models:
    get:
      description: List of models available.

Closing

Coder and machine learning enthusiast

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store