Varol Cagdas Tok

Personal notes and articles.

Application: Linear Regression and the Normal Equation

Problem Formulation

Given \(n\) training samples \(\{(x_1, y_1), (x_2, y_2), \ldots, (x_n, y_n)\}\) where \(x_i \in \mathbb{R}^d\) are feature vectors and \(y_i \in \mathbb{R}\) are targets, find a linear function:

\[f(x) = w^Tx + b\]

that best predicts \(y\) from \(x\).

Equivalently, augment \(x\) with a bias term: \(x_{\text{aug}} = [1, x_1, x_2, \ldots, x_d]^T \in \mathbb{R}^{d+1}\), and learn:

\[f(x) = w^Tx_{\text{aug}}\]

where \(w = [b, w_1, w_2, \ldots, w_d]^T\) includes the bias.
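A minimal sketch of this augmentation in NumPy (function name `augment` is my own, not from the notes):

```python
import numpy as np

def augment(X):
    """Prepend a column of ones so the bias b folds into the weight vector w."""
    X = np.asarray(X, dtype=float)
    return np.hstack([np.ones((X.shape[0], 1)), X])

# Three samples with a single feature each.
X = np.array([[1.0], [2.0], [3.0]])
X_aug = augment(X)  # shape (3, 2): first column is all ones
```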

Matrix Formulation

Organize data into a matrix \(X \in \mathbb{R}^{n \times d}\) (or \(\mathbb{R}^{n \times (d+1)}\) if augmented) and vector \(y \in \mathbb{R}^{n}\):

X = [— x₁ᵀ —]     y = [y₁]
    [— x₂ᵀ —]         [y₂]
    [   ⋮   ]         [ ⋮ ]
    [— xₙᵀ —]         [yₙ]

The predictions for all samples are:

ŷ = Xw
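As a quick sketch, stacking the samples as rows reduces all \(n\) predictions to one matrix-vector product (the weights below are illustrative, not fitted):

```python
import numpy as np

# Rows of X are the augmented samples x_i^T.
X = np.array([[1.0, 1.0],
              [1.0, 2.0],
              [1.0, 3.0]])
w = np.array([0.5, 1.0])  # arbitrary weights [b, w1] for illustration
y_hat = X @ w             # y_hat[i] = w^T x_i for every sample at once
```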

Loss Function

Use mean squared error (MSE):

\[L(w) = \frac{1}{n}\|y - Xw\|_2^2 = \frac{1}{n}\sum_{i=1}^n (y_i - w^Tx_i)^2\]

Goal: find \(w\) that minimizes \(L(w)\).
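The loss above is a one-liner in NumPy; a sketch with made-up data (the helper name `mse_loss` is mine):

```python
import numpy as np

def mse_loss(w, X, y):
    """L(w) = (1/n) * ||y - Xw||_2^2"""
    r = y - X @ w
    return (r @ r) / len(y)

# Illustrative data in the same shape conventions as above.
X = np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])
y = np.array([2.0, 4.0, 5.0])
loss_at_zero = mse_loss(np.zeros(2), X, y)  # (4 + 16 + 25) / 3 = 15.0
```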

Derivation of the Normal Equation

Expand the squared norm:

\[L(w) = \frac{1}{n}(y - Xw)^T(y - Xw) = \frac{1}{n}(y^Ty - y^TXw - w^TX^Ty + w^TX^TXw)\]

Since \(w^TX^Ty\) is a scalar, it equals its transpose \(y^TXw\):

\[L(w) = \frac{1}{n}(y^Ty - 2y^TXw + w^TX^TXw)\]

Take the gradient with respect to \(w\):

\[\nabla_w L(w) = \frac{1}{n}(-2X^Ty + 2X^TXw)\]

Set the gradient to zero; since \(L\) is convex in \(w\), any stationary point is a global minimum:

\[-2X^Ty + 2X^TXw = 0\]

\[X^TXw = X^Ty\]

This is the normal equation.
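A standard sanity check on a derivation like this is to compare the analytic gradient against central finite differences; a sketch on random data (all names and values here are my own, for illustration):

```python
import numpy as np

def loss(w, X, y):
    r = y - X @ w
    return (r @ r) / len(y)

def grad(w, X, y):
    # Analytic gradient from the derivation: (1/n)(-2 X^T y + 2 X^T X w)
    return (2.0 / len(y)) * (X.T @ X @ w - X.T @ y)

def numerical_grad(w, X, y, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    g = np.zeros_like(w)
    for j in range(len(w)):
        e = np.zeros_like(w)
        e[j] = eps
        g[j] = (loss(w + e, X, y) - loss(w - e, X, y)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 3))
y = rng.standard_normal(20)
w = rng.standard_normal(3)
```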

Solving the Normal Equation

If \(X^TX\) is invertible (full column rank), the unique solution is:

\[w = (X^TX)^{-1}X^Ty\]

When \(X\) has full column rank, the matrix \((X^TX)^{-1}X^T\) is the Moore-Penrose pseudo-inverse \(X^+\).
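In code, it is better to solve the linear system than to form the inverse explicitly; a sketch (function name is mine), checked against NumPy's least-squares routine:

```python
import numpy as np

def fit_normal_equation(X, y):
    """Solve X^T X w = X^T y; np.linalg.solve is cheaper and numerically
    safer than computing (X^T X)^{-1} and multiplying."""
    return np.linalg.solve(X.T @ X, X.T @ y)

# Sanity check on random full-column-rank data (hypothetical values).
rng = np.random.default_rng(1)
X = rng.standard_normal((50, 4))
y = rng.standard_normal(50)
w = fit_normal_equation(X, y)
```

In practice `np.linalg.lstsq` (QR/SVD based) is preferred, since it also handles the rank-deficient case where \(X^TX\) is singular.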

Example

Dataset with n=3 samples, d=1 feature (plus bias):

x    y
1    2
2    4
3    5

Augment with bias column:

X = [1  1]     y = [2]
    [1  2]         [4]
    [1  3]         [5]

Compute \(X^TX\):

XᵀX = [1  1  1][1  1]   [3   6]
      [1  2  3][1  2] = [6  14]
                [1  3]

Compute \(X^Ty\):

Xᵀy = [1  1  1][2]   [1*2 + 1*4 + 1*5]   [11]
      [1  2  3][4] = [1*2 + 2*4 + 3*5] = [25]
                [5]

Solve for \(w\):
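The resulting 2×2 system \(X^TX w = X^Ty\) can also be checked numerically; a quick sketch:

```python
import numpy as np

XtX = np.array([[3.0, 6.0], [6.0, 14.0]])
Xty = np.array([11.0, 25.0])
w = np.linalg.solve(XtX, Xty)  # w = [b, w1] = [2/3, 3/2]
```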