Building a Random Forest Classifier from Scratch in Python

  • Writer: Sairam Penjarla
  • Jul 17, 2024
  • 7 min read

Introduction

In this blog post, we will walk through the process of implementing a Random Forest classifier from scratch in Python. A Random Forest is a powerful machine learning algorithm that combines multiple decision trees to improve classification accuracy and reduce overfitting. By the end of this post, you will understand how each component of the Random Forest works and how the components come together to make predictions.

To follow along, you can clone the GitHub repository and try the code yourself. This hands-on approach will help solidify your understanding of the concepts discussed. Let's get started!

Cloning the Repository

To get started, clone the repository using the following commands:

# Clone the repository
git clone https://github.com/sairam-penjarla/Random-Forest-from-scratch.git

# Navigate to the project directory
cd Random-Forest-from-scratch

Understanding the Code Structure

The project is organized into four main files:

  1. tree.py

  2. decision_tree.py

  3. random_forest.py

  4. inference.py

Each file has a specific role in building the Random Forest classifier. Let's dive into each of these files and understand the functions and classes they contain.

tree.py

This file contains the Node class and the resample_data function, which are essential for creating decision trees and generating bootstrap samples.

resample_data

The resample_data function creates a bootstrap sample from the dataset. This means it randomly samples data points with replacement to create a new dataset of the same size as the original.

import numpy as np

def resample_data(features, labels):
    """Generate a bootstrap sample from the dataset."""
    num_samples = features.shape[0]
    sample_indices = np.random.choice(num_samples, size=num_samples, replace=True)
    return features[sample_indices], labels[sample_indices]
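
To see what sampling with replacement does in practice, here is a small sketch that runs resample_data on a made-up dataset. The toy data and the seed are assumptions chosen only for illustration. Because rows are drawn with replacement, some rows appear more than once and others are left out; for large datasets roughly 63% of the original rows (1 - 1/e) end up in each sample.

```python
import numpy as np

def resample_data(features, labels):
    """Generate a bootstrap sample from the dataset."""
    num_samples = features.shape[0]
    sample_indices = np.random.choice(num_samples, size=num_samples, replace=True)
    return features[sample_indices], labels[sample_indices]

np.random.seed(0)  # seed chosen only to make the illustration repeatable

# Toy dataset (made up): row i is [i, 10 * i] and its label is i.
features = np.column_stack([np.arange(1000), 10 * np.arange(1000)])
labels = np.arange(1000)

sample_features, sample_labels = resample_data(features, labels)

# Fraction of distinct original rows that made it into the sample.
unique_fraction = len(np.unique(sample_labels)) / len(labels)
print(sample_features.shape, round(unique_fraction, 2))
```

The sample has the same shape as the original data, and each sampled row keeps its own label, which is what lets every tree train on a slightly different view of the same dataset.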

Node

The Node class represents a single node in a decision tree. It stores information about the depth of the node, the feature and threshold used for splitting, and pointers to the left and right child nodes.

class Node:
    def __init__(self, current_depth=0, depth_limit=None):
        """Initialize a tree node."""
        self.current_depth = current_depth
        self.depth_limit = depth_limit
        self.left_child = None
        self.right_child = None
        self.split_feature = None
        self.split_value = None
        self.leaf_value = None  # Value for leaf nodes

decision_tree.py

The decision_tree.py file contains the CustomDecisionTree class, which encapsulates the logic for building and using a single decision tree. This class includes methods for fitting the tree to the data, recursively growing the tree, finding the best splits, calculating Gini impurity, and making predictions.

CustomDecisionTree

The CustomDecisionTree class is initialized with an optional depth limit, which determines the maximum depth the tree can grow to.

import numpy as np
from tree import Node, resample_data

class CustomDecisionTree:
    def __init__(self, depth_limit=None):
        """Initialize the CustomDecisionTree with an optional depth limit."""
        self.depth_limit = depth_limit

fit

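The fit method trains the decision tree using the provided features and labels. It records the number of classes and features in the dataset and then starts building the tree. Labels are assumed to be integers 0, 1, ..., k-1, because _build_tree counts classes with range(self.num_classes).

def fit(self, features, labels):
    """Fit the CustomDecisionTree to the provided features and labels."""
    # Use max label + 1 rather than the number of unique labels: a bootstrap
    # sample can miss a class, and _build_tree counts labels 0..num_classes-1.
    self.num_classes = int(np.max(labels)) + 1
    self.num_features = features.shape[1]
    self.root = self._build_tree(features, labels)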

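_build_tree

The _build_tree method is a recursive function that constructs the decision tree. At each node it determines the majority class, then checks whether further splitting is allowed by the depth limit and the number of samples. If so, it draws a random subset of int(sqrt(num_features)) features and searches only those for the best split.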

def _build_tree(self, features, labels, current_depth=0):
    """Recursively build the tree."""
    class_counts = [np.sum(labels == c) for c in range(self.num_classes)]
    majority_class = np.argmax(class_counts)
    node = Node(current_depth, self.depth_limit)

    # Check if we should split further (depth_limit=None means no depth limit)
    if (self.depth_limit is None or current_depth < self.depth_limit) and len(features) > 1:
        selected_features = np.random.choice(self.num_features, size=int(np.sqrt(self.num_features)), replace=False)
        best_feature, best_split = self._find_optimal_split(features, labels, selected_features)
        if best_feature is not None:
            node.split_feature = best_feature
            node.split_value = best_split
            left_mask = features[:, best_feature] < best_split
            right_mask = ~left_mask
            features_left, labels_left = features[left_mask], labels[left_mask]
            features_right, labels_right = features[right_mask], labels[right_mask]
            if len(labels_left) > 0 and len(labels_right) > 0:
                node.left_child = self._build_tree(features_left, labels_left, current_depth + 1)
                node.right_child = self._build_tree(features_right, labels_right, current_depth + 1)
            else:
                node.leaf_value = majority_class
        else:
            node.leaf_value = majority_class
    else:
        node.leaf_value = majority_class

    return node
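
The random feature subset drawn at each node is what decorrelates the trees in the forest. The sketch below isolates that one step. The helper name pick_feature_subset is hypothetical and not part of the repository; it only mirrors the np.random.choice call used in _build_tree.

```python
import numpy as np

def pick_feature_subset(num_features):
    """Pick int(sqrt(num_features)) distinct feature indices at random."""
    subset_size = int(np.sqrt(num_features))
    # replace=False guarantees that no feature index is drawn twice.
    return np.random.choice(num_features, size=subset_size, replace=False)

np.random.seed(0)  # seed chosen only to make the illustration repeatable

for num_features in (4, 10, 100):
    subset = pick_feature_subset(num_features)
    print(num_features, len(subset), sorted(subset.tolist()))
```

For the four Iris features this means each node considers only two candidate features, so different nodes and different trees end up splitting on different features.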

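_find_optimal_split

The _find_optimal_split method identifies the feature and threshold that minimize the weighted Gini impurity of the resulting split. It tries every unique value of each candidate feature as a threshold and skips thresholds that would leave one side empty.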

def _find_optimal_split(self, features, labels, feature_indices):
    """Find the best feature and threshold for splitting the data."""
    lowest_gini = float('inf')
    optimal_feature, optimal_threshold = None, None

    if len(feature_indices) == 0:
        return optimal_feature, optimal_threshold

    for feature in feature_indices:
        sorted_thresholds = np.unique(features[:, feature])
        for threshold in sorted_thresholds:
            left_mask = features[:, feature] < threshold
            right_mask = features[:, feature] >= threshold
            if np.sum(left_mask) == 0 or np.sum(right_mask) == 0:
                continue

            gini_score = self._compute_gini(labels[left_mask], labels[right_mask])

            if gini_score < lowest_gini:
                lowest_gini = gini_score
                optimal_feature = feature
                optimal_threshold = threshold

    return optimal_feature, optimal_threshold
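
The threshold search is easier to follow on a single feature. The following is a minimal standalone sketch, not the repository's code: best_threshold and weighted_gini are hypothetical helper names, and the six data points are made up. It uses the same convention as the method above, where samples with a value strictly below the threshold go left.

```python
import numpy as np

def weighted_gini(left_labels, right_labels):
    """Gini impurity of each side, weighted by its share of the samples."""
    def gini(labels):
        _, counts = np.unique(labels, return_counts=True)
        proportions = counts / len(labels)
        return 1.0 - np.sum(proportions ** 2)

    total = len(left_labels) + len(right_labels)
    return (len(left_labels) / total) * gini(left_labels) \
        + (len(right_labels) / total) * gini(right_labels)

def best_threshold(values, labels):
    """Scan every unique value as a threshold and keep the lowest score."""
    best_score, best_value = float("inf"), None
    for threshold in np.unique(values):
        left_mask = values < threshold
        # Skip thresholds that leave one side empty.
        if left_mask.sum() == 0 or left_mask.sum() == len(values):
            continue
        score = weighted_gini(labels[left_mask], labels[~left_mask])
        if score < best_score:
            best_score, best_value = float(score), float(threshold)
    return best_value, best_score

# Made-up feature values with two well-separated classes.
values = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
labels = np.array([0, 0, 0, 1, 1, 1])

threshold, score = best_threshold(values, labels)
print(threshold, score)  # threshold 10.0 separates the classes, score 0.0
```

Note that the smallest value (1.0 here) can never be chosen, because nothing is strictly below it and the left side would be empty.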

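_compute_gini

The _compute_gini method calculates the weighted Gini impurity of a split: the Gini impurity of each side, weighted by the fraction of samples that fall on that side. A value of 0 means both sides are pure; higher values mean the classes are more mixed.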

def _compute_gini(self, left_labels, right_labels):
    """Calculate the Gini impurity of a split."""
    n_left, n_right = len(left_labels), len(right_labels)
    if n_left == 0 or n_right == 0:
        return 0
    total_samples = n_left + n_right
    p_left = n_left / total_samples
    p_right = n_right / total_samples
    gini_left = 1 - sum((np.sum(left_labels == c) / n_left) ** 2 for c in np.unique(left_labels))
    gini_right = 1 - sum((np.sum(right_labels == c) / n_right) ** 2 for c in np.unique(right_labels))
    return p_left * gini_left + p_right * gini_right
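
A few worked numbers make the formula concrete. This is a small sketch with made-up label arrays; the helper name gini is hypothetical and computes the impurity of a single side, 1 minus the sum of squared class proportions.

```python
import numpy as np

def gini(labels):
    """Gini impurity of one group: 1 - sum of squared class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    proportions = counts / len(labels)
    return float(1.0 - np.sum(proportions ** 2))

pure = gini(np.array([0, 0, 0, 0]))     # one class only -> 0.0
balanced = gini(np.array([0, 0, 1, 1])) # 1 - (0.5^2 + 0.5^2) = 0.5
three_way = gini(np.array([0, 1, 2]))   # 1 - 3 * (1/3)^2 = 2/3

# Weighted impurity of a split: a pure left side with 3 samples and
# a mixed right side with 4 samples (one of class 0, three of class 1).
left = np.array([0, 0, 0])
right = np.array([0, 1, 1, 1])
n = len(left) + len(right)
split_score = (len(left) / n) * gini(left) + (len(right) / n) * gini(right)

print(pure, balanced, round(three_way, 4), round(split_score, 4))
```

The split score works out to (3/7) * 0 + (4/7) * 0.375, which is about 0.2143. The split search prefers whichever threshold drives this number lowest.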

predict

The predict method takes a set of data points and returns their predicted class labels by traversing the tree.

def predict(self, data):
    """Predict the class labels for the given data."""
    return np.array([self._predict_single(sample, self.root) for sample in data])

_predict_single

The _predict_single method is a recursive function that traverses the tree for a single data point to determine its predicted class label.

def _predict_single(self, sample, node):
    """Recursively predict the class label for a single sample."""
    if node.leaf_value is not None:
        return node.leaf_value
    if sample[node.split_feature] < node.split_value:
        return self._predict_single(sample, node.left_child)
    else:
        return self._predict_single(sample, node.right_child)
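
To trace the traversal by hand, the sketch below builds a tiny two-level tree manually and walks it for three samples. Everything here is an illustration under assumptions: the Node class is a stripped-down stand-in for the one in tree.py, the thresholds and samples are made up, and the loop is an iterative variant that should behave the same as the recursive method above.

```python
class Node:
    """Stripped-down stand-in for the Node class in tree.py."""
    def __init__(self, split_feature=None, split_value=None, leaf_value=None):
        self.split_feature = split_feature
        self.split_value = split_value
        self.leaf_value = leaf_value
        self.left_child = None
        self.right_child = None

def predict_single(sample, node):
    """Walk from the given node down to a leaf and return its class."""
    while node.leaf_value is None:
        if sample[node.split_feature] < node.split_value:
            node = node.left_child
        else:
            node = node.right_child
    return node.leaf_value

# Hand-built tree: the root splits on feature 0 at 2.5; its right child
# splits on feature 1 at 1.7. The three leaves predict classes 0, 1 and 2.
root = Node(split_feature=0, split_value=2.5)
root.left_child = Node(leaf_value=0)
root.right_child = Node(split_feature=1, split_value=1.7)
root.right_child.left_child = Node(leaf_value=1)
root.right_child.right_child = Node(leaf_value=2)

samples = [[1.4, 0.2], [4.5, 1.5], [6.0, 2.2]]
predictions = [predict_single(sample, root) for sample in samples]
print(predictions)  # [0, 1, 2]
```

The check against None matters: a leaf that predicts class 0 has leaf_value equal to 0, which is falsy in Python, so a plain truthiness test would wrongly treat it as an internal node.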

By breaking down each function, we can see how the decision tree is constructed, how it finds the best splits, and how it makes predictions. This detailed explanation provides a comprehensive understanding of the CustomDecisionTree class and its methods.

random_forest.py

The random_forest.py file contains the RandomForestClassifier class, which implements a Random Forest ensemble of decision trees. This class includes methods for fitting multiple trees to the data, making predictions using the ensemble, and generating bootstrap samples.

RandomForestClassifier

The RandomForestClassifier class is initialized with parameters for the number of trees (num_trees) and an optional depth limit (depth_limit) for individual decision trees.

import numpy as np
from decision_tree import CustomDecisionTree

class RandomForestClassifier:
    def __init__(self, num_trees=100, depth_limit=None):
        """Initialize the RandomForestClassifier with the number of trees and an optional depth limit."""
        self.num_trees = num_trees
        self.depth_limit = depth_limit
        self.trees = []

fit

The fit method trains the random forest by creating multiple decision trees (CustomDecisionTree instances) and fitting them to bootstrap samples of the data.

def fit(self, features, labels):
    """Fit the RandomForestClassifier to the provided features and labels."""
    self.trees = []

    for _ in range(self.num_trees):
        tree = CustomDecisionTree(depth_limit=self.depth_limit)
        sample_features, sample_labels = self._create_bootstrap_sample(features, labels)
        tree.fit(sample_features, sample_labels)
        self.trees.append(tree)

predict

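The predict method combines predictions from all trees in the forest to produce a final prediction. For each data point, it returns the class predicted by the most trees (a majority vote). Averaging the labels instead would be unreliable with more than two classes: votes for classes 0 and 2 would average to class 1.

def predict(self, features):
    """Predict the class labels for the given features by majority vote."""
    predictions = np.zeros((features.shape[0], self.num_trees), dtype=int)

    for idx, tree in enumerate(self.trees):
        predictions[:, idx] = tree.predict(features)

    # Most frequent label across trees for each sample (labels must be non-negative integers)
    return np.array([np.bincount(row).argmax() for row in predictions])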
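
When aggregating class labels across trees, it helps to compare two strategies side by side: averaging the labels and rounding, versus taking a majority vote. The vote matrix below is made up for illustration (rows are samples, columns are trees), and the sketch assumes labels are non-negative integers.

```python
import numpy as np

# Hypothetical votes from five trees for three samples.
votes = np.array([
    [0, 0, 2, 2, 2],
    [1, 1, 1, 0, 2],
    [0, 2, 0, 2, 0],
])

# Strategy 1: average the labels, then round to the nearest class.
mean_rounded = np.mean(votes, axis=1).round().astype(int)

# Strategy 2: majority vote, the most frequent label in each row.
majority = np.array([np.bincount(row).argmax() for row in votes])

print(mean_rounded.tolist())  # [1, 1, 1]
print(majority.tolist())      # [2, 1, 0]
```

In the first and third rows no tree voted for class 1, yet averaging produces it, because class labels are categories rather than quantities. With only two classes coded 0 and 1 the two strategies agree except on exact ties.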

_create_bootstrap_sample

The _create_bootstrap_sample method generates a bootstrap sample from the dataset. It randomly samples data points with replacement, so each tree trains on a different dataset. It performs the same sampling as resample_data in tree.py.

def _create_bootstrap_sample(self, features, labels):
    """Generate a bootstrap sample from the dataset."""
    num_samples = features.shape[0]
    sample_indices = np.random.choice(num_samples, size=num_samples, replace=True)
    return features[sample_indices], labels[sample_indices]

By understanding each method in the RandomForestClassifier class, we gain insight into how a Random Forest is constructed, trained, and used for making predictions. This breakdown helps in grasping the mechanics behind ensemble learning and the role of individual decision trees within the ensemble.

inference.py

The inference.py file demonstrates the practical usage of the RandomForestClassifier for training on the Iris dataset and making predictions.

Example Usage with Iris Dataset

# Example usage with Iris dataset
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from random_forest import RandomForestClassifier

# Load the Iris dataset
iris_data = load_iris()
features, labels = iris_data.data, iris_data.target

Splitting the Dataset

# Split the dataset into training and testing sets
features_train, features_test, labels_train, labels_test = train_test_split(features, labels, test_size=0.2, random_state=42)

Initializing and Training the Random Forest Classifier

# Initialize and train the Random Forest classifier
random_forest = RandomForestClassifier(num_trees=100, depth_limit=5)
random_forest.fit(features_train, labels_train)
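
One practical note: train_test_split is seeded with random_state=42, but the forest itself draws bootstrap samples and feature subsets from NumPy's global random state, so the reported accuracy can change from run to run. A minimal sketch of how seeding that global state makes the draws repeatable is shown below. The helper draw_bootstrap_indices is hypothetical and only mirrors the np.random.choice call used for bootstrapping.

```python
import numpy as np

def draw_bootstrap_indices(num_samples):
    """Draw row indices with replacement, as the bootstrap step does."""
    return np.random.choice(num_samples, size=num_samples, replace=True)

# Seeding the global state before each run reproduces the same draws.
np.random.seed(42)
first_run = draw_bootstrap_indices(10)

np.random.seed(42)
second_run = draw_bootstrap_indices(10)

print(np.array_equal(first_run, second_run))  # True
```

Calling np.random.seed once before fitting the forest should therefore give the same trees on every run, under the assumption that nothing else consumes the global random state in between.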

Making Predictions

# Make predictions on the test set
predicted_labels = random_forest.predict(features_test)

Evaluating Model Accuracy

# Evaluate the accuracy of the model
accuracy = accuracy_score(labels_test, predicted_labels.round())
print(f'Accuracy: {accuracy * 100:.2f}%')

This code snippet provides a complete workflow for using the RandomForestClassifier to train a model on the Iris dataset, make predictions, and evaluate its accuracy. It demonstrates the integration of machine learning libraries (scikit-learn) with custom implementations (RandomForestClassifier from random_forest.py) for ensemble learning.

Conclusion

In this blog post, we delved into the implementation and usage of a Random Forest classifier from scratch using Python. We explored how decision trees are integrated into an ensemble model to improve predictive performance and robustness. Here’s a brief summary of what we covered:

  1. Understanding Random Forest: We introduced the concept of Random Forests, which are composed of multiple decision trees trained on bootstrap samples of the dataset with random feature subsets.

  2. Implementation Details: We walked through the implementation of key components:

  • Decision Tree: Constructed using a custom implementation (CustomDecisionTree) capable of recursive splitting based on Gini impurity.

  • Random Forest: Defined as RandomForestClassifier, which trains multiple decision trees and aggregates their predictions to improve accuracy.

  3. Practical Example: Demonstrated how to apply the Random Forest classifier to the classic Iris dataset. This included:

  • Loading the dataset and splitting it into training and testing sets.

  • Initializing and training the Random Forest model.

  • Making predictions on the test set and evaluating model accuracy using metrics like accuracy score.

  4. Next Steps: Encouraged readers to explore the provided GitHub repository (Random-Forest-from-scratch) for detailed code implementations and further experimentation. Trying out different datasets or modifying parameters can deepen understanding and enhance practical skills in machine learning.

By understanding the workings of Random Forests and implementing them from scratch, you gain insights into ensemble methods and the nuances of decision tree-based models. We hope this blog post has equipped you with foundational knowledge and practical skills to apply Random Forests in your own projects.

Feel free to clone the GitHub repository and explore the code firsthand. Experiment with different datasets and parameters to further enhance your understanding of Random Forests. Happy coding!
