โ† Back to Machine Learning

Softmax Regression

📋 Overview

Softmax Regression (also called Multinomial Logistic Regression) extends logistic regression to multiclass classification problems. It uses the softmax function to convert raw scores (logits) into probability distributions over multiple classes.

🎯 Learning Objectives

  • Understand the mathematical foundation of softmax regression
  • Derive the softmax function and its properties
  • Implement softmax regression from the maximum likelihood estimation (MLE) perspective
  • Apply softmax regression to multiclass problems
  • Compare with binary logistic regression
โฑ๏ธ Estimated Time: 30โ€“35 minutes reading + 60 minutes practice

Mathematical Foundation

โ„น๏ธ Note: Softmax Regression is the natural extension of Logistic Regression to handle more than two classes (K > 2).

The Softmax Function

The softmax function converts a vector of K real numbers into a probability distribution over K classes:

Softmax Function:

$$\sigma(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}} \quad \text{for } i = 1, 2, ..., K$$

Key properties of the softmax function:

  • Normalization: $\sum_{i=1}^{K} \sigma(z_i) = 1$
  • Range: $\sigma(z_i) \in (0, 1)$ for all $i$
  • Argmax Property: $\arg\max(\boldsymbol{z}) = \arg\max(\sigma(\boldsymbol{z}))$
  • Differentiable: Smooth everywhere
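
As a quick reference, here is a minimal NumPy sketch of the function defined above. Subtracting the maximum logit before exponentiating is a standard numerical-stability trick; it leaves the output unchanged because softmax is invariant to adding a constant to every logit. The function name and example logits are illustrative.

```python
import numpy as np

def softmax(z):
    """Map a vector of K logits to a probability distribution over K classes."""
    z = np.asarray(z, dtype=float)
    # Shift by the max logit: same result, but exp() cannot overflow.
    exp_z = np.exp(z - z.max())
    return exp_z / exp_z.sum()

p = softmax([2.0, 1.0, 0.1])
print(p)        # approximately [0.659 0.242 0.099]
print(p.sum())  # 1.0 -- the normalization property
```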

Softmax Regression Model

For multiclass classification with K classes, we model the probability of each class:

Model:

$$P(y=k|\boldsymbol{x}) = \frac{e^{\boldsymbol{w}_k^T \boldsymbol{x} + b_k}}{\sum_{j=1}^{K} e^{\boldsymbol{w}_j^T \boldsymbol{x} + b_j}}$$

Where:

  • $\boldsymbol{x}$ is the feature vector
  • $\boldsymbol{w}_k$ is the weight vector for class k
  • $b_k$ is the bias term for class k
  • $P(y=k|\boldsymbol{x})$ is the probability of class k

Matrix Form

We can write this more compactly using matrix notation:

Logits:

$$\boldsymbol{z} = \boldsymbol{W}^T \boldsymbol{x} + \boldsymbol{b}$$
Probabilities:

$$\boldsymbol{p} = \text{softmax}(\boldsymbol{z})$$

Where:

  • $\boldsymbol{W} \in \mathbb{R}^{d \times K}$ is the weight matrix
  • $\boldsymbol{b} \in \mathbb{R}^K$ is the bias vector
  • $\boldsymbol{z} \in \mathbb{R}^K$ is the logits vector
  • $\boldsymbol{p} \in \mathbb{R}^K$ is the probability vector
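
A minimal sketch of this matrix-form forward pass, assuming $d = 4$ features and $K = 3$ classes; the random $\boldsymbol{W}$, $\boldsymbol{b}$, and $\boldsymbol{x}$ below are placeholders, not fitted parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
d, K = 4, 3                      # feature dimension and number of classes (illustrative)
W = rng.normal(size=(d, K))      # weight matrix: one column of weights per class
b = rng.normal(size=K)           # bias vector
x = rng.normal(size=d)           # a single feature vector

z = W.T @ x + b                  # logits, shape (K,)
p = np.exp(z - z.max())
p /= p.sum()                     # p = softmax(z), shape (K,)

print(z)
print(p, p.sum())                # probabilities sum to 1
```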

Theoretical Foundation: Maximum Likelihood Estimation

Assumption

We assume that $y_i$ follows a categorical distribution:

$$y_i | \boldsymbol{x}_i \sim \text{Categorical}(\boldsymbol{p}_i)$$

Where $\boldsymbol{p}_i = \text{softmax}(\boldsymbol{W}^T \boldsymbol{x}_i + \boldsymbol{b})$.

Likelihood Function

For a single observation, the categorical probability mass function is:

$$P(y_i = k | \boldsymbol{x}_i, \boldsymbol{W}, \boldsymbol{b}) = p_{i,k} = \frac{e^{\boldsymbol{w}_k^T \boldsymbol{x}_i + b_k}}{\sum_{j=1}^{K} e^{\boldsymbol{w}_j^T \boldsymbol{x}_i + b_j}}$$

For all $n$ observations, the likelihood function is:

$$L(\boldsymbol{W}, \boldsymbol{b}) = \prod_{i=1}^{n} \prod_{k=1}^{K} p_{i,k}^{y_{i,k}}$$

Where $y_{i,k}$ is 1 if sample i belongs to class k, 0 otherwise (one-hot encoding).

Log-Likelihood

Taking the natural logarithm:

$$\ell(\boldsymbol{W}, \boldsymbol{b}) = \sum_{i=1}^{n} \sum_{k=1}^{K} y_{i,k} \log(p_{i,k})$$

Cross-Entropy Loss

To obtain a loss to minimize (rather than a likelihood to maximize), we take the negative log-likelihood and average over the $n$ samples:

$$J(\boldsymbol{W}, \boldsymbol{b}) = -\frac{1}{n}\sum_{i=1}^{n} \sum_{k=1}^{K} y_{i,k} \log(p_{i,k})$$

This is exactly the categorical cross-entropy loss function!
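
The loss can be coded directly from this formula. A minimal sketch, assuming `Y` is an $n \times K$ one-hot label matrix and `P` an $n \times K$ matrix of predicted probabilities (both names are ours); the small constant guards against taking the log of zero.

```python
import numpy as np

def cross_entropy(Y, P, eps=1e-12):
    """Average negative log-likelihood for one-hot labels Y and predicted probabilities P."""
    n = Y.shape[0]
    # Only the log-probability of each sample's true class contributes (y_{i,k} is 0 elsewhere).
    return -np.sum(Y * np.log(P + eps)) / n

Y = np.array([[1, 0, 0],
              [0, 0, 1]])         # two samples, K = 3 classes
P = np.array([[0.7, 0.2, 0.1],
              [0.2, 0.3, 0.5]])   # predicted probability for each class
print(cross_entropy(Y, P))        # -(log 0.7 + log 0.5) / 2 ≈ 0.525
```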

Gradient Derivation

The gradient of the loss with respect to $\boldsymbol{w}_k$ is:

$$\frac{\partial J}{\partial \boldsymbol{w}_k} = \frac{1}{n}\sum_{i=1}^{n} (p_{i,k} - y_{i,k})\boldsymbol{x}_i$$

And with respect to $b_k$:

$$\frac{\partial J}{\partial b_k} = \frac{1}{n}\sum_{i=1}^{n} (p_{i,k} - y_{i,k})$$
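
These expressions make a from-scratch training step very short. A sketch of one batch gradient-descent update under the same notation, with `X` an $n \times d$ design matrix and `Y` an $n \times K$ one-hot label matrix; the learning rate and all variable names are illustrative.

```python
import numpy as np

def softmax_rows(Z):
    """Row-wise softmax of an (n, K) logit matrix."""
    E = np.exp(Z - Z.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def gradient_step(W, b, X, Y, lr=0.1):
    """One gradient-descent update using dJ/dW = X^T (P - Y) / n and dJ/db = mean(P - Y)."""
    n = X.shape[0]
    P = softmax_rows(X @ W + b)      # predicted probabilities, shape (n, K)
    grad_W = X.T @ (P - Y) / n       # shape (d, K), matching the per-class formula above
    grad_b = (P - Y).mean(axis=0)    # shape (K,)
    return W - lr * grad_W, b - lr * grad_b

# One illustrative update on random data (d = 4 features, K = 3 classes).
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 4))
Y = np.eye(3)[rng.integers(0, 3, size=5)]    # one-hot labels
W, b = np.zeros((4, 3)), np.zeros(3)
W, b = gradient_step(W, b, X, Y)
```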

Key Properties

📊 Probability Distribution

Outputs valid probability distributions over all classes.

🎯 Multiclass Support

Handles any number of classes K ≥ 2 naturally.

📈 Smooth Function

The softmax function is smooth and differentiable everywhere.

๐Ÿ” Interpretable

Probabilities can be directly interpreted as confidence scores.

⚡ Fast Training

Convex optimization problem, so gradient-based training reliably reaches a global minimum.

🚫 No Assumptions

No assumptions about feature distributions.

Applications

  • Computer Vision: Image classification (CIFAR-10, ImageNet), object detection
  • NLP: Text classification, sentiment analysis, language detection, topic modeling
  • Healthcare: Disease classification, medical image analysis, drug discovery
  • Finance: Risk rating, customer segmentation, fraud detection
  • Engineering: Quality classification, fault diagnosis, system monitoring

Comparison: Binary vs Multiclass

| Aspect | Logistic Regression (Binary) | Softmax Regression (Multiclass) |
| --- | --- | --- |
| Number of Classes | 2 ($K = 2$) | Multiple ($K > 2$) |
| Activation Function | Sigmoid $\sigma(z)$ | Softmax $\sigma(\boldsymbol{z})$ |
| Output | $P(y=1) \in (0, 1)$ | $\sum_{k} P(y=k) = 1$ |
| Parameters | $\boldsymbol{w} \in \mathbb{R}^d$, $b \in \mathbb{R}$ | $\boldsymbol{W} \in \mathbb{R}^{d \times K}$, $\boldsymbol{b} \in \mathbb{R}^K$ |
| Loss Function | Binary Cross-Entropy | Categorical Cross-Entropy |
| Decision Rule | $P(y=1) > 0.5$ | $\arg\max_k P(y=k)$ |
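
One way to see the connection behind this table: for $K = 2$, softmax depends only on the difference of the two logits and collapses to the sigmoid. A quick numerical check with arbitrary logit values:

```python
import numpy as np

z = np.array([1.3, -0.4])                          # two-class logits (arbitrary values)
p_softmax = np.exp(z) / np.exp(z).sum()            # softmax over K = 2 classes
p_sigmoid = 1.0 / (1.0 + np.exp(-(z[0] - z[1])))   # sigmoid of the logit difference

print(p_softmax[0], p_sigmoid)                     # both ≈ 0.846
```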

💻 Code Examples

NumPy, scikit-learn, and PyTorch implementations

📊 Advanced Topics

One-vs-Rest, One-vs-One strategies

๐Ÿ‹๏ธ Exercises

Hands-on practice problems

Detailed Example: Handwritten Digit Classification

Let's work through a practical example of classifying handwritten digits (0-9) using softmax regression.

Problem Setup

We have 10 classes (digits 0-9) and want to predict the probability of each digit given pixel features.

$$P(\text{digit}=k|\boldsymbol{x}) = \frac{e^{\boldsymbol{w}_k^T \boldsymbol{x} + b_k}}{\sum_{j=0}^{9} e^{\boldsymbol{w}_j^T \boldsymbol{x} + b_j}}$$

Sample Prediction

For an input image $\boldsymbol{x}$, we compute logits for all 10 classes:

$$\boldsymbol{z} = \begin{bmatrix} z_0 \\ z_1 \\ z_2 \\ \vdots \\ z_9 \end{bmatrix} = \boldsymbol{W}^T \boldsymbol{x} + \boldsymbol{b}$$

Suppose we get logits: $\boldsymbol{z} = [2.1, -0.5, 0.8, 1.2, -1.1, 3.0, 0.3, -0.2, 1.5, 0.1]^T$

Softmax Calculation

First, compute the exponential of each logit:

$$e^{\boldsymbol{z}} = \begin{bmatrix} e^{2.1} \\ e^{-0.5} \\ e^{0.8} \\ \vdots \\ e^{0.1} \end{bmatrix} = \begin{bmatrix} 8.17 \\ 0.61 \\ 2.23 \\ \vdots \\ 1.11 \end{bmatrix}$$

Sum of exponentials: $\sum_{k=0}^{9} e^{z_k} = 8.17 + 0.61 + 2.23 + ... + 1.11 \approx 42.49$. Note that the largest term, $e^{3.0} \approx 20.09$ for digit 5, dominates the sum.

Final probabilities:

$$\boldsymbol{p} = \begin{bmatrix} P(0) \\ P(1) \\ P(2) \\ \vdots \\ P(9) \end{bmatrix} = \begin{bmatrix} 0.192 \\ 0.014 \\ 0.052 \\ \vdots \\ 0.026 \end{bmatrix}$$

Prediction

The predicted class is the one with the highest probability:

$$\hat{y} = \arg\max_{k} P(\text{digit}=k|\boldsymbol{x}) = \arg\max_{k} p_k = 5$$

So the model predicts this is digit "5" with about 47.3% confidence ($p_5 = 20.09 / 42.49 \approx 0.473$), since the largest logit ($z_5 = 3.0$) always receives the largest probability.
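
These numbers are easy to verify with a few lines of NumPy, using the logits assumed above:

```python
import numpy as np

z = np.array([2.1, -0.5, 0.8, 1.2, -1.1, 3.0, 0.3, -0.2, 1.5, 0.1])
p = np.exp(z) / np.exp(z).sum()

print(np.exp(z).sum())       # ≈ 42.49
print(p.round(3))            # [0.192 0.014 0.052 0.078 0.008 0.473 0.032 0.019 0.105 0.026]
print(p.argmax(), p.max())   # 5, ≈ 0.473
```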

Training Process

During training, we minimize the categorical cross-entropy loss:

$$J(\boldsymbol{W}, \boldsymbol{b}) = -\frac{1}{n}\sum_{i=1}^{n} \sum_{k=0}^{9} y_{i,k} \log(p_{i,k})$$

Where $y_{i,k}$ is 1 if sample i is digit k, 0 otherwise.
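
In practice this minimization is usually delegated to a library. A brief sketch using scikit-learn's LogisticRegression on its built-in 8×8 digits dataset: with more than two classes and the default lbfgs solver it fits exactly this multinomial (softmax) model, though it also applies L2 regularization by default, so the objective is a regularized version of the loss above. The train/test split and max_iter value are illustrative choices.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)                  # 1797 images of 8x8 pixels, 10 classes
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = LogisticRegression(max_iter=5000)              # multinomial (softmax) model
clf.fit(X_train, y_train)

print(clf.coef_.shape)                               # (10, 64): one weight vector per digit
print(clf.score(X_test, y_test))                     # test accuracy, typically around 0.95
print(clf.predict_proba(X_test[:1]).round(3))        # a probability distribution over digits 0-9
```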