Seenivasa Ramadurai

Understanding Stacking Ensemble Method: A Comprehensive Guide -PART IV

Introduction

In machine learning, ensemble methods combine several models to improve predictive performance. Stacking is one such method: it feeds the predictions of multiple base classifiers into a meta-classifier that learns how to combine them. In this blog post, we'll walk through how we implemented a Stacking Classifier for iris flower species classification.

What is Stacking?

Stacking (also known as Stacked Generalization) is an ensemble learning technique that combines multiple base classifiers with a meta-classifier. The process involves:

  1. Training multiple base classifiers on the training data
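  2. Using these base classifiers to make out-of-fold predictions via cross-validation, so that each prediction comes from a model that did not see that sample during training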
  3. Using these predictions as features for a meta-classifier
  4. Training the meta-classifier to make the final prediction
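
To make these four steps concrete, here is a minimal hand-rolled sketch using scikit-learn's cross_val_predict. It is an illustration under our own assumptions (three base models instead of five, and names such as base_models and meta_model are hypothetical), not the code used for the results in this post.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict, train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# Steps 1-2: out-of-fold class probabilities from each base classifier
base_models = [
    DecisionTreeClassifier(random_state=42),
    KNeighborsClassifier(n_neighbors=3),
    GaussianNB(),
]
oof_blocks = [
    cross_val_predict(model, X_train, y_train, cv=5, method="predict_proba")
    for model in base_models
]

# Step 3: concatenate the probabilities column-wise to form meta-features
meta_train = np.hstack(oof_blocks)  # shape: (n_train, n_models * n_classes)

# Step 4: train the meta-classifier on the meta-features
meta_model = LogisticRegression(max_iter=1000)
meta_model.fit(meta_train, y_train)

# For prediction, refit the base models on the full training set
for model in base_models:
    model.fit(X_train, y_train)
meta_test = np.hstack([model.predict_proba(X_test) for model in base_models])
test_accuracy = meta_model.score(meta_test, y_test)
print(meta_train.shape, meta_test.shape)
```

scikit-learn's StackingClassifier, used in the rest of this post, wraps the same idea in a single estimator.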

Our Implementation

Base Classifiers

We used five different base classifiers:

  • Decision Tree
  • Support Vector Machine (SVM)
  • Random Forest
  • K-Nearest Neighbors (KNN)
  • Naive Bayes

Each classifier brings its unique strengths:

  • Decision Tree: Good for capturing non-linear relationships
  • SVM: Effective for high-dimensional data
  • Random Forest: Averages many trees, which makes it less prone to overfitting than a single decision tree
  • KNN: Good for local patterns
  • Naive Bayes: Fast to train and outputs class probabilities directly

Meta Classifier Options

We experimented with several meta-classifiers:

  • Logistic Regression
  • Random Forest
  • Support Vector Machine
  • K-Nearest Neighbors
  • XGBoost (optional)

The Training Process

  1. Data Preparation
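   from sklearn.datasets import load_iris
   from sklearn.model_selection import train_test_split

   # Load the iris dataset (150 samples, 4 features, 3 classes)
   iris = load_iris()
   X = iris.data
   y = iris.target

   # Split into training and testing sets (70% / 30%)
   X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)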
  2. Base Classifier Configuration
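   from sklearn.tree import DecisionTreeClassifier
   from sklearn.svm import SVC
   from sklearn.ensemble import RandomForestClassifier
   from sklearn.neighbors import KNeighborsClassifier
   from sklearn.naive_bayes import GaussianNB

   base_classifiers = [
       ('dt', DecisionTreeClassifier(random_state=42)),
       ('svc', SVC(probability=True, random_state=42)),  # probability=True enables predict_proba
       ('rf', RandomForestClassifier(n_estimators=100, random_state=42)),
       ('knn', KNeighborsClassifier(n_neighbors=3)),
       ('nb', GaussianNB())
   ]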
  3. Cross-validation Process

    • Split training data into 5 folds
    • For each fold:
      • Train base classifiers on 4/5 of the data
      • Make predictions on the remaining 1/5
      • Collect predictions as meta-features
  4. Meta Classifier Training

    • Combine predictions from all base classifiers
    • Use these as features for the meta-classifier
    • Train the meta-classifier to make final predictions
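
The shapes involved are easy to check with NumPy alone. The sketch below uses the sizes reported in the program output (105 training samples, 5 folds, 5 base classifiers, 3 classes). Random probabilities stand in for real predict_proba output, and the contiguous folds are a simplification, so treat it as an illustration of the bookkeeping only.

```python
import numpy as np

n_train, n_folds = 105, 5    # sizes reported in the program output
n_models, n_classes = 5, 3   # five base classifiers, three iris species

# Contiguous, equally sized folds (a simplification of real cross-validation)
indices = np.arange(n_train)
folds = np.array_split(indices, n_folds)

# One (n_train, n_classes) block of out-of-fold probabilities per base model
rng = np.random.default_rng(0)
oof = [np.zeros((n_train, n_classes)) for _ in range(n_models)]
for val_idx in folds:
    train_idx = np.setdiff1d(indices, val_idx)
    for block in oof:
        # A real base classifier would be fit on train_idx and predict on val_idx;
        # random rows that sum to 1 stand in for predict_proba here.
        block[val_idx] = rng.dirichlet(np.ones(n_classes), size=len(val_idx))

# Stack the blocks side by side: 5 models x 3 classes = 15 meta-features
meta_features = np.hstack(oof)
print(len(train_idx), len(val_idx), meta_features.shape)  # 84 21 (105, 15)
```

Every training sample ends up with exactly one out-of-fold prediction per base model, which is why the meta-classifier can be trained on all 105 rows.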

Performance Evaluation

Comparing Meta Classifiers

We evaluated different meta-classifiers to find the optimal combination:

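from sklearn.ensemble import StackingClassifier
from sklearn.metrics import accuracy_score

def evaluate_meta_classifiers(base_classifiers, meta_classifiers, X_train, X_test, y_train, y_test):
    """Fit one stacking model per meta-classifier and return {name: test accuracy}."""
    results = {}
    for name, meta_clf in meta_classifiers.items():
        stacking = StackingClassifier(
            estimators=base_classifiers,
            final_estimator=meta_clf,
            cv=5,  # 5-fold CV produces the out-of-fold meta-features
            stack_method='predict_proba'  # base models pass class probabilities to the meta-classifier
        )
        stacking.fit(X_train, y_train)
        y_pred = stacking.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        results[name] = accuracy
    return results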
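
The function above picks a meta-classifier by its accuracy on the test set, so the test score is no longer a fully independent estimate. One possible alternative, sketched here as an assumption rather than what this post did, is to compare candidates with cross-validation on the training data only. To keep it quick, the sketch uses three base classifiers and two candidate meta-classifiers; the names base, candidates and cv_scores are hypothetical.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

base = [
    ("dt", DecisionTreeClassifier(random_state=42)),
    ("knn", KNeighborsClassifier(n_neighbors=3)),
    ("nb", GaussianNB()),
]
candidates = {
    "logistic": LogisticRegression(max_iter=1000),
    "knn": KNeighborsClassifier(n_neighbors=3),
}

cv_scores = {}
for name, meta in candidates.items():
    stack = StackingClassifier(
        estimators=base,
        final_estimator=meta,
        cv=5,
        stack_method="predict_proba",
    )
    # Outer 5-fold CV on the training data; X_test is never touched here
    scores = cross_val_score(stack, X_train, y_train, cv=5, scoring="accuracy")
    cv_scores[name] = scores.mean()
    print(f"{name}: {scores.mean():.3f} +/- {scores.std():.3f}")
```

The test set can then be used once, at the end, to report the accuracy of whichever candidate was selected.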

Visualization

We created visualizations to compare the performance of different meta-classifiers:

import matplotlib.pyplot as plt

# meta_results maps each meta-classifier name to its test accuracy
plt.figure(figsize=(10, 6))
plt.bar(list(meta_results.keys()), list(meta_results.values()))
plt.title('Meta Classifier Performance Comparison')
plt.xlabel('Meta Classifier')
plt.ylabel('Accuracy')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

Results and Insights

Key Findings

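  1. On this 45-sample test set, the stacking ensemble reached 1.00 accuracy, matching four of the five base classifiers (Decision Tree, SVM, Random Forest, KNN) and exceeding only Naive Bayes (0.98)
  2. All five meta-classifiers also scored 1.00, so this experiment cannot rank them by accuracy. Their general trade-offs still apply:
    • Logistic Regression: Good balance of performance and interpretability
    • Random Forest: Can model non-linear combinations of base predictions
    • SVM: Effective but computationally more expensive
    • KNN: Good for local patterns but sensitive to feature scaling
    • XGBoost: A strong option when available, though it adds a dependency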
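
In the program output further down, all five meta-classifiers report an accuracy of 1.00, and the script selects the winner with max(). When scores tie, Python's max() returns the first maximal item it encounters, so "logistic" wins only because it was inserted into the dictionary first. A small sketch for detecting such ties, with the scores copied from the program output:

```python
import math

# Accuracies copied from the program output in this post
meta_results = {"logistic": 1.0, "rf": 1.0, "svc": 1.0, "knn": 1.0, "xgb": 1.0}

# max() keeps the first maximal item, so insertion order breaks the tie
best_name, best_score = max(meta_results.items(), key=lambda item: item[1])

# List every meta-classifier that shares the best score
tied = [
    name for name, score in meta_results.items()
    if math.isclose(score, best_score)
]

print(best_name)  # logistic
print(tied)       # ['logistic', 'rf', 'svc', 'knn', 'xgb']
```

When the list of tied names has more than one entry, a secondary criterion such as training cost or interpretability is a reasonable tiebreaker.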

Best Practices

  1. Data Preprocessing

    • Ensure proper scaling of features
    • Handle missing values appropriately
    • Consider feature selection
  2. Base Classifier Selection

    • Choose diverse classifiers
    • Ensure they complement each other
    • Consider computational cost
  3. Meta Classifier Selection

    • Experiment with different options
    • Consider interpretability vs. performance
    • Monitor for overfitting
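
As one way to apply the scaling advice, distance-based and kernel-based base classifiers can be wrapped in a scikit-learn pipeline so the scaler is fit inside each cross-validation fold. This is a sketch under our own assumptions (three base classifiers, and the name scaled_base is hypothetical); the model trained in this post does not scale its inputs.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

scaled_base = [
    # Trees are insensitive to feature scale, so no scaler is needed here
    ("dt", DecisionTreeClassifier(random_state=42)),
    # SVM and KNN depend on distances, so each gets its own scaler
    ("svc", make_pipeline(StandardScaler(), SVC(probability=True, random_state=42))),
    ("knn", make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=3))),
]

stack = StackingClassifier(
    estimators=scaled_base,
    final_estimator=LogisticRegression(max_iter=1000),
    cv=5,
    stack_method="predict_proba",
)
stack.fit(X_train, y_train)
scaled_accuracy = stack.score(X_test, y_test)
print(f"Test accuracy with scaled base pipelines: {scaled_accuracy:.2f}")
```

Because the scaler lives inside the pipeline, it is refit on each training fold and never sees the held-out samples.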

API Implementation

We also implemented a FastAPI service for making predictions:

@app.post("/predict")
def predict(data: IrisInput):
    with get_model() as model:
        new_data = np.array([[
            data.sepal_length,
            data.sepal_width,
            data.petal_length,
            data.petal_width
        ]])
        prediction = model.predict(new_data)
        probabilities = model.predict_proba(new_data)
        confidence = float(max(probabilities[0]))

        return {
            "prediction": iris.target_names[prediction[0]],
            "confidence": confidence,
            "input_data": data.model_dump()
        }
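
The endpoint above opens and unpickles stacking_model.pkl on every request. One possible refinement, shown here as a sketch rather than as part of the service, is to cache the loaded object with functools.lru_cache. The names load_model, load_calls and demo_model.pkl are hypothetical, and a small dictionary stands in for the trained model so the example runs on its own.

```python
import os
import pickle
import tempfile
from functools import lru_cache

load_calls = 0  # counts how often the file is actually read


@lru_cache(maxsize=1)
def load_model(path: str):
    """Load a pickled object once; later calls return the cached instance."""
    global load_calls
    load_calls += 1
    with open(path, "rb") as f:
        return pickle.load(f)


# Stand-in for the trained model: pickle a small dict into a temp directory
tmp_dir = tempfile.mkdtemp()
model_path = os.path.join(tmp_dir, "demo_model.pkl")
with open(model_path, "wb") as f:
    pickle.dump({"name": "demo"}, f)

first = load_model(model_path)
second = load_model(model_path)
print(load_calls, first is second)  # 1 True
```

Whichever loading strategy is used, pickle files should only be loaded from trusted sources, since unpickling can execute arbitrary code.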

Program output

Sample of first 5 rows of features:
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]]

Sample of first 5 target values:
[0 0 0 0 0]

First 5 rows of the dataset:
   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  target  target_names
0                5.1               3.5                1.4               0.2       0        setosa
1                4.9               3.0                1.4               0.2       0        setosa
2                4.7               3.2                1.3               0.2       0        setosa
3                4.6               3.1                1.5               0.2       0        setosa
4                5.0               3.6                1.4               0.2       0        setosa

=== Starting Model Training Pipeline ===
Training set size: 105 samples
Testing set size: 45 samples

=== Configuring Base Classifiers ===
Base classifiers configured:

  • dt: DecisionTreeClassifier Parameters: {'ccp_alpha': 0.0, 'class_weight': None, 'criterion': 'gini', 'max_depth': None, 'max_features': None, 'max_leaf_nodes': None, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'monotonic_cst': None, 'random_state': 42, 'splitter': 'best'}
  • svc: SVC Parameters: {'C': 1.0, 'break_ties': False, 'cache_size': 200, 'class_weight': None, 'coef0': 0.0, 'decision_function_shape': 'ovr', 'degree': 3, 'gamma': 'scale', 'kernel': 'rbf', 'max_iter': -1, 'probability': True, 'random_state': 42, 'shrinking': True, 'tol': 0.001, 'verbose': False}
  • rf: RandomForestClassifier Parameters: {'bootstrap': True, 'ccp_alpha': 0.0, 'class_weight': None, 'criterion': 'gini', 'max_depth': None, 'max_features': 'sqrt', 'max_leaf_nodes': None, 'max_samples': None, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'monotonic_cst': None, 'n_estimators': 100, 'n_jobs': None, 'oob_score': False, 'random_state': 42, 'verbose': 0, 'warm_start': False}
  • knn: KNeighborsClassifier Parameters: {'algorithm': 'auto', 'leaf_size': 30, 'metric': 'minkowski', 'metric_params': None, 'n_jobs': None, 'n_neighbors': 3, 'p': 2, 'weights': 'uniform'}
  • nb: GaussianNB Parameters: {'priors': None, 'var_smoothing': 1e-09}

=== Configuring Meta Classifier Options ===
Available meta-classifiers:

  • logistic: LogisticRegression Parameters: {'C': 1.0, 'class_weight': None, 'dual': False, 'fit_intercept': True, 'intercept_scaling': 1, 'l1_ratio': None, 'max_iter': 100, 'multi_class': 'deprecated', 'n_jobs': None, 'penalty': 'l2', 'random_state': 42, 'solver': 'lbfgs', 'tol': 0.0001, 'verbose': 0, 'warm_start': False}
  • rf: RandomForestClassifier Parameters: {'bootstrap': True, 'ccp_alpha': 0.0, 'class_weight': None, 'criterion': 'gini', 'max_depth': None, 'max_features': 'sqrt', 'max_leaf_nodes': None, 'max_samples': None, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'monotonic_cst': None, 'n_estimators': 100, 'n_jobs': None, 'oob_score': False, 'random_state': 42, 'verbose': 0, 'warm_start': False}
  • svc: SVC Parameters: {'C': 1.0, 'break_ties': False, 'cache_size': 200, 'class_weight': None, 'coef0': 0.0, 'decision_function_shape': 'ovr', 'degree': 3, 'gamma': 'scale', 'kernel': 'rbf', 'max_iter': -1, 'probability': True, 'random_state': 42, 'shrinking': True, 'tol': 0.001, 'verbose': False}
  • knn: KNeighborsClassifier Parameters: {'algorithm': 'auto', 'leaf_size': 30, 'metric': 'minkowski', 'metric_params': None, 'n_jobs': None, 'n_neighbors': 3, 'p': 2, 'weights': 'uniform'}
  • xgb: XGBClassifier Parameters: {'objective': 'binary:logistic', 'base_score': None, 'booster': None, 'callbacks': None, 'colsample_bylevel': None, 'colsample_bynode': None, 'colsample_bytree': None, 'device': None, 'early_stopping_rounds': None, 'enable_categorical': False, 'eval_metric': None, 'feature_types': None, 'gamma': None, 'grow_policy': None, 'importance_type': None, 'interaction_constraints': None, 'learning_rate': None, 'max_bin': None, 'max_cat_threshold': None, 'max_cat_to_onehot': None, 'max_delta_step': None, 'max_depth': None, 'max_leaves': None, 'min_child_weight': None, 'missing': nan, 'monotone_constraints': None, 'multi_strategy': None, 'n_estimators': None, 'n_jobs': None, 'num_parallel_tree': None, 'random_state': 42, 'reg_alpha': None, 'reg_lambda': None, 'sampling_method': None, 'scale_pos_weight': None, 'subsample': None, 'tree_method': None, 'validate_parameters': None, 'verbosity': None}

=== Comparing Meta Classifier Performance ===

Evaluating logistic as meta-classifier...
logistic accuracy: 1.00

Evaluating rf as meta-classifier...
rf accuracy: 1.00

Evaluating svc as meta-classifier...
svc accuracy: 1.00

Evaluating knn as meta-classifier...
knn accuracy: 1.00

Evaluating xgb as meta-classifier...
xgb accuracy: 1.00

Best meta-classifier: logistic (accuracy: 1.00)

=== Training Final Model with Best Meta Classifier ===

=== Training Stacking Classifier ===
Step 1: Training base classifiers...
This involves:

  1. Splitting training data into 5 folds for cross-validation
  2. For each fold:
    • Training base classifiers on 4/5 of the data
    • Making predictions on the remaining 1/5
    • Collecting these predictions as meta-features
  3. Training base classifiers on full training data

Detailed training process:

Fold 1/5:
Training samples: 84
Validation samples: 21
dt predictions shape: (21, 3)
svc predictions shape: (21, 3)
rf predictions shape: (21, 3)
knn predictions shape: (21, 3)
nb predictions shape: (21, 3)

Fold 2/5:
Training samples: 84
Validation samples: 21
dt predictions shape: (21, 3)
svc predictions shape: (21, 3)
rf predictions shape: (21, 3)
knn predictions shape: (21, 3)
nb predictions shape: (21, 3)

Fold 3/5:
Training samples: 84
Validation samples: 21
dt predictions shape: (21, 3)
svc predictions shape: (21, 3)
rf predictions shape: (21, 3)
knn predictions shape: (21, 3)
nb predictions shape: (21, 3)

Fold 4/5:
Training samples: 84
Validation samples: 21
dt predictions shape: (21, 3)
svc predictions shape: (21, 3)
rf predictions shape: (21, 3)
knn predictions shape: (21, 3)
nb predictions shape: (21, 3)

Fold 5/5:
Training samples: 84
Validation samples: 21
dt predictions shape: (21, 3)
svc predictions shape: (21, 3)
rf predictions shape: (21, 3)
knn predictions shape: (21, 3)
nb predictions shape: (21, 3)

Step 2: Training meta-classifier...
This involves:

  1. Combining predictions from all base classifiers
  2. Using these predictions as features for the meta-classifier
  3. Training the meta-classifier on the combined predictions

Training completed!

=== Model Evaluation ===
Making predictions on test set...
Overall Accuracy: 1.00

Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        19
           1       1.00      1.00      1.00        13
           2       1.00      1.00      1.00        13

    accuracy                           1.00        45
   macro avg       1.00      1.00      1.00        45
weighted avg       1.00      1.00      1.00        45

=== Base Classifier Performance ===

Training dt...
dt accuracy: 1.00
Predictions shape: (45,)

Training svc...
svc accuracy: 1.00
Predictions shape: (45,)

Training rf...
rf accuracy: 1.00
Predictions shape: (45,)

Training knn...
knn accuracy: 1.00
Predictions shape: (45,)

Training nb...
nb accuracy: 0.98
Predictions shape: (45,)

=== Meta Classifier Performance ===
Getting meta-features from base classifiers...
Meta-features shape: (45, 15)
Making predictions with meta-classifier...
Meta classifier accuracy: 1.00

=== Confusion Matrix ===

=== Saving Model ===
Model saved as 'stacking_model.pkl'

=== Sample Prediction ===
Input data: [5.1 3.5 1.4 0.2]

Base classifier predictions:
dt:
- Prediction: 0
- Probabilities: [1. 0. 0.]
svc:
- Prediction: 0
- Probabilities: [0.96691777 0.02037569 0.01270654]
rf:
- Prediction: 0
- Probabilities: [1. 0. 0.]
knn:
- Prediction: 0
- Probabilities: [1. 0. 0.]
nb:
- Prediction: 0
- Probabilities: [1.00000000e+00 7.82732978e-17 1.66528708e-24]

Final prediction:
Predicted class: 0
Predicted species: setosa
Confidence: 0.98
All class probabilities: [0.97867634 0.01094243 0.01038123]
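
The reported confidence is simply the largest of the three class probabilities, and the predicted species is the class at that position. A tiny sketch with the numbers copied from the output above (the variable names are our own):

```python
target_names = ["setosa", "versicolor", "virginica"]
probabilities = [0.97867634, 0.01094243, 0.01038123]  # copied from the output above

# Index of the largest probability is the predicted class
predicted_class = max(range(len(probabilities)), key=probabilities.__getitem__)
confidence = probabilities[predicted_class]

print(target_names[predicted_class], f"{confidence:.2f}")  # setosa 0.98
```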

API

[Two images of the API, not reproduced here]

Complete code

"""
Stacking Ensemble Method Implementation

This script implements a machine learning model using the Stacking Classifier algorithm
to predict iris flower species. It combines multiple base classifiers with a final meta-classifier
to improve prediction accuracy.

Base Classifiers:
- Decision Tree
- Support Vector Machine (SVM)
- Random Forest
- K-Nearest Neighbors
- Naive Bayes

Meta Classifier Options:
- Logistic Regression
- Random Forest
- Support Vector Machine
- K-Nearest Neighbors
- XGBoost (optional)

The model is trained on the iris dataset and can predict three species:
- setosa
- versicolor
- virginica

Author: Sreeni
Date: 2025
"""

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import StackingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import numpy as np
import uvicorn
from contextlib import contextmanager, asynccontextmanager
from typing import Generator
from fastapi import FastAPI, Depends
from pydantic import BaseModel, Field
import pandas as pd

# Try importing optional classifiers
try:
    import xgboost as xgb
    XGBOOST_AVAILABLE = True
except ImportError:
    print("Note: XGBoost not available. Some meta-classifier options will be limited.")
    XGBOOST_AVAILABLE = False

# Context manager for model loading
@contextmanager
def get_model() -> Generator:
    """
    Context manager for safely loading the trained model from disk.

    Yields:
        model: The loaded scikit-learn model object

    Example:
        with get_model() as model:
            prediction = model.predict(data)
    """
    try:
        with open('stacking_model.pkl', 'rb') as f:
            model = pickle.load(f)
            yield model
    finally:
        pass  # Clean up if needed

# Load the iris dataset
iris = load_iris()
X = iris.data
y = iris.target

# Data preprocessing and visualization
print("Sample of first 5 rows of features:")
print(X[:5])
print("\nSample of first 5 target values:")
print(y[:5])

# Create a pandas DataFrame for better data visualization
df = pd.DataFrame(X, columns=iris.feature_names)
df['target'] = y
target_dict = dict(enumerate(iris.target_names))
df['target_names'] = df['target'].map(target_dict)
print("\nFirst 5 rows of the dataset:")
print(df.head())

# Model Training Pipeline
print("\n=== Starting Model Training Pipeline ===")
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
print(f"Training set size: {len(X_train)} samples")
print(f"Testing set size: {len(X_test)} samples")

# Define base classifiers
print("\n=== Configuring Base Classifiers ===")
base_classifiers = [
    ('dt', DecisionTreeClassifier(random_state=42)),
    ('svc', SVC(probability=True, random_state=42)),
    ('rf', RandomForestClassifier(n_estimators=100, random_state=42)),
    ('knn', KNeighborsClassifier(n_neighbors=3)),
    ('nb', GaussianNB())
]
print("Base classifiers configured:")
for name, clf in base_classifiers:
    print(f"- {name}: {clf.__class__.__name__}")
    if hasattr(clf, 'get_params'):
        print(f"  Parameters: {clf.get_params()}")

# Define meta-classifier options
print("\n=== Configuring Meta Classifier Options ===")
meta_classifiers = {
    'logistic': LogisticRegression(random_state=42),
    'rf': RandomForestClassifier(n_estimators=100, random_state=42),
    'svc': SVC(probability=True, random_state=42),
    'knn': KNeighborsClassifier(n_neighbors=3)
}

# Add XGBoost if available
if XGBOOST_AVAILABLE:
    meta_classifiers['xgb'] = xgb.XGBClassifier(random_state=42)

print("Available meta-classifiers:")
for name, clf in meta_classifiers.items():
    print(f"- {name}: {clf.__class__.__name__}")
    if hasattr(clf, 'get_params'):
        print(f"  Parameters: {clf.get_params()}")

# Function to evaluate stacking with different meta-classifiers
def evaluate_meta_classifiers(base_classifiers, meta_classifiers, X_train, X_test, y_train, y_test):
    results = {}
    for name, meta_clf in meta_classifiers.items():
        print(f"\nEvaluating {name} as meta-classifier...")
        stacking = StackingClassifier(
            estimators=base_classifiers,
            final_estimator=meta_clf,
            cv=5,
            stack_method='predict_proba'
        )

        # Train and evaluate
        stacking.fit(X_train, y_train)
        y_pred = stacking.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        results[name] = accuracy
        print(f"{name} accuracy: {accuracy:.2f}")

    return results

# Evaluate all meta-classifiers
print("\n=== Comparing Meta Classifier Performance ===")
meta_results = evaluate_meta_classifiers(base_classifiers, meta_classifiers, X_train, X_test, y_train, y_test)

# Find best meta-classifier
best_meta = max(meta_results.items(), key=lambda x: x[1])
print(f"\nBest meta-classifier: {best_meta[0]} (accuracy: {best_meta[1]:.2f})")

# Visualize meta-classifier comparison
plt.figure(figsize=(10, 6))
plt.bar(meta_results.keys(), meta_results.values())
plt.title('Meta Classifier Performance Comparison')
plt.xlabel('Meta Classifier')
plt.ylabel('Accuracy')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Use the best meta-classifier for final model
print("\n=== Training Final Model with Best Meta Classifier ===")
final_meta_classifier = meta_classifiers[best_meta[0]]
stacking = StackingClassifier(
    estimators=base_classifiers,
    final_estimator=final_meta_classifier,
    cv=5,
    stack_method='predict_proba'
)

# Train the Stacking classifier
print("\n=== Training Stacking Classifier ===")
print("Step 1: Training base classifiers...")
print("This involves:")
print("1. Splitting training data into 5 folds for cross-validation")
print("2. For each fold:")
print("   - Training base classifiers on 4/5 of the data")
print("   - Making predictions on the remaining 1/5")
print("   - Collecting these predictions as meta-features")
print("3. Training base classifiers on full training data")

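# Illustration only: this loop mimics the fold structure with contiguous folds to show the shapes involved.
# StackingClassifier builds its own meta-features internally (using stratified folds) when fit() is called below.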
print("\nDetailed training process:")
for fold in range(5):
    print(f"\nFold {fold + 1}/5:")
    # Get the indices for this fold
    fold_indices = np.arange(len(X_train))
    fold_size = len(X_train) // 5
    val_indices = fold_indices[fold * fold_size:(fold + 1) * fold_size]
    train_indices = np.concatenate([fold_indices[:fold * fold_size], 
                                  fold_indices[(fold + 1) * fold_size:]])

    print(f"Training samples: {len(train_indices)}")
    print(f"Validation samples: {len(val_indices)}")

    # Train base classifiers on this fold
    for name, clf in base_classifiers:
        clf.fit(X_train[train_indices], y_train[train_indices])
        val_pred = clf.predict_proba(X_train[val_indices])
        print(f"{name} predictions shape: {val_pred.shape}")

print("\nStep 2: Training meta-classifier...")
print("This involves:")
print("1. Combining predictions from all base classifiers")
print("2. Using these predictions as features for the meta-classifier")
print("3. Training the meta-classifier on the combined predictions")

# Train the actual stacking classifier
stacking.fit(X_train, y_train)
print("Training completed!")

# Model Evaluation
print("\n=== Model Evaluation ===")
print("Making predictions on test set...")
y_pred = stacking.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Overall Accuracy: {accuracy:.2f}")

print("\nClassification Report:")
print(classification_report(y_test, y_pred))

# Individual Base Classifier Performance
print("\n=== Base Classifier Performance ===")
for name, clf in base_classifiers:
    print(f"\nTraining {name}...")
    clf.fit(X_train, y_train)
    base_pred = clf.predict(X_test)
    base_acc = accuracy_score(y_test, base_pred)
    print(f"{name} accuracy: {base_acc:.2f}")
    print(f"Predictions shape: {base_pred.shape}")

# Meta Classifier Performance
print("\n=== Meta Classifier Performance ===")
print("Getting meta-features from base classifiers...")
meta_features = stacking.transform(X_test)
print(f"Meta-features shape: {meta_features.shape}")
print("Making predictions with meta-classifier...")
meta_pred = stacking.final_estimator_.predict(meta_features)
meta_acc = accuracy_score(y_test, meta_pred)
print(f"Meta classifier accuracy: {meta_acc:.2f}")

# Visualize the confusion matrix
print("\n=== Confusion Matrix ===")
conf_matrix = confusion_matrix(y_test, y_pred)
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', 
            xticklabels=iris.target_names, yticklabels=iris.target_names)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

# Save the trained model
print("\n=== Saving Model ===")
with open('stacking_model.pkl', 'wb') as f:
    pickle.dump(stacking, f)
print("Model saved as 'stacking_model.pkl'")

# Sample prediction with detailed explanation
print("\n=== Sample Prediction ===")
new_data = np.array([[5.1, 3.5, 1.4, 0.2]])
print("Input data:", new_data[0])

# Get predictions from base classifiers
print("\nBase classifier predictions:")
for name, clf in base_classifiers:
    base_pred = clf.predict(new_data)
    base_prob = clf.predict_proba(new_data)
    print(f"{name}:")
    print(f"- Prediction: {base_pred[0]}")
    print(f"- Probabilities: {base_prob[0]}")

# Get final prediction
prediction = stacking.predict(new_data)
probabilities = stacking.predict_proba(new_data)
confidence = float(max(probabilities[0]))

print("\nFinal prediction:")
print(f"Predicted class: {prediction[0]}")
print(f"Predicted species: {iris.target_names[prediction[0]]}")
print(f"Confidence: {confidence:.2f}")
print(f"All class probabilities: {probabilities[0]}")

# FastAPI Implementation
class IrisInput(BaseModel):
    """
    Pydantic model for input data validation.

    Attributes:
        sepal_length (float): Length of sepal in cm
        sepal_width (float): Width of sepal in cm
        petal_length (float): Length of petal in cm
        petal_width (float): Width of petal in cm
    """
    sepal_length: float = Field(..., gt=0, description="Length of sepal in cm")
    sepal_width: float = Field(..., gt=0, description="Width of sepal in cm")
    petal_length: float = Field(..., gt=0, description="Length of petal in cm")
    petal_width: float = Field(..., gt=0, description="Width of petal in cm")

    # Pydantic v2 style configuration (this script already relies on v2 via model_dump)
    model_config = {
        "json_schema_extra": {
            "example": {
                "sepal_length": 5.1,
                "sepal_width": 3.5,
                "petal_length": 1.4,
                "petal_width": 0.2
            }
        }
    }

@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifecycle manager for the FastAPI application.
    Verifies model availability on startup.
    """
    # Verify model can be loaded
    with get_model() as model:
        pass
    yield
    # Cleanup code here if needed

app = FastAPI(
    title="Iris Species Prediction API",
    description="""
    This API predicts the species of iris flowers using a Stacking Classifier.

    The model combines multiple base classifiers with a meta-classifier to improve accuracy.
    It accepts four measurements of an iris flower and returns the predicted species.

    Measurements required:
    - Sepal length (cm)
    - Sepal width (cm)
    - Petal length (cm)
    - Petal width (cm)
    """,
    version="1.0.0",
    lifespan=lifespan  # register the startup check through the documented parameter
)

@app.get("/", tags=["Status"])
def read_root():
    """
    Root endpoint to check API status.

    Returns:
        dict: Status message indicating the API is running
    """
    return {"message": "Iris Species Prediction API is running successfully"}

@app.post("/predict", tags=["Prediction"])
def predict(data: IrisInput):
    """
    Predict iris species based on flower measurements.

    Args:
        data (IrisInput): Input measurements of the iris flower

    Returns:
        dict: Predicted species of the iris flower
    """
    with get_model() as model:
        new_data = np.array([[
            data.sepal_length,
            data.sepal_width,
            data.petal_length,
            data.petal_width
        ]])
        prediction = model.predict(new_data)
        probabilities = model.predict_proba(new_data)
        confidence = float(max(probabilities[0]))

        return {
            "prediction": iris.target_names[prediction[0]],
            "confidence": confidence,
            "input_data": data.model_dump()
        }

# Run the application
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5020)

Conclusion

On the iris dataset, the stacking ensemble reached 1.00 accuracy on the 45-sample test set, matching the strongest individual base classifiers. The dataset is too small and too easy to show a clear gain from stacking, but the implementation provides a flexible framework that can be adapted to other classification tasks.
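
A perfect score on 45 test samples still leaves real uncertainty about the true accuracy. As a rough illustration, the sketch below computes a 95% Wilson score interval for 45 correct predictions out of 45. The helper wilson_interval is our own addition, not part of the original script.

```python
import math


def wilson_interval(correct: int, total: int, z: float = 1.96):
    """Wilson score interval for a binomial proportion (z=1.96 for ~95%)."""
    p = correct / total
    denom = 1 + z**2 / total
    center = (p + z**2 / (2 * total)) / denom
    half = z * math.sqrt(p * (1 - p) / total + z**2 / (4 * total**2)) / denom
    # Clamp to [0, 1] to absorb floating-point rounding at the boundaries
    return max(0.0, center - half), min(1.0, center + half)


low, high = wilson_interval(45, 45)
print(f"{low:.3f} {high:.3f}")  # 0.921 1.000
```

In other words, the data are consistent with a true accuracy anywhere from roughly 92% to 100%, which is one reason a larger or harder dataset would be needed to separate these models.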

Future Improvements

  1. Add more base classifiers
  2. Implement feature importance analysis
  3. Add hyperparameter tuning
  4. Include more evaluation metrics
  5. Add model interpretability tools
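
As a starting point for the hyperparameter tuning item, scikit-learn's GridSearchCV can search over both base-classifier and meta-classifier settings through the nested parameter names that StackingClassifier exposes. The grid values and the reduced set of base classifiers below are assumptions chosen to keep the sketch fast, not tuned recommendations.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

stack = StackingClassifier(
    estimators=[("knn", KNeighborsClassifier()), ("nb", GaussianNB())],
    final_estimator=LogisticRegression(max_iter=1000),
    cv=3,
    stack_method="predict_proba",
)

# "<estimator name>__<param>" reaches a base classifier,
# "final_estimator__<param>" reaches the meta-classifier
param_grid = {
    "knn__n_neighbors": [3, 5],
    "final_estimator__C": [0.1, 1.0, 10.0],
}

search = GridSearchCV(stack, param_grid, cv=3, scoring="accuracy")
search.fit(X_train, y_train)
print(search.best_params_)
print(f"Best cross-validated accuracy: {search.best_score_:.3f}")
```

The search runs entirely on the training data, so the test set remains available for a final, untouched evaluation.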

Thanks
Sreeni Ramadorai
