#!/usr/bin/env python3
import boto3
import mlflow
import os
import shutil
# Module-level S3 client, shared by all upload calls below.
s3 = boto3.client("s3")
# MLflow tracking server URL — placeholder; set to your real server before running.
MLFLOW_TRACKING_URI = "http://your-mlflow-server"
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
# Destination bucket and key prefix for uploaded model artifacts — placeholders.
S3_BUCKET = "your-s3-bucket-name"
S3_DESTINATION_PREFIX = "mlflow-artifacts/"
def get_prod_model(model_name="your-model-name"):
    """Return the registered model version currently in the 'Production' stage.

    Args:
        model_name: Name of the registered model to search. Defaults to the
            original hard-coded placeholder for backward compatibility.

    Returns:
        The first ModelVersion whose ``current_stage`` is 'Production',
        or ``None`` if no version is in that stage.

    NOTE(review): this checks the (deprecated) registry *stage*, not a
    registry *alias* named 'prod' — confirm which mechanism your registry uses.
    """
    client = mlflow.tracking.MlflowClient()
    # search_model_versions only filters by name; the stage must be
    # inspected per returned version.
    for model in client.search_model_versions(f"name='{model_name}'"):
        # Stage names are 'Production', 'Staging', 'Archived', 'None';
        # compare case-insensitively to be safe.
        if model.current_stage.lower() == "production":
            return model
    return None
def download_and_upload_artifacts(model):
    """Download a model version's artifacts locally and mirror them to S3.

    Artifacts are staged under ``/tmp/<version>``, uploaded to
    ``s3://S3_BUCKET/S3_DESTINATION_PREFIX<version>/<relative path>``, and the
    staging directory is removed afterwards, even on failure.

    Args:
        model: An MLflow ModelVersion; its ``source`` URI is downloaded and
            its ``version`` is used in the local path and S3 key prefix.
    """
    local_path = f"/tmp/{model.version}"
    os.makedirs(local_path, exist_ok=True)
    try:
        # BUG FIX: the destination must be passed as dst_path; the second
        # positional parameter of download_artifacts is run_id, not the
        # download directory.
        mlflow.artifacts.download_artifacts(
            artifact_uri=model.source, dst_path=local_path
        )
        for root, _, files in os.walk(local_path):
            for file_name in files:
                local_file_path = os.path.join(root, file_name)
                # BUG FIX: preserve the directory structure relative to the
                # staging root; keying on the bare filename made files from
                # different subdirectories collide under one prefix.
                rel_path = os.path.relpath(local_file_path, local_path)
                s3_key = (
                    f"{S3_DESTINATION_PREFIX}{model.version}/"
                    f"{rel_path.replace(os.sep, '/')}"
                )
                s3.upload_file(local_file_path, S3_BUCKET, s3_key)
    finally:
        # Always clean up the staging directory, even if download/upload fails.
        shutil.rmtree(local_path, ignore_errors=True)
def test_lambda_function():
    """Smoke-test the pipeline: locate the production model and sync its artifacts to S3."""
    prod_model = get_prod_model()
    if prod_model is None:
        print("No model found with alias 'prod'")
        return
    print(f"Found model version: {prod_model.version}")
    download_and_upload_artifacts(prod_model)
    print(f"Artifacts for model version {prod_model.version} uploaded successfully to S3.")
# Run the smoke test only when executed as a script, not when imported
# (e.g. by a Lambda runtime or another module).
if __name__ == "__main__":
    test_lambda_function()