
LASSO REGRESSION AND ITS IMPLEMENTATION WITH PYTHON

Hi Everyone! Today, we will learn about Lasso regression (L1 regularization), the mathematics behind it and how to implement it using Python!

To build a great foundation on the basics, let's understand a few points given below:

Lasso regression adds an L1 penalty on the weights to the ordinary least-squares cost, which drives many of the weights to (or very close to) zero and so performs feature selection. Probabilistically, it can be viewed as maximum a posteriori (MAP) estimation: we find the parameter value at which the likelihood of the data is maximized, while placing a Laplace prior on the weights; taking the negative log of the posterior turns that prior into the L1 penalty term.
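A minimal sketch of this in standard notation, assuming a D-dimensional weight vector w, Laplace scale b and penalty strength \lambda (proportional to 1/b):

p(w) = \prod_{j=1}^{D} \frac{1}{2b}\exp\!\left(-\frac{|w_j|}{b}\right)

J(w) = \lVert y - Xw \rVert_2^2 + \lambda \lVert w \rVert_1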

Now, let's see how to implement L1 regularization, i.e. Lasso regression, using gradient descent (I will be covering gradient descent in a separate post).
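The weight update used in the training loop below follows from taking the (sub)gradient of the penalized cost, where sign(w) is the subgradient of the L1 term and \eta is the learning rate:

w \leftarrow w - \eta\left( X^{\top}(Xw - y) + \lambda\,\mathrm{sign}(w) \right)

This is exactly what the np.sign(w) term in cell In [4] implements.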

In [1]:
from __future__ import print_function, division
from builtins import range

import numpy as np # importing numpy with alias np
import matplotlib.pyplot as plt # importing matplotlib.pyplot with alias plt
In [2]:
No_of_observations = 50  
No_of_Dimensions = 50

X_input = (np.random.random((No_of_observations, No_of_Dimensions))-0.5)*10 #Generating a 50x50 matrix for X with uniform random values centered around 0 (range -5 to 5)
w_dash =  np.array([1, 0.5, -0.5] + [0]*(No_of_Dimensions-3)) # Making first 3 features significant by setting w for them as non-zero and others zero
Y_output = X_input.dot(w_dash) + np.random.randn(No_of_observations)*0.5 #Setting Y = X.w + some random noise
In [3]:
costs = [] #Setting empty list for costs
w = np.random.randn(No_of_Dimensions)/np.sqrt(No_of_Dimensions) #Setting w to random values
L1_coeff = 5 #L1 penalty strength (lambda); larger values push more weights towards zero
learning_rate = 0.001 #Setting learning rate to a small value so that gradient descent doesn't overshoot the minimum
In [4]:
for i in range(500):
    Yhat = X_input.dot(w)
    delta = Yhat - Y_output #the error between predicted output and actual output
    w = w - learning_rate*(X_input.T.dot(delta) + L1_coeff*np.sign(w)) #gradient descent update for w: squared-error gradient plus L1 subgradient (L1_coeff*np.sign(w))
    meanSquareError = delta.dot(delta)/No_of_observations #Finding mean square error
    costs.append(meanSquareError) #Appending mse for each iteration in costs list
    
In [5]:
plt.plot(costs)
plt.title("Plot of costs of L1 Regularization")
plt.ylabel("Costs")
plt.show()
In [6]:
print("final w:", w) #The final w output. As you can see, first 3 w's are significant , the rest are very small
final w: [  9.65816491e-01   4.27099719e-01  -4.39501114e-01   7.26803718e-04
   1.44676529e-03   4.29653783e-03  -1.88827800e-02   5.01402266e-03
  -1.45435498e-02   2.98832870e-03  -1.94071569e-03  -1.47917010e-02
   3.56488642e-02   2.44495593e-02  -3.40885499e-03  -2.23948913e-02
  -8.56983401e-04   1.00292301e-02   3.33973800e-03   8.51922055e-03
  -3.72198952e-02   5.31823613e-03  -3.35052948e-02   7.15853488e-03
  -1.00094617e-02  -1.44190084e-03   2.96771082e-03  -6.51081371e-03
   3.54465569e-02  -3.30111666e-02   4.42377796e-03  -7.87768360e-03
   1.26511065e-02  -5.43831611e-04  -4.58914064e-04   5.53972101e-03
  -8.31677251e-03   8.63159114e-03  -6.17622135e-03  -3.08958154e-03
   1.39908214e-02   9.34415972e-03  -3.76350383e-03  -2.16322570e-03
   3.84337810e-03  -6.68382801e-04  -2.84473367e-03   2.48744388e-03
  -8.91564845e-03   6.97568406e-02]
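As a quick sanity check (the 0.1 threshold here is an arbitrary illustrative choice, not part of the original run), we can count how many weights stay meaningfully away from zero:

print("weights with |w| > 0.1:", np.sum(np.abs(w) > 0.1)) #for this run, this picks out the 3 truly significant features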
In [7]:
# plot our w vs true w
plt.plot(w_dash, label='true w')
plt.plot(w, label='w_map')
plt.legend()
plt.show()
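For a cross-check, scikit-learn's Lasso solves the same L1-penalized least-squares problem via coordinate descent. Note that its alpha parameter uses a different scaling convention than L1_coeff above, so the alpha value below is only an illustrative assumption and the fitted coefficients will be similar but not identical:

from sklearn.linear_model import Lasso # coordinate-descent Lasso from scikit-learn

lasso = Lasso(alpha=0.1, fit_intercept=False) # alpha chosen for illustration; data is already centered, so no intercept
lasso.fit(X_input, Y_output) # fit on the same simulated data
plt.plot(w_dash, label='true w')
plt.plot(lasso.coef_, label='sklearn Lasso w')
plt.legend()
plt.show()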