DEV Community

Harsh Mishra

Decision Tree, Classification: Supervised Machine Learning

What is a Decision Tree?

Definition and Purpose

A decision tree is a supervised learning technique used in machine learning and data science for both classification and regression tasks. It represents predictions as a tree of decisions: each internal node tests a feature, and each path from the root to a leaf forms one decision rule. In classification, the main purpose of a decision tree is to predict the value of a target variable from several input variables by learning simple decision rules inferred from the data features.

Key Objectives:

  • Prediction: Categorizing new data points into predefined classes.
  • Interpretability: Providing a clear and intuitive representation of the decision-making process.
  • Handling Non-linearity: Capturing complex, non-linear relationships between features and target variables.

Decision Tree Structure

A decision tree is composed of the following components:

  • Root Node: Represents the entire dataset and the starting point of the tree.
  • Internal Nodes: Represent tests on a feature (for a numeric feature, typically whether its value is less than or equal to a threshold) that split the data.
  • Branches: Represent the outcome of a decision or test.
  • Leaf Nodes (Terminal Nodes): Represent the final class labels (for classification) or predicted values (for regression).
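To make these components concrete, here is a minimal sketch of a binary decision tree as a plain Python data structure. The `Node` class and `predict_one` function are hypothetical names for illustration, not part of scikit-learn. The sketch assumes numeric features and "go left if value <= threshold" tests, which mirrors how scikit-learn routes samples.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Node:
    """One node of a binary decision tree (hypothetical minimal structure).

    Internal nodes store a feature index and a threshold; leaves store a label.
    """
    feature: Optional[int] = None      # index of the feature tested at this node
    threshold: Optional[float] = None  # go left if x[feature] <= threshold
    left: Optional["Node"] = None      # branch taken when the test is true
    right: Optional["Node"] = None     # branch taken when the test is false
    label: Optional[int] = None        # class label, set only on leaf nodes

    def is_leaf(self) -> bool:
        return self.label is not None


def predict_one(node: Node, x: Sequence[float]) -> int:
    """Follow branches from the root until a leaf is reached."""
    while not node.is_leaf():
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.label


# Root node tests feature 0; its right child is an internal node testing feature 1.
tree = Node(
    feature=0, threshold=0.0,
    left=Node(label=0),
    right=Node(feature=1, threshold=0.5, left=Node(label=0), right=Node(label=1)),
)

print(predict_one(tree, [-1.0, 2.0]))  # 0: left branch at the root
print(predict_one(tree, [1.0, 2.0]))   # 1: right at the root, then right at feature 1
```

The root is simply the first `Node`. Each call to `predict_one` walks one root-to-leaf path, which corresponds to exactly one decision rule.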

Decision Tree Algorithm

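  1. Selecting the Best Feature: At each node, the algorithm selects the best feature to split the data on (and, for numeric features, the best threshold). It uses a criterion such as Gini impurity or entropy; with entropy, it chooses the split that maximizes information gain.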

  2. Splitting the Data: The selected feature splits the data into subsets that maximize the homogeneity of the target variable within each subset.

  3. Recursively Splitting: The process is repeated recursively for each subset until a stopping criterion is met (e.g., maximum depth, minimum samples per leaf, or no further information gain).

  4. Assigning Class Labels: Once the splitting is complete, each leaf node is assigned a class label based on the majority class of the data points in that node.
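The four steps can be sketched from scratch in a few dozen lines. This is a simplified, CART-style illustration under stated assumptions, not scikit-learn's implementation. It uses Gini impurity and tries every observed feature value as a threshold, whereas scikit-learn uses midpoints between sorted values. Nodes are stored as dicts, and recursion stops at a hypothetical `max_depth=3` or when a node is pure. The names `gini`, `best_split`, `build` and `predict` are our own.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(X, y):
    """Step 1: try every feature and every observed value as a threshold,
    returning the (feature, threshold) with the lowest weighted child Gini."""
    best = None
    best_score = gini(y)  # a split must improve on the parent's impurity
    n = len(y)
    for f in range(len(X[0])):
        for t in sorted({row[f] for row in X}):
            left = [lbl for row, lbl in zip(X, y) if row[f] <= t]
            right = [lbl for row, lbl in zip(X, y) if row[f] > t]
            if not left or not right:
                continue  # skip splits that leave one side empty
            score = len(left) / n * gini(left) + len(right) / n * gini(right)
            if score < best_score:
                best, best_score = (f, t), score
    return best


def build(X, y, depth=0, max_depth=3):
    """Steps 2-4: split, recurse on each subset, label leaves by majority class."""
    majority = Counter(y).most_common(1)[0][0]
    if depth >= max_depth or len(set(y)) == 1:
        return {"label": majority}  # stopping criterion reached
    split = best_split(X, y)
    if split is None:
        return {"label": majority}  # no split reduces impurity
    f, t = split
    left_idx = [i for i, row in enumerate(X) if row[f] <= t]
    right_idx = [i for i, row in enumerate(X) if row[f] > t]
    return {
        "feature": f,
        "threshold": t,
        "left": build([X[i] for i in left_idx], [y[i] for i in left_idx], depth + 1, max_depth),
        "right": build([X[i] for i in right_idx], [y[i] for i in right_idx], depth + 1, max_depth),
    }


def predict(node, x):
    """Walk from the root to a leaf and return its label."""
    while "label" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["label"]


X = [[1.0, 5.0], [2.0, 4.0], [3.0, 7.0], [6.0, 1.0], [7.0, 2.0], [8.0, 3.0]]
y = [0, 0, 0, 1, 1, 1]
tree = build(X, y)
print(tree)
print([predict(tree, x) for x in X])
```

On this toy data, the root split (feature 0 <= 3.0) already separates the classes. Both children therefore become pure leaves, and recursion stops.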

Cost Function and Loss Minimization in Decision Trees

Cost Function

The cost function in decision trees quantifies the impurity or heterogeneity of the data in the nodes. The goal is to minimize this impurity by choosing the best splits at each node.

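Gini Impurity: The probability that a randomly chosen sample from a node would be misclassified if it were labeled at random according to the node's class distribution. It is 0 for a pure node.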

Entropy: Measures the disorder in a node's class distribution. It is 0 for a pure node and largest when all classes are equally frequent.

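Information Gain: Measures the reduction in entropy achieved by splitting a node on an attribute. It is computed as the parent's entropy minus the size-weighted average entropy of the child nodes.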
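The three criteria translate directly into code. The sketch below uses hypothetical helper names and the standard definitions, where p_k is the fraction of samples in class k:

- Gini impurity: 1 - Σ p_k²
- Entropy (in bits): -Σ p_k log₂ p_k
- Information gain: parent entropy minus the size-weighted entropy of the children

```python
import math
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2, where p_k is the share of class k."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels):
    """Entropy in bits: -sum_k p_k * log2(p_k)."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent, children):
    """Parent entropy minus the size-weighted entropy of the child nodes."""
    n = len(parent)
    weighted = sum(len(ch) / n * entropy(ch) for ch in children)
    return entropy(parent) - weighted


parent = [0, 0, 0, 0, 1, 1, 1, 1]  # perfectly mixed node
left, right = [0, 0, 0, 1], [0, 1, 1, 1]

print(gini(parent))     # 0.5
print(entropy(parent))  # 1.0
print(round(information_gain(parent, [left, right]), 4))  # 0.1887
```

The split in this example leaves each child at a 3:1 class ratio. As a result, it removes only about 0.19 of the parent's 1 bit of entropy.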

Loss Minimization (Optimization)

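Decision trees are not trained by gradient-based loss minimization. Instead, a greedy search chooses, at each node, the split that minimizes the weighted impurity (Gini impurity or entropy) of the child nodes. With entropy, this is equivalent to maximizing information gain.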

Steps of Optimization:

  1. Calculate Impurity: Compute the impurity (Gini impurity or entropy) of the current node.

  2. Evaluate Splits: For each possible split, evaluate the resulting impurity for the child nodes.

  3. Select Best Split: Choose the split that results in the lowest impurity or highest information gain.

  4. Repeat: Recursively apply the process to each child node until a stopping criterion is met.
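As a worked illustration of steps 1 to 3 for a single numeric feature, the sketch below scores every candidate threshold. It assumes, as scikit-learn does for continuous features, that the candidates are the midpoints between consecutive distinct values. The data and function names are made up for illustration; step 4 would apply the same search to each child node.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def evaluate_splits(values, labels):
    """Score every midpoint between consecutive distinct values of one feature.

    Returns a list of (threshold, weighted child Gini) pairs.
    """
    pairs = sorted(zip(values, labels))
    distinct = sorted(set(values))
    n = len(pairs)
    results = []
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2  # midpoint threshold
        left = [lbl for v, lbl in pairs if v <= t]
        right = [lbl for v, lbl in pairs if v > t]
        weighted = len(left) / n * gini(left) + len(right) / n * gini(right)
        results.append((t, weighted))
    return results


values = [2.0, 3.5, 1.0, 6.0, 5.0, 4.0]
labels = [0, 0, 0, 1, 1, 0]

print(f"parent Gini: {gini(labels):.3f}")  # Step 1: impurity of the current node
results = evaluate_splits(values, labels)   # Step 2: evaluate every candidate split
for t, score in results:
    print(f"threshold {t:.2f}: weighted Gini {score:.3f}")
best_t, best_score = min(results, key=lambda r: r[1])  # Step 3: lowest impurity wins
print("best threshold:", best_t)
```

Only the threshold 4.5 produces two pure children (weighted Gini 0), so it is selected.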

Decision Tree (Binary Classification) Example

This example shows how to train a decision tree for binary classification on synthetic data, evaluate the model's performance, and visualize its decision boundary.

Python Code Example

1. Import Libraries

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

This block imports the necessary libraries for data manipulation, plotting, and machine learning.

2. Generate Sample Data

np.random.seed(42)  # For reproducibility

# Generate synthetic data for 2 classes
n_samples = 1000
n_samples_per_class = n_samples // 2

# Class 0: Centered around (-1, -1)
X0 = np.random.randn(n_samples_per_class, 2) * 0.7 + [-1, -1]

# Class 1: Centered around (1, 1)
X1 = np.random.randn(n_samples_per_class, 2) * 0.7 + [1, 1]

# Combine the data
X = np.vstack([X0, X1])
y = np.hstack([np.zeros(n_samples_per_class), np.ones(n_samples_per_class)])

# Shuffle the dataset
shuffle_idx = np.random.permutation(n_samples)
X, y = X[shuffle_idx], y[shuffle_idx]

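This block generates 1,000 synthetic points with two features. Class 0 has 500 points drawn from a Gaussian cluster centered at (-1, -1), and class 1 has 500 points centered at (1, 1); both clusters use a standard deviation of 0.7. The clusters overlap slightly, so perfect classification is not expected.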

3. Split the Dataset

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

This block splits the dataset into training and testing sets for model evaluation.

4. Create and Train the Decision Tree Classifier

model = DecisionTreeClassifier(random_state=42, max_depth=1)  # Limit depth for visualization
model.fit(X_train, y_train)

This block initializes the decision tree classifier with max_depth=1, so the tree makes exactly one split (a decision stump). It then trains the model on the training set.

5. Make Predictions

y_pred = model.predict(X_test)

This block uses the trained model to make predictions on the test set.

6. Evaluate the Model

accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)

print(f"Accuracy: {accuracy:.4f}")
print("\nConfusion Matrix:")
print(conf_matrix)
print("\nClassification Report:")
print(class_report)

Output:

Accuracy: 0.9200

Confusion Matrix:
[[96  8]
 [ 8 88]]

Classification Report:
              precision    recall  f1-score   support

         0.0       0.92      0.92      0.92       104
         1.0       0.92      0.92      0.92        96

    accuracy                           0.92       200
   macro avg       0.92      0.92      0.92       200
weighted avg       0.92      0.92      0.92       200

This block calculates and prints the accuracy, confusion matrix, and classification report, providing insights into the model's performance.
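The reported metrics can be recomputed by hand from the confusion matrix printed above. In scikit-learn's `confusion_matrix`, rows are true classes and columns are predicted classes. This sketch treats class 1 as the positive class.

```python
# Confusion matrix from the output above: rows = true class, columns = predicted class
cm = [[96, 8],
      [8, 88]]

tn, fp = cm[0]  # true class 0: predicted 0, predicted 1
fn, tp = cm[1]  # true class 1: predicted 0, predicted 1

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision_1 = tp / (tp + fp)  # share of predicted class-1 points that are truly class 1
recall_1 = tp / (tp + fn)     # share of true class-1 points the model found
f1_1 = 2 * precision_1 * recall_1 / (precision_1 + recall_1)

print(f"accuracy={accuracy:.4f}")  # accuracy=0.9200
print(f"class 1: precision={precision_1:.2f} recall={recall_1:.2f} f1={f1_1:.2f}")
```

Here `precision_1` and `recall_1` both equal 88/96, about 0.917, which rounds to the 0.92 shown in the classification report.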

7. Visualize the Decision Boundary

x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                     np.arange(y_min, y_max, 0.1))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.figure(figsize=(10, 8))
plt.contourf(xx, yy, Z, alpha=0.4, cmap='RdYlBu')
scatter = plt.scatter(X[:, 0], X[:, 1], c=y, cmap='RdYlBu', edgecolor='black')
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("Binary Decision Tree Classification")
plt.colorbar(scatter)
plt.show()

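This block visualizes the decision boundary created by the decision tree model. Because max_depth=1 allows only one split on one feature, the boundary is a single straight line perpendicular to that feature's axis, dividing the feature space into two class regions.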

Output:

[Figure: Binary Decision Tree Classification]
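Why does a depth-1 tree top out near 92% on this data? Here is a rough check, assuming the exact distributions from step 2: equal class sizes, independent Gaussian features with standard deviation 0.7, and centers at (-1, -1) and (1, 1). A single split can only draw a line parallel to an axis, while the ideal boundary for these two clusters is the diagonal feature_1 + feature_2 = 0. The sketch below computes the expected accuracy of each boundary in closed form using the normal CDF. The helper `normal_cdf` is our own and not part of the original code.

```python
import math


def normal_cdf(z):
    """Standard normal CDF, written with math.erf."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


sigma = 0.7   # per-feature standard deviation used in step 2
center = 1.0  # class 0 at (-1, -1), class 1 at (1, 1)

# Depth-1 tree: by symmetry, the best single split is feature_1 <= 0.
# A class-1 point is classified correctly when its feature_1 > 0.
stump_acc = normal_cdf(center / sigma)

# Diagonal boundary feature_1 + feature_2 = 0 (not reachable with one split).
# For class 1, feature_1 + feature_2 has mean 2 * center and std sigma * sqrt(2).
diagonal_acc = normal_cdf(2 * center / (sigma * math.sqrt(2)))

print(f"best axis-aligned split: {stump_acc:.4f}")
print(f"diagonal boundary:       {diagonal_acc:.4f}")
```

Under these assumptions, the best single split scores roughly 0.92, in line with the 0.9200 test accuracy reported above. The diagonal boundary would reach about 0.98. A deeper tree can approximate the diagonal with a staircase of axis-aligned splits.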

This example shows how to train and evaluate a decision tree for binary classification. The decision-boundary plot makes the model's single split visible, which helps in interpreting its predictions.

Decision Tree (Multiclass Classification) Example

Decision trees handle multiclass classification natively: each leaf predicts the majority class of its training samples, however many classes there are. This example shows how to train a decision tree on synthetic data, evaluate its performance, and visualize the decision boundaries for five classes.

Python Code Example

1. Import Libraries

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

This block imports the necessary libraries for data manipulation, plotting, and machine learning.

2. Generate Sample Data with 5 Classes

np.random.seed(42)  # For reproducibility
n_samples = 1000  # Total number of samples
n_samples_per_class = n_samples // 5  # 200 samples for each of the 5 classes

# Class 0: Top-left corner
X0 = np.random.randn(n_samples_per_class, 2) * 0.5 + [-2, 2]

# Class 1: Top-right corner
X1 = np.random.randn(n_samples_per_class, 2) * 0.5 + [2, 2]

# Class 2: Bottom-left corner
X2 = np.random.randn(n_samples_per_class, 2) * 0.5 + [-2, -2]

# Class 3: Bottom-right corner
X3 = np.random.randn(n_samples_per_class, 2) * 0.5 + [2, -2]

# Class 4: Center
X4 = np.random.randn(n_samples_per_class, 2) * 0.5 + [0, 0]

# Combine the data
X = np.vstack([X0, X1, X2, X3, X4])
y = np.hstack([np.zeros(n_samples_per_class), 
               np.ones(n_samples_per_class), 
               np.full(n_samples_per_class, 2),
               np.full(n_samples_per_class, 3),
               np.full(n_samples_per_class, 4)])

# Shuffle the dataset
shuffle_idx = np.random.permutation(n_samples)
X, y = X[shuffle_idx], y[shuffle_idx]

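This block generates 1,000 synthetic points for five classes (200 per class). Each class is drawn from a Gaussian cluster with standard deviation 0.5: four clusters are centered at the corners (±2, ±2) and one at the origin.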

3. Split the Dataset

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

This block splits the dataset into training and testing sets for model evaluation.

4. Create and Train the Decision Tree Classifier

model = DecisionTreeClassifier(random_state=42)
model.fit(X_train, y_train)

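This block initializes the decision tree classifier and trains it on the training set. No max_depth is set, so scikit-learn's default (max_depth=None) lets the tree keep splitting until every leaf is pure or holds fewer than min_samples_split samples.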

5. Make Predictions

y_pred = model.predict(X_test)

This block uses the trained model to make predictions on the test set.

6. Evaluate the Model

accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)

print(f"Accuracy: {accuracy:.4f}")
print("\nConfusion Matrix:")
print(conf_matrix)
print("\nClassification Report:")
print(class_report)

Output:

Accuracy: 0.9900

Confusion Matrix:
[[43  0  0  0  0]
 [ 0 40  0  0  1]
 [ 0  0 35  0  0]
 [ 0  0  0 33  0]
 [ 1  0  0  0 47]]

Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      1.00      0.99        43
         1.0       1.00      0.98      0.99        41
         2.0       1.00      1.00      1.00        35
         3.0       1.00      1.00      1.00        33
         4.0       0.98      0.98      0.98        48

    accuracy                           0.99       200
   macro avg       0.99      0.99      0.99       200
weighted avg       0.99      0.99      0.99       200

This block calculates and prints the accuracy, confusion matrix, and classification report, providing insights into the model's performance.
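The same arithmetic generalizes to any number of classes. Precision for class k divides the diagonal entry by its column sum, and recall divides it by its row sum (the support). The sketch below copies the matrix from the output above, recomputes the per-class values, and computes the macro average, which is the unweighted mean over classes.

```python
# Confusion matrix from the output above: rows = true class, columns = predicted class
cm = [[43, 0, 0, 0, 0],
      [0, 40, 0, 0, 1],
      [0, 0, 35, 0, 0],
      [0, 0, 0, 33, 0],
      [1, 0, 0, 0, 47]]

k = len(cm)
correct = sum(cm[i][i] for i in range(k))
total = sum(sum(row) for row in cm)
precision = [cm[i][i] / sum(cm[r][i] for r in range(k)) for i in range(k)]  # column sums
recall = [cm[i][i] / sum(cm[i]) for i in range(k)]                         # row sums

print(f"accuracy={correct / total:.4f}")  # accuracy=0.9900
for i in range(k):
    print(f"class {i}: precision={precision[i]:.2f} recall={recall[i]:.2f} support={sum(cm[i])}")
print(f"macro precision={sum(precision) / k:.2f}")
print(f"macro recall={sum(recall) / k:.2f}")
```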

7. Visualize the Decision Boundary

x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                     np.arange(y_min, y_max, 0.1))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.figure(figsize=(10, 8))
plt.contourf(xx, yy, Z, alpha=0.4, cmap='viridis')
scatter = plt.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis', edgecolor='black')
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("Multiclass Decision Tree Classification (5 Classes)")
plt.colorbar(scatter)
plt.show()

This block visualizes the decision boundaries created by the decision tree classifier. Because every split tests a single feature against a threshold, the region assigned to each of the five classes is made up of axis-aligned rectangles.

Output:

[Figure: Decision Tree Multiclass Classification]

This example shows how to train and evaluate a decision tree for multiclass classification. The decision-boundary plot shows how the tree divides the feature space into regions, each assigned to one predicted class.
