Road to ML Engineer #6 - Cross Validation

Last Edited: 8/2/2024

This blog post introduces the bias-variance tradeoff that plagues all of machine learning, along with cross-validation as one way of dealing with it.

ML

Common Misconception

When people learn about linear regression, they often assume it is called linear regression because it fits a straight line or linear function to data. However, this is not where the term "linear" comes from. In fact, we can fit non-linear functions like quadratic functions, radial basis functions, Fourier (sine and cosine) functions, and so on. The "linear" in linear regression comes from the fact that we fit functions that are linear in their parameters, not necessarily in their inputs. For example, the following functions can be fitted using linear regression:

$$
\begin{aligned}
f(x) &= w_1 x^2 + w_2 x + w_3 \\
g(x, y) &= w_1 x^3 + w_2 y^2 + w_3 \\
h(x, y, z) &= w_1 \cos(x) + w_2 \cos(y) + w_3 \cos(z) + w_4
\end{aligned}
$$

Meanwhile, the following functions can NOT be fitted using linear regression:

$$
\begin{aligned}
l(x) &= w_1^2 x + w_2 \\
k(x, y) &= w_1^3 x^2 + w_2^2 y^2 + \log(w_3) \\
m(x, y, z) &= \cos(w_1) \cos(x) + w_2 \cos(y) + w_3 \cos(z) + w_4
\end{aligned}
$$

While all the functions that can be fitted using linear regression are linear in their weights/parameters ($w_1$, $w_2$, $w_3$), all the functions that cannot be fitted using linear regression contain non-linear functions of their weights/parameters ($w_1^2$, $\log(w_3)$, $\cos(w_1)$). The functions of the inputs that the weights multiply (such as $x^2$, $x$, or $\cos(x)$) are called basis functions, and the function we fit must be a weighted sum of them, i.e., linear in its parameters like the above.
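To make this concrete, here is a minimal NumPy sketch that fits $f(x) = w_1 x^2 + w_2 x + w_3$ with ordinary least squares. The data are synthetic and noise-free, generated from weights chosen purely for illustration. Each basis function ($x^2$, $x$, and the constant $1$) becomes one column of a design matrix, and because the model is linear in $w_1, w_2, w_3$, a single linear least-squares solve recovers them.

```python
import numpy as np

# Synthetic inputs and targets from f(x) = 2x^2 - 3x + 1 (illustrative weights)
x = np.linspace(-2.0, 2.0, 20)
y = 2.0 * x**2 - 3.0 * x + 1.0

# Each column is one basis function evaluated at every input: [x^2, x, 1]
Phi = np.column_stack((x**2, x, np.ones_like(x)))

# Ordinary least squares: find w minimizing ||Phi @ w - y||^2
w, *_ = np.linalg.lstsq(Phi, y, rcond=None)
print(np.round(w, 6))  # the recovered weights [w1, w2, w3]
```

Swapping the columns for $\cos(x)$, $\cos(y)$, $\cos(z)$, and $1$ would fit $h(x, y, z)$ the same way; only the design matrix changes.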

In linear regression, we take the partial derivatives of the loss function with respect to the weights/parameters, and the form of those derivatives depends on the basis functions. If the function is non-linear in its parameters, the derivatives become much harder to calculate, and the loss is no longer guaranteed to be convex (with squared error, a function that is linear in its parameters always gives a convex loss). Also, you might notice that non-linear weights add little beyond making the calculations more complicated, since we can simply substitute $w = w_1^2$ or $w = \log(w_3)$ and make the functions linear in their parameters. Hence, it makes sense that we want functions that are linear in their parameters. As logistic and softmax regression both model the log-odds as a linear function of the parameters, we can choose any basis functions that keep them linear in their parameters as well.
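Because the model is linear in its weights, the gradient of the mean squared error has a simple closed form: $\nabla_w \frac{1}{n}\lVert \Phi w - y \rVert^2 = \frac{2}{n}\Phi^\top(\Phi w - y)$, where $\Phi$ is the design matrix of basis functions. The sketch below uses made-up data with the cosine basis of $h(x, y, z)$, runs plain gradient descent with that gradient, and compares the result to the least-squares solution. The learning rate, step count, and "true" weights are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.uniform(-3.0, 3.0, size=(200, 3))    # columns play the roles of x, y, z
true_w = np.array([1.5, -2.0, 0.5, 3.0])     # illustrative w1..w4

# Basis functions of h(x, y, z): [cos(x), cos(y), cos(z), 1]
Phi = np.column_stack((np.cos(X), np.ones(len(X))))
target = Phi @ true_w + rng.normal(0.0, 0.1, size=len(X))

w = np.zeros(4)
lr = 0.1
for _ in range(2000):
    # Gradient of the mean squared error; a simple matrix product because h is linear in w
    grad = 2.0 / len(X) * Phi.T @ (Phi @ w - target)
    w -= lr * grad

w_lstsq, *_ = np.linalg.lstsq(Phi, target, rcond=None)
print(w)
print(w_lstsq)  # gradient descent converges to the least-squares weights
```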

Bias-Variance Tradeoff

Given that we can choose any basis function, let's see the example below.

When presented with the above dataset, what basis function will you choose? For example, you can choose to fit a simple linear function (green), a less simple log function (blue), and an even more complex trigonometric function (purple) like this.

From the above, we can observe that the more complex the function is, the more closely it can fit the training data. In fact, the complex trigonometric function fits the training data perfectly. Loosely speaking, how poorly a model fits the training data, because its assumptions are too simple to capture the pattern, is called bias, so the trigonometric function has the lowest bias while the linear function has the highest bias.

Technically, you can make the function so complex that the number of parameters equals the number of training data points, which lets it fit any training data perfectly (as long as no two points share the same input). Then why do we even bother using linear or log functions when we know they cannot fit the training data that well? The answer has to do with the testing data.

As you can see from the above, the trigonometric function fits the testing data poorly, while the simpler linear and log functions fit it well. Loosely speaking, how much a model's performance drops from the training data to the testing data, which reflects how sensitive the model is to the particular training data it saw, is called variance, and the trigonometric function has the highest variance. Because it fit the training data too closely (very low bias), it ended up with the highest variance on the testing data, and we call this overfitting (overly fitting to the training data).
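The overfitting pattern can be reproduced in a few lines. This is an illustrative setup of our own, not the post's dataset: noisy samples of a log-shaped curve, with polynomials standing in for the post's linear and trigonometric fits. A polynomial with as many coefficients as training points passes through every training point, so its training error is essentially zero, while the test targets carry fresh noise the model never saw.

```python
import numpy as np

rng = np.random.default_rng(42)

def log_data(x):
    # Noisy samples of a log-shaped relationship (illustrative)
    return np.log(x) + rng.normal(0.0, 0.1, len(x))

x_train = np.linspace(1.0, 10.0, 8)
y_train = log_data(x_train)
x_test = np.linspace(1.25, 9.75, 50)   # fresh inputs from the same range
y_test = log_data(x_test)

def mse(coefs, x, y):
    return np.mean((np.polyval(coefs, x) - y) ** 2)

results = {}
# Degree 1 is a straight line; degree 7 has 8 coefficients, one per training point
for degree in (1, len(x_train) - 1):
    coefs = np.polyfit(x_train, y_train, degree)
    results[degree] = (mse(coefs, x_train, y_train), mse(coefs, x_test, y_test))
    print(f"degree {degree}: train MSE = {results[degree][0]:.2e}, "
          f"test MSE = {results[degree][1]:.2e}")
```

The degree-7 polynomial's training error is near zero, but its test error cannot be, since the test targets contain independent noise; that gap between the two errors is the variance described above.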

Meanwhile, the linear function and the log function have low variance because they do not fit the training data too closely. However, in this particular case, the log function seems to perform better, as it captures the general non-linear relationship that the linear function cannot. Because the linear function was too simple for the data (too much bias), it failed to capture that relationship and performed worse on the testing data than it could have; we call this underfitting.

In general, there is a bias-variance tradeoff: a model with lower bias tends to have higher variance, and vice versa. To find the optimal model, we need to balance the bias-variance tradeoff and avoid both overfitting and underfitting by choosing a model with the right level of complexity. What makes this difficult is that we need to do it without seeing the testing dataset, because we choose the model in Step 3: Model Training, before we engage with the testing dataset in Step 4: Model Evaluation.
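To see the tradeoff with the textbook definitions, one can fix the training inputs, redraw the noise many times, refit each model, and measure (a) how far the average prediction lands from the true function (squared bias) and (b) how much the predictions move between redraws (variance). This simulation sketch is not something the post runs; its settings are assumptions: true function $\log x$, Gaussian noise, and polynomials of degree 1 and 5 as the simple and complex models.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(1.0, 10.0, 20)   # fixed training inputs
f = np.log(x)                    # true, noise-free function (assumed)
n_trials, noise_sd = 500, 0.3

def bias_variance(degree):
    preds = np.empty((n_trials, len(x)))
    for t in range(n_trials):
        y = f + rng.normal(0.0, noise_sd, len(x))          # a fresh noisy training set
        preds[t] = np.polyval(np.polyfit(x, y, degree), x)
    bias_sq = np.mean((preds.mean(axis=0) - f) ** 2)       # squared bias, averaged over inputs
    variance = np.mean(preds.var(axis=0))                   # prediction variance, averaged over inputs
    return bias_sq, variance

results = {degree: bias_variance(degree) for degree in (1, 5)}
for degree, (bias_sq, variance) in results.items():
    print(f"degree {degree}: bias^2 = {bias_sq:.5f}, variance = {variance:.5f}")
```

The straight line shows the larger squared bias and the smaller variance, and the degree-5 polynomial shows the reverse.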

Cross Validation

The challenge of choosing the model with the right bias for achieving low variance comes up all the time, and ML researchers have been tackling this challenge in multiple ways. One of the ways they came up with is cross-validation.

When choosing the model, we can set up a pseudo testing dataset, called a validation dataset, by holding out some portion of the training dataset. We then train each candidate model on the rest of the training dataset, use metrics on the validation dataset to estimate how well each model will perform on unseen data, and choose the one that performs best. However, we want the validation dataset to be representative of the whole dataset, which is hard to guarantee when it is just a fraction of the training dataset.
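A minimal hold-out sketch using scikit-learn's `train_test_split`, with NumPy polynomials as the candidate models. The data, the 25% validation fraction, and the candidate degrees are illustrative assumptions. Each candidate is trained on the remaining data and scored on the held-out validation set, and the lowest validation error wins.

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
x = rng.uniform(1.0, 10.0, 60)
y = np.log(x) + rng.normal(0.0, 0.1, 60)

# Hold out 25% of the training data as a validation set
x_tr, x_val, y_tr, y_val = train_test_split(x, y, test_size=0.25, random_state=0)

val_mse = {}
for degree in (1, 2, 3, 6):   # candidate model complexities
    coefs = np.polyfit(x_tr, y_tr, degree)                            # train on the rest
    val_mse[degree] = np.mean((np.polyval(coefs, x_val) - y_val) ** 2)  # score on validation

best_degree = min(val_mse, key=val_mse.get)
print(val_mse, best_degree)
```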

Hence, instead of using only one validation dataset, we divide the training dataset into k groups (called folds). Each fold takes one turn as the validation dataset while the model is trained on the remaining k−1 folds, and we average the k resulting metrics. This way, every data point is used for both training and validation, and we get a more reliable estimate of how the model will perform on unseen data. This method is called k-fold cross-validation.
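The procedure can be written out by hand in NumPy. `k_fold_indices` and `cv_mse` below are hypothetical helpers for illustration, not library functions. The first shuffles the indices, cuts them into k folds, and yields each fold once as validation indices together with the remaining training indices. The second averages a metric over the k rounds.

```python
import numpy as np

def k_fold_indices(n_samples, k, seed=0):
    """Yield (train_idx, val_idx) pairs; each sample is validated exactly once."""
    idx = np.random.default_rng(seed).permutation(n_samples)
    folds = np.array_split(idx, k)   # k nearly equal-sized groups
    for i in range(k):
        val_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, val_idx

def cv_mse(x, y, degree, k=5):
    scores = []
    for train_idx, val_idx in k_fold_indices(len(x), k):
        coefs = np.polyfit(x[train_idx], y[train_idx], degree)   # train on k-1 folds
        pred = np.polyval(coefs, x[val_idx])                      # predict the held-out fold
        scores.append(np.mean((pred - y[val_idx]) ** 2))
    return np.mean(scores)   # average over the k folds

rng = np.random.default_rng(1)
x = rng.uniform(1.0, 10.0, 50)
y = np.log(x) + rng.normal(0.0, 0.1, 50)
print({d: round(cv_mse(x, y, d), 4) for d in (1, 2, 3)})
```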

Even then, when the dataset is imbalanced, some folds might end up with few or even no data points from a minority class. In such a case, we can use stratified k-fold cross-validation, which preserves the class proportions of the whole dataset in every fold, so each class appears in every fold as long as it has at least k data points. You can also use shuffle split cross-validation, where each split draws a fresh random validation set, so the same data point can be picked for validation in more than one split; this lets you choose the number of splits and the size of the validation set independently.
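Here is a small scikit-learn sketch of both splitters on a made-up imbalanced label vector: 18 samples of class 0, 6 of class 1, and 3 folds. `StratifiedKFold` keeps the 3:1 class ratio in every validation fold. `ShuffleSplit` draws each validation set independently, so a sample can be validated in several splits or in none.

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit, StratifiedKFold

y = np.array([0] * 18 + [1] * 6)   # imbalanced labels (illustrative)
X = np.zeros((len(y), 1))          # features do not affect how the splits are made

skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)
for train_idx, val_idx in skf.split(X, y):
    # Each validation fold gets 6 samples of class 0 and 2 of class 1
    print("stratified fold class counts:", np.bincount(y[val_idx], minlength=2))

ss = ShuffleSplit(n_splits=5, test_size=0.25, random_state=0)
val_counts = np.zeros(len(y), dtype=int)
for train_idx, val_idx in ss.split(X):
    val_counts[val_idx] += 1   # how often each sample lands in a validation set
print("times each sample was validated:", val_counts)
```

With 5 splits of 6 validation samples each (30 slots for 24 samples), at least one sample is necessarily validated more than once.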

Code Implementation

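Let's see cross-validation in action. Here, I will be using the Iris dataset, and the task is to predict the iris species from four measurements: sepal length (sl), sepal width (sw), petal length (pl), and petal width (pw). We have three candidate models to choose from, each modeling the log-odds with a different set of basis functions.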

$$
\begin{aligned}
\text{model\_1:} \quad \log(\text{odds}) &= w_1\, pl + w_2 \\
\text{model\_2:} \quad \log(\text{odds}) &= w_1\, sl + w_2\, sw + w_3\, pl + w_4\, pw + w_5 \\
\text{model\_3:} \quad \log(\text{odds}) &= w_1\, pl^2 + w_2\, pl + w_3
\end{aligned}
$$

We can reuse SoftmaxRegressionGD from the previous article for all of the models, with a few modifications.

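import numpy as np
from sklearn.metrics import log_loss

class SoftmaxRegressionGD():
  def __init__(self, lr=0.001, basis="model_1"):
    self.lr = lr # Learning rate
    self.history = [] # History of loss
    self.basis = basis
    self.W = None # Weights, created in fit() once the number of classes is known
    self.b = None # Biases, created in fit()

  def transform(self, X):
    # Map raw features to the basis functions of each candidate model
    # (assumes column 0 of X holds petal length, pl, as in the equations above;
    # in scikit-learn's default Iris column order, petal length is column 2)
    if self.basis == "model_1":
      return X[:, [0]] # [pl]
    elif self.basis == "model_3":
      return np.column_stack((X[:, 0]**2, X[:, 0])) # [pl^2, pl]
    else:
      return X # model_2: [sl, sw, pl, pw]

  def _probs(self, Xb):
    # Softmax over logits computed from already-transformed features
    logits = np.matmul(Xb, self.W) + self.b
    # Stabilize the softmax calculation using the log-sum-exp trick
    max_logits = np.max(logits, axis=1, keepdims=True)
    stabilized_logits = logits - max_logits # Prevent overflow
    odds = np.exp(stabilized_logits)
    total_odds = np.sum(odds, axis=1, keepdims=True)
    return odds / total_odds # Normalized probabilities

  def predict(self, X):
    return self._probs(self.transform(X))

  def fit(self, X, y, epochs=100):
    Xb = self.transform(X) # Transform once; calling predict() here would transform twice
    # One row of weights per basis function, one column per class (y is one-hot)
    self.W = np.ones((Xb.shape[1], y.shape[1]))
    self.b = np.ones(y.shape[1])
    for i in range(epochs):
      pred = self._probs(Xb)

      self.history.append(log_loss(y, pred))

      # Negative gradient of the cross-entropy loss w.r.t. the logits is (y - pred)
      diff = y - pred
      grad_W = np.matmul(Xb.T, diff)
      grad_b = np.sum(diff, axis=0)

      # Gradient descent step along the negative gradient
      self.W += self.lr * grad_W
      self.b += self.lr * grad_b
    return self.history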

We set up a transform method that maps the input to each model's basis functions, so the weights have one row per basis function (1, 4, and 2 rows for model_1, model_2, and model_3). Another change is the log-sum-exp trick in predict: before computing the odds, we subtract the largest logit in each row from all of that row's logits. Since softmax is unchanged when the same constant is subtracted from every logit, the probabilities stay the same, but exp never receives a large positive number and therefore cannot overflow. Now, we can define a function for performing cross-validation.
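To see why the max-subtraction matters, here is a standalone sketch; the helper names `naive_softmax` and `stable_softmax` are ours, not from the post. With large logits, the naive formula overflows in `np.exp` and returns `nan`. The stabilized version returns the correct probabilities, the same as the softmax of `[0, 1, 2]`.

```python
import numpy as np

def naive_softmax(logits):
    odds = np.exp(logits)
    return odds / np.sum(odds, axis=1, keepdims=True)

def stable_softmax(logits):
    # Subtract each row's max logit first; the softmax is unchanged by this shift
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    odds = np.exp(shifted)
    return odds / np.sum(odds, axis=1, keepdims=True)

logits = np.array([[1000.0, 1001.0, 1002.0]])
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(logits))   # exp(1000) overflows to inf, and inf / inf gives nan
print(stable_softmax(logits))      # identical to the softmax of [0, 1, 2]
```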

import numpy as np
from sklearn.model_selection import StratifiedKFold

def sm_cross_validation(X, y, models, score, epochs=500, n_splits=10, shuffle=True, verbose=True):
  scores = {m: 0.0 for m in models} # Running average of the metric for each model
  skf = StratifiedKFold(n_splits=n_splits, shuffle=shuffle)
  labels = np.argmax(y, axis=1) # Class labels recovered from one-hot y, used for stratification

  for i, (train_index, val_index) in enumerate(skf.split(X, labels)):
    if verbose:
      print(f"Fold {i+1}")
    for m in models:
      # Train a fresh model with basis m on the other k-1 folds
      sm = SoftmaxRegressionGD(basis=m)
      sm.fit(X[train_index], y[train_index], epochs=epochs)
      # Compare predicted and true class labels on the held-out fold
      pred = np.argmax(sm.predict(X[val_index]), axis=1)
      y_val = labels[val_index]
      scores[m] += score(y_val, pred) / n_splits
      if verbose:
        print(f"{m}: {scores[m]}")

  return scores

The above uses StratifiedKFold from sklearn.model_selection to get the indices of each stratified fold. We fit models with different basis functions on each fold and compare the averages of their metrics. The metric we will use here is the weighted-average F1-score.

from sklearn.metrics import f1_score
 
def f1_weighted(y_true, pred):
  return f1_score(y_true, pred, average="weighted")

Using all of the above, let's perform the cross-validation!

sm_cross_validation(X_train, y_train, ["model_1", "model_2", "model_3"], f1_weighted, verbose=False)
 
# {'model_1': 0.32254578754578755,
#  'model_2': 0.8637539682539682,
#  'model_3': 0.1835164835164835}

We can see that model_2, which is the same model we trained in the previous article on softmax regression, achieved the best average score, i.e., the best estimated performance on unseen data. Then, we can choose model_2 and train it on the whole training dataset.
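For comparison, the same model selection can be sketched with scikit-learn's built-in tools. This sketch is not the post's implementation and rests on several assumptions. It swaps the hand-written gradient descent for `LogisticRegression`, which fits a multinomial (softmax) model for multiclass problems with its default solver and adds L2 regularization by default. It loads Iris with `load_iris`, whose columns are sepal length, sepal width, petal length, and petal width, so pl is column 2. It runs on the whole dataset rather than the post's `X_train`, and it uses `cross_val_score` with stratified 10-fold splits and weighted F1 scoring.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

X, y = load_iris(return_X_y=True)   # columns: sl, sw, pl, pw; y holds class labels 0-2

# Basis functions for each candidate model (pl is column 2 in load_iris)
bases = {
    "model_1": lambda X: X[:, [2]],                                 # [pl]
    "model_2": lambda X: X,                                         # [sl, sw, pl, pw]
    "model_3": lambda X: np.column_stack((X[:, 2] ** 2, X[:, 2])),  # [pl^2, pl]
}

cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
scores = {}
for name, basis in bases.items():
    model = make_pipeline(FunctionTransformer(basis), LogisticRegression(max_iter=1000))
    scores[name] = cross_val_score(model, X, y, cv=cv, scoring="f1_weighted").mean()
print(scores)
```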

Hyperparameter Tuning

As we have just demonstrated, we can use cross-validation to choose the best basis function or, more generally, the best hyperparameter. Hyperparameters are settings of the model that are not optimized during training, and the process of finding the best combination of hyperparameters is called hyperparameter tuning.

Other than the basis function, SoftmaxRegressionGD has another hyperparameter that we did not optimize: the learning rate. In fact, the above cross-validation only compared basis functions with the learning rate fixed at its default value of 0.001. How do we perform hyperparameter tuning and obtain the best combination of basis function and learning rate? Two common approaches are grid search and random search.

Grid search cross-validation is essentially a brute-force approach that uses nested loops to cross-validate every combination of hyperparameter values. However, some models take a long time to train or have so many hyperparameters that testing every combination is practically impossible. Random search cross-validation instead samples a fixed number of combinations from a distribution (or list of values) specified for each hyperparameter, so we can save time while still gaining insight into which hyperparameter values work well.
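Both strategies are available in scikit-learn as `GridSearchCV` and `RandomizedSearchCV`. This sketch is illustrative only. Because the post's `SoftmaxRegressionGD` is not a scikit-learn estimator, it swaps in `LogisticRegression` and tunes two of its hyperparameters: the inverse regularization strength `C` (not the learning rate) and `fit_intercept`. The value ranges are arbitrary choices. Grid search tries all 5 × 2 combinations, while random search samples `n_iter` combinations, drawing `C` from a log-uniform distribution.

```python
from scipy.stats import loguniform
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, StratifiedKFold

X, y = load_iris(return_X_y=True)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
model = LogisticRegression(max_iter=1000)

# Grid search: every combination is cross-validated (5 values of C x 2 intercept options)
grid = GridSearchCV(model, {"C": [0.01, 0.1, 1.0, 10.0, 100.0], "fit_intercept": [True, False]},
                    cv=cv, scoring="f1_weighted")
grid.fit(X, y)

# Random search: 5 combinations, with C drawn log-uniformly from [1e-2, 1e2]
rand = RandomizedSearchCV(model, {"C": loguniform(1e-2, 1e2), "fit_intercept": [True, False]},
                          n_iter=5, cv=cv, scoring="f1_weighted", random_state=0)
rand.fit(X, y)

print("grid search:  ", grid.best_params_, round(grid.best_score_, 4))
print("random search:", rand.best_params_, round(rand.best_score_, 4))
```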

Hyperparameter tuning not only helps us tackle the bias-variance tradeoff but also lets us find the hyperparameters under which learning mechanisms like gradient descent perform best. Hence, it is a great method to use when a model has a tractable number of hyperparameters. However, when people talk about solutions to the bias-variance tradeoff in machine learning, they often discuss regularization, bagging, and boosting rather than cross-validation, and I would like to cover those in future articles.