
PySpark Linear Regression – How to Build and Evaluate Linear Regression Models using PySpark MLlib

Written by Jagdeesh | 4 min read

MLlib, the machine learning library within PySpark, offers various tools and functions for machine learning algorithms, including linear regression. In this blog post, you will learn how to build and evaluate a linear regression model using PySpark MLlib, with example code.

Linear regression is a simple yet powerful machine learning algorithm used to predict a continuous target variable based on one or more input features. PySpark, the Apache Spark library for Python, provides a distributed and scalable platform for big data processing.
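
To make the idea concrete, here is a minimal sketch of ordinary least squares for a single feature in plain Python. The data points are made up for illustration and the helper name fit_simple_ols is hypothetical; MLlib solves the same kind of problem with many features on distributed data.

```python
def fit_simple_ols(xs, ys):
    """Return (slope, intercept) that minimize the squared error of y = slope * x + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Covariance-like and variance-like sums (the 1/n factors cancel in the ratio)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept

# Illustrative data lying exactly on the line y = 2x + 1
xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0, 5.0, 7.0, 9.0]
slope, intercept = fit_simple_ols(xs, ys)
print(slope, intercept)  # 2.0 1.0
```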

Prerequisites:

To follow along, you need to have the following installed on your machine:

  1. Python 3.6 or later
  2. Apache Spark with PySpark
  3. Jupyter Notebook or any other Python IDE

1. Import required libraries and initialize SparkSession

First, let’s import the necessary libraries and create a SparkSession, the entry point to use PySpark.

python
import findspark
findspark.init()

from pyspark.sql import SparkSession
from pyspark import SparkFiles
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.evaluation import RegressionEvaluator

spark = SparkSession.builder \
    .appName("Linear Regression with PySpark MLlib") \
    .getOrCreate()

2. Load the dataset

For this example, we will use the “Boston Housing” dataset. The following code registers the CSV file from its URL with addFile, which downloads it, and then loads the downloaded copy into a PySpark DataFrame.

python
url = "https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv"
spark.sparkContext.addFile(url)

data = spark.read.csv(SparkFiles.get("BostonHousing.csv"), header=True, inferSchema=True)
data.show(5)
Output:
+-------+----+-----+----+-----+-----+----+------+---+---+-------+------+-----+----+
|   crim|  zn|indus|chas|  nox|   rm| age|   dis|rad|tax|ptratio|     b|lstat|medv|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+------+-----+----+
|0.00632|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3| 396.9| 4.98|24.0|
|0.02731| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8| 396.9| 9.14|21.6|
|0.02729| 0.0| 7.07|   0|0.469|7.185|61.1|4.9671|  2|242|   17.8|392.83| 4.03|34.7|
|0.03237| 0.0| 2.18|   0|0.458|6.998|45.8|6.0622|  3|222|   18.7|394.63| 2.94|33.4|
|0.06905| 0.0| 2.18|   0|0.458|7.147|54.2|6.0622|  3|222|   18.7| 396.9| 5.33|36.2|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+------+-----+----+
only showing top 5 rows

3. Prepare the data

Before building the model, we need to assemble the input features into a single feature vector using the VectorAssembler class. Then, we will split the dataset into a training set (80%) and a testing set (20%).

python
assembler = VectorAssembler(
    inputCols=["crim", "zn", "indus", "chas", "nox", "rm", "age", "dis", "rad", "tax", "ptratio", "b", "lstat"],
    outputCol="features")

data = assembler.transform(data)
final_data = data.select("features", "medv")

train_data, test_data = final_data.randomSplit([0.8, 0.2], seed=42)

4. Build the linear regression model

Next, we will create an instance of the LinearRegression class and fit the model to the training data.

python
lr = LinearRegression(featuresCol="features", labelCol="medv", predictionCol="predicted_medv")
lr_model = lr.fit(train_data)

5. Make predictions and evaluate the model

Now, we will use the trained model to make predictions on the test data and evaluate its performance using the RegressionEvaluator class. We will calculate the Root Mean Squared Error (RMSE) and R-squared (R2) metrics.

python
predictions = lr_model.transform(test_data)

evaluator = RegressionEvaluator(labelCol="medv", predictionCol="predicted_medv", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data: {:.3f}".format(rmse))

evaluator_r2 = RegressionEvaluator(labelCol="medv", predictionCol="predicted_medv", metricName="r2")
r2 = evaluator_r2.evaluate(predictions)
print("R-squared (R2) on test data: {:.3f}".format(r2))
Output:
Root Mean Squared Error (RMSE) on test data: 4.672
R-squared (R2) on test data: 0.793
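
For intuition about what these two numbers measure, here is a plain-Python sketch of the standard RMSE and R-squared formulas. The actual and predicted values are made up, and the function names are hypothetical; RegressionEvaluator computes the corresponding metrics on the distributed DataFrame.

```python
import math

def rmse(actual, predicted):
    """Square root of the mean squared prediction error."""
    squared_errors = [(a - p) ** 2 for a, p in zip(actual, predicted)]
    return math.sqrt(sum(squared_errors) / len(actual))

def r_squared(actual, predicted):
    """1 minus the ratio of residual variation to total variation around the mean."""
    mean_actual = sum(actual) / len(actual)
    ss_res = sum((a - p) ** 2 for a, p in zip(actual, predicted))
    ss_tot = sum((a - mean_actual) ** 2 for a in actual)
    return 1.0 - ss_res / ss_tot

# Made-up values: every prediction is off by exactly 1
actual = [24.0, 21.6, 34.7, 33.4]
predicted = [25.0, 20.6, 33.7, 34.4]
rmse_value = rmse(actual, predicted)
r2_value = r_squared(actual, predicted)
print(round(rmse_value, 3), round(r2_value, 3))
```

RMSE is in the same units as the target, so an RMSE of 1 here means predictions miss by about one unit of medv on average, while R-squared is unitless.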

6. Inspect the model coefficients and intercept

To better understand the linear regression model, you can examine its coefficients and intercept. These values represent the weights assigned to each feature and the bias term, respectively.

python
coefficients = lr_model.coefficients
intercept = lr_model.intercept

print("Coefficients: ", coefficients)
print("Intercept: {:.3f}".format(intercept))
Output:
Coefficients:  [-0.11362203729408954,0.048909186934053925,0.02379542898673389,2.801771998735119,-18.4154245411894,3.5158797633120065,0.0052116821614709204,-1.4163830723539739,0.3317669315937035,-0.013607893704163878,-0.9534143338408072,0.008602677392853256,-0.519503531247664]
Intercept: 38.617
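
A prediction from a linear model is the intercept plus the dot product of the coefficients and the feature values. The sketch below shows that arithmetic in plain Python with a hypothetical two-feature model; the numbers are made up and are not taken from the fitted model above.

```python
def predict(coefficients, intercept, features):
    """Linear model prediction: intercept + sum of coefficient * feature value."""
    return intercept + sum(c * x for c, x in zip(coefficients, features))

# Hypothetical two-feature model and one input row
coefficients = [3.5, -0.5]
intercept = 10.0
row = [6.0, 4.0]

prediction = predict(coefficients, intercept, row)
print(prediction)  # 10.0 + 3.5 * 6.0 - 0.5 * 4.0 = 29.0
```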

7. Analyze feature importance

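To get a rough idea of which features contribute most to the model’s predictions, you can rank the absolute values of the coefficients. Keep in mind that each coefficient is expressed in the units of its feature, so a larger absolute coefficient indicates a greater impact on the target variable only when the features are on comparable scales.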

python
# Pair each coefficient with its feature name, taken from the assembler so the order matches
feature_importance = sorted(zip(assembler.getInputCols(), map(abs, coefficients)), key=lambda x: x[1], reverse=True)

print("Feature Importance:")
for feature, importance in feature_importance:
    print("  {}: {:.3f}".format(feature, importance))
Output:
Feature Importance:
  nox: 18.415
  rm: 3.516
  chas: 2.802
  dis: 1.416
  ptratio: 0.953
  lstat: 0.520
  rad: 0.332
  crim: 0.114
  zn: 0.049
  indus: 0.024
  tax: 0.014
  b: 0.009
  age: 0.005
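
Because the features are measured in different units, one common adjustment is to multiply each absolute coefficient by the standard deviation of its feature before ranking. The plain-Python sketch below illustrates this with made-up numbers; the function name and data are hypothetical, and this is one possible approach rather than something the model computes for you.

```python
import statistics

def standardized_importance(coefficients, columns):
    """Scale each |coefficient| by the standard deviation of its feature column."""
    return [abs(c) * statistics.pstdev(col) for c, col in zip(coefficients, columns)]

# Made-up example: feature "a" has a large coefficient but barely varies,
# feature "b" has a small coefficient but spans a wide range
coefficients = [10.0, 0.1]
col_a = [0.50, 0.51, 0.49, 0.50]
col_b = [100.0, 300.0, 500.0, 700.0]

importance = standardized_importance(coefficients, [col_a, col_b])
for name, value in zip(["a", "b"], importance):
    print("  {}: {:.3f}".format(name, value))
```

In this toy case the ranking flips: feature "b" moves the prediction far more across its observed range than feature "a", even though its raw coefficient is 100 times smaller.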

8. Improve the model (optional)

If the model’s performance does not meet your expectations, you can try the following strategies to improve it:

  1. Feature selection: Remove less important features or add new features based on domain knowledge.
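  2. Feature scaling: Standardize or normalize the input features to ensure they are on the same scale. Note that LinearRegression already standardizes features internally during training by default (its standardization parameter), while reporting coefficients on the original scale.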
  3. Hyperparameter tuning: Adjust the model’s hyperparameters, such as regularization strength or iteration count.

9. Save and load the model (optional)

If you want to reuse the model in the future, you can save it to disk and load it back when needed.

python
# Save the model
lr_model.save("lr_model")

# Load the model
from pyspark.ml.regression import LinearRegressionModel
loaded_model = LinearRegressionModel.load("lr_model")

Conclusion

In this blog post, you learned how to build and evaluate a linear regression model using PySpark MLlib. We walked through the entire process, from loading the dataset and preparing the data to building and evaluating the model.

We also discussed how to inspect the model’s coefficients and intercept, analyze feature importance, and save and load the model.

With this knowledge, you can now leverage the power of PySpark and MLlib to build scalable and efficient linear regression models for your big data projects.
