Cross entropy and training-test class imbalance


Suppose we want to train a machine learning model on a binary classification problem.  A standard way of measuring model performance here is log-loss, or binary cross entropy (I will refer to this as cross entropy throughout this post).  Given the task of predicting some binary label y, the model does not output a hard 0/1 class prediction.  Instead it outputs a probability, \hat{y} say, that the label is positive.  The cross entropy score of the model is then

\sum_i -y_i\log\hat{y_i} - (1-y_i)\log(1-\hat{y_i}).
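As a minimal sketch, the sum above can be computed directly with NumPy.  The clipping constant `eps` is an assumption added only to keep the logarithms finite when a prediction is exactly 0 or 1.  Libraries such as scikit-learn's `log_loss` report the mean rather than the sum by default, which does not change which predictions are optimal.

```python
import numpy as np

def binary_cross_entropy(y, y_hat, eps=1e-15):
    """Summed binary cross entropy, following the formula above.

    y     : array of 0/1 labels
    y_hat : predicted probabilities that each label is positive
    eps   : clipping constant (an assumption, not part of the formula)
            so that log() stays finite for predictions of exactly 0 or 1
    """
    y = np.asarray(y, dtype=float)
    y_hat = np.clip(np.asarray(y_hat, dtype=float), eps, 1 - eps)
    return np.sum(-y * np.log(y_hat) - (1 - y) * np.log(1 - y_hat))

# Confident correct predictions cost little; confident wrong ones cost a lot.
print(binary_cross_entropy([1, 0], [0.9, 0.1]))  # 2 * -log(0.9), about 0.2107
print(binary_cross_entropy([1, 0], [0.1, 0.9]))  # 2 * -log(0.1), about 4.6052
```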

We will explain roughly where this loss comes from in the next section.  Now suppose that the test set has a different proportion of positives to negatives than the training set.  This is not a hypothetical scenario: it is exactly what competitors in the recently added ‘Quora question pair challenge’ are faced with.  This post explains why the nature of cross entropy makes this a problematic setup (something I and other posters have pointed out), and gives a theoretical solution.  The same problem could also come up when the proportion of positives changes over time in a known way, but the model is still trained on cross entropy over the original data.  Some posters on the Kaggle discussion boards have mentioned attempts to convert training set predictions to test set predictions, but to my knowledge there is no serious published analysis of it so far, so here goes…

Cross-entropy and class imbalance problems

Cross entropy is a loss function that derives from information theory.  One way to think about it is as the amount of extra information required to derive the label set from the predicted set; this is how it is explained on the Wikipedia page, for example.  In my opinion, a more intuitive view is as a loss function that rewards the model for being ‘honest’ about how probable it believes labels to be.  Say our predictive model believes that there is a probability p of a given label being positive.  What value 0 \leq q \leq 1 should we output to minimise cross entropy loss?  Well, if we really believe that there is a chance p that the label is positive, then our expected loss is

\text{loss}(q) = - p \log (q) - (1 - p) \log (1 - q).

So we differentiate this function with respect to q and set the derivative, -p/q + (1-p)/(1-q), to zero (to find the minimum).  We get

\begin{aligned}  \frac{1 - p}{1 - q} = \frac{p}{q}, \end{aligned}

and rearranging gives q = p.  Since the second derivative, p/q^2 + (1-p)/(1-q)^2, is positive, the loss is convex in q and this stationary point is indeed the minimum.
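A quick numerical sanity check of this result, as a sketch: evaluate the expected loss on a grid of candidate outputs q (step 0.001, an arbitrary choice) and confirm that the minimising q coincides with the believed probability p.

```python
import numpy as np

def expected_loss(q, p):
    # Expected cross entropy of outputting q when the label is
    # positive with probability p
    return -p * np.log(q) - (1 - p) * np.log(1 - q)

qs = np.linspace(0.001, 0.999, 999)  # candidate outputs, step 0.001

for p in (0.1, 0.37, 0.8):
    best_q = qs[np.argmin(expected_loss(qs, p))]
    print(p, round(best_q, 3))  # the minimising q matches p
```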

Now let’s think about why different class proportions in the training and test sets are problematic for cross entropy loss.  Suppose we take the most naive possible model, which outputs the same value for every label.  By the above discussion, the single value that optimises loss on the training set is \mathbb{P}(y), i.e. the probability that a randomly chosen label is positive.  But it will only also optimise loss on the test set if this probability is the same there, i.e. if positives are equally likely in the training and test sets.  This probability is our ‘prior’ in Bayesian-speak.

Moreover, more complicated models will tend to gravitate towards this prior when they are very ‘unsure’, i.e. when they glean no extra information about the label from the features, and they will be punished for doing so if the training and test class balances differ.
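To put a number on the naive constant model, here is a small sketch using, for illustration, the Quora proportions discussed later in this post (37% positives in training, 16.5% in test).  It compares the per-sample expected test-set loss of predicting the training prior against predicting the test prior.

```python
import numpy as np

def expected_constant_loss(q, prevalence):
    # Per-sample expected cross entropy of always predicting q when a
    # fraction `prevalence` of the labels are positive
    return -prevalence * np.log(q) - (1 - prevalence) * np.log(1 - q)

train_prior, test_prior = 0.37, 0.165  # illustrative class proportions

loss_using_train_prior = expected_constant_loss(train_prior, test_prior)
loss_using_test_prior = expected_constant_loss(test_prior, test_prior)
print(loss_using_train_prior, loss_using_test_prior)  # about 0.550 vs about 0.448
```

The gap is the penalty the model pays on the test set purely for having learned the wrong prior.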

Note that there are other evaluation metrics that are less sensitive to this class imbalance problem, for example the area under the ROC curve (AUC), which depends only on how the predictions rank positives relative to negatives.
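A sketch of why AUC behaves this way: computed as the fraction of (positive, negative) pairs that are ranked correctly, it only looks at the ordering of the scores.  Duplicating every negative (a crude way of shifting the class balance) leaves it unchanged, and so does any strictly increasing transform of the scores.  The labels and scores below are made up for illustration.

```python
import numpy as np

def pairwise_auc(y, scores):
    # AUC as the fraction of (positive, negative) pairs ranked correctly,
    # with ties counted as half. O(n_pos * n_neg): fine for small examples.
    y = np.asarray(y)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y == 1]
    neg = scores[y == 0]
    diff = pos[:, None] - neg[None, :]
    return (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / (len(pos) * len(neg))

y = np.array([1, 1, 0, 1, 0, 0])
scores = np.array([0.9, 0.6, 0.55, 0.4, 0.3, 0.1])

# Shift the class balance by duplicating every negative example
y_rebalanced = np.concatenate([y, y[y == 0]])
scores_rebalanced = np.concatenate([scores, scores[y == 0]])

squashed = scores ** 3  # strictly increasing on [0, 1], so ranks are unchanged

print(pairwise_auc(y, scores),
      pairwise_auc(y_rebalanced, scores_rebalanced),
      pairwise_auc(y, squashed))  # all three agree
```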

Converting predictions using Bayes’ theorem

Let’s suppose our training set is drawn from a distribution (X, y), and our test set is drawn from (X', y').  Our assumption at this point is that the only difference between these two distributions is their proportions of positives and negatives, i.e. X | (y = 0) \sim X' | (y' = 0) and X | (y = 1) \sim X' | (y' = 1).

Suppose we have some sample x \in X.  Our model is trying to estimate \mathbb{P} (y| x), where (abusing notation for now) y is the event that the label is positive (and \neg y is the event that the label is negative). Suppose our model’s best estimate of this is p. By Bayes’ theorem, we have

\begin{aligned} p \approx \mathbb{P} (y| x) &= \frac{\mathbb{P}(x | y) \mathbb{P}(y)} {\mathbb{P}(x)} \\ &= \frac{\mathbb{P}(x | y) \mathbb{P}(y)} {\mathbb{P}(x | y) \mathbb{P}(y) + \mathbb{P}(x | \neg y) \mathbb{P}( \neg y)} \\ &= \frac{u}{u + v}, \end{aligned}

say.  Now suppose that the same x was instead sampled from X', and we are now trying to estimate \mathbb{P}(y' | x).  We suppose that X' is the same as X except that the positives have been oversampled by a factor \alpha and the negatives by a factor \beta (a factor below 1 meaning undersampling), i.e. \mathbb{P}(y') = \alpha \mathbb{P}(y) and \mathbb{P}(\neg y') = \beta \mathbb{P}(\neg y).  As noted, conditional on the label, X' and X are identical, so that \mathbb{P}(x | y') = \mathbb{P}(x | y) and \mathbb{P}(x | \neg y') = \mathbb{P}(x | \neg y).  So then:

\begin{aligned} \mathbb{P} (y' | x) &= \frac{\mathbb{P}(x | y') \mathbb{P}(y')} {\mathbb{P}(x | y') \mathbb{P}(y') + \mathbb{P}(x | \neg y') \mathbb{P}( \neg y')} \\ &= \frac{\alpha u}{\alpha u + \beta v}. \end{aligned}

Now from the above equation, we have that v \approx u (1 - p) / p, and so

\begin{aligned} \mathbb{P} (y' | x) &\approx \frac{\alpha u}{\alpha u + \beta u (1 - p) / p } \\ &= \frac{\alpha p}{\alpha p + \beta (1 - p)} \end{aligned}
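The derivation can be checked numerically under an assumed toy model.  In this sketch the class-conditional densities are taken, purely for illustration, to be Gaussians N(1, 1) for positives and N(-1, 1) for negatives, and the priors are 0.37 (training) and 0.165 (test).  Mapping the exact training-set posterior p through \alpha p / (\alpha p + \beta (1 - p)) should reproduce the exact test-set posterior.

```python
import math

def gaussian_pdf(x, mean, sd):
    return math.exp(-0.5 * ((x - mean) / sd) ** 2) / (sd * math.sqrt(2 * math.pi))

def posterior(x, prior_pos):
    # P(y = 1 | x) by Bayes' theorem, with hypothetical class conditionals:
    # positives ~ N(1, 1), negatives ~ N(-1, 1)
    u = gaussian_pdf(x, 1.0, 1.0) * prior_pos
    v = gaussian_pdf(x, -1.0, 1.0) * (1 - prior_pos)
    return u / (u + v)

train_prior, test_prior = 0.37, 0.165
alpha = test_prior / train_prior              # P(y') / P(y)
beta = (1 - test_prior) / (1 - train_prior)   # P(not y') / P(not y)

for x in (-2.0, 0.0, 0.5, 3.0):
    p = posterior(x, train_prior)  # what a perfect training-set model would output
    mapped = alpha * p / (alpha * p + beta * (1 - p))
    print(x, mapped, posterior(x, test_prior))  # last two columns agree
```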

Thus a link function mapping from probabilities in the training set to probabilities in the test set is

\begin{aligned} f(x) = \frac{\alpha x}{\alpha x + \beta (1 - x)}. \end{aligned}
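As a sketch, f is a one-line vectorised function.  The check below illustrates two facts that follow from the formula: f depends on \alpha and \beta only through the ratio \alpha / \beta, and it is the same as shifting the log-odds of x by \log(\alpha / \beta).  The values \alpha = 0.5, \beta = 1.5 are arbitrary.

```python
import numpy as np

def f(x, alpha, beta):
    """Map a training-set probability x to a test-set probability.

    alpha = P(y') / P(y) and beta = P(not y') / P(not y).
    """
    x = np.asarray(x, dtype=float)
    return alpha * x / (alpha * x + beta * (1 - x))

def logit(p):
    return np.log(p / (1 - p))

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

xs = np.linspace(0.01, 0.99, 99)
alpha, beta = 0.5, 1.5  # arbitrary illustrative values

# Equivalent form: add log(alpha / beta) to the log-odds
shifted = sigmoid(logit(xs) + np.log(alpha / beta))
print(np.allclose(f(xs, alpha, beta), shifted))  # True
```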

Further work could be done to estimate how uncertainty on x affects the uncertainty on f(x), but I’m not going to pursue that here.

Note that one can also derive this formula by minimising the weighted loss

 -\alpha p \log (q) - \beta (1 - p) \log (1 - q),

with respect to q.  This suggests that we can minimise the weighted loss

-\alpha y \log (\hat{y}) - \beta (1 - y) \log (1 - \hat{y})

on X in order to minimise the cross entropy on X', which is another way to arrive at f (though slightly less insightful, in my opinion).
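A sketch of the weighted-loss route, again with the naive constant model: the training labels are a hypothetical 37 positives and 63 negatives, each positive is weighted by \alpha and each negative by \beta.  A grid search over constant predictions then lands on the test-set prior rather than the training one.

```python
import numpy as np

train_prior, test_prior = 0.37, 0.165
alpha = test_prior / train_prior             # weight on positive training examples
beta = (1 - test_prior) / (1 - train_prior)  # weight on negative training examples

# Hypothetical training labels with 37% positives
y = np.array([1] * 37 + [0] * 63)
weights = np.where(y == 1, alpha, beta)

def weighted_cross_entropy(y, y_hat, weights):
    return np.sum(weights * (-y * np.log(y_hat) - (1 - y) * np.log(1 - y_hat)))

# Fit a single constant prediction by grid search (step 0.001)
qs = np.linspace(0.001, 0.999, 999)
losses = [weighted_cross_entropy(y, np.full(y.shape, q), weights) for q in qs]
best_q = qs[int(np.argmin(losses))]
print(round(best_q, 3))  # about 0.165: the test prior, not the training 0.37
```

In practice the same weights could be passed as per-sample weights to whatever training routine is used, though that is not something this post tests.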

As an example, let’s go back to the original Quora dataset.  It is (currently) believed that the training set has 37% positives, whereas the test set has only 16.5% positives.  From the above discussion, we take \alpha = 16.5 / 37 and \beta = 83.5 / 63, and then f looks like this:

[Figure: plot of f(x) for 0 \leq x \leq 1, with lines marking f(0.37) = 0.165]

Here we have marked lines to confirm that f(0.37) = 0.165.  Our function also has some desirable properties that a simple linear scaling could not have: for example, f(0) = 0 and f(1) = 1, and if x is very close to 0 or 1, then f(x) will be too.
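A short sketch with the Quora proportions quoted above, comparing f with the naive linear alternative of multiplying every prediction by \alpha (the name `linear` is just for illustration).  f sends 0.37 to 0.165 and keeps confident predictions confident, whereas linear scaling drags a prediction of 0.999 down to roughly 0.45.

```python
train_prior, test_prior = 0.37, 0.165
alpha = test_prior / train_prior             # 16.5 / 37
beta = (1 - test_prior) / (1 - train_prior)  # 83.5 / 63

def f(x):
    return alpha * x / (alpha * x + beta * (1 - x))

def linear(x):
    # Naive alternative: rescale every prediction by the ratio of priors
    return alpha * x

for x in (0.001, 0.37, 0.999):
    print(x, f(x), linear(x))
```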

Also note that we relied heavily on the assumption that the positives (respectively negatives) in X are identically distributed to the positives (respectively negatives) in X'.  If that is not true, then this analysis could be of limited use!


6 thoughts on “Cross entropy and training-test class imbalance”

    • Sorry if not clear – the idea is to train on log-loss on the training set as it is (no oversampling or anything), then use this function to convert the predictions for the test set, which has a different class balance.


      • It’s OK, I believe it’s me. I mean, the first method implied generating more data to get the right proportions between classes. What are we doing here? Applying your function on the test set predictions before submitting?


      • Hey nah, at no point in this post are we oversampling to get the correct proportions. The idea is to train using log-loss on the proportion that is in the training set to get log-loss predictions on that set. But these won’t be correctly calibrated for the test set, so we apply this function to get predictions that are.


  1. A simpler – but equivalent – formula is to use the logit transformation:
    C = logit(0.165) - logit(0.37)
    y' = inv_logit(logit(y) + C)

    The intuition here is easy: we just change the constant of a logistic regression.

    import numpy as np

    def logit(p): return np.log(p / (1 - p))         # log-odds
    def inv_logit(x): return 1. / (1 + np.exp(-x))   # sigmoid

