Proving Convexity of Mean Squared Error Loss in a Regression Setting.

by Pritish Jadhav, Mrunal Jadhav - Sun, 25 Aug 2019
Tags: #python #Matrices #Convexity #Optimization.

Proving Convexity of Mean Squared Error Loss - Case Study

  • In this blog post, we shall quickly cover the convexity proof for the Mean Squared Error (MSE) loss function used in a traditional Regression setting.
  • In case you haven't checked out my previous blog post, The Curious Case of Convex Functions, I highly recommend reading it first. It covers all the basic building blocks for proving convexity.

With that in mind, let us start by reviewing -

  1. The MSE loss for a Regression Algorithm.
  2. Conditions for checking Convexity.

1. MSE Loss Function -

The MSE loss function in a Regression setting is defined as -

$$ \begin{align} J(W) = \frac{1}{2m}\sum_{i=1}^{m} [y^{(i)} - \hat{y}^{(i)}]^2 \tag{1} \end{align} $$

Where,

m = number of training examples.
$J(W)$ = loss as a function of the regression coefficients $w_1, \dots, w_n$.
$y^{(i)}$ = true value for the $i$-th training example.
$\hat{y}^{(i)}$ = predicted value for the $i$-th training example.

For the $i$-th training example, $\hat{y}^{(i)}$ is defined as -

$$ \begin{align} \hat{y}^{(i)} = \sum_{j = 1}^{n}(w_jx_{j}^{(i)} ) \tag{2} \end{align} $$

Where,

n = number of features.
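To make Eq. (1) and Eq. (2) concrete, here is a minimal Python sketch. It assumes the features and weights are plain lists; the helper names `predict` and `mse_loss` are hypothetical, chosen only for illustration.

```python
def predict(w, x):
    """Eq. (2): y_hat = sum_j w_j * x_j for a single training example."""
    return sum(w_j * x_j for w_j, x_j in zip(w, x))


def mse_loss(w, X, y):
    """Eq. (1): J(W) = 1/(2m) * sum_i (y_i - y_hat_i)^2 over m examples."""
    m = len(X)
    return sum((y_i - predict(w, x_i)) ** 2 for x_i, y_i in zip(X, y)) / (2 * m)


# Two training examples (m = 2), each with n = 3 features.
X = [[1.0, 2.0, 3.0], [0.5, -1.0, 2.0]]
y = [4.0, 1.0]
w = [0.5, 0.5, 0.5]
print(mse_loss(w, X, y))  # 0.265625
```

The residuals here are 1.0 and 0.25, so the loss is (1.0 + 0.0625) / 4.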

For the sake of convenience and readability, let's assume n = 3. Eq. (2) can then be written as -

$$ \begin{align} \hat{y}^{(i)} &= \sum_{j = 1}^{n}(w_jx_{j}^{(i)} ) \\ & = w_1x_{1}^{(i)} + w_2x_{2}^{(i)} + w_3x_{3}^{(i)} \tag{3} \end{align} $$

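To keep things simple, let us also consider a single training example (m = 1), which lets us drop the training index $(i)$. This loses no generality: Eq. (1) is an average of m such per-example terms, and a non-negative sum of convex functions is convex, so it is enough to show that a single term is convex.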

$$ \begin{align} \therefore J(W) = \frac{1}{2} [y - (w_1x_{1} + w_2x_{2} + w_3x_{3})]^2 \tag{4} \end{align} $$
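As a quick illustration of that simplification, the sketch below codes Eq. (4) for one example with n = 3 and rebuilds Eq. (1) as the plain average of these per-example terms. The function names are hypothetical.

```python
def single_example_loss(w, x, y):
    """Eq. (4): J(W) = 1/2 * [y - (w1*x1 + w2*x2 + w3*x3)]^2 for one example."""
    w1, w2, w3 = w
    x1, x2, x3 = x
    return 0.5 * (y - (w1 * x1 + w2 * x2 + w3 * x3)) ** 2


def mse_loss(w, X, Y):
    """Eq. (1) rewritten as the average of the per-example terms of Eq. (4)."""
    return sum(single_example_loss(w, x, y) for x, y in zip(X, Y)) / len(X)


w = [1.0, -2.0, 0.5]
X = [[2.0, 1.0, 4.0], [0.0, 1.0, 0.0]]
Y = [3.0, 0.0]
print(single_example_loss(w, X[0], Y[0]))  # 0.5
print(mse_loss(w, X, Y))                   # 1.25
```

Since $\frac{1}{2m}\sum_i r_i^2 = \frac{1}{m}\sum_i \frac{1}{2} r_i^2$, the two forms agree exactly.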

2. Checking for Convexity of J(W) -

To check the convexity of the Mean Squared Error function, we shall perform the following steps -

  • Step 1 - Computing the Hessian of J(W)
  • Step 2 - Computing the Principal Minors of the Hessian.
  • Step 3 - Based on the values of principal minors, determine the definiteness of Hessian.
  • Step 4 - Comment on Convexity based on convexity tests.

Let us get down to it right away-

Step 1 - Hessian of $J(w)$ -

$$ \begin{align} J^H = \begin{bmatrix} \frac{\partial ^2 J}{\partial w_1^2} & \frac{\partial^2 J}{\partial w_1 \partial w_2} & \frac{\partial^2 J}{\partial w_1 \partial w_3} \\ \frac{\partial ^2 J}{\partial w_2 \partial w_1} & \frac{\partial^2 J}{\partial {w_2}^2 } & \frac{\partial^2 J}{\partial w_2 \partial w_3} \\ \frac{\partial ^2 J}{\partial w_3 \partial w_1} & \frac{\partial^2 J}{\partial w_3 \partial w_2 } & \frac{\partial^2 J}{\partial {w_3}^2 } \end{bmatrix} \end{align} $$

Let's compute each entry of the matrix.

$$ \begin{align} \frac{\partial ^2 J}{\partial w_1^2} &= \frac{\partial}{\partial w_1} \big[ \frac{\partial}{\partial w_1}\big[\frac{1}{2} [y - (w_1x_{1} + w_2x_{2} + w_3x_{3})]^2\big] \big] \\ &= \frac{\partial}{\partial w_1} \big[ [y - (w_1x_{1} + w_2x_{2} + w_3x_{3})](-x_1) \big] \\ &= (-x_1)(-x_1) \\ &= x_1^2 \end{align} $$

$$ \begin{align} \frac{\partial ^2 J}{\partial w_1 \partial w_2} &= \frac{\partial}{\partial w_1} \big[ \frac{\partial}{\partial w_2}\big[\frac{1}{2} [y - (w_1x_{1} + w_2x_{2} + w_3x_{3})]^2\big] \big] \\ &= \frac{\partial}{\partial w_1} \big[ [y - (w_1x_{1} + w_2x_{2} + w_3x_{3})](-x_2) \big] \\ &= (-x_2)(-x_1) \\ &= x_1x_2 \\ &= \frac{\partial ^2 J}{\partial w_2 \partial w_1} \end{align} $$

Similarly, it can be shown that -

$$ \begin{align} \frac{\partial ^2 J}{\partial w_1 \partial w_3} &= \frac{\partial ^2 J}{\partial w_3 \partial w_1} = x_1x_3 \\ \frac{\partial ^2 J}{\partial w_2 \partial w_3} &= \frac{\partial ^2 J}{\partial w_3 \partial w_2} = x_2x_3 \end{align} $$

$$ \begin{align} \frac{\partial ^2 J}{\partial w_2^2} &= x_2^2 \\ \frac{\partial ^2 J}{\partial w_3^2} &= x_3^2 \end{align} $$

$$ \begin{align} \therefore J^H = \begin{bmatrix} x_1^2 & x_1x_2 & x_1x_3 \\ x_2x_1 & x_2^2 & x_2x_3 \\ x_3x_1 & x_3x_2 & x_3^2 \end{bmatrix} \end{align} $$
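The closed form $J^H_{jk} = x_j x_k$ can be sanity-checked numerically. The sketch below (hypothetical helper names) differentiates the analytic gradient from the derivation, $\partial J / \partial w_j = -[y - (w_1x_1 + w_2x_2 + w_3x_3)]\,x_j$, using central finite differences. It then compares the result with the outer product of $x$ with itself at two different weight vectors. Because Eq. (4) is quadratic in $W$, the Hessian should not depend on $W$ or on $y$.

```python
def gradient(w, x, y):
    """dJ/dw_j = -(y - w . x) * x_j for the single-example loss in Eq. (4)."""
    residual = y - sum(wj * xj for wj, xj in zip(w, x))
    return [-residual * xj for xj in x]


def hessian_closed_form(x):
    """J^H[j][k] = x_j * x_k, i.e. the outer product of x with itself."""
    return [[xj * xk for xk in x] for xj in x]


def hessian_numeric(w, x, y, h=1e-5):
    """Central differences of the gradient: H[j][k] = d(grad_j)/dw_k."""
    n = len(w)
    H = [[0.0] * n for _ in range(n)]
    for k in range(n):
        w_plus = list(w)
        w_plus[k] += h
        w_minus = list(w)
        w_minus[k] -= h
        g_plus, g_minus = gradient(w_plus, x, y), gradient(w_minus, x, y)
        for j in range(n):
            H[j][k] = (g_plus[j] - g_minus[j]) / (2 * h)
    return H


x, y = [1.5, 0.7, -0.4], 2.0
for w in ([0.0, 0.0, 0.0], [3.0, -1.0, 0.5]):
    # Up to rounding, both evaluations print the same matrix.
    print(hessian_numeric(w, x, y))
print(hessian_closed_form(x))
```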

Step 2 - Computing the Principal Minors -

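From the previous blog post, the Hessian is positive semidefinite (and the function is therefore convex) if all of its principal minors are greater than or equal to zero, i.e. $\bigtriangleup_k \geq 0 \;\; \forall k$. Note that for semidefiniteness, every principal minor of every order must be checked, not just the leading ones.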
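The principal-minor test can be sketched in a few lines of Python. This is an illustration under simple assumptions: the matrix is symmetric and stored as nested lists, determinants use plain cofactor expansion, and the function names are hypothetical. Choosing which k indices to keep is the same as deleting the other n - k rows and the matching columns.

```python
from itertools import combinations


def det(M):
    """Determinant by cofactor (Laplace) expansion along the first row."""
    if len(M) == 1:
        return M[0][0]
    total = 0
    for c in range(len(M)):
        minor = [row[:c] + row[c + 1:] for row in M[1:]]
        total += (-1) ** c * M[0][c] * det(minor)
    return total


def principal_minors(H):
    """Map each order k to the determinants of all k x k principal submatrices."""
    n = len(H)
    minors = {}
    for k in range(1, n + 1):
        minors[k] = [
            det([[H[r][c] for c in idx] for r in idx])
            for idx in combinations(range(n), k)
        ]
    return minors


def is_positive_semidefinite(H):
    """Symmetric H is PSD iff every principal minor (not only the leading ones) is >= 0."""
    return all(m >= 0 for ms in principal_minors(H).values() for m in ms)


H = [[2, -1, 0], [-1, 2, -1], [0, -1, 2]]
print(principal_minors(H))            # {1: [2, 2, 2], 2: [3, 4, 3], 3: [4]}
print(is_positive_semidefinite(H))    # True

# Leading minors of [[0, 0], [0, -1]] are 0 and 0, yet the matrix is not PSD:
# the principal minor left after deleting row/column 1 is -1.
print(is_positive_semidefinite([[0, 0], [0, -1]]))  # False
```

The last example shows why the test for semidefiniteness cannot stop at the leading minors.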

Computing $\bigtriangleup_1$ -

Principal minors of order 1 ($\bigtriangleup_1$) are obtained by deleting any 3 - 1 = 2 rows and the corresponding columns.

a. By deleting rows 2 and 3 along with the corresponding columns, $\bigtriangleup_1 = x_1^2$.

b. By deleting rows 1 and 3 along with the corresponding columns, $\bigtriangleup_1 = x_2^2$.

c. By deleting rows 1 and 2 along with the corresponding columns, $\bigtriangleup_1 = x_3^2$.

Computing $\bigtriangleup_2$ -

Principal minors of order 2 are obtained by deleting any 3 - 2 = 1 row and the corresponding column.

a. By deleting row 1 and corresponding column 1 -

$$ \begin{align} \bigtriangleup_2 & = \begin{vmatrix} x_2^2 & x_2x_3 \\ x_3x_2 & x_3^2 \end{vmatrix} \\ & = x_2^2x_3^2 - (x_2x_3)(x_3x_2) \\ & = x_2^2x_3^2 - x_2^2x_3^2 \\ & = 0 \end{align} $$

b. By deleting row 2 and corresponding column 2 -

$$ \begin{align} \bigtriangleup_2 & = \begin{vmatrix} x_1^2 & x_1x_3 \\ x_3x_1 & x_3^2 \end{vmatrix} \\ & = x_1^2x_3^2 - (x_1x_3)(x_3x_1) \\ & = x_1^2x_3^2 - x_1^2x_3^2 \\ & = 0 \end{align} $$

c. By deleting row 3 and corresponding column 3 -

$$ \begin{align} \bigtriangleup_2 & = \begin{vmatrix} x_1^2 & x_1x_2 \\ x_2x_1 & x_2^2 \end{vmatrix} \\ & = x_1^2x_2^2 - (x_1x_2)(x_2x_1) \\ & = x_1^2x_2^2 - x_1^2x_2^2 \\ & = 0 \end{align} $$
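The three order-2 computations above share one pattern, $x_a^2 x_b^2 - (x_a x_b)(x_b x_a)$, which vanishes identically. The small sketch below, with a hypothetical helper name, evaluates that expression for random integer feature vectors. Integers keep the arithmetic exact, so the zeros are not a rounding effect.

```python
import random
from itertools import combinations


def order2_principal_minors(x):
    """2 x 2 principal minors of J^H = [x_j * x_k]: x_a^2 x_b^2 - (x_a x_b)(x_b x_a)."""
    return [
        (x[a] * x[a]) * (x[b] * x[b]) - (x[a] * x[b]) * (x[b] * x[a])
        for a, b in combinations(range(len(x)), 2)
    ]


random.seed(0)
for _ in range(5):
    x = [random.randint(-9, 9) for _ in range(3)]
    print(x, order2_principal_minors(x))  # each list is [0, 0, 0]
```

The index pairs (1, 2), (1, 3), (2, 3) correspond to cases c, b and a above, respectively.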

Computing $\bigtriangleup_3$ -

The principal minor of order 3 is the determinant of the Hessian $J^H$ itself, since no rows or columns are deleted.

$$ \begin{align} \bigtriangleup_3 & = \begin{vmatrix}J^H \end{vmatrix} \\ &= \begin{vmatrix} x_1^2 & x_1x_2 & x_1x_3 \\ x_2x_1 & x_2^2 & x_2x_3 \\ x_3x_1 & x_3x_2 & x_3^2 \end{vmatrix}\\ &= x_1^2 (x_2^2x_3^2 - x_2^2x_3^2) - x_1x_2 (x_1x_2x_3^2 - x_1x_2x_3^2) + x_1x_3 (x_1x_2^2x_3 - x_1x_2^2x_3) \\ &= 0 \end{align} $$
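The same first-row cofactor expansion can be written directly in Python. In this sketch (hypothetical function names), `det3_first_row` mirrors the three bracketed terms above. For $J^H$ every bracket cancels, which is expected: row j of $J^H$ is $x_j$ times the vector $(x_1, x_2, x_3)$, so the rows are proportional and the matrix has rank one. A generic matrix is included for contrast.

```python
def hessian(x):
    """J^H = [[x_j * x_k]] for the single-example loss in Eq. (4)."""
    return [[xj * xk for xk in x] for xj in x]


def det3_first_row(H):
    """3 x 3 determinant by cofactor expansion along the first row, as in the text."""
    term1 = H[0][0] * (H[1][1] * H[2][2] - H[1][2] * H[2][1])
    term2 = H[0][1] * (H[1][0] * H[2][2] - H[1][2] * H[2][0])
    term3 = H[0][2] * (H[1][0] * H[2][1] - H[1][1] * H[2][0])
    return term1 - term2 + term3


print(det3_first_row(hessian([2, -3, 5])))                     # 0
print(det3_first_row([[2, -1, 0], [-1, 2, -1], [0, -1, 2]]))   # 4
```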

Step 3 - Comment on Definiteness of Hessian of J(w) -

  • The principal minors of order 1 are squares of real numbers ($x_1^2$, $x_2^2$, $x_3^2$), and the square of a real number is always non-negative.
  • The principal minors of order 2 and 3 are all equal to zero.
  • It can be concluded that $\bigtriangleup_k \geq 0 \;\; \forall k$
  • Hence the Hessian of J(w) is Positive Semidefinite.
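The same conclusion can be reached from a different angle. Since $J^H$ is the outer product of $x = (x_1, x_2, x_3)$ with itself, the quadratic form $v^T J^H v$ collapses to $(x \cdot v)^2$, which is never negative. The sketch below checks this on random integer directions $v$ for one assumed feature vector. It also shows that a direction orthogonal to $x$ gives exactly 0, which is consistent with the zero determinant: the Hessian is positive semidefinite but not positive definite.

```python
import random


def quadratic_form(H, v):
    """Compute v^T H v for a square matrix H given as nested lists."""
    n = len(v)
    return sum(v[j] * H[j][k] * v[k] for j in range(n) for k in range(n))


x = [2, -3, 5]                               # an example feature vector (assumed)
H = [[xj * xk for xk in x] for xj in x]      # J^H = x x^T

random.seed(1)
for _ in range(1000):
    v = [random.randint(-10, 10) for _ in range(3)]
    dot = sum(vj * xj for vj, xj in zip(v, x))
    # v^T (x x^T) v equals (x . v)^2, which can never be negative.
    assert quadratic_form(H, v) == dot ** 2
    assert quadratic_form(H, v) >= 0

# x . (3, 2, 0) = 6 - 6 + 0 = 0, so this direction gives a zero quadratic form.
print(quadratic_form(H, [3, 2, 0]))  # 0
```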

Step 4 - Comment on Convexity -

Before we comment on the convexity of J(W), let's revise the conditions for convexity -

If $X^H$ is the Hessian matrix of a twice-differentiable function f(x) on an open convex set $S$, then -

  • f(x) is strictly convex in $S$ if $X^H$ is a Positive Definite Matrix.
  • f(x) is convex in $S$ if $X^H$ is a Positive Semi-Definite Matrix.
  • f(x) is strictly concave in $S$ if $X^H$ is a Negative Definite Matrix.
  • f(x) is concave in $S$ if $X^H$ is a Negative Semi-Definite Matrix.

Since the Hessian of J(w) is Positive Semidefinite, it can be concluded that the function J(w) is convex.
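As an informal numerical sanity check (not a proof), the sketch below tests the definition of convexity, $J(\theta a + (1-\theta) b) \leq \theta J(a) + (1-\theta) J(b)$, on the full loss of Eq. (1). It uses randomly generated data with m = 20 examples and n = 3 features and random pairs of weight vectors. The data, the sampling ranges, and the relative tolerance that absorbs floating-point rounding are all assumptions made for illustration.

```python
import random


def mse_loss(w, X, Y):
    """Eq. (1): J(W) = 1/(2m) * sum_i (y_i - w . x_i)^2."""
    m = len(X)
    return sum((y - sum(wj * xj for wj, xj in zip(w, x))) ** 2 for x, y in zip(X, Y)) / (2 * m)


def violates_convexity(a, b, X, Y, theta, tol=1e-9):
    """True if J(theta*a + (1-theta)*b) exceeds theta*J(a) + (1-theta)*J(b)."""
    mix = [theta * ai + (1 - theta) * bi for ai, bi in zip(a, b)]
    lhs = mse_loss(mix, X, Y)
    rhs = theta * mse_loss(a, X, Y) + (1 - theta) * mse_loss(b, X, Y)
    return lhs > rhs + tol * max(1.0, abs(rhs))


random.seed(42)
X = [[random.gauss(0, 1) for _ in range(3)] for _ in range(20)]
Y = [random.gauss(0, 1) for _ in range(20)]
violations = 0
for _ in range(10_000):
    a = [random.uniform(-5, 5) for _ in range(3)]
    b = [random.uniform(-5, 5) for _ in range(3)]
    if violates_convexity(a, b, X, Y, random.random()):
        violations += 1
print(violations)  # 0, as expected for a convex function
```

A random search like this can only fail to find a counterexample. The Hessian argument above is what actually establishes convexity.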

Final Comments -

  • This blog post is aimed at proving the convexity of MSE loss function in a Regression setting by simplifying the problem.
  • There are different ways of proving the convexity but I found this easier to comprehend.
  • Feel free to try out the process for different loss functions that you may have encountered.
