Road to ML Engineer #9 - Decision Trees

Last Edited: 8/10/2024

This blog post introduces how decision trees can be utilized in machine learning.


So far, we have been using linear regression to solve various problems. However, fitting a line is not the only approach we can take. We can also use trees.

Decision Trees

Let's say we have a classification problem where we use the height and weight of a dog to predict its species. Instead of using softmax regression, we can use a decision tree like the one below to perform classification.

Decision Tree Example

We call each element in the tree a node, and the nodes at the bottom containing the predicted class are called leaf nodes. We check each condition at the nodes to move down the tree until we end up at a leaf node to make a prediction using the decision tree. We can see from the diagram above that each condition effectively partitions the space for classifying the data.
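To make the traversal concrete, here is a minimal sketch of how a prediction walks down a tree like the one above. The thresholds and species names are hypothetical, chosen only for illustration.

# A minimal sketch of traversing a hand-built decision tree.
# The thresholds and species below are hypothetical, for illustration only.
def predict_species(height_cm, weight_kg):
  if height_cm <= 30:            # first split on height
    if weight_kg <= 5:           # second split on weight
      return "Chihuahua"
    return "Corgi"
  if weight_kg <= 25:            # split on weight for taller dogs
    return "Border Collie"
  return "Great Dane"

print(predict_species(height_cm=25, weight_kg=4))   # -> "Chihuahua"
print(predict_species(height_cm=60, weight_kg=40))  # -> "Great Dane"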

This method can also be used for regression. For example, let's predict the height of a dog given its species and weight.

Decision Tree Example

The opacity of the points in the diagram indicates the height of the dogs. Fortunately, we only needed to partition a few times to predict the height perfectly. Decision trees used for classification are called classification trees, while those used for regression are called regression trees. We can use either type depending on the nature of the task.
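As a quick sketch of the regression case, the snippet below fits scikit-learn's DecisionTreeRegressor on a tiny, made-up species/weight table (species encoded as integers) and predicts a height. The data are invented purely to show the API; each leaf predicts the average height of the dogs in its partition.

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Tiny made-up dataset: [species_id, weight_kg] -> height_cm (for illustration only).
X = np.array([[0, 3], [0, 4], [1, 12], [1, 14], [2, 35], [2, 40]])
y = np.array([20, 22, 45, 48, 75, 80])

reg = DecisionTreeRegressor()   # each leaf predicts the mean height of its partition
reg.fit(X, y)

print(reg.predict([[1, 13]]))   # falls into the partition of the medium-sized dogs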

How to Build Trees

The question now is how to build such trees. Surprisingly, it's quite simple. For both discrete and continuous variables, we take the value in the middle of two adjacent points as a candidate for splitting, and we compute either the sum of entropy (for classification trees) or the sum of squared residuals (for regression trees) for all the partitioned spaces for each split.

E = -\sum p \log(p) \\ SSR = \sum (y - f(x))^2

Here, p is the probability of observing each class in the partitioned space, y is the true value, and f(x) is the value predicted by the tree, which is the average of the target values in that partition. We keep track of the sum of these values and terminate the creation of branches once we see no further improvement. This process is similar to how we track the cost function to inform parameter updates in a model. In decision trees, instead of calling it a cost function, we refer to it as impurity. The decision tree chooses partitions to minimize the impurity of the predictions.
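Here is a minimal sketch of the split search described above: candidate thresholds are taken at the midpoints of adjacent sorted feature values, and each candidate is scored by the summed entropy (classification) or summed squared residuals (regression) of the two partitions it creates. The helper names and toy data are my own, for illustration only.

import numpy as np

def entropy(labels):
  # E = -sum(p * log(p)) over the classes observed in this partition
  _, counts = np.unique(labels, return_counts=True)
  p = counts / counts.sum()
  return -np.sum(p * np.log(p))

def ssr(values):
  # SSR = sum((y - mean)^2); the leaf prediction is the partition mean
  return np.sum((values - values.mean()) ** 2)

def best_split(x, y, impurity):
  # Candidate thresholds are midpoints between adjacent sorted feature values.
  xs = np.sort(np.unique(x))
  candidates = (xs[:-1] + xs[1:]) / 2
  scores = {}
  for t in candidates:
    left, right = y[x <= t], y[x > t]
    scores[t] = impurity(left) + impurity(right)   # total impurity of this split
  return min(scores, key=scores.get), scores

# Toy example: weight -> class (0 or 1)
x = np.array([3.0, 4.0, 12.0, 14.0, 35.0, 40.0])
y = np.array([0, 0, 0, 1, 1, 1])
threshold, _ = best_split(x, y, entropy)
print(threshold)   # the midpoint that minimizes total entropy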

Problems with Decision Trees

Decision trees are straightforward to build and understand, but they have a critical drawback: overfitting. Decision trees are more prone to overfitting because they can create a leaf node for each data point, achieving a zero sum of squared residuals and zero entropy, or perfect purity with zero bias. In fact, the basic setup has no inherent mechanism to prevent this, leading to poor performance on testing datasets and high variance, as described by the bias-variance tradeoff.
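As a quick illustration of this failure mode, the sketch below fits an unconstrained tree on a random split of the Iris data and compares training and test accuracy. Iris is an easy dataset, so the gap may be small, but the same pattern becomes more pronounced on noisier data; the exact numbers depend on the split.

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

# No depth limit and no pruning: the tree can grow until every leaf is pure.
clf = DecisionTreeClassifier(criterion="entropy")
clf.fit(X_tr, y_tr)

print("train accuracy:", clf.score(X_tr, y_tr))  # typically 1.0 (zero impurity on training data)
print("test accuracy: ", clf.score(X_te, y_te))  # usually lower: the variance side of the tradeoff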

Cost Complexity Pruning

Because decision trees are prone to overfitting, there needs to be a mechanism in place to prevent it. One such mechanism is cost complexity pruning, which penalizes trees with more leaf nodes. Instead of using the sum of entropy or sum of squared residuals directly, it adds a tree complexity penalty:

Score_{clf} = E + \alpha T \\ Score_{reg} = SSR + \alpha T

Here, α is the penalty rate and T is the number of leaf nodes. This encourages the algorithm to favor simpler trees, much like how regularization adds penalty terms to encourage smaller parameters. The penalty rate, like the regularization rate, is a hyperparameter that controls the size of the penalty and can be tuned via cross-validation. When we remove leaf nodes to make the tree smaller, we call that pruning. The approach described above achieves pruning indirectly by adding a complexity penalty to the cost, hence the name cost complexity pruning.
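scikit-learn can also compute the effective penalty rates at which pruning actually changes the tree via cost_complexity_pruning_path, which is a convenient way to pick candidate values of the penalty rate instead of guessing. A brief sketch, loading the Iris data directly so it is self-contained:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# The pruning path gives the effective alphas at which leaves would be removed.
clf = DecisionTreeClassifier(criterion="entropy")
path = clf.cost_complexity_pruning_path(X, y)
print(path.ccp_alphas)   # candidate penalty rates, from 0 (no pruning) upward
print(path.impurities)   # total leaf impurity of the pruned tree at each alpha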

Code Implementation

The sklearn library provides DecisionTreeClassifier and DecisionTreeRegressor, which allow us to perform the above easily. Let's use DecisionTreeClassifier for species classification on the Iris dataset (multiclass). (I recommend trying DecisionTreeRegressor out yourself.) The data exploration and preprocessing steps are omitted here.

Step 3. Model

We can set up a cross-validation function for choosing the right penalty rate for cost complexity pruning, as shown below.

import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn import tree

def dt_cross_validation(alphas, score, n_splits=10, shuffle=True, verbose=True):
  # Cross-validates DecisionTreeClassifier over candidate ccp_alpha values,
  # using the globally defined X_train and one-hot encoded y_train.
  scores = {}
  kf = StratifiedKFold(n_splits=n_splits, shuffle=shuffle)

  for i, (train_index, test_index) in enumerate(kf.split(X_train, np.argmax(y_train, axis=1))):
    print(f"Fold {i+1}") if verbose else None
    for alpha in alphas:
      clf = tree.DecisionTreeClassifier(criterion="entropy", ccp_alpha=alpha)
      clf.fit(X_train[train_index], y_train[train_index])
      pred = clf.predict(X_train[test_index])
      # Convert one-hot predictions and labels back to class indices for scoring.
      pred, y_val = np.argmax(pred, axis=1), np.argmax(y_train[test_index], axis=1)
      # Accumulate the average validation score per alpha across folds.
      scores[alpha] = scores.get(alpha, 0) + (score(y_val, pred) / n_splits)
      print(f"{alpha}: {scores[alpha]}") if verbose else None

  return scores

The ccp_alpha parameter sets the penalty rate. We can use f1_weighted (a weighted F1 score) as the scoring function and perform cross-validation.

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
 
def f1_weighted(y_true, pred):
  return f1_score(y_true, pred, average="weighted")
 
dt_cross_validation([1, 0.1, 0.01, 0.001], f1_weighted, verbose=False)
# {1: 0.1835164835164835,
#  0.1: 0.9692698412698411,
#  0.01: 0.9688253968253968,
#  0.001: 0.9485396825396826}

The cross-validation results reveal that the model starts to overfit when the penalty rate drops below 0.1 and underfits badly when it rises above 0.1. Hence, we can choose 0.1 as the penalty rate for our model.

from sklearn import tree
clf = tree.DecisionTreeClassifier(criterion="entropy", ccp_alpha=0.1)
clf.fit(X_train, y_train)

Step 4. Model Evaluation

Let's evaluate the model. We can use the same methods for evaluation that we covered in softmax regression. (I will omit the confusion matrix here for presentation purposes.)

from sklearn.metrics import classification_report

# Predict on the held-out test set, then convert one-hot vectors back to class indices.
pred = clf.predict(X_test)
pred = np.argmax(pred, axis=1)
y_test = np.argmax(y_test, axis=1)

print(classification_report(y_test, pred))
 
#               precision    recall  f1-score   support
 
#            0       1.00      1.00      1.00        15
#            1       0.91      0.95      0.93        22
#            2       0.92      0.85      0.88        13
 
#     accuracy                           0.94        50
#    macro avg       0.94      0.93      0.94        50
# weighted avg       0.94      0.94      0.94        50

We can see from the above that the model performs quite well, despite some confusion between the Versicolor and Virginica classes. One advantage of using a decision tree is that we can easily visualize the tree.

tree.plot_tree(clf)
Tree Visualization

We can see that the model predominantly uses the third feature, petal width. This makes sense as the distributions are most separable when looking at petal width, as shown in the data exploration phase (available in "Road to ML Engineer #1 - Linear Regression"). It's surprising to see how simple the model is, given its high performance.

Caution

High-dimensional data tend not to form clusters because it becomes harder for points to be similar across all dimensions. Hence, decision trees tend to struggle to find good partitions that separate the data into groups with a fair number of samples, and they often underfit. Therefore, it's recommended to perform dimensionality reduction, like PCA, beforehand so that the trees are more likely to find discriminative features. (This issue with high-dimensional data affects all machine learning models and is known as the curse of dimensionality.)
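As a sketch of that recommendation, the snippet below chains a PCA step before the tree with scikit-learn's Pipeline and evaluates it with cross-validation on the Iris data; the component count and penalty rate are placeholders to tune for your own data.

from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Reduce dimensionality first, then fit the tree on the compressed features.
model = Pipeline([
  ("pca", PCA(n_components=2)),   # placeholder: tune the component count for your data
  ("tree", DecisionTreeClassifier(criterion="entropy", ccp_alpha=0.01)),
])
print(cross_val_score(model, X, y, cv=5).mean())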

Additionally, there are many other hyperparameters that constrain tree growth, such as max_depth, min_samples_split, and min_samples_leaf, which can also prevent overfitting. When using decision trees in more practical settings, it's advisable to tune these hyperparameters as well, as sketched below. (Refer to the documentation for more practical tips.)
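For example, a small grid search over those growth-limiting hyperparameters might look like the sketch below; the grid values are arbitrary starting points rather than recommendations.

from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Arbitrary starting grid over the growth-limiting hyperparameters.
param_grid = {
  "max_depth": [2, 3, 5, None],
  "min_samples_split": [2, 5, 10],
  "min_samples_leaf": [1, 2, 5],
}
search = GridSearchCV(DecisionTreeClassifier(criterion="entropy"), param_grid, cv=5)
search.fit(X, y)
print(search.best_params_)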

Resources