Data analysis and machine learning for 11-year-olds

My 11-year-old son recently started to watch movies about science and facts on YouTube, in addition to the usual gamer shows. (I hope it is not just because he has run out of gaming videos.)
A couple of weeks ago he was super excited when I came home from work. He had watched a video about the Titanic and could not stop talking about it. He had learned all about why it sank, all the different circumstances, and how the outcome could have been different if even some of the smaller things had been different.
Then it occurred to me that one of the standard data sets for getting started with data analysis and machine learning is the passenger list from the Titanic, with information about fare, passenger class, age, gender, whether the person had a spouse or siblings on board, and whether they had children or parents on board.
Since my son recently learned some elementary statistics in school, I thought it would be fun to show him how we can work with the Titanic data set. I have earlier given some workshops on Apache Spark in Java, so that is what we used. Spark has great documentation of its machine learning library here.
The Titanic data set is available, with some variations, from many places. I downloaded the titanic_original.csv from this GitHub repository. (I tried the cleaned data set first, but passengers with a missing age are given the average age there, and that would mess up what I wanted to do with the data, so I continued with the original data set.)

Load data from CSV

The first thing we have to do is to start Spark and load the data. Reading data with Spark is quite easy. If you are lucky, the option inferSchema will figure out the correct types.

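import org.apache.spark.SparkConf;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Run Spark locally, using all available cores.
SparkConf conf = new SparkConf().setAppName("Titanic").setMaster("local[*]");
SparkSession spark = SparkSession.builder().config(conf).getOrCreate();

// Read the CSV file; the first line holds the column names.
Dataset<Row> passengers = spark.read()
        .option("inferSchema", "true")
        .option("delimiter", ",")
        .option("header", "true")
        .csv(TITANIC_DATA_PATH);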

When the data has been loaded into a Dataset, it is a good idea to check that it actually contains what it should, and that the columns have the right data types. We can do that with the methods passengers.show() and passengers.printSchema(). The first one prints the first twenty rows of the data set, and the second prints the type of each column.

+------+--------+--------------------+------+------+-----+-----+--------+--------+-------+--------+----+----+--------------------+
|pclass|survived|                name|   sex|   age|sibsp|parch|  ticket|    fare|  cabin|embarked|boat|body|           home.dest|
+------+--------+--------------------+------+------+-----+-----+--------+--------+-------+--------+----+----+--------------------+
|     1|       1|Allen, Miss. Elis...|female|  29.0|    0|    0|   24160|211.3375|     B5|       S|   2|null|        St Louis, MO|
|     1|       1|Allison, Master. ...|  male|0.9167|    1|    2|  113781|  151.55|C22 C26|       S|  11|null|Montreal, PQ / Ch...|
|     1|       0|Allison, Miss. He...|female|   2.0|    1|    2|  113781|  151.55|C22 C26|       S|null|null|Montreal, PQ / Ch...|
|     1|       0|Allison, Mr. Huds...|  male|  30.0|    1|    2|  113781|  151.55|C22 C26|       S|null| 135|Montreal, PQ / Ch...|
|     1|       0|Allison, Mrs. Hud...|female|  25.0|    1|    2|  113781|  151.55|C22 C26|       S|null|null|Montreal, PQ / Ch...|
|     1|       1| Anderson, Mr. Harry|  male|  48.0|    0|    0|   19952|   26.55|    E12|       S|   3|null|        New York, NY|
|     1|       1|Andrews, Miss. Ko...|female|  63.0|    1|    0|   13502| 77.9583|     D7|       S|  10|null|          Hudson, NY|
|     1|       0|Andrews, Mr. Thom...|  male|  39.0|    0|    0|  112050|     0.0|    A36|       S|null|null|         Belfast, NI|
|     1|       1|Appleton, Mrs. Ed...|female|  53.0|    2|    0|   11769| 51.4792|   C101|       S|   D|null| Bayside, Queens, NY|
|     1|       0|Artagaveytia, Mr....|  male|  71.0|    0|    0|PC 17609| 49.5042|   null|       C|null|  22| Montevideo, Uruguay|
|     1|       0|Astor, Col. John ...|  male|  47.0|    1|    0|PC 17757| 227.525|C62 C64|       C|null| 124|        New York, NY|
|     1|       1|Astor, Mrs. John ...|female|  18.0|    1|    0|PC 17757| 227.525|C62 C64|       C|   4|null|        New York, NY|
|     1|       1|Aubart, Mme. Leon...|female|  24.0|    0|    0|PC 17477|    69.3|    B35|       C|   9|null|       Paris, France|
|     1|       1|"Barber, Miss. El...|female|  26.0|    0|    0|   19877|   78.85|   null|       S|   6|null|                null|
|     1|       1|Barkworth, Mr. Al...|  male|  80.0|    0|    0|   27042|    30.0|    A23|       S|   B|null|       Hessle, Yorks|
|     1|       0| Baumann, Mr. John D|  male|  null|    0|    0|PC 17318|  25.925|   null|       S|null|null|        New York, NY|
|     1|       0|Baxter, Mr. Quigg...|  male|  24.0|    0|    1|PC 17558|247.5208|B58 B60|       C|null|null|        Montreal, PQ|
|     1|       1|Baxter, Mrs. Jame...|female|  50.0|    0|    1|PC 17558|247.5208|B58 B60|       C|   6|null|        Montreal, PQ|
|     1|       1|Bazzani, Miss. Al...|female|  32.0|    0|    0|   11813| 76.2917|    D15|       C|   8|null|                null|
|     1|       0|Beattie, Mr. Thomson|  male|  36.0|    0|    0|   13050| 75.2417|     C6|       C|   A|null|        Winnipeg, MN|
+------+--------+--------------------+------+------+-----+-----+--------+--------+-------+--------+----+----+--------------------+
root
 |-- pclass: integer (nullable = true)
 |-- survived: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: double (nullable = true)
 |-- sibsp: integer (nullable = true)
 |-- parch: integer (nullable = true)
 |-- ticket: string (nullable = true)
 |-- fare: double (nullable = true)
 |-- cabin: string (nullable = true)
 |-- embarked: string (nullable = true)
 |-- boat: string (nullable = true)
 |-- body: integer (nullable = true)
 |-- home.dest: string (nullable = true)

As mentioned, age is missing for some of the passengers, and the same holds for fare. For our purpose we will simply remove the rows where age or fare is null, with a syntax quite similar to SQL.

// col is a static import from org.apache.spark.sql.functions.
passengers = passengers.where(col("age").isNotNull().and(col("fare").isNotNull()));

Most of the columns are self-explanatory, but pclass is the passenger class, sibsp is the number of siblings and spouses the passenger had on board, and similarly, parch is the number of parents and children.

Average, median and mode

Now that we’ve got the data we can finally start doing some statistics. The statistics taught in 5th grade basically cover the average, median and mode of a set of numbers, and how to draw different types of charts. I asked my son what he believed the average age of the passengers would be, and he guessed 30 years. Let’s see how we can get the average, together with the median and mode, from Spark.

A Dataset has a useful method, describe, which gives us various statistics about the columns we ask for. It turns out that 30 was a good guess for the average 🙂

passengers.describe("age").show();
+-------+------------------+
|summary|               age|
+-------+------------------+
|  count|              1045|
|   mean|29.851834162679427|
| stddev|14.389200976290613|
|    min|            0.1667|
|    max|              80.0|
+-------+------------------+

To find the mode, the value that occurs most often, we draw a plot with ages along the x-axis and the number of passengers with that age along the y-axis. The plot is made with the Java library XChart.

Looking at the chart, it seems that 24 years is the mode. We can verify that by querying the data set and looking at the first row that is printed.

passengers.groupBy("age").count().orderBy(col("count").desc()).show();

The median is the middle data point when the data is sorted. We can sort the data set on age, but a Dataset does not have an index we can query. The data set has 1045 entries, so the median is the 523rd row. The easiest thing would be to call .show(523) on the sorted data and look at the last row that is printed. Alternatively, we can add an id column like this (according to the docs the ids are not guaranteed to be consecutive, but they will be when the data sits in a single partition, as in our case; the ids then start at 0, so the median row has id 522).

passengers.orderBy("age").withColumn("id", monotonically_increasing_id());

A more proper, but much less intuitive, way of finding the median is to use the quantile function like this (the last argument, 0, is the allowed relative error, so the result is exact):

passengers.stat().approxQuantile("age", new double[]{0.5}, 0);

No matter how we find the median, the result is 28.
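
To make the three measures concrete without Spark, here is a minimal sketch in plain Java. The class name BasicStats and the seven ages in main are made up for illustration; they are not taken from the Titanic data, and the methods assume a non-empty array.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

class BasicStats {

    // Average: the sum of all values divided by how many there are.
    static double mean(double[] values) {
        double sum = 0;
        for (double v : values) {
            sum += v;
        }
        return sum / values.length;
    }

    // Median: the middle value of the sorted data, or the average of
    // the two middle values when the number of values is even.
    static double median(double[] values) {
        double[] sorted = values.clone();
        Arrays.sort(sorted);
        int mid = sorted.length / 2;
        if (sorted.length % 2 == 1) {
            return sorted[mid];
        }
        return (sorted[mid - 1] + sorted[mid]) / 2;
    }

    // Mode: the value that occurs most often.
    // If several values tie, the smallest one is returned.
    static double mode(double[] values) {
        Map<Double, Integer> counts = new TreeMap<>();
        for (double v : values) {
            counts.merge(v, 1, Integer::sum);
        }
        double best = Double.NaN;
        int bestCount = 0;
        for (Map.Entry<Double, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > bestCount) {
                best = entry.getKey();
                bestCount = entry.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Made-up ages, only for illustration.
        double[] ages = {24, 30, 24, 2, 48, 39, 24};
        System.out.println("mean:   " + mean(ages));
        System.out.println("median: " + median(ages));
        System.out.println("mode:   " + mode(ages));
    }
}
```

This is the same thing a 5th grader does by hand: add up and divide, sort and pick the middle, count and pick the most common.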

That was the basic statistics. Let us move on to analyze who the survivors were.

Survivors by gender

Let us start by finding the number and fraction of survivors in total. It can be done with the following line:

passengers.groupBy("survived").count().withColumn("fraction", col("count").divide(sum("count").over()));

The table below shows the depressing result that only 41% of the 1045 passengers in our data set survived.

Survived  Count  Fraction
1         427    0.41
0         618    0.59

So what is the survival rate by gender?
I told my son that it was common to save children and women before men, and he was shocked; “What, is that true? That’s totally unfair!”.
Well, I actually found the paper Gender, social norms, and survival in maritime disasters, where the authors have studied maritime disasters and to what extent the “social norm of women and children first” is followed. They conclude that “Women have a distinct survival disadvantage compared with men. Captains and crew survive at a significantly higher rate than passengers.”

If survival were independent of gender, we would expect around 41% of the women to survive, and likewise 41% of the men. That is logical, even for a 5th grader. But what are the actual rates by gender?

Gender  Survived  Total  Percentage of survivors by gender
Female  292       388    75.3%
Male    135       657    20.5%

We clearly see that survival depends on gender, so we can assume that gender will be an important feature for predicting whether a person survived or not. This is the intuition behind the chi-squared test, which Spark uses in its ChiSqSelector to find the most important features of a data set.
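
The "expected versus actual" comparison can be worked through in plain Java. The sketch below computes Pearson's chi-squared statistic for the gender table; the numbers of women and men who died (96 and 522) are derived from the totals in the table, and the class name ChiSquared is made up. It illustrates the idea only and is not how Spark's ChiSqSelector is implemented.

```java
class ChiSquared {

    // Pearson's chi-squared statistic for a table of observed counts.
    // For each cell we compare the observed count with the count we
    // would expect if rows and columns were independent.
    static double statistic(long[][] observed) {
        int rows = observed.length;
        int cols = observed[0].length;
        double[] rowTotals = new double[rows];
        double[] colTotals = new double[cols];
        double total = 0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                rowTotals[i] += observed[i][j];
                colTotals[j] += observed[i][j];
                total += observed[i][j];
            }
        }
        double chi2 = 0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double expected = rowTotals[i] * colTotals[j] / total;
                double diff = observed[i][j] - expected;
                chi2 += diff * diff / expected;
            }
        }
        return chi2;
    }

    public static void main(String[] args) {
        // Rows: female, male. Columns: survived, died.
        // The "died" counts are total minus survived (388 - 292, 657 - 135).
        long[][] byGender = {{292, 96}, {135, 522}};
        System.out.println("chi-squared: " + statistic(byGender));
    }
}
```

For a 2x2 table like this one, a statistic above roughly 3.84 is conventionally taken as evidence (at the 5% level) that the two variables are not independent. The value for the gender table is far above that.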

Prediction with decision trees

Decision trees are one of the simplest types of algorithms in machine learning, and the result of the algorithm is easy to understand. I think of it as if-else statements written by the program and not by the programmer. Decision trees are (usually) built top down, by repeatedly selecting the feature that separates the data points best, in terms of grouping data with the same label (the value we try to predict).
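
What "separates the data points best" means can be sketched with Gini impurity, one common measure of how mixed the labels in a group are (0 means everyone in the group has the same label). The plain-Java example below uses the survival counts by gender from the table above; the class and method names are made up, and this is an illustration of the idea rather than Spark's implementation.

```java
class GiniSplit {

    // Gini impurity of a group with two labels: 2 * p * (1 - p),
    // where p is the fraction of survivors in the group.
    static double gini(long survived, long died) {
        double total = survived + died;
        if (total == 0) {
            return 0;
        }
        double p = survived / total;
        return 2 * p * (1 - p);
    }

    // Impurity after a split: the impurity of each group,
    // weighted by the share of the data that ends up in the group.
    // Each group is given as {survived, died}.
    static double weightedGini(long[][] groups) {
        double total = 0;
        for (long[] g : groups) {
            total += g[0] + g[1];
        }
        double weighted = 0;
        for (long[] g : groups) {
            weighted += (g[0] + g[1]) / total * gini(g[0], g[1]);
        }
        return weighted;
    }

    public static void main(String[] args) {
        // All passengers in one group: 427 survived, 618 died.
        double before = gini(427, 618);
        // Split on gender: female {292, 96}, male {135, 522}.
        double after = weightedGini(new long[][]{{292, 96}, {135, 522}});
        System.out.println("impurity before split: " + before);
        System.out.println("impurity after split on gender: " + after);
    }
}
```

A tree builder would compute this for every candidate split and pick the one that lowers the impurity the most.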
A decision tree classifier has a tendency to overfit, which means that the model fits the training set very well but performs worse on other data sets. A common way to avoid this is to build several decision trees, a random forest, where some randomness is added to the generation of the trees to make them different. The prediction for an item is the label it gets most often when it is scored by each of the separate trees.
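
The voting step can be sketched in a few lines of plain Java. This only illustrates the idea as described here; the names are made up, and real implementations may combine the trees differently, for example by averaging class probabilities instead of counting votes.

```java
import java.util.Map;
import java.util.TreeMap;

class MajorityVote {

    // Returns the label predicted by most trees.
    // If several labels tie, the smallest label is returned.
    static int majorityLabel(int[] treePredictions) {
        Map<Integer, Integer> votes = new TreeMap<>();
        for (int label : treePredictions) {
            votes.merge(label, 1, Integer::sum);
        }
        int bestLabel = -1;
        int bestVotes = 0;
        for (Map.Entry<Integer, Integer> entry : votes.entrySet()) {
            if (entry.getValue() > bestVotes) {
                bestLabel = entry.getKey();
                bestVotes = entry.getValue();
            }
        }
        return bestLabel;
    }

    public static void main(String[] args) {
        // Hypothetical predictions from five trees for one passenger
        // (1 = survived, 0 = died).
        int[] predictions = {1, 0, 1, 1, 0};
        System.out.println("forest prediction: " + majorityLabel(predictions));
    }
}
```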

In Spark we start by splitting the data set randomly into two sets: a training set for fitting the model, containing 70% of the data, and a test set used to evaluate the model.

Dataset<Row>[] splits = passengers.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> training = splits[0];
Dataset<Row> test = splits[1];

Before we can feed data into a random forest classifier we need to transform it. Spark’s machine learning algorithms want data with a column “label” that contains what we are predicting, and a column “features” that contains a vector of the data attributes we want to include. I find the RFormula in Spark very useful for making label and features. The syntax is a bit strange, but the value to the left of “~” is the label, and on the right side one can add the attributes one wants, either by naming each of them with “+” between them, or by starting with “.”, which gives all of them, and then “subtracting” the attributes one does not want. Survived is our label, and as features we select pclass, sex, age, fare, sibsp and parch.

RFormula formula = new RFormula().setFormula("survived ~ pclass + sex + age + fare + sibsp + parch");

We then use the RandomForestClassifier with its default configuration. Spark has a nice pipeline abstraction, which makes it easy to work with multiple steps in the machine learning process. We only have to fit and transform data once on the pipeline, instead of once for each individual step, and it is easy to experiment and replace parts of the chain. The following lines of code create a pipeline and fit it on the training data, which gives back a PipelineModel.

RandomForestClassifier forestClassifier = new RandomForestClassifier();
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{formula, forestClassifier});
PipelineModel model = pipeline.fit(training);

To find out how the model is doing on the test set, we can use the binary classification evaluator, since we are doing classification with a binary label (survived/died). We first let the model classify each entry in the test set, and the evaluator then compares the predictions with the actual labels. Note that by default the evaluator reports the area under the ROC curve, which is not the same as the fraction of correctly classified entries.

Dataset<Row> predictions = model.transform(test);
BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
double res = eval.evaluate(predictions);

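The result will vary a bit each time the program is run, since there is randomness both in the splitting of the data set and in the algorithm, but an example of a score we got is 86.5%. (With the evaluator's default settings this is the area under the ROC curve rather than plain accuracy.)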
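
Spark's BinaryClassificationEvaluator reports the area under the ROC curve by default, which differs from plain accuracy (the fraction of correctly classified entries). The plain-Java sketch below shows both measures on four made-up passengers; the class name, the scores and the 0.5 cut-off are assumptions for illustration only.

```java
class BinaryMetrics {

    // Accuracy: the fraction of entries where the prediction equals the label.
    static double accuracy(int[] labels, int[] predictions) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] == predictions[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    // Area under the ROC curve: the probability that a randomly chosen
    // survivor (label 1) gets a higher score than a randomly chosen
    // non-survivor (label 0). Ties count as half.
    static double areaUnderRoc(int[] labels, double[] scores) {
        double wins = 0;
        long pairs = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] != 1) {
                continue;
            }
            for (int j = 0; j < labels.length; j++) {
                if (labels[j] != 0) {
                    continue;
                }
                pairs++;
                if (scores[i] > scores[j]) {
                    wins += 1.0;
                } else if (scores[i] == scores[j]) {
                    wins += 0.5;
                }
            }
        }
        return pairs == 0 ? Double.NaN : wins / pairs;
    }

    public static void main(String[] args) {
        // Made-up labels and model scores for four passengers.
        int[] labels = {1, 0, 1, 0};
        double[] scores = {0.9, 0.8, 0.4, 0.1};
        // Turn scores into predictions with a cut-off of 0.5.
        int[] predictions = new int[scores.length];
        for (int i = 0; i < scores.length; i++) {
            predictions[i] = scores[i] >= 0.5 ? 1 : 0;
        }
        System.out.println("accuracy: " + accuracy(labels, predictions));
        System.out.println("area under ROC: " + areaUnderRoc(labels, scores));
    }
}
```

The two numbers differ for the same model output, so it matters which one a reported percentage refers to.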

The random forest model has a string representation of the trees one can use to look at the actual trees.

RandomForestClassificationModel treeModel = (RandomForestClassificationModel) model.stages()[1];
System.out.println(treeModel.toDebugString());

I visualized one of these trees, and my son wanted to see if he would have survived or not. Luckily he just turned eleven, and according to this tree he would survive if we were travelling in first or second class. I would also have survived, but it certainly doesn’t look good for dad/husband.
[Figure: decision tree]
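
To show what "if-else statements written by the program" can look like, here is a hand-written stand-in for a single tree in plain Java. It is not the tree Spark learned: the structure and the age threshold are invented for illustration, chosen only so that the outcomes match the ones described above.

```java
class HypotheticalTree {

    // A made-up decision tree written as if-else statements.
    // The splits and the age threshold of 12 are assumptions,
    // not values taken from the trained model.
    static boolean survives(String sex, double age, int pclass) {
        if ("female".equals(sex)) {
            return true;
        }
        // Male passengers: only young boys in first or second class.
        if (age <= 12 && pclass <= 2) {
            return true;
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println("11-year-old boy, 2nd class: " + survives("male", 11, 2));
        System.out.println("adult man, 2nd class: " + survives("male", 40, 2));
        System.out.println("adult woman, 2nd class: " + survives("female", 40, 2));
    }
}
```

A learned tree has the same shape, but the program chooses the questions and the thresholds from the training data.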