Predict churn in music streaming services

Robert Offner
Mar 13, 2021
Photo taken from Auto Bild Service Center

Problem Introduction

All services have two opposing forces fighting each other: user acquisition and user churn. As long as you acquire more users than you lose, you are on a good road to success. If you’re not, there are some things you can do:

  1. Acquire more users
  2. Have users churn slower

Let us disregard acquiring more users for now, as it is expensive and will only become more expensive as time goes by.

That leaves us with reducing churn. In a data-driven company we know quite a lot about our users, and if we are lucky (and in this case we are), we know precisely whether a user has churned.

This data is kindly provided by Udacity in their Sparkify data set.

Sparkify is an imaginary music streaming service that collects a wide array of data about its users. This data will be analyzed with the help of Spark to identify the characteristics of churning users and to predict whether a user is about to churn. Having this information is a great step toward reducing churn by catering to these users and removing their pain points.

Strategy to solve the problem

To solve the problem it’s imperative we have a good understanding of the data set. For this, we’ll take a look at the metrics, engineer some additional ones and conduct an EDA.

Afterwards, several classification models will be trained to find the one best suited to the data set. This includes hyperparameter tuning to find the best setup for each model. The result is a clear prediction of whether a user is likely to churn, which the product team can use to give those users special care.

Finally, we’ll discuss the results we found and think about what we could improve.

To make sure we can scale this solution to a production environment with a lot of data, we will use Spark for all of this, specifically the PySpark interface.


We will start by looking at the data set we were provided. For this, the medium data set is used, to reduce computational effort.

 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)

A couple of columns are unlikely to provide much benefit to our analysis: firstName, lastName, length (length of the song) and method (HTTP method). So let’s drop these columns from the data set.

Missing values are our next step. From the graph below we can see that there are two categories of missing values: user identifiers and song identifiers.

The missing user identifiers are caused by users not being logged in, or being logged in as guests. As there is no value in looking at users that are not logged in for churn analysis, their entries will be disregarded. The empty artist and song values will not be removed, as the rest of their rows may provide valuable data.

Churn is a binary classification problem, so we need an evaluation metric that fits it. Based on this article by Rahul Agarwal, both Recall and F1 Score are well suited to our problem, as we want to catch as many positives as possible. For this project we will use the F1 Score, which still strikes a good balance between precision and recall.
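To make the trade-off concrete, here is a minimal plain-Python sketch of how precision, recall and the F1 Score follow from confusion matrix counts. The function name and the counts are made up for illustration and are not results from this project.

```python
from typing import Tuple


def precision_recall_f1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Return (precision, recall, F1) from true positive, false positive
    and false negative counts. Empty denominators yield 0.0."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    # F1 is the harmonic mean of precision and recall.
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Hypothetical counts: 20 churners caught, 10 false alarms, 5 churners missed.
precision, recall, f1 = precision_recall_f1(tp=20, fp=10, fn=5)
print(f"precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")
```

Optimizing recall alone would reward a model that flags everyone as a churner. Because F1 also includes precision, it penalizes that behavior.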


Now that we have a basic sense of the data, let’s process it.

Feature Engineering

  • Churn for this report is defined as having had a Cancellation Confirmation page visit. This is the feature we are trying to predict. For users that visited this page, we will create a new column with a 1 that marks all their events. For all users that have not churned, the value is 0.
  • The userAgent column contains information on how the user accessed the service. With a few regular expressions, we can extract the operating system provider (Platform) from this. The platforms that come up are Microsoft, Apple and Linux.
  • Using the ts column (timestamp) and the registration column (ts at the time of registering to the service), the user age at the time of the event is calculated and saved in another column.
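The three engineered features above can be sketched in plain Python as follows. This only illustrates the logic and is not the project's PySpark code: the regular expressions, the "Other" fallback, the output field names and the sample events are assumptions, and ts and registration are assumed to be epoch milliseconds.

```python
import re

MS_PER_DAY = 24 * 60 * 60 * 1000  # assumption: timestamps are epoch milliseconds


def platform_from_user_agent(user_agent):
    """Map a userAgent string to a coarse platform label (hypothetical patterns)."""
    if re.search(r"Windows", user_agent):
        return "Microsoft"
    if re.search(r"Macintosh|iPhone|iPad", user_agent):
        return "Apple"
    if re.search(r"Linux|X11", user_agent):
        return "Linux"
    return "Other"


def add_features(events):
    """Return copies of the events with churn, platform and userAgeDays added."""
    # A user counts as churned if any of their events is a cancellation confirmation.
    churned_users = {
        e["userId"] for e in events if e["page"] == "Cancellation Confirmation"
    }
    enriched = []
    for e in events:
        enriched.append({
            **e,
            "churn": 1 if e["userId"] in churned_users else 0,
            "platform": platform_from_user_agent(e["userAgent"]),
            "userAgeDays": (e["ts"] - e["registration"]) / MS_PER_DAY,
        })
    return enriched


# Made-up sample events, for illustration only.
sample_events = [
    {"userId": "1", "page": "NextSong", "ts": 10 * MS_PER_DAY, "registration": 0,
     "userAgent": "Mozilla/5.0 (Windows NT 6.1; WOW64)"},
    {"userId": "1", "page": "Cancellation Confirmation", "ts": 12 * MS_PER_DAY,
     "registration": 0, "userAgent": "Mozilla/5.0 (Windows NT 6.1; WOW64)"},
    {"userId": "2", "page": "NextSong", "ts": 5 * MS_PER_DAY,
     "registration": 2 * MS_PER_DAY,
     "userAgent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4)"},
]
enriched_events = add_features(sample_events)
for row in enriched_events:
    print(row["userId"], row["churn"], row["platform"], row["userAgeDays"])
```

Note that the churn flag is set on every event of a churned user, not only on the cancellation event itself, matching the definition above.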

Dropping Columns & Rows

  • To make the users comparable, we’ll only look at their latest 30 days of data and aggregate it per user. This will be the input to the classification model. For each user, we take the timestamp of their latest event (or of their churn event) and subtract 30 days. Only events with a timestamp later than this cutoff are considered.
  • As stated above, firstName, lastName, length and method are dropped as they are unlikely to benefit the prediction. Rows with an empty userId are dropped as well, as we can’t use them for the model.
  • The exploratory data analysis will show that the Platform, Level, User Age, Songs Listened and Down Votes columns are likely to be the most valuable inputs for the model, so all other columns will be dropped after the EDA as well.
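As a rough illustration of the 30-day window and the per-user aggregation, here is a plain-Python sketch. It assumes that a user's churn event is also their latest event, that ts is in epoch milliseconds, and that songs and down votes show up as the page values "NextSong" and "Thumbs Down". The function names and sample events are hypothetical.

```python
from collections import defaultdict

MS_PER_DAY = 24 * 60 * 60 * 1000  # assumption: ts is in epoch milliseconds
WINDOW_DAYS = 30


def last_window_events(events, window_days=WINDOW_DAYS):
    """Keep only events newer than (latest event of the same user - window)."""
    latest_ts = {}
    for e in events:
        uid = e["userId"]
        latest_ts[uid] = max(latest_ts.get(uid, e["ts"]), e["ts"])
    window_ms = window_days * MS_PER_DAY
    # Strictly greater: an event exactly on the cutoff is excluded.
    return [e for e in events if e["ts"] > latest_ts[e["userId"]] - window_ms]


def aggregate_per_user(events):
    """Count songs listened and down votes per user."""
    totals = defaultdict(lambda: {"songsListened": 0, "downVotes": 0})
    for e in events:
        user_totals = totals[e["userId"]]
        if e["page"] == "NextSong":
            user_totals["songsListened"] += 1
        elif e["page"] == "Thumbs Down":
            user_totals["downVotes"] += 1
    return dict(totals)


# Made-up sample events, for illustration only.
window_sample = [
    {"userId": "1", "page": "NextSong", "ts": 0},
    {"userId": "1", "page": "NextSong", "ts": 30 * MS_PER_DAY},
    {"userId": "1", "page": "Thumbs Down", "ts": 40 * MS_PER_DAY},
    {"userId": "1", "page": "NextSong", "ts": 60 * MS_PER_DAY},
    {"userId": "2", "page": "NextSong", "ts": 5 * MS_PER_DAY},
]
recent_events = last_window_events(window_sample)
user_features = aggregate_per_user(recent_events)
print(user_features)
```

Anchoring the window at each user's own latest event, rather than at a global date, is what makes long-gone churners comparable to users who are still active.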


  • The subscription level is indexed manually, with paying users as 1 and free users as 0.
  • For the machine learning model, the engineered Platform column is indexed using Spark’s StringIndexer.


  • The features are scaled using a MinMaxScaler. Assuming there are no heavy outliers in our data, this will serve the model well by making sure the features have equal chance of influencing the outcome.
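To see why the outlier assumption matters, here is a minimal plain-Python sketch of min-max scaling to the range [0, 1]. It mirrors the formula behind Spark's MinMaxScaler with its default range, but the function and the sample values are only an illustration.

```python
def min_max_scale(values):
    """Rescale values linearly so the minimum maps to 0 and the maximum to 1."""
    lowest, highest = min(values), max(values)
    if highest == lowest:
        # Constant column: Spark's MinMaxScaler maps it to the middle of the
        # target range, which is 0.5 for the default range [0, 1].
        return [0.5 for _ in values]
    return [(v - lowest) / (highest - lowest) for v in values]


# Evenly spread values use the whole range.
print(min_max_scale([0, 5, 10]))

# A single heavy outlier squeezes all other values close to 0.
print(min_max_scale([1, 2, 3, 4, 100]))
```

In the second call, the four ordinary values all end up below 0.05, so the feature would barely differentiate between most users. This is the situation the no-outliers assumption rules out.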

Exploratory Data Analysis

To get a better understanding of the data set, and to see which metrics may be of value for our classification problem, let’s dive into the data.

There are roughly 13% more free than premium users in the data set.

The service has more male than female users, but not by a huge margin.

Only about 100 of the users have a Cancellation Confirmation page visit. Since this page visit is our definition of churn, about 3/4 of the users in the data set have not churned. In general, most of the engagement pages are used by most users. It may be worrying that more than 50 % of the users hit the Error page at some point, and that far more users visited the settings page than actually saved their settings. This could mean they did not find what they were looking for.

Using both the engineered and the original metrics, we can check whether churned and non-churned users behave differently.

  • It does not look like gender has an influence on churn.
  • Apple users are slightly more likely to churn compared to Linux and Windows users.
  • Churned users seem to be slightly more engaged, having a higher percentage of premium users, more items per session, songs listened and votes in general.
  • The one thing that stands out is that they give far more Down Votes on average.
  • They seem to be relatively newer to the service than the non churned users.

Based on this, the Platform, Level, User Age, Songs Listened and Down Votes metrics will be used for creating a classification model.


For this project, 4 different models will be evaluated. All columns are scaled using a MinMaxScaler within a Spark pipeline, which trains the model and performs hyperparameter tuning based on the provided parameters. The performance during training is evaluated using the F-Score.

After training, the model with the best grid-searched parameters is evaluated on the test data, again using the F-Score and by computing a confusion matrix.

The 4 selected models are:

  1. Logistic Regression
  2. Random Forest
  3. Gradient Boost
  4. Multilayer Perceptron

Hyperparameter Tuning

For each of the models, a set of hyperparameters is tested to find the best setup.

  • Logistic Regression
    maxIter: [5, 10]
  • Random Forest
    numTrees: [5, 10, 15]
    maxBins: [4, 8, 16]
  • Gradient Boost
    maxIter: [5, 10, 15]
    stepSize: [0.01, 0.1, 0.2]
  • Multilayer Perceptron
    layers: [[5, 6, 5, 2], [5, 4, 4, 2], [5, 4, 2]]
    blockSize: [32, 64, 128]
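To get a feeling for the computational effort, the grids above can be expanded into their individual combinations. The following plain-Python sketch counts them per model and multiplies by 3, assuming the 3-fold cross validation used in this project. The helper function is hypothetical, and the count ignores the final refit on the full training data.

```python
from itertools import product

# The hyperparameter grids listed above.
param_grids = {
    "Logistic Regression": {"maxIter": [5, 10]},
    "Random Forest": {"numTrees": [5, 10, 15], "maxBins": [4, 8, 16]},
    "Gradient Boost": {"maxIter": [5, 10, 15], "stepSize": [0.01, 0.1, 0.2]},
    "Multilayer Perceptron": {
        "layers": [[5, 6, 5, 2], [5, 4, 4, 2], [5, 4, 2]],
        "blockSize": [32, 64, 128],
    },
}
NUM_FOLDS = 3  # assumption: 3-fold cross validation


def expand_grid(grid):
    """Return every combination of the grid as a list of parameter dicts."""
    names = list(grid)
    return [dict(zip(names, combo)) for combo in product(*(grid[n] for n in names))]


for model_name, grid in param_grids.items():
    combinations = expand_grid(grid)
    fits = len(combinations) * NUM_FOLDS
    print(f"{model_name}: {len(combinations)} combinations, {fits} fits")
```

The number of fits grows multiplicatively with every added parameter, which is why the grids here were kept small.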


Model Comparison

1. Logistic Regression

With an F-Score of 0.72, this is the worst performing model of the 4. Because of this, not much effort was spent on grid search parameters.

The confusion matrix shows 16 Type I errors (false positives) and no Type II errors (false negatives). This is likely the favorable outcome for this model, as we would rather be too careful than overlook users who are actually close to churning.
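For readers less familiar with the terminology, this plain-Python sketch shows how the four cells of a binary confusion matrix are counted and which cells correspond to Type I and Type II errors. The labels are made up and do not reproduce the numbers reported here.

```python
def confusion_counts(y_true, y_pred):
    """Count the four cells of a binary confusion matrix (1 = churned)."""
    counts = {"tp": 0, "fp": 0, "fn": 0, "tn": 0}
    for actual, predicted in zip(y_true, y_pred):
        if predicted == 1 and actual == 1:
            counts["tp"] += 1  # churner correctly flagged
        elif predicted == 1 and actual == 0:
            counts["fp"] += 1  # Type I error: loyal user flagged as churner
        elif predicted == 0 and actual == 1:
            counts["fn"] += 1  # Type II error: churner missed
        else:
            counts["tn"] += 1  # loyal user correctly left alone
    return counts


# Hypothetical labels: a cautious model that over-predicts churn.
actual_labels = [1, 1, 0, 0, 0, 1]
predicted_labels = [1, 1, 1, 0, 1, 1]
cells = confusion_counts(actual_labels, predicted_labels)
print(cells)
```

In this made-up example the model produces 2 Type I errors and no Type II errors, the same kind of cautious pattern as described above.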

2. Random Forest

With an F-Score of 0.8, this model takes first place. The setup that led to this result was 15 trees and 16 bins.

The confusion matrix again showed 16 Type I errors, and no Type II errors.

3. Gradient Boost

An F-Score of 0.76 puts this model in 3rd place. The setup for this was 5 iterations and a step size of 0.2.

This model actually had a decent grasp of the data, but unlike the other 2 models evaluated so far, it led to a couple of Type II errors, which are unfavorable for this use case.

4. Multilayer Perceptron

With an F-Score of 0.79, this model is just barely behind the Random Forest model. The configuration for this was 5 nodes in the input layer, 4 nodes in the hidden layer and 2 nodes in the output layer. The block size was 32.

Despite the slightly lower F-Score, this model would be recommended as it led to the highest number of True Positives without having any Type II errors.

Model Evaluation and Validation

So we’ve decided to go forth with the Multilayer Perceptron. But why did the winning setup of hyperparameters work so well?

The best block size for this model was 32. Considering the relatively small subset of data we used here, it is plausible that a low block size leads to favorable results if we treat it like a batch size, which trades lower learning speed for higher accuracy (see Ahmed Shahzad’s post on this topic). Note that Spark documents blockSize as the size of the blocks into which input rows are stacked to speed up computation, so this analogy is only approximate. Using the full data set, we may need to increase this number to keep learning speed reasonable. Any loss in accuracy should be balanced out by the higher data volume.

The best performing model only used one hidden layer, outperforming the models with 2 hidden layers. This means that one hidden layer was enough to capture the pattern in the data, while more layers just made the model more complex without catching more patterns. Taking into account our data, we did not have many features and they were of low complexity, so having a simple model perform well makes sense.
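One way to quantify "more complex" is to count the trainable parameters of each candidate layout. The sketch below assumes fully connected layers with one bias per node, and the helper function is hypothetical.

```python
def mlp_parameter_count(layers):
    """Count weights and biases of a fully connected network.

    Each pair of consecutive layers contributes n_in * n_out weights
    plus n_out biases, i.e. (n_in + 1) * n_out parameters.
    """
    return sum((n_in + 1) * n_out for n_in, n_out in zip(layers, layers[1:]))


# The three layouts from the hyperparameter grid.
for layout in ([5, 4, 2], [5, 4, 4, 2], [5, 6, 5, 2]):
    print(layout, mlp_parameter_count(layout))
```

Under these assumptions the winning layout [5, 4, 2] has 34 parameters, compared to 54 and 83 for the two layouts with 2 hidden layers. With only 5 input features, the smallest network already has enough capacity.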

We can assume that this model is a robust solution to the problem as we have used cross validation with 3 folds to validate it. To improve on this validation, we could write a custom cross validation function that outputs the result of each fold and make our own judgement call. This model however has the caveat of potentially long learning times. If the model does not perform well on the entire data set at a higher block size, to keep computational effort low we might need to reconsider our choice in favor of the Random Forest classification.
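A custom cross validation function that reports the result of each fold could look roughly like the following plain-Python sketch. It is not tied to Spark: k_fold_scores takes any function that trains on one part of the data and returns a score on the rest, and majority_baseline is only a hypothetical stand-in for a real model.

```python
import random


def k_fold_scores(rows, k, fit_and_score, seed=42):
    """Split rows into k folds and return one validation score per fold."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    # Deal the shuffled rows into k folds of (almost) equal size.
    folds = [shuffled[i::k] for i in range(k)]
    scores = []
    for i, validation in enumerate(folds):
        training = [row for j, fold in enumerate(folds) if j != i for row in fold]
        scores.append(fit_and_score(training, validation))
    return scores


def majority_baseline(training, validation):
    """Stand-in model: predict the most common training label, return accuracy."""
    labels = [label for _, label in training]
    majority = max(set(labels), key=labels.count)
    hits = sum(1 for _, label in validation if label == majority)
    return hits / len(validation)


# Made-up data: (feature, label) pairs with 3 churners among 12 users.
toy_rows = [(i, 1 if i % 4 == 0 else 0) for i in range(12)]
fold_scores = k_fold_scores(toy_rows, k=3, fit_and_score=majority_baseline)
print(fold_scores)
```

Looking at the individual fold scores instead of only their average shows how much the result depends on the split. A large spread between folds would be a warning sign for the robustness claimed above.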


How does our solution compare against the problems we want to solve?

Our goal was to predict whether a user, based on their behavior in the past 30 days, is likely to churn. We decided that we prefer to have Type I errors, rather than Type II errors. To achieve this, we decided to use Spark to train several classification models and pick one that performs well on the medium data set.

Our result is two models that perform well on the subset of data we used. They were able to correctly predict churning users without any Type II errors. The best performing model (MLP) is the most time consuming to train. If this is not feasible on the production data, there is still a well performing backup model (Random Forest) available.

The other 2 models had worse F-Scores and one of them had a couple of unfavorable Type II errors, so the MLP and Random Forest models are the clear winners here.


With just very limited information about a user we were able to make a decent prediction on whether this user will churn soon or not. As it turns out, the Random Forest and Multilayer Perceptron models were best suited for this task.

Of course, this would need to be validated on the full data set. It is quite possible that there is user behavior we were not able to pick up on in the small data set we used here. This could also mean that the parameters found for the evaluated models may not be the best for the full data set.

As a proof of concept, this project would have shown Sparkify that predicting churn is possible. Their product team could start thinking about product solutions to offer users identified as likely churners to keep them in the product, while the data team continues refining the results.


It is quite possible that some metrics that are important for making an accurate prediction were not found in this project. Revisiting the metrics with a more in-depth EDA could uncover these.

The models were only trained on the medium data set, so the next step would be to train them on the entire data set. Searching a bigger grid of hyperparameters may benefit model performance as well.

To make this more user friendly, one possible solution would be a web app where the product team can provide a user’s data and get their churn prediction.

Special thanks to Udacity for providing the data set for this project.

You can find the code used for this analysis in this GitHub repository.


