
PySpark Logistic Regression – How to Build and Evaluate Logistic Regression Models using PySpark MLlib

Written by Jagdeesh | 6 min read

Let’s explore how to build and evaluate a Logistic Regression model using PySpark MLlib, Apache Spark’s machine learning library.

Logistic Regression is a widely used statistical method for modeling the relationship between a binary outcome and one or more explanatory variables.
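In symbols, take a feature vector x with coefficient vector β and intercept β₀. The model passes a linear score through the sigmoid function to estimate the probability of the positive class (label 1, which is malignant in this example). Equivalently, the linear score is the log-odds of that class:

```latex
p(y = 1 \mid \mathbf{x}) = \sigma\left(\beta_0 + \boldsymbol{\beta}^\top \mathbf{x}\right)
  = \frac{1}{1 + e^{-\left(\beta_0 + \boldsymbol{\beta}^\top \mathbf{x}\right)}},
\qquad
\log \frac{p(y = 1 \mid \mathbf{x})}{1 - p(y = 1 \mid \mathbf{x})} = \beta_0 + \boldsymbol{\beta}^\top \mathbf{x}
```

The intercept and coefficients printed in step 5 are β₀ and β in these formulas.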

We will cover the following steps:

  1. Setting up the environment
  2. Loading and preprocessing the data
  3. Building the Logistic Regression model
  4. Evaluating the model on test data
  5. Interpretation of results
  6. Example code

1. Import required libraries and initialize SparkSession

First, let’s import the necessary libraries and create a SparkSession, which is the entry point for working with DataFrames and MLlib in PySpark.

python
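# Optional: findspark locates a local Spark installation (using SPARK_HOME if set)
# and adds pyspark to sys.path; skip these two lines if pyspark is already importable
import findspark
findspark.init()

from pyspark.sql import SparkSession
from pyspark import SparkFiles
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder  # only needed for optional hyperparameter tuning
from pyspark.ml.feature import VectorAssembler

# Create a new SparkSession, or reuse one that is already running
spark = SparkSession.builder \
    .appName("LogisticRegression with PySpark MLlib") \
    .getOrCreate()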

2. Load the dataset

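For this example, we will use the Breast Cancer Wisconsin (Diagnostic) dataset. It contains 569 samples. Each sample has an id, a diagnosis (M = malignant, B = benign), and 30 numeric features computed from a digitized image of a fine needle aspirate of a breast mass.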

python
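url = "https://raw.githubusercontent.com/pkmklong/Breast-Cancer-Wisconsin-Diagnostic-DataSet/master/data.csv"
# Download the file once and register it with the Spark application
# (this pattern suits local mode; on a cluster, read the file from shared storage instead)
spark.sparkContext.addFile(url)

# header=True reads column names from the first line; inferSchema=True detects numeric types
df = spark.read.csv(SparkFiles.get("data.csv"), header=True, inferSchema=True)
df.show(2)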
python
+------+---------+-----------+------------+--------------+---------+---------------+----------------+--------------+-------------------+-------------+----------------------+---------+----------+------------+-------+-------------+--------------+------------+-----------------+-----------+--------------------+------------+-------------+---------------+----------+----------------+-----------------+---------------+--------------------+--------------+-----------------------+----+
|    id|diagnosis|radius_mean|texture_mean|perimeter_mean|area_mean|smoothness_mean|compactness_mean|concavity_mean|concave points_mean|symmetry_mean|fractal_dimension_mean|radius_se|texture_se|perimeter_se|area_se|smoothness_se|compactness_se|concavity_se|concave points_se|symmetry_se|fractal_dimension_se|radius_worst|texture_worst|perimeter_worst|area_worst|smoothness_worst|compactness_worst|concavity_worst|concave points_worst|symmetry_worst|fractal_dimension_worst|_c32|
+------+---------+-----------+------------+--------------+---------+---------------+----------------+--------------+-------------------+-------------+----------------------+---------+----------+------------+-------+-------------+--------------+------------+-----------------+-----------+--------------------+------------+-------------+---------------+----------+----------------+-----------------+---------------+--------------------+--------------+-----------------------+----+
|842302|        M|      17.99|       10.38|         122.8|   1001.0|         0.1184|          0.2776|        0.3001|             0.1471|       0.2419|               0.07871|    1.095|    0.9053|       8.589|  153.4|     0.006399|       0.04904|     0.05373|          0.01587|    0.03003|            0.006193|       25.38|        17.33|          184.6|    2019.0|          0.1622|           0.6656|         0.7119|              0.2654|        0.4601|                 0.1189|null|
|842517|        M|      20.57|       17.77|         132.9|   1326.0|        0.08474|         0.07864|        0.0869|            0.07017|       0.1812|               0.05667|   0.5435|    0.7339|       3.398|  74.08|     0.005225|       0.01308|      0.0186|           0.0134|    0.01389|            0.003532|       24.99|        23.41|          158.8|    1956.0|          0.1238|           0.1866|         0.2416|               0.186|         0.275|                0.08902|null|
+------+---------+-----------+------------+--------------+---------+---------------+----------------+--------------+-------------------+-------------+----------------------+---------+----------+------------+-------+-------------+--------------+------------+-----------------+-----------+--------------------+------------+-------------+---------------+----------+----------------+-----------------+---------------+--------------------+--------------+-----------------------+----+
only showing top 2 rows
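The last column, _c32, is entirely null. It most likely comes from a trailing comma at the end of each line of this CSV file. That comma leaves the header with an empty final field, and Spark names such a field _c followed by its zero-based position. The column carries no data. When the columns are renamed in the next step it becomes feature_31, so it must stay out of any feature list. Python’s built-in csv module shows the same effect on a shortened, hypothetical header line:

```python
import csv
import io

# A shortened, hypothetical header line that ends with a trailing comma
sample = "id,diagnosis,radius_mean,texture_mean,\n"

# csv.reader yields one empty string for the field after the final comma
header = next(csv.reader(io.StringIO(sample)))
print(header)       # ['id', 'diagnosis', 'radius_mean', 'texture_mean', '']
print(len(header))  # 5
```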

3. Prepare the data

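Spark ML estimators expect all input features in a single vector column. Before building the model, we therefore assemble the feature columns with the VectorAssembler class. Then we split the dataset into a training set (about 80%) and a testing set (about 20%). Because randomSplit assigns each row at random, these proportions are approximate.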

python
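# Rename the columns to short, space-free names (some originals, such as
# 'concave points_mean', contain spaces). df has 33 columns: id, diagnosis,
# the 30 measurements, and the empty trailing column _c32, which becomes feature_31.
columns = ['id', 'diagnosis'] + [f'feature_{i}' for i in range(1, 32)]
data = df.toDF(*columns)

# Map 'M' (malignant) to 1 and 'B' (benign) to 0
data = data.withColumn("label", (data["diagnosis"] == "M").cast("integer")).drop("diagnosis")

# Assemble feature_1..feature_24: the 10 "_mean", 10 "_se", and first 4 "_worst" measurements.
# feature_31 is all null and must be excluded. Use range(1, 31) to include all 30 measurements
# (the coefficients and metrics shown below were produced with range(1, 25)).
feature_columns = [f'feature_{i}' for i in range(1, 25)]
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")

data = assembler.transform(data)

# Approximate 80/20 split; the fixed seed gives the same split on reruns (same data and partitioning)
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)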
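For intuition, the label mapping, the VectorAssembler step, and the random split reduce to the per-row logic below. This is a plain-Python sketch on hypothetical row dictionaries, not how Spark executes it. Spark applies the same idea to every row of the DataFrame in a distributed way. The helper names to_label_and_features and random_split are illustrative. Python’s random module also draws different numbers than Spark for seed=42, so the split sizes will not match Spark’s.

```python
import random

# Hypothetical rows after renaming (first row uses values shown above, second is made up)
rows = [
    {"diagnosis": "M", "feature_1": 17.99, "feature_2": 10.38},
    {"diagnosis": "B", "feature_1": 12.0, "feature_2": 15.0},
]

def to_label_and_features(row, feature_columns):
    """Mimic withColumn('label', ...) plus VectorAssembler for one row."""
    # 'M' (malignant) -> 1, 'B' (benign) -> 0
    label = 1 if row["diagnosis"] == "M" else 0
    # VectorAssembler concatenates the chosen columns, in order, into one vector
    features = [float(row[name]) for name in feature_columns]
    return label, features

def random_split(items, weights, seed):
    """Put each item into exactly one split using a single uniform draw per item."""
    total = sum(weights)
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    bounds[-1] = 1.0  # guard against floating-point round-off in the last bound
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for item in items:
        u = rng.random()  # uniform in [0.0, 1.0)
        for i, upper in enumerate(bounds):
            if u < upper:
                splits[i].append(item)
                break
    return splits

prepared = [to_label_and_features(r, ["feature_1", "feature_2"]) for r in rows]
print(prepared)  # [(1, [17.99, 10.38]), (0, [12.0, 15.0])]

# Sizes are only approximately 80/20 because every item is assigned independently
train, test = random_split(list(range(569)), [0.8, 0.2], seed=42)
print(len(train), len(test))
```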

4. Building the Logistic Regression model

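Create a LogisticRegression estimator and fit it to the training data. The default settings fit an unregularized model (regParam=0.0) with at most 100 iterations (maxIter=100) and a decision threshold of 0.5.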

python
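# The estimator reads the "features" vector column and the "label" column
logistic_regression = LogisticRegression(featuresCol="features", labelCol="label")
# fit() trains on the training split and returns a LogisticRegressionModel
model = logistic_regression.fit(train_data)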
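Calling model.transform later adds three columns to each row. For a binary model, Spark first computes the margin, intercept + coefficients · features. It then stores rawPrediction as [-margin, margin] and probability as [1 - p, p], where p = sigmoid(margin). Finally, prediction is 1 when p exceeds the threshold parameter (0.5 by default). The plain-Python sketch below reproduces that arithmetic for a single row. The score_row helper, coefficients, and feature values are hypothetical and only illustrate the calculation.

```python
import math

def sigmoid(z):
    """Logistic function mapping a real-valued margin to a probability."""
    return 1.0 / (1.0 + math.exp(-z))

def score_row(features, coefficients, intercept, threshold=0.5):
    """Mimic the columns a binary LogisticRegressionModel adds for one row."""
    margin = intercept + sum(w * x for w, x in zip(coefficients, features))
    p = sigmoid(margin)
    raw_prediction = [-margin, margin]
    probability = [1.0 - p, p]
    prediction = 1.0 if p > threshold else 0.0
    return raw_prediction, probability, prediction

# Hypothetical two-feature model and row
raw, prob, pred = score_row(features=[2.0, 1.0], coefficients=[0.5, -1.0], intercept=0.5)
print(raw)   # [-0.5, 0.5]
print(pred)  # 1.0
```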

5. Inspect the model coefficients and intercept

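To better understand the logistic regression model, you can examine its coefficients and intercept. Each coefficient is the change in the log-odds of a malignant diagnosis for a one-unit increase in the corresponding feature, holding the other features fixed. The intercept is the log-odds when all features are zero. Spark standardizes the features internally by default but reports the coefficients on the original feature scales, so their magnitudes are not directly comparable across features.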

python
coefficients = model.coefficients
intercept = model.intercept

print("Coefficients: ", coefficients)
print("Intercept: {:.3f}".format(intercept))
python
Coefficients:  [-2.0108625612250113,-0.3315810006345366,-0.7534860096052377,0.018384771078677746,70.8441022574601,-152.35548045168963,104.52836257785901,88.86150324173941,61.07558915635416,819.0314463442933,28.221234440039396,-6.976368215240555,-6.722492688536127,0.4827954037778578,261.52959221617755,77.51784123708751,-119.26560355293951,640.1619368253066,114.26224549821082,-2178.1476132844323,-1.4114896156553725,1.2113831638079073,0.9059951584935497,0.02056292189691024]
Intercept: -95.705
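Because the coefficients are on the log-odds scale, exponentiating one gives an odds ratio. This is the factor by which the odds of a malignant diagnosis change when that feature increases by one unit and all other features stay fixed. Here is a minimal sketch using the first printed coefficient, which belongs to feature_1 (radius_mean):

```python
import math

# First coefficient printed above (feature_1 = radius_mean)
coef_radius_mean = -2.0108625612250113

# exp(coefficient) = multiplicative change in the odds per one-unit increase
odds_ratio = math.exp(coef_radius_mean)
print(round(odds_ratio, 3))  # 0.134
```

To apply this to the trained model, map math.exp over model.coefficients.toArray(). Read such per-feature values with caution here. Radius, perimeter, and area measure closely related quantities, and without regularization the coefficients of correlated features can grow large and offset one another, as the values in the hundreds and thousands above suggest.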

6. Evaluating the model on test data

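Now that we have trained the model, we can evaluate its performance on the test data. Our primary metric is the Area Under the ROC Curve (AUC-ROC), which is the default metric of BinaryClassificationEvaluator. We will also calculate accuracy, weighted precision, and weighted recall. The weighted variants average the per-class scores over both classes, weighting each class by the number of test samples that belong to it: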

python
predictions = model.transform(test_data)

# AUC-ROC
evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", labelCol="label")
auc = evaluator.evaluate(predictions)

# Accuracy, Precision, and Recall
multi_evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction")
accuracy = multi_evaluator.evaluate(predictions, {multi_evaluator.metricName: "accuracy"})
precision = multi_evaluator.evaluate(predictions, {multi_evaluator.metricName: "weightedPrecision"})
recall = multi_evaluator.evaluate(predictions, {multi_evaluator.metricName: "weightedRecall"})

print(f"AUC-ROC: {auc:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
python
AUC-ROC: 0.9989
Accuracy: 0.9651
Precision: 0.9653
Recall: 0.9651

7. Interpretation of results

AUC-ROC measures how well the model’s scores rank malignant samples above benign ones across all possible decision thresholds. A value of 0.5 corresponds to random guessing, and 1.0 corresponds to perfect separation. The value of 0.9989 means the model separates the two classes almost perfectly on this test split.
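AUC-ROC also has a direct probabilistic reading. It is the probability that a randomly chosen malignant sample receives a higher score than a randomly chosen benign one, with ties counting as half. The stdlib sketch below computes AUC that way by comparing every positive/negative pair on made-up labels and scores. BinaryClassificationEvaluator instead integrates the ROC curve built from the rawPrediction scores, which yields the same value when every distinct score is used as a threshold. The pairwise loop is quadratic in the number of samples, so treat it as an illustration rather than a replacement.

```python
def pairwise_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))

# Made-up labels (1 = malignant) and model scores
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(pairwise_auc(labels, scores))  # 0.75
```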

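Accuracy (0.9651) is the fraction of test samples classified correctly. Precision is the share of predicted positives that are truly positive, TP / (TP + FP). Recall is the share of actual positives the model finds, TP / (TP + FN). Together, precision and recall show how often the model makes false-positive and false-negative predictions. Because the weighted variants average over both classes using class frequencies, weighted recall always equals accuracy, which is why the two printed values are identical. To score the malignant class alone, MulticlassClassificationEvaluator also accepts metricName="precisionByLabel" or metricName="recallByLabel" together with metricLabel=1.0 (Spark 3.0 and later).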
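The plain-Python sketch below computes these metrics from a confusion matrix on made-up labels and predictions (1 = malignant, 0 = benign). It reports them both for the malignant class alone and as the class-weighted averages behind the weightedPrecision and weightedRecall metric names, where each class’s weight is its share of the true labels. The classification_metrics helper is illustrative, not part of Spark.

```python
def classification_metrics(labels, predictions):
    """Confusion-matrix metrics for binary labels (1 = malignant, 0 = benign)."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)
    n = len(pairs)

    def ratio(num, den):
        # Treat an empty denominator as 0.0 instead of raising ZeroDivisionError
        return num / den if den else 0.0

    # Per-class precision and recall
    precision = {1: ratio(tp, tp + fp), 0: ratio(tn, tn + fn)}
    recall = {1: ratio(tp, tp + fn), 0: ratio(tn, tn + fp)}
    # Class weights = share of each class among the true labels
    weight = {1: (tp + fn) / n, 0: (tn + fp) / n}

    return {
        "accuracy": (tp + tn) / n,
        "precision_malignant": precision[1],
        "recall_malignant": recall[1],
        "weighted_precision": sum(weight[c] * precision[c] for c in (0, 1)),
        "weighted_recall": sum(weight[c] * recall[c] for c in (0, 1)),
    }

# Made-up labels and predictions: one benign sample misclassified as malignant
labels = [1, 1, 1, 0, 0, 0, 0, 0]
predictions = [1, 1, 1, 1, 0, 0, 0, 0]
for name, value in classification_metrics(labels, predictions).items():
    print(f"{name}: {value:.4f}")
```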

8. Example code

Here is the complete example code:

python
# Optional: only needed when pyspark is not already importable
import findspark
findspark.init()

from pyspark.sql import SparkSession
from pyspark import SparkFiles
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder  # only needed for optional hyperparameter tuning
from pyspark.ml.feature import VectorAssembler

spark = SparkSession.builder.appName("LogisticRegression with PySpark MLlib").getOrCreate()

# Load data
url = "https://raw.githubusercontent.com/pkmklong/Breast-Cancer-Wisconsin-Diagnostic-DataSet/master/data.csv"
spark.sparkContext.addFile(url)

df = spark.read.csv(SparkFiles.get("data.csv"), header=True, inferSchema=True)

# Rename columns (df has 33: id, diagnosis, 30 measurements, and the empty _c32, which becomes feature_31)
# and map diagnosis to label
columns = ['id', 'diagnosis'] + [f'feature_{i}' for i in range(1, 32)]
data = df.toDF(*columns)
data = data.withColumn("label", (data["diagnosis"] == "M").cast("integer")).drop("diagnosis")

# Assemble feature_1..feature_24 into a single vector (never include the all-null feature_31;
# range(1, 31) uses all 30 measurements but will change the results shown)
feature_columns = [f'feature_{i}' for i in range(1, 25)]
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data = assembler.transform(data)

# Split data into training and test sets
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)

# Build the Logistic Regression model
logistic_regression = LogisticRegression(featuresCol="features", labelCol="label")
model = logistic_regression.fit(train_data)

# Evaluate the model on test data
predictions = model.transform(test_data)

# AUC-ROC
evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", labelCol="label")
auc = evaluator.evaluate(predictions)

# Accuracy, Precision, and Recall
multi_evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction")
accuracy = multi_evaluator.evaluate(predictions, {multi_evaluator.metricName: "accuracy"})
precision = multi_evaluator.evaluate(predictions, {multi_evaluator.metricName: "weightedPrecision"})
recall = multi_evaluator.evaluate(predictions, {multi_evaluator.metricName: "weightedRecall"})

print(f"AUC-ROC: {auc:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
python
AUC-ROC: 0.9989
Accuracy: 0.9651
Precision: 0.9653
Recall: 0.9651

9. Improve the model (optional)

If the model’s performance does not meet your expectations, you can try the following strategies to improve it:

  1. Feature selection: Remove less important features or add new features based on domain knowledge.
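  2. Feature scaling: Standardize or normalize the input features (for example with StandardScaler). Spark’s LogisticRegression already standardizes internally by default (standardization=True), so explicit scaling mainly helps when you want coefficients that can be compared across features.
  3. Hyperparameter tuning: Adjust the model’s hyperparameters, such as the regularization strength (regParam), the L1/L2 mix (elasticNetParam), or the maximum number of iterations (maxIter). The ParamGridBuilder and CrossValidator classes imported in step 1 can search over these values with cross-validation.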
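To see what the regularization strength does, here is a small stdlib-only sketch. It fits L2-penalized logistic regression by plain batch gradient descent on a made-up one-feature dataset, minimizing the average log loss plus (reg / 2) * w². That mirrors the form of Spark’s regParam penalty when elasticNetParam=0.0. Spark, however, uses a different optimizer (L-BFGS, or OWL-QN when an L1 term is present) and applies the penalty to standardized features. The sketch therefore illustrates only the trend: a larger penalty shrinks the coefficient toward zero. The fit_l2_logistic helper, learning rate, and step count are illustrative choices.

```python
import math

def fit_l2_logistic(xs, ys, reg, lr=0.1, steps=5000):
    """Fit p = sigmoid(b + w * x) by gradient descent on mean log loss + (reg / 2) * w**2."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        grad_w, grad_b = 0.0, 0.0
        for x, y in zip(xs, ys):
            p = 1.0 / (1.0 + math.exp(-(b + w * x)))
            # Gradient of one sample's log loss: (p - y) times the input (1 for the intercept)
            grad_w += (p - y) * x
            grad_b += p - y
        # Average over samples; only the coefficient is penalized, not the intercept
        w -= lr * (grad_w / n + reg * w)
        b -= lr * (grad_b / n)
    return w, b

# Made-up, overlapping (not linearly separable) one-feature data
xs = [-2.0, -1.5, -1.0, -0.5, 0.5, 1.0, 1.5, 2.0]
ys = [0, 0, 1, 0, 1, 0, 1, 1]

for reg in (0.0, 0.1, 1.0):
    w, b = fit_l2_logistic(xs, ys, reg)
    print(f"reg={reg}: w={w:.3f}, b={b:.3f}")
```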

10. Save and load the model (optional)

If you want to reuse the model in the future, you can save it to disk and load it back when needed.

python
# Save the model; overwrite() replaces an existing directory instead of raising an error
model.write().overwrite().save("logit_model")

# Load the model
from pyspark.ml.classification import LogisticRegressionModel
loaded_model = LogisticRegressionModel.load("logit_model")

Conclusion

In this blog post, we have learned how to build and evaluate a Logistic Regression model using PySpark MLlib. We set up the environment, loaded and preprocessed the data, built the model, and evaluated its performance on the test data using multiple metrics.

By following these steps, you can adapt this example to your own datasets.
