Linear Regression using PyTorch
Linear regression is a widely used statistical method for determining and studying the relationship between two continuous variables. The properties of linear regression and its plain Python implementation have been covered in a previous article. Now we shall see how to implement it in PyTorch, a popular deep learning library originally developed by Facebook.
First, you will need to install PyTorch into your Python environment. The easiest way to do this is with the pip or conda package manager. Visit pytorch.org, select the options that match your setup (including the package manager you would like to use), and run the install command it shows.
With PyTorch installed, let us now have a look at the code.
First, import the necessary library functions and objects.
We also define some data and assign it to the variables x_data and y_data.
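A minimal sketch of the imports and the dataset follows. The data values are an assumption: three points on the line y = 2x, which is consistent with the expected prediction of about 8.0 for the input 4.0 later on. Each row is one sample, so both tensors have shape (N, 1), the layout torch.nn.Linear expects. Older PyTorch code often wraps such tensors in torch.autograd.Variable; since PyTorch 0.4 plain tensors support autograd directly, so the wrapper is not needed.

```python
import torch

# Assumed toy dataset: one sample per row, shape (3, 1)
x_data = torch.tensor([[1.0], [2.0], [3.0]])  # independent variable
y_data = torch.tensor([[2.0], [4.0], [6.0]])  # dependent variable, y = 2x

print(x_data.shape)  # torch.Size([3, 1])
```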
Here, x_data is our independent variable and y_data is our dependent variable. This will be our dataset for now. Next, we need to define our model, which involves two main steps:
- Initialising our model.
- Declaring the forward pass.
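We implement both steps in a model class. Our Model class is a subclass of torch.nn.Module: its constructor (__init__) initialises the layers, and its forward() method declares the forward pass. Since we have only one input and one output, we use a linear layer, torch.nn.Linear, with both the input and output dimensions set to 1.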
Next, we create an object of this model.
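A minimal sketch of such a class and of creating an object from it. The class name LinearRegressionModel and the attribute name linear are illustrative assumptions. Calling the object, as in our_model(x), runs forward() through torch.nn.Module's call machinery, so forward() is never called directly.

```python
import torch


class LinearRegressionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()  # set up torch.nn.Module internals
        # One input feature, one output feature: y_pred = w * x + b
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        # Forward pass: apply the affine map to a batch of shape (N, 1)
        return self.linear(x)


our_model = LinearRegressionModel()

# Quick check: a batch of 3 inputs gives 3 predictions
out = our_model(torch.tensor([[1.0], [2.0], [3.0]]))
print(out.shape)  # torch.Size([3, 1])
```

The layer holds exactly two trainable numbers, one weight and one bias, which is all a one-variable linear regression needs.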
After this, we select the optimiser and the loss criterion. Here, we use the mean squared error (MSE) as our loss function and stochastic gradient descent (SGD) as our optimiser, with an arbitrarily chosen learning rate of 0.01.
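The sketch below checks what these two choices compute, assuming the same toy data and a bare torch.nn.Linear(1, 1) in place of the model class. It shows that torch.nn.MSELoss (default reduction "mean") equals the mean of the squared errors, that the gradients are 2 * mean(err * x) for the weight and 2 * mean(err) for the bias, and that one torch.optim.SGD step without momentum applies p <- p - lr * grad. Because every step in this article uses all three samples at once, SGD here behaves as plain full-batch gradient descent; it becomes stochastic only when fed random mini-batches.

```python
import torch

# Assumed toy data (y = 2x)
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])

lr = 0.01
model = torch.nn.Linear(1, 1)      # the layer the model class wraps
criterion = torch.nn.MSELoss()     # default reduction="mean"
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Loss: with reduction="mean", MSELoss is the mean of the squared errors
pred = model(x_data)
loss = criterion(pred, y_data)
manual_loss = ((pred - y_data) ** 2).mean()

# Snapshot the parameters, then run one backward pass and one optimiser step
w_before = model.weight.detach().clone()
b_before = model.bias.detach().clone()
optimizer.zero_grad()
loss.backward()
w_grad = model.weight.grad.clone()
b_grad = model.bias.grad.clone()
optimizer.step()

# Gradients of the MSE with respect to w and b, written out by hand
err = (pred - y_data).detach()
analytic_w_grad = 2 * (err * x_data).mean()
analytic_b_grad = 2 * err.mean()

# Plain SGD (no momentum, no weight decay): p <- p - lr * grad
expected_w = w_before - lr * w_grad
expected_b = b_before - lr * b_grad
print(torch.allclose(model.weight.detach(), expected_w),
      torch.allclose(model.bias.detach(), expected_b))
```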
We now arrive at the training step. We perform the following tasks 500 times during training:
- Perform a forward pass by passing our data through the model to obtain the predicted values of y.
- Compute the loss using MSE.
- Reset all the gradients to 0, perform backpropagation, and then update the weights.
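A minimal sketch of this loop, assuming the same toy data. A bare torch.nn.Linear(1, 1) stands in for the model class (it computes the same function), and torch.manual_seed(0) is an addition that only makes runs repeatable. The order of the calls matters: PyTorch accumulates gradients across backward() calls, so zero_grad() must clear them before each new backpropagation.

```python
import torch

torch.manual_seed(0)  # reproducible random initial weight and bias

# Assumed toy data (y = 2x)
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])

model = torch.nn.Linear(1, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

losses = []
for epoch in range(500):
    pred_y = model(x_data)            # 1. forward pass
    loss = criterion(pred_y, y_data)  # 2. MSE loss
    optimizer.zero_grad()             # 3a. reset accumulated gradients
    loss.backward()                   # 3b. backpropagation
    optimizer.step()                  # 3c. update weight and bias
    losses.append(loss.item())

print(f"loss: first {losses[0]:.4f}, last {losses[-1]:.4f}")
```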
Once training is complete, we check whether the model gives correct results by testing it on a value of x that was not in the training data, in this case 4.0.
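To isolate the prediction step, this sketch sets the layer's parameters by hand to the exact fit of the assumed y = 2x data (w = 2, b = 0) instead of training, so the output is known in advance. torch.no_grad() turns off gradient tracking, which is not needed at inference time, and .item() converts the one-element result tensor to a Python float.

```python
import torch

model = torch.nn.Linear(1, 1)

# Hypothetical: hand-set parameters equal to the exact fit of y = 2x
with torch.no_grad():
    model.weight.fill_(2.0)
    model.bias.zero_()

# Inference: no gradients are needed to make a prediction
with torch.no_grad():
    new_x = torch.tensor([[4.0]])  # one sample, one feature: shape (1, 1)
    pred_y = model(new_x)

print("predict (after training)", 4, pred_y.item())  # predict (after training) 4 8.0
```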
If you performed all the steps correctly, you will see that for the input 4.0 the model outputs a value very close to 8.0, as below. The exact value varies from run to run because the initial weight and bias are random. So, our model learns the relationship between the input data and the output data from the examples, without that relationship being programmed explicitly.
predict (after training) 4 7.966438293457031
For your reference, you can find the entire code of this article given below:
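A complete, runnable sketch that assembles the steps above, under the same assumptions: toy data following y = 2x and the illustrative class name LinearRegressionModel. Because no random seed is set, the printed prediction differs slightly between runs.

```python
import torch

# Assumed toy dataset following y = 2x; shape (N, 1)
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])


class LinearRegressionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)  # one input, one output

    def forward(self, x):
        return self.linear(x)  # forward pass: w * x + b


our_model = LinearRegressionModel()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(our_model.parameters(), lr=0.01)

for epoch in range(500):
    pred_y = our_model(x_data)        # forward pass
    loss = criterion(pred_y, y_data)  # mean squared error
    optimizer.zero_grad()             # reset gradients
    loss.backward()                   # backpropagation
    optimizer.step()                  # update weight and bias

# Predict for an input the model has not seen
with torch.no_grad():
    new_x = torch.tensor([[4.0]])
    pred_y = our_model(new_x)
print("predict (after training)", 4, pred_y.item())
```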