A decision tree is a machine learning model that makes predictions by learning a sequence of yes/no questions about features in the data. Random forests are ensembles of many decorrelated decision trees whose predictions are averaged, dramatically reducing variance and overfitting. Together, they are among the most widely used ML algorithms in industry — especially for tabular data, where they often outperform deep learning.
## How decision trees learn
A decision tree learns by recursively splitting the training data into subsets. At each node, it searches for the feature and split threshold that best separates the data — measured by impurity reduction.
| Splitting criterion | Formula | Used for | Notes |
|---|---|---|---|
| Gini impurity | G = 1 - Σ pᵢ² | Classification | Measures probability of misclassifying a randomly drawn sample. Gini=0 means pure (all one class). |
| Information gain (Entropy) | IG = H(parent) - weighted H(children) | Classification | Used in ID3/C4.5. Slightly favors features with more values. |
| Mean squared error (MSE) reduction | MSE_parent - weighted MSE_children | Regression | Chooses the split whose child groups are most homogeneous in target value. |
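A minimal, self-contained sketch of the two classification criteria from the table (the function names and toy labels are illustrative, not from any library):

```python
import math
from collections import Counter

def gini(labels):
    """Gini impurity: 1 - sum(p_i^2) over the class proportions."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def entropy(labels):
    """Entropy: -sum(p_i * log2(p_i)) over the class proportions."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())

def impurity_reduction(parent, left, right, measure):
    """Parent impurity minus the size-weighted impurity of the children."""
    n = len(parent)
    return measure(parent) - (len(left) * measure(left) + len(right) * measure(right)) / n

parent = [0, 0, 0, 1, 1, 1]           # 50/50 mix: Gini = 0.5, entropy = 1.0 bit
left, right = [0, 0, 0], [1, 1, 1]    # a perfect split: both children are pure
print(impurity_reduction(parent, left, right, gini))     # → 0.5
print(impurity_reduction(parent, left, right, entropy))  # → 1.0 (information gain)
```

The tree builder evaluates every candidate feature/threshold pair this way and keeps the split with the largest reduction.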
## Decision tree and Random Forest with scikit-learn: the most common ML workflow

```python
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
feature_names = ["sepal_len", "sepal_wid", "petal_len", "petal_wid"]

# ── Single decision tree ──────────────────────────────────────────────────
dt = DecisionTreeClassifier(max_depth=3, random_state=42)
dt.fit(X_train, y_train)

# Print the tree in text form — fully interpretable!
print(export_text(dt, feature_names=feature_names))
print("Decision Tree accuracy:", dt.score(X_test, y_test))  # e.g. 0.933

# ── Random Forest (ensemble of 100 trees) ────────────────────────────────
rf = RandomForestClassifier(
    n_estimators=100,     # number of trees
    max_depth=None,       # trees grow until pure or min_samples_split
    max_features="sqrt",  # each split considers √n_features candidates
    bootstrap=True,       # each tree is trained on a bootstrap sample
    random_state=42,
)
rf.fit(X_train, y_train)
print("Random Forest accuracy:", rf.score(X_test, y_test))  # e.g. 0.967

# Feature importances — how much each feature reduces impurity on average
for name, importance in zip(feature_names, rf.feature_importances_):
    print(f"  {name}: {importance:.3f}")
# → petal_len: 0.441, petal_wid: 0.423 (dominate prediction)
```

## The ensemble idea: why 100 weak trees beat one strong tree
A single deep decision tree overfits — it memorizes training data perfectly but generalizes poorly. Random Forests fix this through three sources of randomness that decorrelate the trees:
- Bootstrap sampling (bagging): each tree is trained on a random sample with replacement (~63% unique examples per tree). The ~37% not selected form the 'out-of-bag' (OOB) set, giving a free internal validation estimate.
- Random feature subsets: at each split, only √p features are considered (where p is total features). This forces trees to use different features and prevents any one strong feature from dominating all trees.
- Averaging predictions: final prediction is the majority vote (classification) or mean (regression) across all trees. Errors of individual trees are largely uncorrelated and cancel out.
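The out-of-bag estimate from the first bullet can be checked directly in scikit-learn, which scores each tree on the samples its bootstrap draw skipped. A minimal sketch, reusing the iris data:

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)

# oob_score=True evaluates each tree on the ~37% of samples it never saw
# during training, giving a validation estimate without a held-out set.
rf = RandomForestClassifier(
    n_estimators=200, oob_score=True, bootstrap=True, random_state=42
)
rf.fit(X, y)
print("OOB accuracy estimate:", rf.oob_score_)
```

The `oob_score_` attribute typically lands close to cross-validated accuracy, which is why it is useful as a free sanity check.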
| Property | Decision Tree | Random Forest |
|---|---|---|
| Interpretability | High — you can print and read every decision | Low — 100 trees are hard to inspect |
| Overfitting risk | Very high (especially with unlimited depth) | Low — averaging prevents overfitting |
| Hyperparameter sensitivity | Very sensitive to max_depth, min_samples | Robust — n_estimators just needs to be large enough |
| Training time | Fast | Slower (100× a single tree) but parallelizable |
| Out-of-box accuracy | Moderate | Consistently strong; often beats neural nets on tabular data |
| Feature importance | Built-in (impurity-based) but unstable for correlated features | More stable averages; pairs well with permutation importance |
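The permutation importance mentioned in the last row is available in `sklearn.inspection`. A sketch on the same iris setup as above: it shuffles one column at a time and measures the accuracy drop, so it can be computed on held-out data.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)

# Shuffle one feature at a time on the *test* set and measure the accuracy drop.
result = permutation_importance(rf, X_test, y_test, n_repeats=20, random_state=42)
for name, mean_drop in zip(
    ["sepal_len", "sepal_wid", "petal_len", "petal_wid"], result.importances_mean
):
    print(f"{name}: {mean_drop:.3f}")
```

On iris, the petal features dominate; shuffling a sepal feature barely moves accuracy.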
## When to use trees vs neural networks vs gradient boosting
| Data type | Best algorithm | Why |
|---|---|---|
| Tabular data (CSV, SQL) | Gradient Boosted Trees (XGBoost, LightGBM, CatBoost) | Consistently wins Kaggle competitions on tabular data; handles missing values natively; fast |
| Tabular data, interpretability required | Decision Tree or shallow RF | Fully inspectable; regulatory/compliance-safe (e.g. credit scoring, medicine) |
| Images, audio, video | Deep learning (CNN, ViT) | Trees cannot exploit spatial/temporal structure; neural nets are far better |
| Text / language | LLMs or fine-tuned Transformers | Trees have no concept of word order or meaning |
| Mixed tabular + text/image | Ensemble: tree model for tabular + neural net for text/image → blend | Combine strengths of both |
| Small dataset (<1K samples) | Random Forest or gradient boosting | Deep learning needs far more data to generalize |
## Practice questions
- What is information gain and how does it guide decision tree splits? (Answer: Information gain = Entropy(parent) - weighted average Entropy(children). Entropy = -Σ p_i log₂(p_i). A split is good if it reduces uncertainty (entropy). Pure children (all one class) have entropy=0 — maximum information gain. The tree greedily chooses the split that maximises information gain at each node. Gini impurity (1 - Σ p_i²) is a computationally cheaper alternative used by sklearn's DecisionTreeClassifier by default.)
- A decision tree with depth=20 achieves 99% training accuracy and 68% test accuracy. What is happening and what are three fixes? (Answer: Overfitting — the tree has memorised training data. Fixes: (1) max_depth=5–10: limit tree depth to prevent memorisation. (2) min_samples_leaf=10: require at least 10 samples per leaf — prevents splitting on noise. (3) min_samples_split=20: require at least 20 samples to split a node. (4) Cross-validation to find optimal depth. (5) Use Random Forest instead of single tree.)
- Why does Random Forest decorrelate trees and why is this important for variance reduction? (Answer: Random Forest uses random feature subsets at each split — different trees learn different aspects of the data. Without random feature selection (bagging alone), all trees tend to use the same dominant features → trees are correlated → their average does not reduce variance much. With max_features=sqrt(p), trees are decorrelated → their errors are relatively independent → averaging genuinely reduces variance. The theory: for n identically distributed trees with variance σ² and pairwise correlation ρ, the variance of their average is ρσ² + (1-ρ)σ²/n, which tends to ρσ² as n grows; lowering ρ, not just adding more trees, is what unlocks the variance reduction.)
- What is the difference between feature importance computed from Random Forest vs permutation importance? (Answer: RF impurity importance: average decrease in Gini/entropy when a feature is used to split, across all trees. Biased toward high-cardinality features (features with many unique values get more split opportunities). Permutation importance: randomly shuffle one feature's values and measure how much model performance drops. Model-agnostic, less biased. The gold standard for feature selection. Can be computed on test set rather than training set — avoids inflated importance from memorised features.)
- What is gradient boosting and how does it improve on AdaBoost? (Answer: AdaBoost: sequentially trains weak learners (stumps) on reweighted data — misclassified points get higher weights. Final prediction: weighted majority vote. Gradient boosting: each new tree fits the negative gradient of the loss function (for regression: residuals; for classification: logistic loss gradients). More general — works with any differentiable loss. AdaBoost minimises exponential loss implicitly; gradient boosting can optimise any loss. Friedman (2001) unified the framework.)
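The overfitting scenario in the depth question can be reproduced on synthetic data. A sketch (the dataset shape and hyperparameters are illustrative choices, not a prescription):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic noisy dataset, so the unrestricted tree has noise to memorise.
X, y = make_classification(
    n_samples=600, n_features=20, n_informative=5, flip_y=0.2, random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

deep = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
pruned = DecisionTreeClassifier(
    max_depth=5, min_samples_leaf=10, random_state=0
).fit(X_train, y_train)

# The unrestricted tree fits the training set perfectly but drops on test data;
# the regularised tree gives up training accuracy for a typically smaller gap.
print("deep   train/test:", deep.score(X_train, y_train), deep.score(X_test, y_test))
print("pruned train/test:", pruned.score(X_train, y_train), pruned.score(X_test, y_test))
```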
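The variance argument in the decorrelation question can be checked numerically: the variance of the mean of n correlated errors with pairwise correlation ρ converges to ρσ², not σ²/n. A simulation sketch (ρ, n, and the shared-component construction are chosen purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n_trees, rho = 100, 0.6          # unit-variance "tree errors", correlation rho
n_trials = 20_000

# Correlated errors: a shared component (variance rho) plus an independent
# component (variance 1 - rho), so Var = 1 and Corr(i, j) = rho for i != j.
shared = rng.normal(0.0, np.sqrt(rho), size=(n_trials, 1))
indep = rng.normal(0.0, np.sqrt(1 - rho), size=(n_trials, n_trees))
correlated = shared + indep
independent = rng.normal(0.0, 1.0, size=(n_trials, n_trees))

var_ind_mean = independent.mean(axis=1).var()  # ≈ 1/n = 0.01
var_cor_mean = correlated.mean(axis=1).var()   # ≈ rho + (1 - rho)/n ≈ 0.604
print(var_ind_mean, var_cor_mean)
```

Averaging 100 independent errors cuts the variance 100-fold; with ρ = 0.6 the shared component survives the averaging almost untouched.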
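The residual-fitting loop from the gradient boosting answer can be written by hand for squared-error loss, where the negative gradient is just the residual. A sketch using shallow scikit-learn regression trees (the data, depth, and learning rate are illustrative):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(300, 1))
y = np.sin(X[:, 0]) + rng.normal(0, 0.1, size=300)

# Start from the mean prediction, then repeatedly fit a shallow tree to the
# residuals (the negative gradient of squared-error loss) and add a damped
# version of its output.
pred = np.full_like(y, y.mean())
learning_rate = 0.1
for _ in range(100):
    residuals = y - pred                      # negative gradient for MSE
    stump = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, residuals)
    pred += learning_rate * stump.predict(X)

print("final training MSE:", np.mean((y - pred) ** 2))
```

Swapping the residual for the gradient of another differentiable loss gives the general algorithm; sklearn's GradientBoostingRegressor and libraries like XGBoost add regularisation and subsampling on top of this loop.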