Decision trees learn a hierarchy of feature-threshold splitting rules to make predictions. They capture non-linear relationships naturally and don't require feature scaling.

Using Decision Trees

Classification with Decision Trees

Train a decision tree classifier on the Iris dataset:
import { loadIris } from "deepbox/datasets";
import { accuracy } from "deepbox/metrics";
import { DecisionTreeClassifier } from "deepbox/ml";
import { trainTestSplit } from "deepbox/preprocess";

const iris = loadIris();
const [XTrain, XTest, yTrain, yTest] = trainTestSplit(
  iris.data, 
  iris.target, 
  {
    testSize: 0.2,
    randomState: 42,
  }
);

// Create and train decision tree
const dtc = new DecisionTreeClassifier({ 
  maxDepth: 5, 
  minSamplesSplit: 2 
});
dtc.fit(XTrain, yTrain);

const dtcPred = dtc.predict(XTest);
const acc = accuracy(yTest, dtcPred);

console.log("Decision Tree Classifier");
console.log(`Accuracy: ${(Number(acc) * 100).toFixed(2)}%`);
Output:
Decision Tree Classifier
Accuracy: 96.67%
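Under the hood, a CART-style classifier builds its splitting rules greedily: at each node it scans candidate thresholds on each feature and keeps the one that minimizes the weighted Gini impurity of the two children. The sketch below illustrates that criterion in plain TypeScript; it is independent of deepbox, whose actual internals may differ:

```typescript
/** Gini impurity of a set of class labels: 1 - sum over classes of p_k^2. */
function gini(labels: number[]): number {
  const counts = new Map<number, number>();
  for (const y of labels) counts.set(y, (counts.get(y) ?? 0) + 1);
  let sumSq = 0;
  for (const c of counts.values()) sumSq += (c / labels.length) ** 2;
  return 1 - sumSq;
}

/** Find the threshold on one feature minimizing the weighted child impurity. */
function bestSplit(
  feature: number[],
  labels: number[],
): { threshold: number; impurity: number } {
  let best = { threshold: NaN, impurity: Infinity };
  const thresholds = [...new Set(feature)].sort((a, b) => a - b);
  for (const t of thresholds) {
    const left: number[] = [], right: number[] = [];
    feature.forEach((x, i) => (x <= t ? left : right).push(labels[i]));
    if (left.length === 0 || right.length === 0) continue; // skip degenerate splits
    const weighted =
      (left.length / labels.length) * gini(left) +
      (right.length / labels.length) * gini(right);
    if (weighted < best.impurity) best = { threshold: t, impurity: weighted };
  }
  return best;
}

// Toy sample: petal length cleanly separates Iris setosa (class 0).
const petalLength = [1.4, 1.3, 1.5, 4.7, 4.5, 4.9, 6.0, 5.1, 5.9];
const species     = [0,   0,   0,   1,   1,   1,   2,   2,   2];
console.log(bestSplit(petalLength, species)); // threshold 1.5 isolates setosa
```

On this toy sample the search recovers the well-known petal-length split: every sample at or below the threshold is setosa, so the left child has zero impurity.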
Regression with Decision Trees

Use decision trees for continuous predictions:
import { mse, r2Score } from "deepbox/metrics";
import { DecisionTreeRegressor } from "deepbox/ml";
import { tensor } from "deepbox/ndarray";
import { trainTestSplit } from "deepbox/preprocess";

// Create regression dataset
const XReg = tensor([
  [1, 2], [2, 3], [3, 4], [4, 5], [5, 6],
  [6, 7], [7, 8], [8, 9], [9, 10], [10, 11],
  [1, 3], [2, 5], [3, 2], [4, 1], [5, 4],
  [6, 3], [7, 6], [8, 5], [9, 8], [10, 7],
]);
const yReg = tensor([5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 7, 12, 7, 6, 13, 12, 19, 18, 25, 24]);

const [XRegTrain, XRegTest, yRegTrain, yRegTest] = trainTestSplit(
  XReg, 
  yReg, 
  {
    testSize: 0.2,
    randomState: 42,
  }
);

// Train decision tree regressor
const dtr = new DecisionTreeRegressor({ maxDepth: 5 });
dtr.fit(XRegTrain, yRegTrain);

const dtrPred = dtr.predict(XRegTest);
console.log("\nDecision Tree Regressor");
console.log(`MSE: ${Number(mse(yRegTest, dtrPred)).toFixed(4)}`);
console.log(`R²:  ${Number(r2Score(yRegTest, dtrPred)).toFixed(4)}`);
Output:
Decision Tree Regressor
MSE: 0.5000
R²:  0.9850
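A regression tree runs the same greedy search, but scores a candidate split by the weighted variance (i.e. MSE) of the two child nodes, and each leaf predicts the mean target of its training samples. A minimal sketch of that scoring rule, again in plain TypeScript and independent of deepbox:

```typescript
/** Mean of a list of targets (what a leaf predicts). */
function mean(ys: number[]): number {
  return ys.reduce((a, b) => a + b, 0) / ys.length;
}

/** Variance of a list of targets (the MSE of predicting their mean). */
function variance(ys: number[]): number {
  const m = mean(ys);
  return ys.reduce((a, y) => a + (y - m) ** 2, 0) / ys.length;
}

/** Score a split at a threshold: weighted variance of the two children. */
function splitScore(feature: number[], target: number[], threshold: number): number {
  const left: number[] = [], right: number[] = [];
  feature.forEach((x, i) => (x <= threshold ? left : right).push(target[i]));
  return (left.length / target.length) * variance(left) +
         (right.length / target.length) * variance(right);
}

const x = [1, 2, 3, 4, 5, 6];
const y = [5, 8, 11, 20, 23, 26]; // jump in the target between x=3 and x=4
console.log(splitScore(x, y, 3)); // splitting at the jump leaves low variance
console.log(splitScore(x, y, 1)); // a poor split keeps the variance high
```

The tree picks the threshold with the lowest score, so it naturally places splits at jumps in the target, producing the piecewise-constant fit characteristic of regression trees.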
Controlling Tree Complexity

Adjust hyperparameters to prevent overfitting:
// Shallow tree (less overfitting)
const shallowTree = new DecisionTreeClassifier({ 
  maxDepth: 3,
  minSamplesSplit: 10 
});

// Deep tree (more expressive)
const deepTree = new DecisionTreeClassifier({ 
  maxDepth: 10,
  minSamplesSplit: 2 
});
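These two parameters bound growth in complementary ways: maxDepth caps a binary tree at 2^maxDepth leaves, while minSamplesSplit turns small nodes into leaves early regardless of depth. The toy counter below (plain TypeScript; it splits samples in half rather than by impurity, so it only illustrates the stopping rules, not deepbox's actual splits) shows how each setting limits leaf count:

```typescript
interface TreeParams {
  maxDepth: number;        // stop splitting at this depth
  minSamplesSplit: number; // stop splitting nodes smaller than this
}

/** Count the leaves a binary tree would grow under these stopping rules. */
function countLeaves(nSamples: number, params: TreeParams, depth = 0): number {
  if (depth >= params.maxDepth || nSamples < params.minSamplesSplit) return 1;
  const left = Math.floor(nSamples / 2); // toy scheme: split samples in half
  return countLeaves(left, params, depth + 1) +
         countLeaves(nSamples - left, params, depth + 1);
}

console.log(countLeaves(120, { maxDepth: 3, minSamplesSplit: 10 })); // shallow: 8 leaves
console.log(countLeaves(120, { maxDepth: 10, minSamplesSplit: 2 })); // deep: one leaf per sample
```

With 120 training samples, the shallow setting stops at 2^3 = 8 leaves, while the deep setting (depth 10, minimum split size 2) grows all the way down to one leaf per sample under this halving scheme, which is exactly the memorize-the-training-set behavior that overfits.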
Key parameters:
  • maxDepth: Maximum tree depth (prevents overfitting)
  • minSamplesSplit: Minimum samples required to split a node

When to Use Decision Trees

  • Handle non-linear relationships naturally
  • No feature scaling required
  • Interpretable decision rules
  • Work with both numerical and categorical features
  • Prone to overfitting on noisy data (use Random Forests instead)

Next Steps

  • Random Forests: Reduce overfitting with ensemble methods
  • Gradient Boosting: Achieve higher accuracy with boosting
