Last time we started talking about how one would design an ML algorithm to predict the click-through-rate (CTR) on an ad impression, and you designed two extreme types of algorithms. The first one was the constant algorithm, which simply records the overall historical click-rate a, and predicts a for every new ad opportunity.
This was fine for overall accuracy but would do poorly on individual ad impressions, because its predictions are not differentiated based on the attributes of an impression, and CTRs could be very different for different types of impressions.
Exactly, and the other extreme was the memorization algorithm: it records the historical CTR for every combination of feature-values, and for a new ad impression with feature-value-combination x (e.g., browser = “Safari”, gender = “male”, age = 32, city = “NYC”, ISP = “Verizon”, dayOfWeek = “Sunday”), it looks up the historical CTR a(x) (if any) for x, and outputs a(x) as its prediction.
And we saw that this algorithm would be highly impractical and it would not generalize well to new ad impressions.
Yes, it would not generalize well for two reasons: (a) the specific feature-value-combination x may not have been seen before, and (b) even if x occurred in the training data, it may have occurred too infrequently to reliably estimate CTR for cases resembling x.
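To make the two extremes concrete, here is a minimal Python sketch of both. The training log, the two-feature combinations, and the helper names (`train_constant`, `train_memorizer`) are all made up for illustration; a real impression would carry many more features.

```python
from collections import defaultdict

# Hypothetical training log: (feature-value combination, clicked?) pairs.
# The feature values and click outcomes are invented for illustration.
history = [
    (("Safari", "NYC"), 1),
    (("Safari", "NYC"), 0),
    (("Chrome", "NYC"), 0),
    (("Chrome", "NYC"), 0),
]

def train_constant(history):
    """Constant algorithm: one overall historical click rate a."""
    a = sum(click for _, click in history) / len(history)
    return lambda x: a

def train_memorizer(history):
    """Memorization algorithm: one historical CTR a(x) per combination x."""
    clicks = defaultdict(int)
    impressions = defaultdict(int)
    for x, click in history:
        clicks[x] += click
        impressions[x] += 1
    table = {x: clicks[x] / impressions[x] for x in impressions}
    # Returns None when x was never seen: there is nothing to look up.
    return lambda x: table.get(x)

constant = train_constant(history)
memorizer = train_memorizer(history)

seen = ("Safari", "NYC")
unseen = ("Firefox", "Boston")
print(constant(seen), constant(unseen))    # same prediction for every impression
print(memorizer(seen), memorizer(unseen))  # per-combination CTR, or None if unseen
```

Note that the memorizer's estimate for `seen` rests on only two impressions, which is reason (b) in miniature, and it has no answer at all for `unseen`, which is reason (a).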
So how would we design an algorithm that generalizes well?
Let’s do a thought experiment: what do you do when you generalize in daily situations?
I start from specific facts or observations and formulate a general concept.
How, exactly, do you do that?
I abstract common aspects of the specific observations, so that I can make a broader, more universal statement. I build a mental model that’s consistent with my observations.
In terms of detail, would you say that your mental model is relatively simple or complex?
Simple, definitely. I abstract away the non-essential details.
Great, you’ve just identified some of the characteristics of a good generalization: a relatively simple, abstract, less detailed model that is consistent with (or fits, or explains) the observations, and is more broadly applicable beyond the cases you have seen.
That’s how humans think, but how does this help us design a computer algorithm that generalizes well?
This helps us in a couple of ways. Firstly, it gives us a framework for designing an ML algorithm. Just like humans build mental models based on their observations, an ML algorithm should ingest training data and output a model that fits, or explains the training data well.
What does a “model” mean specifically?
A model is the mathematical analog to the human idea of a “concept” or “mental model”; it’s the mathematical formalization of what we have been informally calling “rules” until now. A model is essentially a function that takes as input the characteristics (i.e., features) of a case (i.e., example), and outputs a classification (e.g., “cat” or “not cat”) or score (e.g., click-probability).
Wow, so a model is essentially a program, and you’re saying that an ML program is producing another program?
Sure, that’s a good way to look at it, if you think of a function as a program that takes an input and produces an output.
We seem to be back in voodoo-land: a program that ingests training data, and spits out a program…
Well the ML algorithm does not magically output code for a model; instead the ML algorithm designer usually restricts the models to a certain class of models M that have a common structure or form, and the models in the class only differ by certain parameters p. Think of the class of models M as a generic template, and a specific model from the class is selected by specifying the values of the parameters p. Once restricted to a certain class of models, the ML algorithm only needs to find the values of the parameters p such that the specific model with these parameter-values fits the training data well.
For example, suppose we’re trying to predict the sale price y of a home in a certain locality based only on the square footage x. If we believe home prices are roughly proportional to their square footage, a simple class of models for this might be the class of linear models, i.e., the home price y is modeled as a*x. The reason these models are called linear is that, for a fixed value of a, a plot of y against x would be a straight line (here, one passing through the origin).
Notice that a*x is a generic template for a linear model, and a value for a needs to be specified to get a specific linear model, which would then be used to predict the sale price for a given square-footage x. The ML algorithm would be trained on examples of pairs (x,y) of square-footage and home-price values, and it needs to find the value of the parameter a such that the model a*x “best fits” the training data. I hope this does not sound so mysterious any more?
It certainly helps! Instead of magically spitting out code for some model, an ML algorithm is actually “merely” finding parameters p for a specific class of models.
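Here is a small Python sketch of that idea for the template a*x. It assumes one particular reading of “best fits”, namely least squares, which this post has not committed to; under that assumption the best slope has the closed form a = sum(x*y) / sum(x*x). The training pairs and the function names are hypothetical.

```python
def make_linear_model(a):
    """Select one specific model from the template a*x by fixing a."""
    return lambda x: a * x

def fit_slope(examples):
    """Least-squares value of a for the template a*x (an assumed notion of fit)."""
    sum_xy = sum(x * y for x, y in examples)
    sum_xx = sum(x * x for x, _ in examples)
    return sum_xy / sum_xx

# Hypothetical (square-footage, price) pairs, invented for illustration.
examples = [(1.0, 2.1), (2.0, 3.9), (3.0, 6.0)]

a = fit_slope(examples)          # the "learning" step: find the parameter
model = make_linear_model(a)     # the learned model is just a function
print(round(a, 3), round(model(4.0), 3))
```

The algorithm never writes code for the model; it only computes the number a, and the template does the rest.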
But how do we know a model is “correct” or is the “true” one?
Remember that a model expresses a relationship between features and response (e.g. classification or score). All non-trivial real-world phenomena are governed by some “true” underlying model (think of this as the signal, or the essence, of the relationship between the variables involved) combined with random noise. The noise could result from errors in measurement or inherent random variation in the relationship. However an ML algorithm does not need to discover this “true” underlying model in order to be able to generalize well; it suffices to find a model that fits the training data adequately. As the statistician George Box famously said,
All models are wrong, but some are useful,
meaning that in any non-trivial domain, every model is at best an approximation when examined closely enough, but some models can be more useful than others.
What kinds of models would be more useful?
In the ML context, a useful model is one that generalizes well beyond the training data. This brings us to the second way in which thinking about how humans generalize helps, which is that it suggests a guiding principle in designing a true learning algorithm that has good generalization abilities. This principle is called Occam’s Razor:
Find the simplest model that explains the training data well. Such a model would be more likely to generalize well beyond the training examples.
And what is a “simple” model?
In a given context or application-domain, a model is more complex if its structure or form is more complex and/or it has more parameters (the two are often related). For example when you’re predicting CTR, and you have thousands of features you could be looking at, but you’re able to explain the training data “adequately” using just 3 of them, then the Occam’s Razor principle says you should just use these 3 to formulate your CTR-prediction model. Or when predicting house prices as a function of square-footage, if a linear model fits the training data adequately, you should use this model (rather than a more complex model such as quadratic, etc) as it would generalize better to examples outside the training set. I’ve used italics when talking about fitting/explaining the training data because that’s something I haven’t explained yet.
I see, so the constant algorithm I designed to predict CTR produces an extremely simple model that does not even fit the training data well, so there’s no hope it would generalize beyond the training data-set?
Exactly, and on the other hand the memorization algorithm produces an extremely complex model: it records the historical CTR for each feature-value combination in the training data. When tested on any example (i.e. feature-value combination) that already occurred in the training data, this model would produce exactly the historical CTR observed for that example. In this sense this complex model fits the training data perfectly. When an overly complex model fits training data exceptionally well, we say the model is overfitting the training data. Such models would not generalize well beyond the training data-set.
Interesting, so we want to find a relatively simple model that fits the training data “adequately”, but we don’t want to overdo it, i.e. we shouldn’t overly complicate our model in order to improve the fit to the training data. Is it obvious why complex models in general wouldn’t generalize well?
It’s actually not obvious, but here’s an intuitive explanation. Remember we said that any real-world phenomenon is governed by some true model that represents the signal or essence of the phenomenon, combined with random noise. A complex model-class has so many degrees of freedom (either in its structure or number of parameters) that when a best-fitting model is found from this class, it overfits the training data: it fits the signal as well as the random noise. The random noise gets baked into the model and this makes the model perform poorly when applied to cases not in the training set.
Can you give me an example where a simpler model generalizes better than a complex one?
Let’s look at the problem of learning a model to predict home sale prices in a certain locality based on the square-footage. You’re given a training data-set of pairs (x,y) where x is the square-footage in thousands, and y is the sale price in millions of dollars:
(2, 4.2), (3, 5.8), (4, 8.2), (5, 9.8), (6, 12.2), (15, 29.8),
and your task is to find a model that “best” fits the data (you decide how to interpret “best”). Any ideas?
I notice from the training data that when the square-footage x is an even number, the home price is (2*x + 0.2), and when it’s an odd number the home price is (2*x - 0.2).
Great, your model fits the training data perfectly, but in order to do so you’ve made your model complex because you have two different expressions based on whether the square footage is even or odd. Can you think of a simpler model?
I see what you mean. I notice that the home prices (in millions of dollars) in the training data-set are fairly close to twice the number of thousands of square feet, so my simpler model would be 2*x.
This would be a simpler model, even though it does not exactly fit the training data: the home prices are not exactly 2*x but close to it. We can think of the home price as being 2*x plus some random noise (which could be positive or negative). It just so happens that in the specific training examples you’ve seen, the noise part is 0.2 in the even-square-footage examples and -0.2 in the others. However, it could well be that for a new example with x=7 (an odd number) the price is exactly 2*x or slightly above 2*x, but if we were to use your previous complex model, we would be predicting 2*x - 0.2, which would not be a good prediction.
The point is that in an effort to fit the training data very well (in your case perfectly), you made your model overly complex and latched on to a spurious pattern that is unlikely to hold beyond the training data-set. By definition the noise part is random and unpredictable, so any attempt to model and predict it would result in errors. Instead, the simpler model captures only the “signal” and is oblivious to the random noise.
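The comparison can be checked with a few lines of Python. This is only a sketch under stated assumptions: the last training pair is taken as (15, 29.8), the value the odd-case rule 2*x - 0.2 gives for x = 15; the new home with x = 7 priced at exactly 2*x is hypothetical; and mean absolute error is just one convenient way to measure fit.

```python
# Training pairs (square-footage in thousands, price in millions of dollars).
# Assumption: the last pair is taken as (15, 29.8), the value given by the
# odd-case rule 2*x - 0.2.
training = [(2, 4.2), (3, 5.8), (4, 8.2), (5, 9.8), (6, 12.2), (15, 29.8)]

def parity_model(x):
    """Complex model: a different expression for even and odd x."""
    return 2 * x + 0.2 if x % 2 == 0 else 2 * x - 0.2

def simple_model(x):
    """Simple model: price is twice the square-footage."""
    return 2 * x

def mean_abs_error(model, examples):
    """Average absolute gap between predicted and actual price."""
    return sum(abs(model(x) - y) for x, y in examples) / len(examples)

# Hypothetical unseen home: x = 7 with a price of exactly 2*x.
new_examples = [(7, 14.0)]

for name, model in [("parity", parity_model), ("simple", simple_model)]:
    train_err = mean_abs_error(model, training)
    new_err = mean_abs_error(model, new_examples)
    print(f"{name}: training error {train_err:.2f}, new-example error {new_err:.2f}")
```

On this data the parity model shows a training error of 0.00 and a new-example error of 0.20, while the simple model shows the reverse: 0.20 on the training data and 0.00 on the new home. A single hypothetical home proves nothing by itself, of course; it only illustrates how a memorized noise pattern can turn into prediction error.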
Here, then, is the recipe for designing a true ML algorithm that generalizes well:
- Pick a relatively simple class of models M, parametrized by a set of parameters p
- Find values of the parameters p such that the model from M with these parameters fits the training data optimally. Denote this model as M[p]. M[p] is a function that maps a feature-value combination x to a response (e.g. classification or score).
- Output M[p] as the predictive model for the task, i.e. for an example x, respond with M[p](x) as the predicted classification or score.
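The three steps can be sketched in Python as follows. Everything specific here is an assumption made for illustration: the model class is the one-parameter template p*x, “fits optimally” is read as smallest total squared error, the search is a brute-force scan over a small grid of candidate values, and the training pairs are invented. Real ML algorithms use far better search procedures.

```python
def linear_class(p):
    """Model class M: the template p*x; fixing p gives the model M[p]."""
    return lambda x: p * x

def total_squared_error(model, examples):
    """Assumed measure of fit: sum of squared prediction errors."""
    return sum((model(x) - y) ** 2 for x, y in examples)

def learn(model_class, candidates, examples):
    """Scan the candidate parameters and return the best p along with M[p]."""
    best_p = min(
        candidates,
        key=lambda p: total_squared_error(model_class(p), examples),
    )
    return best_p, model_class(best_p)

# Hypothetical training pairs and candidate grid, invented for illustration.
examples = [(1, 3.1), (2, 5.9), (3, 9.2)]
candidates = [k / 10 for k in range(0, 51)]  # p = 0.0, 0.1, ..., 5.0

best_p, model = learn(linear_class, candidates, examples)
print(best_p, model(4))  # the chosen parameter, and M[p](x) for x = 4
```

The function `learn` takes the model class as an argument, which mirrors the recipe: the designer chooses M, and the algorithm's only job is to choose p.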
What does it mean for a model to fit or explain the training data, and how do I find the parameters that “optimally” fit the data?
Informally, a model fits the training data well when it is consistent with the training data. This is analogous to how you form a mental model based on your observations: your mental model is consistent with your observations (at least if you’re observant enough and think rationally). Of course to implement a computer algorithm we need to make precise the notion of an “optimal fit”, and specify how the computer should go about finding model parameters that “best” fit the training data. These are great topics for a future post!