Data analysis and machine learning for 11-year-olds

My 11-year-old son recently started watching science and fact videos on YouTube, in addition to the usual gamer shows. (I hope it is not just because he has run out of gaming videos.)
A couple of weeks ago he was super excited when I came home from work: he had watched a video about the Titanic and could not stop talking about it. He had learned all about why it sank, all the different circumstances, and how the outcome could have been different if even some of the smaller things had been different.
Then it occurred to me that one of the standard data sets for getting started with data analysis and machine learning is the passenger list from the Titanic, with information about fare, passenger class, age, gender, whether the person had a spouse or siblings on board, and whether they had children or parents on board.
Since my son recently learned some elementary statistics in school, I thought it would be fun to show him how we can work with the Titanic data set. I have previously given some workshops on Apache Spark in Java, so that is what we used. Spark has great documentation of its machine learning library here.
The Titanic data set is available, with some variations, from many places. I downloaded titanic_original.csv from this GitHub repository. (I tried the cleaned data set first, but there passengers with missing age are given the average age, and that would mess up what I wanted to do with the data, so I continued with the original data set.)

Load data from CSV

The first thing we have to do is to start Spark and load the data. Reading data with Spark is quite easy, and if you are lucky, the option inferSchema will figure out the correct types.

import org.apache.spark.SparkConf;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

SparkConf conf = new SparkConf().setAppName("Titanic").setMaster("local[*]");
SparkSession spark = SparkSession.builder().config(conf).getOrCreate();

// Read the csv file, using the header row for column names and letting Spark infer the types
Dataset<Row> passengers = spark.read()
        .option("inferSchema", "true")
        .option("delimiter", ",")
        .option("header", true)
        .csv(TITANIC_DATA_PATH);

When the data has been loaded into a Dataset, it is a good idea to check that it actually contains what it should, and that the columns have the right data types. We can do that with the methods passengers.show() and passengers.printSchema(). The first one prints the first twenty rows of the data set, and the latter prints the type for each column.
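
For reference, these are the two calls that produce the output below:

passengers.show();
passengers.printSchema();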

+------+--------+--------------------+------+------+-----+-----+--------+--------+-------+--------+----+----+--------------------+
|pclass|survived|                name|   sex|   age|sibsp|parch|  ticket|    fare|  cabin|embarked|boat|body|           home.dest|
+------+--------+--------------------+------+------+-----+-----+--------+--------+-------+--------+----+----+--------------------+
|     1|       1|Allen, Miss. Elis...|female|  29.0|    0|    0|   24160|211.3375|     B5|       S|   2|null|        St Louis, MO|
|     1|       1|Allison, Master. ...|  male|0.9167|    1|    2|  113781|  151.55|C22 C26|       S|  11|null|Montreal, PQ / Ch...|
|     1|       0|Allison, Miss. He...|female|   2.0|    1|    2|  113781|  151.55|C22 C26|       S|null|null|Montreal, PQ / Ch...|
|     1|       0|Allison, Mr. Huds...|  male|  30.0|    1|    2|  113781|  151.55|C22 C26|       S|null| 135|Montreal, PQ / Ch...|
|     1|       0|Allison, Mrs. Hud...|female|  25.0|    1|    2|  113781|  151.55|C22 C26|       S|null|null|Montreal, PQ / Ch...|
|     1|       1| Anderson, Mr. Harry|  male|  48.0|    0|    0|   19952|   26.55|    E12|       S|   3|null|        New York, NY|
|     1|       1|Andrews, Miss. Ko...|female|  63.0|    1|    0|   13502| 77.9583|     D7|       S|  10|null|          Hudson, NY|
|     1|       0|Andrews, Mr. Thom...|  male|  39.0|    0|    0|  112050|     0.0|    A36|       S|null|null|         Belfast, NI|
|     1|       1|Appleton, Mrs. Ed...|female|  53.0|    2|    0|   11769| 51.4792|   C101|       S|   D|null| Bayside, Queens, NY|
|     1|       0|Artagaveytia, Mr....|  male|  71.0|    0|    0|PC 17609| 49.5042|   null|       C|null|  22| Montevideo, Uruguay|
|     1|       0|Astor, Col. John ...|  male|  47.0|    1|    0|PC 17757| 227.525|C62 C64|       C|null| 124|        New York, NY|
|     1|       1|Astor, Mrs. John ...|female|  18.0|    1|    0|PC 17757| 227.525|C62 C64|       C|   4|null|        New York, NY|
|     1|       1|Aubart, Mme. Leon...|female|  24.0|    0|    0|PC 17477|    69.3|    B35|       C|   9|null|       Paris, France|
|     1|       1|"Barber, Miss. El...|female|  26.0|    0|    0|   19877|   78.85|   null|       S|   6|null|                null|
|     1|       1|Barkworth, Mr. Al...|  male|  80.0|    0|    0|   27042|    30.0|    A23|       S|   B|null|       Hessle, Yorks|
|     1|       0| Baumann, Mr. John D|  male|  null|    0|    0|PC 17318|  25.925|   null|       S|null|null|        New York, NY|
|     1|       0|Baxter, Mr. Quigg...|  male|  24.0|    0|    1|PC 17558|247.5208|B58 B60|       C|null|null|        Montreal, PQ|
|     1|       1|Baxter, Mrs. Jame...|female|  50.0|    0|    1|PC 17558|247.5208|B58 B60|       C|   6|null|        Montreal, PQ|
|     1|       1|Bazzani, Miss. Al...|female|  32.0|    0|    0|   11813| 76.2917|    D15|       C|   8|null|                null|
|     1|       0|Beattie, Mr. Thomson|  male|  36.0|    0|    0|   13050| 75.2417|     C6|       C|   A|null|        Winnipeg, MN|
+------+--------+--------------------+------+------+-----+-----+--------+--------+-------+--------+----+----+--------------------+
root
 |-- pclass: integer (nullable = true)
 |-- survived: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: double (nullable = true)
 |-- sibsp: integer (nullable = true)
 |-- parch: integer (nullable = true)
 |-- ticket: string (nullable = true)
 |-- fare: double (nullable = true)
 |-- cabin: string (nullable = true)
 |-- embarked: string (nullable = true)
 |-- boat: string (nullable = true)
 |-- body: integer (nullable = true)
 |-- home.dest: string (nullable = true)

As mentioned, age is missing for some of the passengers, and the same holds for fare, so for our purpose we will simply remove the rows where age or fare is null, with a syntax quite similar to SQL.

passengers = passengers.where(col("age").isNotNull().and(col("fare").isNotNull()));

Most of the columns are self-explanatory, but pclass is the passenger class, sibsp is the number of siblings and spouses the passenger had on board, and similarly, parch is the number of parents and children.

Average, median and mode

Now that we’ve got the data we can finally start doing some statistics. The statistics taught in 5th grade basically cover the average, median and mode of a set of numbers, and how to draw different types of charts. I asked my son what he believed the average age of the passengers would be, and he guessed 30 years. Let’s see how we can get the average, together with the median and mode, from Spark.

There is a useful method describe on a dataset, which gives us various information about the columns we ask for. It turns out that 30 was a good guess for the average 🙂 .

passengers.describe("age").show();
+-------+------------------+
|summary|               age|
+-------+------------------+
|  count|              1045|
|   mean|29.851834162679427|
| stddev|14.389200976290613|
|    min|            0.1667|
|    max|              80.0|
+-------+------------------+

To find the mode, the value that occurs most often, we draw a plot with ages along the x-axis and the number of passengers with that age along the y-axis. The plot is made with the Java library XChart.
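
A sketch of how such a chart can be put together with XChart, collecting the per-age counts from Spark first (the styling of the original plot may well have been different):

import java.util.List;
import java.util.stream.Collectors;
import org.knowm.xchart.CategoryChart;
import org.knowm.xchart.CategoryChartBuilder;
import org.knowm.xchart.SwingWrapper;

// Collect the number of passengers for each age from Spark
List<Row> ageCounts = passengers.groupBy("age").count().orderBy("age").collectAsList();
List<Double> ages = ageCounts.stream().map(r -> r.getDouble(0)).collect(Collectors.toList());
List<Long> counts = ageCounts.stream().map(r -> r.getLong(1)).collect(Collectors.toList());

// Plot age along the x-axis and passenger count along the y-axis
CategoryChart chart = new CategoryChartBuilder()
        .width(800).height(600)
        .title("Passengers by age")
        .xAxisTitle("age").yAxisTitle("count")
        .build();
chart.addSeries("passengers", ages, counts);
new SwingWrapper<>(chart).displayChart();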

Looking at the chart, it seems that 24 years is the mode. We can verify that by querying the data set and looking at the first row that is printed.

passengers.groupBy("age").count().orderBy(col("count").desc()).show();

The median is the middle data point when the data is sorted. We can sort the data set on age, but a Dataset does not have an index we can query by. The data set has 1045 entries, so the easiest thing would be to do .show(523) and look at the last row that is printed, or we can add an id column like this (according to the docs it is not guaranteed to be consecutive, but it will be as long as the data is not partitioned, as in our case).

Dataset<Row> withId = passengers.orderBy("age").withColumn("id", monotonically_increasing_id());
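
With the id column in place we can pick out the middle row directly; a small sketch (the data set has 1045 rows and the ids start at 0, so index 522 is the middle one):

withId.where(col("id").equalTo(522)).show();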

A more proper, but much less intuitive, way of finding the median would be to use the quantile function like this:

double[] medianAge = passengers.stat().approxQuantile("age", new double[]{0.5}, 0);

No matter how we find the median, the result is 28.

That was the basic statistics; let us now move on to analyzing who the survivors were.

Survivors by gender

Let us start by finding the number and fraction of survivors in total. It can be done with the following line:

passengers.groupBy("survived").count().withColumn("fraction", col("count").divide(sum("count").over()));

The table below shows the depressing result that only 41% of the 1045 passengers in our data set survived.

Survived | Count | Fraction
1        | 427   | 0.41
0        | 618   | 0.59

So how is the rate of survivors by gender?
I told my son that it was common to save children and women before men, and he was shocked: “What, is that true? That’s totally unfair!”
Well, I actually found the paper Gender, social norms, and survival in maritime disasters, where the authors have studied maritime disasters and to what extent the social norm of “women and children first” is followed. They conclude that “Women have a distinct survival disadvantage compared with men. Captains and crew survive at a significantly higher rate than passengers.”

If survival were independent of gender, we would expect around 41% of the women to survive, and likewise, 41% of the men. That is logic even a 5th grader can follow. But what are the actual rates by gender?

Gender | Survived | Total | Percentage survivors by gender
Female | 292      | 388   | 75.3%
Male   | 135      | 657   | 20.5%
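
These numbers can be computed with something along these lines (a sketch, using the statically imported aggregate functions from org.apache.spark.sql.functions):

passengers.groupBy("sex")
        .agg(count("sex").as("total"), sum("survived").as("survived"), avg("survived").as("rate"))
        .show();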

We clearly see that survival depends on gender, so we can assume that gender will be an important feature for predicting whether a person survived or not. This is the intuition behind the Chi-squared test, which is also implemented in Spark as the ChiSqSelector, and which can be used to find the most important features of a data set.
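
As a sketch of how that could look (the “features” and “label” columns are the ones produced by the RFormula further down, and numTopFeatures is just an example value):

import org.apache.spark.ml.feature.ChiSqSelector;

// Pick the three features with the strongest Chi-squared dependence on the label
ChiSqSelector selector = new ChiSqSelector()
        .setNumTopFeatures(3)
        .setFeaturesCol("features")
        .setLabelCol("label")
        .setOutputCol("selectedFeatures");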

Prediction with decision trees

Decision trees are one of the simplest types of algorithms in machine learning, and it is easy to understand the result of the algorithm; I think of it as if-else statements written by the program instead of the programmer. Decision trees are (usually) built top down, by repeatedly selecting the feature that best separates the data points, in the sense of grouping together data with the same label (the value we try to predict).
A decision tree classifier can have a tendency to over-fit, which means that the model fits the training set very well, but generalizes poorly to other data. A common way to avoid this is to build several decision trees, a random forest, where some randomness is added to the generation of the trees to make them different. The prediction for an item is then the label it gets most often when it is scored in each of the separate trees.

In Spark we start by splitting the data set randomly into two sets: a training set for fitting the model, containing 70% of the data, and a test set used to evaluate the model.

Dataset<Row>[] splits = passengers.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> training = splits[0];
Dataset<Row> test = splits[1];

Before we can feed data into a random forest classifier we need to transform it. Spark’s machine learning algorithms want data with a column “label” that contains what we are predicting, and a column “features” that contains a vector of the attributes we want to include. I find the RFormula in Spark very useful for producing label and features. The syntax is a bit odd, but the value to the left of “~” is the label, and on the right side one adds the attributes one wants, either by naming each of them with “+” between them, or by starting with “.”, which gives all of them, and then “subtracting” the attributes you do not want. Survived is our label, and as features we select pclass, sex, age, fare, sibsp and parch.

RFormula formula = new RFormula().setFormula("survived ~ pclass + sex + age + fare + sibsp + parch");

We then use the RandomForestClassifier with its default configuration. Spark has a nice Pipeline abstraction which makes it easy to work with multiple steps in the machine learning process: we only have to fit and transform data once on the pipeline, instead of for each individual step, and it is easy to experiment with and replace parts of the chain. The following lines of code create a pipeline and fit it on the training data, which gives back a PipelineModel.

RandomForestClassifier forestClassifier = new RandomForestClassifier();
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{formula, forestClassifier});
PipelineModel model = pipeline.fit(training);

To find out how the model is doing on the test set, we can use the binary classification evaluator, since we are doing classification with a binary label (survived/died). The model classifies each entry in the test set, and the evaluator then compares the predictions with the actual labels and computes a score for the model.

Dataset<Row> predictions = model.transform(test);
BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
double res = eval.evaluate(predictions);

The result will vary a bit each time the program is run, since there is randomness both in the splitting of the data set and in the algorithm, but as an example, one run gave us a score of 86.5%.
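
The default metric of the BinaryClassificationEvaluator is the area under the ROC curve. If we specifically want the fraction of correctly classified passengers, i.e. the accuracy, one option is the MulticlassClassificationEvaluator; a small sketch:

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;

// "accuracy" is one of the supported metric names
MulticlassClassificationEvaluator accuracyEvaluator = new MulticlassClassificationEvaluator()
        .setMetricName("accuracy");
double accuracy = accuracyEvaluator.evaluate(predictions);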

The random forest model has a string representation one can use to look at the actual trees.

RandomForestClassificationModel treeModel = (RandomForestClassificationModel) model.stages()[1];
System.out.println(treeModel.toDebugString());

I visualized one of these trees, and my son wanted to see if he would have survived or not. Luckily he just turned eleven, and would, according to this tree, survive if we were travelling in first or second class. I would also have survived, but it certainly doesn’t look good for dad/husband.
[Figure: one of the decision trees from the random forest]

K-means clustering in Neo4j

The last few days I’ve been playing with Neo4j. I’ve always been intrigued by graphs, and I’ve wanted to learn more about graph databases for some time now, and finally, I was able to find some time for it.

One challenge I usually have when I want to test out new technology is to find a suitable use case: simple enough, but still one that illustrates a reasonable amount of the features and functionality of the new technology. In this case I ended up with a simple toy example that illustrates how the k-means clustering algorithm, used in machine learning, can partition a data set into k disjoint clusters.

The k-means clustering algorithm

The algorithm itself is fairly simple. Assume that we have a data set of observations, where each sample has values for $d$ features; hence each sample can be considered as a real $d$-dimensional vector $[x_1, \ldots, x_d]$. This means that we can calculate the squared Euclidean distance of two vectors $\textbf{x} = [x_1, \ldots, x_d]$ and $\textbf{y} = [y_1, \ldots, y_d]$ as

    \[ \|\textbf{x} - \textbf{y}\|^2 = (x_1 - y_1)^2 + (x_2 - y_2)^2 + \dots + (x_d - y_d)^2, \]

and if we have $m$ vectors $\textbf{x}^1, \textbf{x}^2, \ldots, \textbf{x}^m$ we can calculate the mean, or average, as

    \[ \frac{1}{m}\sum_{j = 1}^{m}{\textbf{x}^j} = \frac{1}{m}[\sum_{j = 1}^{m}{x_1^j}, \ldots, \sum_{j = 1}^{m}{x_d^j}]. \]

To use the algorithm, we first have to decide the number of clusters $k$, and then initialize what are called centroids, one for each cluster, $\{ c_1, \ldots, c_k \}$. The centroids will be the centres of the clusters. One way of initializing the centroids is to randomly pick $k$ samples from the data set. Once we have the initial $k$ centroids, the following two steps are repeated.

Cluster assignment step

In this step each sample $s$ is assigned to the cluster corresponding to the centroid with the smallest (squared) Euclidean distance to the sample, i.e. the centroid closest to $s$. Hence, $s$ is assigned to one of the centroids of the following set. (It might happen that the sample is exactly in the middle of two or more centroids, but usually this set consists of only one centroid.)

    \[ \{c_i:  \|c_i - s\|^2 =  \min_{j}{\|c_j - s\|^2}\} \]

Centroid update step

When all the samples are assigned to a centroid, this step calculates a new centroid for each of the $k$ clusters by taking the mean of all the assigned samples. So for assigned samples $s^1, \ldots, s^r$, the vector for the new centroid $c_i$ will be

    \[ \textbf{c}_i = \frac{1}{r}\sum_{j = 1}^{r}{\textbf{s}^j} \]

These two steps are repeated until the algorithm converges, i.e., until the assignment step no longer changes any of the assignments. Note that the algorithm might converge to a local optimum; there is no guarantee that the global optimum is found. A common approach is therefore to run the algorithm several times with different initializations of the centroids and choose the best result.
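
To make the two steps concrete, here is a minimal k-means sketch in plain Java (this is only an illustration, not the Neo4j implementation that follows; samples are assumed to be double arrays of equal length, and the centroids are initialized by picking random samples):

import java.util.Random;

static double squaredDistance(double[] x, double[] y) {
    double sum = 0;
    for (int i = 0; i < x.length; i++) {
        double diff = x[i] - y[i];
        sum += diff * diff;
    }
    return sum;
}

static int[] cluster(double[][] samples, int k, int maxIterations) {
    Random random = new Random();
    int d = samples[0].length;
    // Initialize the centroids by picking k random samples
    double[][] centroids = new double[k][];
    for (int i = 0; i < k; i++) {
        centroids[i] = samples[random.nextInt(samples.length)].clone();
    }
    int[] assignment = new int[samples.length];
    for (int iteration = 0; iteration < maxIterations; iteration++) {
        // Cluster assignment step: each sample goes to the closest centroid
        boolean changed = false;
        for (int s = 0; s < samples.length; s++) {
            int best = 0;
            for (int c = 1; c < k; c++) {
                if (squaredDistance(samples[s], centroids[c]) < squaredDistance(samples[s], centroids[best])) {
                    best = c;
                }
            }
            if (assignment[s] != best) {
                assignment[s] = best;
                changed = true;
            }
        }
        if (!changed) {
            break; // converged: no sample changed cluster
        }
        // Centroid update step: each new centroid is the mean of its assigned samples
        double[][] sums = new double[k][d];
        int[] counts = new int[k];
        for (int s = 0; s < samples.length; s++) {
            counts[assignment[s]]++;
            for (int i = 0; i < d; i++) {
                sums[assignment[s]][i] += samples[s][i];
            }
        }
        for (int c = 0; c < k; c++) {
            if (counts[c] > 0) {
                for (int i = 0; i < d; i++) {
                    centroids[c][i] = sums[c][i] / counts[c];
                }
            }
        }
    }
    return assignment;
}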

The data set

In order to do the clustering I needed a data set, and I ended up with a data set of seeds from the UCI Machine Learning Repository. The set consists of 210 samples of kernels from three types of wheat, 70 samples of each type. Each sample has seven attributes: area, perimeter, compactness, length of kernel, width of kernel, asymmetry coefficient and length of kernel groove, and hence each sample can be considered as a 7-dimensional vector. The set is labelled, so we know which samples belong to the same type of wheat. The k-means algorithm is often used with unlabelled sets and does not use the labels, but I have included them in the example so that we can easily see how the clustering compares to the original labelling.

Implementation in Neo4j

I have two kinds of nodes in my model, seeds and centroids. They both contain the seven attributes from the data set, and seeds also have their original labelling. Further, centroids have an index which gives the number of the corresponding cluster, and an attribute “iteration”, which tells which iteration the centroid is used in. (Another approach would be to remove the old centroids after calculating the new ones, so that there is no need to keep track of the iterations, but we want to see the change, so we keep all the centroids.) The initial centroids are chosen very non-randomly: I have just picked the data from one seed of each label, so the clustering will be quite good already after the first assignment step.

The creation of the nodes is pretty straightforward. The attributes of a node in Neo4j are written as a JSON-like map, so I started off by reading the data set file in Java, creating Seed objects, and then using Gson to convert the objects to JSON strings. Before the colon and the type name one can add an identifier, if one wants to refer to the same node later in the same script, but we do not need that in our case.

CREATE (:Seed {area:15.26,perimeter:14.84,compactness:0.871,kernelLength:5.763,kernelWidth:3.312,asymmetryCoefficient:2.221,kernelGrooveLength:5.22,label:1})
CREATE (:Centroid {area:13.84,perimeter:13.94,compactness:0.8955,kernelLength:5.324,kernelWidth:3.379,asymmetryCoefficient:2.259,kernelGrooveLength:4.805, index:1, iteration:1})
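
As a rough sketch of how such statements can be generated (the post describes using Gson, but here the property map is simply formatted by hand to match the unquoted keys shown above, and the Seed record is a hypothetical stand-in):

// Hypothetical holder for one row of the seeds data set; produces one CREATE
// statement per seed, matching the format shown above
record Seed(double area, double perimeter, double compactness, double kernelLength,
            double kernelWidth, double asymmetryCoefficient, double kernelGrooveLength, int label) {

    String toCreateStatement() {
        return String.format(
                "CREATE (:Seed {area:%s,perimeter:%s,compactness:%s,kernelLength:%s,kernelWidth:%s,"
                        + "asymmetryCoefficient:%s,kernelGrooveLength:%s,label:%d})",
                area, perimeter, compactness, kernelLength, kernelWidth,
                asymmetryCoefficient, kernelGrooveLength, label);
    }
}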

The complete script file for creating all the nodes is in this file.

Cluster assignment

The code below shows how I assign the seeds to the right centroid. It feels a bit clumsy, but it was the best I could come up with in pure Cypher. If this was part of an application, one would probably let the application code calculate the distances and find the minimum. And as far as I can tell, Neo4j doesn’t support user-defined functions, so I cannot create something like a distance function instead of duplicating the distance calculation three times.
The MATCH statement does pattern matching, and one can specify paths, nodes and relationships, with or without labels on the nodes (and types on the relationships), and with identifiers to refer to the elements later in the statement. In our case we need all the seeds, and for each seed we will calculate the distances to each of the three centroids and then find the closest one. The MATCH clause finds 210 rows, one for each seed, containing the seed and the three centroids.
The three SET statements add new attributes distC1, distC2 and distC3 to each seed, containing the squared distance from the seed to each centroid. The following WITH clause binds variables from the matching so they can be used later: we keep each seed s, and for each seed the closest centroid as minC, and finally we create an IN_CLUSTER relationship from the seed s to the centroid minC. After the CREATE one could have tidied up the seeds and deleted the three distance attributes.

MATCH (s:Seed), (c1:Centroid{index: 1, iteration: 1}), (c2:Centroid{index: 2, iteration: 1}), (c3:Centroid{index: 3, iteration: 1}) 
SET s.distC1 = (s.area - c1.area)^2 + (s.perimeter - c1.perimeter)^2 + (s.compactness - c1.compactness)^2 + (s.kernelLength - c1.kernelLength)^2 + (s.kernelWidth - c1.kernelWidth)^2 + (s.asymmetryCoefficient - c1.asymmetryCoefficient)^2 + (s.kernelGrooveLength - c1.kernelGrooveLength)^2
SET s.distC2 = (s.area - c2.area)^2 + (s.perimeter - c2.perimeter)^2 + (s.compactness - c2.compactness)^2 + (s.kernelLength - c2.kernelLength)^2 + (s.kernelWidth - c2.kernelWidth)^2 + (s.asymmetryCoefficient - c2.asymmetryCoefficient)^2 + (s.kernelGrooveLength - c2.kernelGrooveLength)^2
SET s.distC3 = (s.area - c3.area)^2 + (s.perimeter - c3.perimeter)^2 + (s.compactness - c3.compactness)^2 + (s.kernelLength - c3.kernelLength)^2 + (s.kernelWidth - c3.kernelWidth)^2 + (s.asymmetryCoefficient - c3.asymmetryCoefficient)^2 + (s.kernelGrooveLength - c3.kernelGrooveLength)^2
WITH s,
CASE
  WHEN s.distC1 <= s.distC2 AND s.distC1 <= s.distC3 THEN c1
  WHEN s.distC2 <= s.distC1 AND s.distC2 <= s.distC3 THEN c2
  ELSE c3
END AS minC
CREATE (s)-[:IN_CLUSTER]->(minC)
RETURN *

The picture shows the three clusters after the seeds have been assigned, and the table shows the statistics. The first cluster has few seeds (each cluster should end up with 70 seeds if the algorithm gets everything right), but it has a high rate of correctly assigned seeds, compared to the third cluster where only 80% of the seeds are correct.
[Figure: the three clusters after the first assignment step]

Cluster no | Total assigned | Correctly assigned | Percentage correct
1          | 50             | 44                 | 88%
2          | 80             | 69                 | 86.25%
3          | 80             | 64                 | 80%

Centroid update

The next step is to find the new centroids for the next assignment step, by calculating the average values of the seven attributes over the seeds assigned to each cluster. Cypher has a useful avg function that calculates the average of an attribute over all nodes bound to the same identifier. In this case it is important to think through what you include in a single match, though. If the match statement had been MATCH (s1:Seed)-[:IN_CLUSTER]->(c1:Centroid {index: 1, iteration: 1}), (s2:Seed)-[:IN_CLUSTER]->(c2:Centroid {index: 2, iteration: 1}), (s3:Seed)-[:IN_CLUSTER]->(c3:Centroid {index: 3, iteration: 1}), we would get the Cartesian product over s1, s2 and s3, and the number of rows returned would be 50 * 80 * 80 = 320,000. This would still give the right numbers when using the average function, since taking the average over multiple copies of a value gives back the original value, but for other aggregate functions, like sum, one would of course get wrong values.

MATCH (s1:Seed)-[:IN_CLUSTER]->(c1:Centroid {index: 1, iteration: 1})
WITH avg(s1.area) as s1Area, avg(s1.perimeter) as s1Perimeter, avg(s1.compactness) as s1Compactness, avg(s1.kernelLength) as s1KernelLength, avg(s1.kernelWidth) as s1KernelWidth, avg(s1.asymmetryCoefficient) as s1AsymmetryCoefficient, avg(s1.kernelGrooveLength) as s1KernelGrooveLength
MATCH (s2:Seed)-[:IN_CLUSTER]->(c2:Centroid {index: 2, iteration: 1})
WITH s1Area, s1Perimeter, s1Compactness, s1KernelLength, s1KernelWidth, s1AsymmetryCoefficient, s1KernelGrooveLength, avg(s2.area) as s2Area, avg(s2.perimeter) as s2Perimeter, avg(s2.compactness) as s2Compactness, avg(s2.kernelLength) as s2KernelLength, avg(s2.kernelWidth) as s2KernelWidth, avg(s2.asymmetryCoefficient) as s2AsymmetryCoefficient, avg(s2.kernelGrooveLength) as s2KernelGrooveLength
MATCH (s3:Seed)-[:IN_CLUSTER]->(c3:Centroid {index: 3, iteration: 1})
WITH s1Area, s1Perimeter, s1Compactness, s1KernelLength, s1KernelWidth, s1AsymmetryCoefficient, s1KernelGrooveLength, 
s2Area, s2Perimeter, s2Compactness, s2KernelLength, s2KernelWidth, s2AsymmetryCoefficient, s2KernelGrooveLength, 
avg(s3.area) as s3Area, avg(s3.perimeter) as s3Perimeter, avg(s3.compactness) as s3Compactness, avg(s3.kernelLength) as s3KernelLength, avg(s3.kernelWidth) as s3KernelWidth, avg(s3.asymmetryCoefficient) as s3AsymmetryCoefficient, avg(s3.kernelGrooveLength) as s3KernelGrooveLength
CREATE (:Centroid {area: s1Area, perimeter: s1Perimeter, compactness: s1Compactness, kernelLength: s1KernelLength, kernelWidth: s1KernelWidth, asymmetryCoefficient: s1AsymmetryCoefficient, kernelGrooveLength: s1KernelGrooveLength, index: 1, iteration: 2})
CREATE (:Centroid {area: s2Area, perimeter: s2Perimeter, compactness: s2Compactness, kernelLength: s2KernelLength, kernelWidth: s2KernelWidth, asymmetryCoefficient: s2AsymmetryCoefficient, kernelGrooveLength: s2KernelGrooveLength, index: 2, iteration: 2})
CREATE (:Centroid {area: s3Area, perimeter: s3Perimeter, compactness: s3Compactness, kernelLength: s3KernelLength, kernelWidth: s3KernelWidth, asymmetryCoefficient: s3AsymmetryCoefficient, kernelGrooveLength: s3KernelGrooveLength, index: 3, iteration: 2})
RETURN *

Next cluster assignment

The next cluster assignment is done with the same statement as the last one, only with the iteration number of the centroids increased to two. And now we can easily see which of the seeds have moved to a new cluster. The statistics show that clusters two and three have improved, whereas the first cluster now has more seeds, but its correctness has decreased.
[Figure: the clusters after the second assignment step]

Cluster no | Total assigned | Correctly assigned | Percentage correct
1          | 66             | 55                 | 83.3%
2          | 73             | 66                 | 90.4%
3          | 71             | 63                 | 88.7%

Now it is up to you to continue the clustering if you’d like; all the code is in this gist. If you find errors in this post, have suggestions for improvements, or have other comments or questions, please let me know! 🙂

References

  1. http://neo4j.com/
  2. http://archive.ics.uci.edu/ml/datasets/seeds
  3. https://en.wikipedia.org/wiki/K-means_clustering