Log Any Type of Model in MLflow
Question:
I am trying to create a wrapper function that allows my Data Scientists to log their models in MLflow.
This is what the function looks like:
def log_model(self, params, metrics, model, run_name, artifact_path, artifacts=None):
    with mlflow.start_run(run_name=run_name):
        run_id = mlflow.active_run().info.run_id
        mlflow.log_params(params)
        mlflow.log_metrics(metrics)
        if model:
            mlflow.lightgbm.log_model(model, artifact_path=artifact_path)
        if artifacts:
            for artifact in artifacts:
                mlflow.log_artifact(artifact, artifact_path=artifact_path)
    return run_id
As you can see, the model is logged as a lightgbm model; however, the model parameter passed into this function can be of any type.
How can I update this function, so that it will be able to log any kind of model?
As far as I know, there is no log_model function that comes with mlflow. It’s always mlflow.<model_type>.log_model.
How can I go about handling this?
Answers:
I was able to solve this using the following approach:
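import logging

import mlflow

logger = logging.getLogger(__name__)


def log_model(model, artifact_path):
    # Top-level package of the model's class, e.g. 'lightgbm' or 'sklearn'
    model_class = get_model_class(model).split('.')[0]
    try:
        # Look up the mlflow module of the same name, e.g. mlflow.lightgbm
        log_model = getattr(mlflow, model_class).log_model
        log_model(model, artifact_path)
    except AttributeError:
        logger.info('The log_model function is not available as expected!')


def get_model_class(model):
    # Fully qualified class name, e.g. 'lightgbm.basic.Booster'
    klass = model.__class__
    module = klass.__module__
    if module == 'builtins':
        return klass.__qualname__
    return module + '.' + klass.__qualname__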
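From what I have seen, this handles most cases. get_model_class() returns the fully qualified name of the model's class (module path plus class name). The first component of that name is the top-level package, e.g. lightgbm, and getattr() uses it to look up the mlflow module of the same name and its log_model() function. If mlflow has no module with that name, the AttributeError is caught and a message is logged instead.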
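One gap in the approach above: the package name does not always match the MLflow module name. PyTorch models, for example, live in the torch package but are logged through mlflow.pytorch. The sketch below is one possible way to handle that, not a definitive implementation. The names PACKAGE_TO_FLAVOR, resolve_flavor_name and log_any_model are hypothetical, and the mapping is deliberately incomplete. It also only inspects the model's own class, so a user-defined subclass declared in your own module would not be recognised.

```python
import importlib

# Hypothetical mapping for packages whose import name differs from the
# MLflow module name. Illustrative only; extend it for the libraries you use.
PACKAGE_TO_FLAVOR = {
    "torch": "pytorch",
}


def resolve_flavor_name(model):
    """Return the assumed MLflow module name for a model, based on its top-level package."""
    top_level_package = type(model).__module__.split(".")[0]
    return PACKAGE_TO_FLAVOR.get(top_level_package, top_level_package)


def log_any_model(model, artifact_path):
    """Log `model` through mlflow.<flavor>.log_model, where <flavor> is derived from its package."""
    flavor_name = resolve_flavor_name(model)
    try:
        # Imported lazily, so mlflow is only needed when a model is actually logged.
        flavor = importlib.import_module(f"mlflow.{flavor_name}")
    except ImportError as exc:
        raise ValueError(f"Could not import mlflow.{flavor_name}") from exc
    return flavor.log_model(model, artifact_path)
```

Inside the original wrapper, the hard-coded mlflow.lightgbm.log_model(model, artifact_path=artifact_path) call could then be replaced with log_any_model(model, artifact_path). Raising an error rather than logging at info level is a design choice here: it makes an unsupported model type visible to the caller instead of silently skipping the model.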