DEV Community

Developer213
Developer213

Posted on

Copy Artifact


py
import boto3
import mlflow
import os
import shutil

# Module-level S3 client, reused by all upload calls below.
s3 = boto3.client("s3")

# MLflow Tracking URI — placeholder; point at your real tracking server.
MLFLOW_TRACKING_URI = "http://your-mlflow-server"
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

# Destination bucket and key prefix for mirrored artifacts (placeholders).
S3_BUCKET = "your-s3-bucket-name"
S3_DESTINATION_PREFIX = "mlflow-artifacts/"

def get_prod_model(model_name="your-model-name"):
    """Return the registered model version in the 'Production' stage.

    Args:
        model_name: Registered model name to search. Defaults to the
            original hard-coded placeholder for backward compatibility.

    Returns:
        The first ModelVersion whose ``current_stage`` is 'Production',
        or ``None`` if no version is in production.
    """
    client = mlflow.tracking.MlflowClient()
    models = client.search_model_versions(f"name='{model_name}'")

    for model in models:
        # The 'prod' alias corresponds to the 'Production' registry stage.
        # current_stage may be None for unstaged versions — guard before .lower().
        if (model.current_stage or "").lower() == "production":
            return model
    return None

def download_and_upload_artifacts(model):
    """Download a model version's artifacts and mirror them to S3.

    Artifacts are staged under ``/tmp/<version>``, uploaded to
    ``s3://{S3_BUCKET}/{S3_DESTINATION_PREFIX}{version}/...`` preserving
    the artifact directory layout, then the staging directory is removed.

    Args:
        model: An MLflow ModelVersion with ``source`` (artifact URI) and
            ``version`` attributes.
    """
    local_path = f"/tmp/{model.version}"
    os.makedirs(local_path, exist_ok=True)

    try:
        artifact_uri = model.source
        # BUG FIX: the second positional parameter of download_artifacts is
        # run_id, not the destination — dst_path must be passed by keyword.
        mlflow.artifacts.download_artifacts(artifact_uri, dst_path=local_path)

        for root, _, files in os.walk(local_path):
            for file in files:
                local_file_path = os.path.join(root, file)
                # BUG FIX: key on the path relative to local_path, not the
                # basename alone — otherwise files in different artifact
                # subdirectories collide under the same key. Normalize to
                # forward slashes for S3 keys.
                rel_path = os.path.relpath(local_file_path, local_path)
                s3_key = f"{S3_DESTINATION_PREFIX}{model.version}/{rel_path.replace(os.sep, '/')}"

                s3.upload_file(local_file_path, S3_BUCKET, s3_key)
    finally:
        # Always remove the staging directory, even if download/upload failed.
        shutil.rmtree(local_path, ignore_errors=True)

def test_lambda_function():
    """Smoke-test the sync pipeline: locate the prod model and mirror its artifacts to S3."""
    prod_model = get_prod_model()

    # Guard clause: nothing to do when no version is in production.
    if prod_model is None:
        print("No model found with alias 'prod'")
        return

    print(f"Found model version: {prod_model.version}")
    download_and_upload_artifacts(prod_model)
    print(f"Artifacts for model version {prod_model.version} uploaded successfully to S3.")

# Run the test only when executed as a script, so importing this module
# (e.g. from a Lambda handler or notebook) does not trigger network calls.
if __name__ == "__main__":
    test_lambda_function()
Enter fullscreen mode Exit fullscreen mode

Top comments (0)