Serialising Arbitrary ML models with MLflow
MLflow has a myriad of tools for getting models to production. One reason I find it so useful is the openness and flexibility of the MLflow API.
Most of the time, deploying a model involves more steps than simply executing a feedforward pass of an off-the-shelf framework like scikit-learn or PyTorch.
MLflow allows you to serialise arbitrary model code as an mlflow.pyfunc object, compatible with the broader MLflow ecosystem.
“Whether you’re working with a niche machine learning framework, need custom preprocessing steps, or want to integrate unique prediction logic, mlflow.pyfunc is the tool for the job.”
Here I’ll demonstrate how to:
- Serialise (save) any arbitrary Python model as an MLflow-compatible object.
- Override the built-in predict method to allow for more flexible and complete model deployments.
All the code is available here. Examples are derived from the MLflow documentation.
The mlflow.pyfunc.PythonModel class
The core idea is to inherit compatibility with MLflow from PythonModel.
1. Serialise an Arbitrary Model
An arbitrary model can inherit from PythonModel and subsequently be serialised as an MLflow-compatible model.
The key takeaway is that any logic that encompasses the model pipeline can be executed in the model class, and is thereafter readily compatible with the MLflow API (serialisation, tracking, deployment, etc.).
Consider a simple model that adds n to each input. The mlflow.pyfunc.PythonModel class requires that we implement a predict method.
import mlflow.pyfunc
import pandas as pd
import logging


class AddN(mlflow.pyfunc.PythonModel):
    """
    A custom model that adds a specified value `n` to all columns of the input DataFrame.

    Attributes:
    -----------
    n : int
        The value to add to input columns.
    """

    def __init__(self, n):
        """
        Initialise the model with the specified value `n`.

        Parameters:
        -----------
        n : int
            The value to add to input columns.
        """
        self.n = n

    def predict(self, context, model_input, params=None):
        """
        Prediction method for the custom model.

        Parameters:
        -----------
        context : Any
            Ignored in this example. It's a placeholder for additional data or utility methods.
        model_input : pd.DataFrame
            The input DataFrame to which `n` should be added.
        params : dict, optional
            Additional prediction parameters. Ignored in this example.

        Returns:
        --------
        pd.DataFrame
            The input DataFrame with `n` added to all columns.
        """
        return model_input.apply(lambda column: column + self.n)
Now saving (serialising), loading and performing inference on the model is seamless.
logging.basicConfig(level=logging.DEBUG)

if __name__ == "__main__":
    # Define the path to save the model
    model_path = "./tmp/add_n_model"

    # Create an instance of the model with `n=5`
    logging.info("Creating a model with n=5.")
    add5_model = AddN(n=5)

    # Save the model using MLflow
    logging.info("Saving the model to %s.", model_path)
    mlflow.pyfunc.save_model(path=model_path, python_model=add5_model)
    logging.info("Model saved successfully!")

    # Load the saved model
    logging.info("Loading the model from %s.", model_path)
    loaded_model = mlflow.pyfunc.load_model(model_path)

    # Define a sample input DataFrame
    model_input = pd.DataFrame([range(10)])

    # Use the loaded model to make predictions
    logging.info("Making predictions with the loaded model.")
    model_output = loaded_model.predict(model_input)
    logging.info("The model output is:\n%s", model_output)

    logging.info("Model loaded and predictions made successfully!")
    logging.info("Runtime complete.")
2. Add Complexity, Hyperparameters & Model Signature
To demonstrate a more complex model build, we might be interested in:
- Model hyperparameters.
- MLflow signatures.
- An arbitrary return type (in this example, a figure).
Consider the Lissajous model from this MLflow tutorial; the key takeaway is the separation of the (hyper)parameter delta from the data x, y.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import mlflow.pyfunc
from mlflow.models import infer_signature


class Lissajous(mlflow.pyfunc.PythonModel):
    def __init__(self, A=1, B=1, num_points=1000):
        self.A = A
        self.B = B
        self.num_points = num_points
        self.t_range = (0, 2 * np.pi)

    def generate_lissajous(self, a, b, delta):
        t = np.linspace(self.t_range[0], self.t_range[1], self.num_points)
        x = self.A * np.sin(a * t + delta)
        y = self.B * np.sin(b * t)
        return pd.DataFrame({"x": x, "y": y})

    def predict(self, context, model_input, params=None):
        """
        Generate and plot the Lissajous curve with annotations for parameters.

        Args:
        - model_input (pd.DataFrame): DataFrame containing columns 'a' and 'b'.
        - params (dict, optional): Dictionary containing optional parameter 'delta'.
        """
        # Extract a and b values from the input DataFrame
        a = model_input["a"].iloc[0]
        b = model_input["b"].iloc[0]

        # Extract delta from params, defaulting to 0 if params is not provided
        delta = params.get("delta", 0) if params else 0

        # Generate the Lissajous curve data
        df = self.generate_lissajous(a, b, delta)

        sns.set_theme()

        # Create the plot components
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.plot(df["x"], df["y"])
        ax.set_title("Lissajous Curve")

        # Define the annotation string
        annotation_text = f"""
        A = {self.A}
        B = {self.B}
        a = {a}
        b = {b}
        delta = {np.round(delta, 2)} rad
        """

        # Add the annotation with a bounding box outside the plot area
        ax.annotate(
            annotation_text,
            xy=(1.05, 0.5),
            xycoords="axes fraction",
            fontsize=12,
            bbox={"boxstyle": "round,pad=0.25", "facecolor": "aliceblue", "edgecolor": "black"},
        )

        # Adjust plot borders to make space for the annotation
        plt.subplots_adjust(right=0.65)
        plt.show()
        plt.close()

        # Return the figure
        return fig
Model Signature
The signature description is useful for many reasons: input validation, consistency, self-documentation, versioning, automated API generation, reproducibility, ease of use, and error handling.
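To see what a signature actually records, you can inspect one directly (a minimal sketch; the printed repr varies slightly between MLflow versions):

from mlflow.models import infer_signature
import numpy as np
import pandas as pd

# Inputs are inferred from a representative DataFrame; params from a dict of defaults
signature = infer_signature(
    model_input=pd.DataFrame([{"a": 1, "b": 2}]),
    params={"delta": np.pi / 5},
)
print(signature)  # shows the input schema ('a': long, 'b': long) plus the 'delta' param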
The separation of inputs & (hyper)parameters
The separation of params from model_input allows for easy adjustments to model behaviour without altering the main input data, improves versioning and compatibility, and aligns with good API design practices.
This separation also facilitates batch processing, can lead to performance optimisations, and provides a more intuitive and adaptable interface.
This approach extends readily to hyperparameters.
import logging
import os
import shutil

logging.basicConfig(level=logging.DEBUG)

if __name__ == "__main__":
    logging.info("Instantiating Lissajous model.")

    # Define the path to save the model
    model_path = "./tmp/lissajous_model"

    # Remove artifacts from previous runs
    if os.path.exists(model_path):
        shutil.rmtree(model_path, ignore_errors=True)

    # Create an instance of the model, overriding the default instance variables `A`, `B`, and `num_points`
    model_10k_standard = Lissajous(1, 1, 10_000)

    # Infer the model signature, defining the params that will be available for customisation at inference time
    signature = infer_signature(
        model_input=pd.DataFrame([{"a": 1, "b": 2}]),
        params={"delta": np.pi / 5},
    )

    # Save our custom model to the path we defined, with the signature that we declared
    mlflow.pyfunc.save_model(
        path=model_path,
        python_model=model_10k_standard,
        signature=signature,
    )
    logging.info("Model saved successfully!")

    # Load our custom model from the local artifact store
    loaded_pyfunc_model = mlflow.pyfunc.load_model(model_path)

    # Inference: our custom model reads only the first row of data to generate a plot
    model_input = pd.DataFrame({"a": [3], "b": [2]})

    # Define a params override for the `delta` parameter
    params = {"delta": np.pi / 3}

    # Run predict, which calls our internal method `generate_lissajous` before plotting the curve with `matplotlib`
    fig = loaded_pyfunc_model.predict(model_input, params)

    logging.info("Runtime complete.")
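Since predict returns a matplotlib figure, it can be persisted like any other figure (a small usage sketch; the filename is illustrative):

# Save the returned figure to disk
fig.savefig("./tmp/lissajous_curve.png", bbox_inches="tight")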
3. Delivering a More Complete Model API (Overriding the predict Method)
In order to maintain compatibility and consistency with the broader API, MLflow introduces several restrictions when performing model serialisation.
MLflow Tracking
To run the below code first launch a local MLflow tracking server (in a separate terminal):
mlflow server --host 127.0.0.1 --port 8080
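Then point the MLflow client at the server before logging any runs (a minimal sketch; the host and port must match whatever you launched above):

import mlflow

# Must match the host/port of the tracking server started above
mlflow.set_tracking_uri("http://127.0.0.1:8080")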
A more expressive model
Suppose your model class has various prediction & processing methods that are required at different stages in production. We can override the PyFunc predict method to maintain this flexibility.
Consider a scikit-learn model with the following methods:
- predict
- predict_proba
- predict_log_proba
Training & evaluating a scikit-learn model is straightforward.
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# load data
iris = load_iris()
x = iris.data[:, 2:]
y = iris.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=9001)

# train model
model = LogisticRegression(random_state=0, max_iter=5_000, solver="newton-cg").fit(x_train, y_train)

# evaluate model/predict
print("Prediction: \n", model.predict(x_test)[:5])
print("Predictive Probability: \n", model.predict_proba(x_test)[:5])
print("Score: ", model.score(x_test, y_test))
If we save & load this model using the PyFunc API, we only retain the predict method.
import mlflow
import mlflow.sklearn

# save model 1: sklearn without predict_proba
sklearn_path = "./tmp/sklearn_logreg"

with mlflow.start_run() as run:
    mlflow.sklearn.save_model(
        sk_model=model,
        path=sklearn_path,
        input_example=x_train[:2],
    )

# load and predict (only .predict is exposed)
loaded_logreg_model = mlflow.pyfunc.load_model(sklearn_path)
y_pred = loaded_logreg_model.predict(x_test)
print("Prediction: \n", y_pred[:5])

# Will raise an exception -- the pyfunc wrapper does not expose predict_proba!
loaded_logreg_model.predict_proba(x_test)
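As an aside (not in the original post), loading with the sklearn flavour rather than the pyfunc flavour returns the native estimator, so the full scikit-learn API remains available; the pyfunc flavour, however, is what MLflow's generic deployment tooling consumes, which is why the wrapper below is still worthwhile:

# Loading with the sklearn flavour returns the raw LogisticRegression estimator
native_model = mlflow.sklearn.load_model(sklearn_path)
print(native_model.predict_proba(x_test)[:5])  # works as usual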
The MLflow PyFunc API only exposes the predict method. We can override this behaviour by introducing an additional parameter that specifies the prediction method.
We can write a ModelWrapper class that extends the functionality of mlflow.pyfunc.PythonModel.
from mlflow.pyfunc import PythonModel


class ModelWrapper(PythonModel):
    def __init__(self):
        self.model = None

    def load_context(self, context):
        # Load the joblib-serialised model from the artifact store when the pyfunc model is loaded
        from joblib import load

        self.model = load(context.artifacts["model_path"])

    def predict(self, context, model_input, params=None):
        # Default to the standard `predict` method when no override is supplied
        params = params or {"predict_method": "predict"}
        predict_method = params.get("predict_method")

        if predict_method == "predict":
            return self.model.predict(model_input)
        elif predict_method == "predict_proba":
            return self.model.predict_proba(model_input)
        elif predict_method == "predict_log_proba":
            return self.model.predict_log_proba(model_input)
        else:
            raise ValueError(f"The prediction method '{predict_method}' is not supported.")
More generally, the predict method must encapsulate all inference functionality.
This time we save/serialise the model with joblib (which is independent of MLflow). The model artifacts and the wrapper are passed to the save_model method separately. This is incredibly powerful because we did not need to rewrite the original model code to inherit from ModelWrapper. The signature is not strictly required, but it is useful when working with the broader MLflow API.
from joblib import dump

# serialise the model with joblib (independent of MLflow)
model_directory = "./tmp/sklearn_model.joblib"
dump(model, model_directory)

# reference the serialised model as an artifact
artifacts = {"model_path": model_directory}

# define the signature associated with the model
signature = infer_signature(x_train, params={"predict_method": "predict_proba"})

# model path
pyfunc_path = "./tmp/dynamic_regressor"

# Save the custom model to the specified path
with mlflow.start_run() as run:
    mlflow.pyfunc.save_model(
        path=pyfunc_path,
        python_model=ModelWrapper(),
        input_example=x_train,
        signature=signature,
        artifacts=artifacts,
        pip_requirements=["joblib", "scikit-learn"],
    )

# Load the custom model
loaded_dynamic = mlflow.pyfunc.load_model(pyfunc_path)
y_pred = loaded_dynamic.predict(x_test)
print("Prediction: \n", y_pred[:5])

# custom model with predict_log_proba via the params override
y_pred_log_proba = loaded_dynamic.predict(x_test, params={"predict_method": "predict_log_proba"})
print("Predictive Log Probability: \n", y_pred_log_proba[:5])
We now maintain full flexibility and functionality while still being able to port our model to the MLflow ecosystem: predict_log_proba is accessible by specifying the predict_method parameter.
It is straightforward to see how other or custom methods can be retained in this way; a generalised sketch follows.
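For instance, the if/elif dispatch can be generalised with getattr (a hypothetical variant, not from the original post):

class GenericMethodWrapper(PythonModel):
    def __init__(self):
        self.model = None

    def load_context(self, context):
        from joblib import load

        self.model = load(context.artifacts["model_path"])

    def predict(self, context, model_input, params=None):
        # Look the requested method up on the wrapped model instead of enumerating each one
        predict_method = (params or {}).get("predict_method", "predict")
        method = getattr(self.model, predict_method, None)
        if not callable(method):
            raise ValueError(f"The prediction method '{predict_method}' is not supported.")
        return method(model_input)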
All the code is available here.