Data analysis and machine learning for 11-year-olds

My 11-year-old son recently started to watch movies about science and facts on YouTube, in addition to the usual gamer shows. (I hope it is not just because he has run out of gaming videos).
A couple of weeks ago he was super excited when I came home from work. He had watched a video about Titanic and could not stop talking about it. He had learned all about why it sank, all the different circumstances, and how the outcome could have been different if even some of the smaller things had been different.
Then it occurred to me that one of the standard data sets for getting started with data analysis and machine learning is the passenger list from Titanic, with information about fare, passenger class, age, gender, whether the person had a spouse or siblings on board, and whether they had children or parents on board.
Since my son recently learned some elementary statistics in school, I thought it would be fun to show him how we can work with the Titanic data set. I have earlier given some workshops on Apache Spark in Java, so that is what we used. Spark has great documentation of its machine learning library here.
The Titanic data set is available, with some variations, from many places. I downloaded the titanic_original.csv from this GitHub repository. (I tried the cleaned data set first, but persons with missing age are given the average age there, and that would mess up what I wanted to do with the data, so I continued with the original data set.)

Load data from csv

The first thing we have to do is to start Spark and load the data. It is quite easy to read data with Spark. If you are lucky, the option inferSchema will figure out the correct types.

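import org.apache.spark.SparkConf;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.sql.functions.*; // col, sum, monotonically_increasing_id are used later

// Run Spark locally, using all available cores.
SparkConf conf = new SparkConf().setAppName("Titanic").setMaster("local[*]");
SparkSession spark = SparkSession.builder().config(conf).getOrCreate();

// TITANIC_DATA_PATH is a constant (defined elsewhere) holding the path to the csv file.
Dataset<Row> passengers = spark.read()
        .option("inferSchema", "true")
        .option("delimiter", ",")
        .option("header", true)
        .csv(TITANIC_DATA_PATH);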

When the data has been loaded into a Dataset, it is a good idea to check that it actually contains what it should, and that the columns have the right data types. We can do that with the methods passengers.show() and passengers.printSchema(). The first one prints the first twenty rows of the data set, and the latter prints the type for each column.

+------+--------+--------------------+------+------+-----+-----+--------+--------+-------+--------+----+----+--------------------+
|pclass|survived|                name|   sex|   age|sibsp|parch|  ticket|    fare|  cabin|embarked|boat|body|           home.dest|
+------+--------+--------------------+------+------+-----+-----+--------+--------+-------+--------+----+----+--------------------+
|     1|       1|Allen, Miss. Elis...|female|  29.0|    0|    0|   24160|211.3375|     B5|       S|   2|null|        St Louis, MO|
|     1|       1|Allison, Master. ...|  male|0.9167|    1|    2|  113781|  151.55|C22 C26|       S|  11|null|Montreal, PQ / Ch...|
|     1|       0|Allison, Miss. He...|female|   2.0|    1|    2|  113781|  151.55|C22 C26|       S|null|null|Montreal, PQ / Ch...|
|     1|       0|Allison, Mr. Huds...|  male|  30.0|    1|    2|  113781|  151.55|C22 C26|       S|null| 135|Montreal, PQ / Ch...|
|     1|       0|Allison, Mrs. Hud...|female|  25.0|    1|    2|  113781|  151.55|C22 C26|       S|null|null|Montreal, PQ / Ch...|
|     1|       1| Anderson, Mr. Harry|  male|  48.0|    0|    0|   19952|   26.55|    E12|       S|   3|null|        New York, NY|
|     1|       1|Andrews, Miss. Ko...|female|  63.0|    1|    0|   13502| 77.9583|     D7|       S|  10|null|          Hudson, NY|
|     1|       0|Andrews, Mr. Thom...|  male|  39.0|    0|    0|  112050|     0.0|    A36|       S|null|null|         Belfast, NI|
|     1|       1|Appleton, Mrs. Ed...|female|  53.0|    2|    0|   11769| 51.4792|   C101|       S|   D|null| Bayside, Queens, NY|
|     1|       0|Artagaveytia, Mr....|  male|  71.0|    0|    0|PC 17609| 49.5042|   null|       C|null|  22| Montevideo, Uruguay|
|     1|       0|Astor, Col. John ...|  male|  47.0|    1|    0|PC 17757| 227.525|C62 C64|       C|null| 124|        New York, NY|
|     1|       1|Astor, Mrs. John ...|female|  18.0|    1|    0|PC 17757| 227.525|C62 C64|       C|   4|null|        New York, NY|
|     1|       1|Aubart, Mme. Leon...|female|  24.0|    0|    0|PC 17477|    69.3|    B35|       C|   9|null|       Paris, France|
|     1|       1|"Barber, Miss. El...|female|  26.0|    0|    0|   19877|   78.85|   null|       S|   6|null|                null|
|     1|       1|Barkworth, Mr. Al...|  male|  80.0|    0|    0|   27042|    30.0|    A23|       S|   B|null|       Hessle, Yorks|
|     1|       0| Baumann, Mr. John D|  male|  null|    0|    0|PC 17318|  25.925|   null|       S|null|null|        New York, NY|
|     1|       0|Baxter, Mr. Quigg...|  male|  24.0|    0|    1|PC 17558|247.5208|B58 B60|       C|null|null|        Montreal, PQ|
|     1|       1|Baxter, Mrs. Jame...|female|  50.0|    0|    1|PC 17558|247.5208|B58 B60|       C|   6|null|        Montreal, PQ|
|     1|       1|Bazzani, Miss. Al...|female|  32.0|    0|    0|   11813| 76.2917|    D15|       C|   8|null|                null|
|     1|       0|Beattie, Mr. Thomson|  male|  36.0|    0|    0|   13050| 75.2417|     C6|       C|   A|null|        Winnipeg, MN|
+------+--------+--------------------+------+------+-----+-----+--------+--------+-------+--------+----+----+--------------------+
root
 |-- pclass: integer (nullable = true)
 |-- survived: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: double (nullable = true)
 |-- sibsp: integer (nullable = true)
 |-- parch: integer (nullable = true)
 |-- ticket: string (nullable = true)
 |-- fare: double (nullable = true)
 |-- cabin: string (nullable = true)
 |-- embarked: string (nullable = true)
 |-- boat: string (nullable = true)
 |-- body: integer (nullable = true)
 |-- home.dest: string (nullable = true)

As mentioned, age is missing for some of the passengers, and the same holds for fare. For our purpose we will just remove the rows where age or fare is null, with a syntax quite similar to SQL.

passengers = passengers.where(col("age").isNotNull().and(col("fare").isNotNull()));

Most of the columns are self-explanatory, but pclass is the passenger class, sibsp is the number of siblings and spouses the passenger had on board, and similarly, parch is the number of parents and children.

Average, median and mode

Now that we’ve got the data we can finally start doing some statistics. The statistics taught in 5th grade basically cover the average, median and mode of a set of numbers, and how to draw different types of charts. I asked my son what he believed the average age of the passengers would be, and he guessed 30 years. Let’s see how we can get the average, together with the median and mode, from Spark.

There is a useful method describe on a Dataset, which gives us various information about the columns we ask for. It turns out that 30 was a good guess for the average 🙂

passengers.describe("age").show();
+-------+------------------+
|summary|               age|
+-------+------------------+
|  count|              1045|
|   mean|29.851834162679427|
| stddev|14.389200976290613|
|    min|            0.1667|
|    max|              80.0|
+-------+------------------+

To find the mode, the value that occurs most often, we draw a plot with ages along the x-axis and the count of passengers with that age along the y-axis. The plot is made with the Java library XChart.

By looking at the chart it seems that 24 years is the mode. We can verify that by querying the data set and looking at the first row that is printed.

passengers.groupBy("age").count().orderBy(col("count").desc()).show();

The median is the middle data point when the data is sorted. We can sort the data set on age, but a Dataset does not have an index we can query. The data set has 1045 entries, so the median is entry number 523. The easiest thing would be to call .show(523) on the sorted data and look at the last row that is printed. Alternatively, we can add an id column like this (according to the docs the ids are increasing and unique but not guaranteed to be consecutive; they will be consecutive when all the data is in a single partition, as in our case).

passengers.orderBy("age").withColumn("id", monotonically_increasing_id());

A more proper, but much less intuitive, way of finding the median would be to use the quantile function like this:

passengers.stat().approxQuantile("age", new double[]{0.5}, 0);

No matter how we find the median, the result is 28.
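
For readers who want to see the three definitions without Spark, here is a minimal sketch in plain Java. The class name AgeStats and the seven ages are made up for illustration (they are not rows from the Titanic data), and the methods assume a non-empty array.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

final class AgeStats {
    /** Average: the sum of the values divided by how many there are. */
    static double mean(double[] values) {
        return Arrays.stream(values).average().orElse(Double.NaN);
    }

    /** Median: the middle value of the sorted data (mean of the two middle values for an even count). */
    static double median(double[] values) {
        double[] sorted = values.clone();
        Arrays.sort(sorted);
        int n = sorted.length;
        return n % 2 == 1 ? sorted[n / 2] : (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0;
    }

    /** Mode: the value that occurs most often (on a tie, the smallest such value). */
    static double mode(double[] values) {
        Map<Double, Integer> counts = new HashMap<>();
        for (double v : values) {
            counts.merge(v, 1, Integer::sum);
        }
        double best = Double.NaN;
        int bestCount = 0;
        for (Map.Entry<Double, Integer> e : counts.entrySet()) {
            boolean moreFrequent = e.getValue() > bestCount;
            boolean tieButSmaller = e.getValue() == bestCount && e.getKey() < best;
            if (moreFrequent || tieButSmaller) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] ages = {24, 30, 24, 2, 48, 28, 36}; // hypothetical sample
        System.out.println("mean   = " + mean(ages));
        System.out.println("median = " + median(ages)); // 28.0
        System.out.println("mode   = " + mode(ages));   // 24.0
    }
}
```

Sorting a copy keeps the caller's array untouched, which mirrors how orderBy on a Dataset returns a new Dataset instead of changing the original.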

That was the basic statistics. Let us move on to analyze who the survivors were.

Survivors by gender

Let us start by finding the number and fraction of survivors in total. It can be done with the following line:

passengers.groupBy("survived").count().withColumn("fraction", col("count").divide(sum("count").over()));

The table below shows the depressing result that only 41% of the 1045 passengers in our data set survived.

Survived  Count  Fraction
1         427    0.41
0         618    0.59

So how is the rate of survivors by gender?
I told my son that it was common to save children and women before men, and he was shocked; “What, is that true? That’s totally unfair!”.
Well, I actually found the paper Gender, social norms, and survival in maritime disasters, where the authors have studied maritime disasters and to what extent the “social norm of women and children first” is followed. They conclude that “Women have a distinct survival disadvantage compared with men. Captains and crew survive at a significantly higher rate than passengers.”

If survival was independent of gender, we would expect that around 41% of the women would survive, and likewise, 41% of the men. That is logical, also for a 5th grader. But what are the actual rates based on gender?

Gender  Survived  Total  Percentage of survivors by gender
Female  292       388    75.3%
Male    135       657    20.5%

We clearly see that survival is dependent on gender, so we can assume that gender will be an important feature for predicting if a person survived or not. Comparing the observed counts with what we would expect if survival was independent of gender is the intuition behind the chi-squared test. Spark uses this test in the ChiSqSelector, which can be used to find the most important features of a data set.
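
The comparison of observed and expected counts can be made concrete with a small sketch in plain Java. It computes Pearson's chi-squared statistic for the gender/survival counts from the table above. The class name ChiSquared is made up, and this is only the textbook formula, not Spark's implementation.

```java
final class ChiSquared {
    /**
     * Pearson's chi-squared statistic for a table of observed counts:
     * the sum over all cells of (observed - expected)^2 / expected, where
     * expected = rowTotal * columnTotal / grandTotal is the count we would
     * see if rows and columns were independent.
     */
    static double statistic(long[][] observed) {
        int rows = observed.length;
        int cols = observed[0].length;
        double[] rowTotals = new double[rows];
        double[] colTotals = new double[cols];
        double total = 0;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                rowTotals[r] += observed[r][c];
                colTotals[c] += observed[r][c];
                total += observed[r][c];
            }
        }
        double chi2 = 0;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                double expected = rowTotals[r] * colTotals[c] / total;
                double diff = observed[r][c] - expected;
                chi2 += diff * diff / expected;
            }
        }
        return chi2;
    }

    public static void main(String[] args) {
        // Rows: female, male. Columns: survived, died. Counts from the table above.
        long[][] table = {
            {292, 388 - 292},
            {135, 657 - 135}
        };
        System.out.println("chi-squared = " + statistic(table));
    }
}
```

For a 2x2 table like this one, a statistic above roughly 3.84 is the usual threshold for rejecting independence at the 5% level, and the value here is far above that.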

Prediction with decision trees

Decision trees are one of the simplest types of algorithms in machine learning, and it is easy to understand the result of the algorithm. I think of it as if-else statements written by the program and not by the programmer. Decision trees are (usually) built top down, by selecting at each step the feature that separates the data points best, in terms of grouping data with the same label (the value we try to predict).
A decision tree classifier has a tendency to over-fit, which means that the model fits the training set very well but does worse on other data sets. A common way to avoid this is to make several decision trees, a random forest, where some randomness is added to the generation of the trees to make them different. The prediction for an item is the label it gets most often when it is scored in each of the separate trees.
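
To illustrate the "if-else statements" view and the majority vote, here is a toy sketch in plain Java. The three trees are written by hand and their split values are invented for illustration; a real random forest learns them from the training data. The names TinyForest and Passenger are hypothetical.

```java
import java.util.List;
import java.util.function.Predicate;

final class TinyForest {
    /** A passenger with just the attributes the toy trees look at. */
    static final class Passenger {
        final int pclass;
        final String sex;
        final double age;

        Passenger(int pclass, String sex, double age) {
            this.pclass = pclass;
            this.sex = sex;
            this.age = age;
        }
    }

    // Each "tree" is nested if-else that returns true for "survived".
    static boolean treeA(Passenger p) {
        if (p.sex.equals("female")) {
            return p.pclass <= 2;
        } else {
            return p.age < 13;
        }
    }

    static boolean treeB(Passenger p) {
        if (p.pclass == 1) {
            return true;
        } else {
            return p.sex.equals("female");
        }
    }

    static boolean treeC(Passenger p) {
        if (p.age < 12) {
            return p.pclass <= 2;
        } else {
            return p.sex.equals("female");
        }
    }

    /** The forest predicts the label that more than half of the trees vote for. */
    static boolean predict(Passenger p) {
        List<Predicate<Passenger>> trees =
                List.of(TinyForest::treeA, TinyForest::treeB, TinyForest::treeC);
        long votesForSurvived = trees.stream().filter(tree -> tree.test(p)).count();
        return votesForSurvived * 2 > trees.size();
    }

    public static void main(String[] args) {
        System.out.println(predict(new Passenger(2, "male", 11)));   // true: two of three trees vote "survived"
        System.out.println(predict(new Passenger(3, "male", 40)));   // false: no tree votes "survived"
    }
}
```

Using an odd number of trees avoids ties in the vote for a binary label.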

In Spark we start by splitting the data set randomly into two sets: a training set for fitting the model, containing 70% of the data, and a test set used to evaluate the model.

Dataset<Row>[] splits = passengers.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> training = splits[0];
Dataset<Row> test = splits[1];
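
The idea behind a random split can be sketched in plain Java: every row is sent to the training set with probability 0.7 and to the test set otherwise, so the sizes are only approximately 70/30. The class RandomSplit is a made-up illustration and Spark's own implementation may differ. A fixed seed makes the split repeatable (Spark's randomSplit also has an overload that takes a seed).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

final class RandomSplit {
    /**
     * Sends each item to the first list with probability trainFraction,
     * otherwise to the second list. The same seed always gives the same split.
     */
    static <T> List<List<T>> split(List<T> items, double trainFraction, long seed) {
        Random random = new Random(seed);
        List<T> training = new ArrayList<>();
        List<T> test = new ArrayList<>();
        for (T item : items) {
            (random.nextDouble() < trainFraction ? training : test).add(item);
        }
        return List.of(training, test);
    }

    public static void main(String[] args) {
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < 1045; i++) {
            rows.add(i); // stand-ins for the 1045 passengers
        }
        List<List<Integer>> parts = split(rows, 0.7, 42L);
        System.out.println("training: " + parts.get(0).size() + ", test: " + parts.get(1).size());
    }
}
```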

Before we can feed data into a random forest classifier we need to transform the data. Spark’s machine learning algorithms want data with a column “label” that contains what we are predicting, and a column “features” that contains a vector of the data attributes we want to include. I find the RFormula in Spark very useful for making label and features. The syntax is a bit strange, but the value to the left of “~” is the label, and on the right side one can add the attributes one wants, either by naming each of them with + between them, or by starting with “.”, which gives all of them, and then “subtracting” the attributes one does not want. Survived is our label, and as features we select pclass, sex, age, fare, sibsp and parch.

RFormula formula = new RFormula().setFormula("survived ~ pclass + sex + age + fare + sibsp + parch");

We then use the RandomForestClassifier, with default configuration. Spark has this nice pipeline concept, which makes it easy to work with multiple steps in the machine learning process. We only have to fit and transform data once on the pipeline, instead of for each individual step, and it is easy to experiment and replace parts of the chain. The following lines of code create a pipeline and fit it on the training data, which gives back a PipelineModel.

RandomForestClassifier forestClassifier = new RandomForestClassifier();
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{formula, forestClassifier});
PipelineModel model = pipeline.fit(training);

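To find out how the model is doing on the test set, we can use the binary classification evaluator, since we are doing classification with a binary label (survived/died). The model first classifies each entry in the test set, and the evaluator then compares the predictions with the actual labels. Note that with its default configuration the evaluator does not count the correctly classified entries: its default metric is the area under the ROC curve, a score between 0 and 1 where 0.5 corresponds to random guessing.

Dataset<Row> predictions = model.transform(test);
BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
double res = eval.evaluate(predictions);

The result will vary a bit each time the program is run, since there is randomness both in the splitting of the data set and in the algorithm, but an example of a score we got is 86.5%. To get the plain accuracy, the fraction of correctly classified entries, one could use Spark’s MulticlassClassificationEvaluator with the metric name “accuracy” instead.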
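
Accuracy in the sense of "fraction of correctly classified entries" is easy to compute by hand, and it is not the same number as the area under the ROC curve, which is the default metric of Spark's BinaryClassificationEvaluator. The sketch below, in plain Java with made-up scores and labels, shows both on four hypothetical passengers. The class name Metrics is an assumption for illustration.

```java
final class Metrics {
    /** Fraction of entries where the predicted label equals the actual label. */
    static double accuracy(int[] predicted, int[] actual) {
        if (predicted.length != actual.length || predicted.length == 0) {
            throw new IllegalArgumentException("need two non-empty arrays of equal length");
        }
        int correct = 0;
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] == actual[i]) {
                correct++;
            }
        }
        return (double) correct / predicted.length;
    }

    /**
     * Area under the ROC curve, computed as the fraction of (survivor, non-survivor)
     * pairs where the survivor got the higher score; ties count as half.
     * Returns NaN if one of the two classes is missing.
     */
    static double areaUnderRoc(double[] scores, int[] actual) {
        double wins = 0;
        long pairs = 0;
        for (int i = 0; i < actual.length; i++) {
            if (actual[i] != 1) continue;
            for (int j = 0; j < actual.length; j++) {
                if (actual[j] != 0) continue;
                pairs++;
                if (scores[i] > scores[j]) {
                    wins += 1.0;
                } else if (scores[i] == scores[j]) {
                    wins += 0.5;
                }
            }
        }
        return wins / pairs;
    }

    public static void main(String[] args) {
        double[] scores = {0.9, 0.8, 0.4, 0.3}; // hypothetical predicted survival probabilities
        int[] actual = {1, 0, 1, 0};            // 1 = survived, 0 = died
        int[] predicted = {1, 1, 0, 0};         // scores thresholded at 0.5
        System.out.println("accuracy = " + accuracy(predicted, actual));       // 0.5
        System.out.println("area under ROC = " + areaUnderRoc(scores, actual)); // 0.75
    }
}
```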

The random forest model has a string representation of the trees one can use to look at the actual trees.

RandomForestClassificationModel treeModel = (RandomForestClassificationModel) model.stages()[1];
System.out.println(treeModel.toDebugString());

I visualized one of these trees, and my son wanted to see if he would have survived or not. Luckily he just turned eleven, and would according to this tree survive if we were travelling by first or second class. I would also have survived, but it certainly doesn’t look good for dad/husband.
[Figure: decision tree]

A tiny theorem prover

At work we have many great (and some quite funny and not so great) old books, and when I came across this LISP book, I had to borrow it. (At some point there must have been a real library: the books are numbered and have an old-fashioned book card at the back, with the dates and names of previous borrowers. So maybe I didn’t borrow the book properly, I just took it off the shelf.) The book has many great chapters, for instance “Symbolic pattern matching and simple theorem proving”, “Program writing programs and natural language interfaces” and “LISP in LISP”. Here is a little theorem prover from the chapter “Symbolic pattern matching and simple theorem proving”, based on the resolution rule and rewritten in Racket.

The resolution rule produces a new clause if two clauses contain complementary literals, i.e., one clause contains \(c\) and the other contains \(\neg c\).

    \[ \frac{a_1\lor \ldots \lor a_i \lor c \lor a_{i+1}\lor \ldots \lor a_n, \quad b_1\lor \ldots \lor b_j \lor \neg c \lor b_{j+1}\lor \ldots \lor b_m}{a_1\lor \ldots \lor a_i \lor a_{i+1}\lor \ldots \lor a_n \lor b_1\lor \ldots \lor b_j \lor b_{j+1}\lor \ldots \lor b_m} \]

The simplest case is when we have the clauses \(a \lor c\) and \(b \lor \neg c\). Since \(c\) and \(\neg c\) cannot both be true at the same time, we know that either \(a\) or \(b\) has to be true: if \(c\) is true, the second clause forces \(b\), and if \(c\) is false, the first clause forces \(a\). Hence \(\frac{a \lor c, \quad b \lor \neg c}{a \lor b}\).
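
The simplest case has only three letters, so it can be checked by brute force over all eight combinations of true and false. The following plain Java sketch does that; the class name ResolutionCheck is made up for illustration.

```java
final class ResolutionCheck {
    /** True if "(a or c) and (b or not c)" implies "(a or b)" for the given truth values. */
    static boolean ruleHolds(boolean a, boolean b, boolean c) {
        boolean premises = (a || c) && (b || !c);
        boolean conclusion = a || b;
        return !premises || conclusion; // an implication only fails when premises hold and the conclusion does not
    }

    /** Tries all eight assignments of true/false to a, b and c. */
    static boolean holdsForAllAssignments() {
        boolean[] values = {false, true};
        for (boolean a : values) {
            for (boolean b : values) {
                for (boolean c : values) {
                    if (!ruleHolds(a, b, c)) {
                        return false;
                    }
                }
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(holdsForAllAssignments()); // true
    }
}
```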

(If you’re not familiar with logical operators and terms, here is the shortest intro, from a non-expert: the sign \(\neg\) is negation, a literal is either a letter or the negation of a letter, and a letter is a variable that can be either true or false. The operator \(\lor\) is “or” (disjunction) and \(\land\) is “and” (conjunction), and a clause is a finite disjunction of literals, i.e., letters or negations of letters connected with “or”. In the fraction above, the clauses above the line are the arguments, and the clause below the line is the conclusion, which follows from (can be deduced from) the arguments. Hence, the conclusion is true if all the arguments are true.)

The logical operators and the literals are symbols/quotes in the code (since we’re not in the “Lisp in Lisp” chapter of the book), and the following resolve function is my Racket version of the equivalent Common Lisp function from the book. The code in the book uses SETQ to update variables and GO for jumping in the code, a lot, but I’m using let-expressions and recursion instead.
The resolve function uses two helper functions: invert puts the symbol 'not in front of a literal, or removes the 'not if the literal is a negation, and combine is a cleanup function that adds a literal to a list of literals if neither the literal itself nor its negation is in the list already, and removes duplicates.

(define (resolve x y)
  (letrec ([res (λ (rest-x rest-y)
                 (cond
                   [(null? rest-x) 'no-resolvent]
                   [(member? (invert (car rest-x)) rest-y) (cons 'or (combine (car rest-x) (append (cdr x) (cdr y))))]
                   [else (res (cdr rest-x) rest-y)]))])
          (res (cdr x) (cdr y))))

And some tests to show how the function works:

(check-equal? (resolve '(or P) '(or Q)) 'no-resolvent)
(check-equal? (resolve '(or P) '(or (not P))) '(or))
(check-equal? (resolve '(or P (not Q)) '(or Q (not R))) '(or P (not R)))

So how can we use the resolution rule in a proof?

A conjecture is proved if the premise implies the conclusion, hence a conjecture is false if there is a combination of truth values for the literals that makes the premise true and the conclusion false at the same time. This is equivalent to saying that the conjecture is false if there is a combination that makes the premise and the negation of the conclusion true at the same time.

We require both the premise and the negation of the conclusion to be in conjunctive normal form, which means that both consist of clauses connected with “and”s. So for the premise and the negation of the conclusion to be true at the same time, all the clauses have to be true.

If we can derive both a literal and its negation from the combined clauses, then it is impossible for all the clauses to be true at the same time, hence the conjecture is not false, it is true. This is the key idea of the proof algorithm: loop through all the clauses and use the resolution rule on each pair of clauses. If the resolvent is empty, we know that the theorem is true. If we get a non-empty resolvent back, it is added to the list of clauses, and the resolvent is again used in the resolution rule with each of the other clauses in the list. If we have looped through all the clauses without meeting an empty resolvent, the theorem is not proved.

(define (prove premise negation)
  (letrec ([try-next-remainder (λ (remainder clauses)
                     (cond [(null? remainder)
                            (displayln '(theorem not proved))
                            #f]
                           [else (try-next-resolvent (car remainder) (cdr remainder) remainder clauses)]))]
           [try-next-resolvent (λ (first rest remainder clauses)
                       (if (null? rest) (try-next-remainder (cdr remainder) clauses)
                       (let ([resolvent (resolve first (car rest))])
                         (cond
                           [(or (equal? resolvent 'no-resolvent)
                                (member? resolvent clauses)) (try-next-resolvent first (cdr rest) remainder clauses)]
                           [(null? (cdr resolvent)) (display-clauses first (car rest))
                                                    (displayln '(produce the empty resolvent))
                                                    (displayln '(theorem proved))
                                                    #t]
                           [else (display-clauses first (car rest))
                                 (displayln (append '(produce a resolvent:) (list resolvent)))
                                 (try-next-remainder (cons resolvent clauses) (cons resolvent clauses))]))))]
           [clauses (append (cdr premise) (cdr negation))])
    (try-next-remainder clauses clauses)))

(test-case
 "prove true"
 (let ([premise '(and (or Q (not P))
                      (or R (not Q))
                      (or S (not R))
                      (or (not U) (not S)))]
       [negation '(and (or P) (or U))])
   (check-true (prove premise negation))))
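
For readers more at home in Java, the same refutation idea can be sketched as follows. This is not a translation of the Racket code above: it represents a clause as a set of strings, writes the negation of P as "!P" (a made-up convention), and simply keeps resolving all pairs until the empty clause shows up or nothing new can be added. The names TinyProver, negate, resolvents and refutes are hypothetical.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

final class TinyProver {
    /** "P" becomes "!P" and "!P" becomes "P". */
    static String negate(String literal) {
        return literal.startsWith("!") ? literal.substring(1) : "!" + literal;
    }

    /** One resolvent for each literal in x whose negation is in y. */
    static List<Set<String>> resolvents(Set<String> x, Set<String> y) {
        List<Set<String>> result = new ArrayList<>();
        for (String literal : x) {
            if (y.contains(negate(literal))) {
                Set<String> resolvent = new TreeSet<>(x);
                resolvent.remove(literal);
                Set<String> restOfY = new TreeSet<>(y);
                restOfY.remove(negate(literal));
                resolvent.addAll(restOfY);
                result.add(resolvent);
            }
        }
        return result;
    }

    /**
     * Returns true if the empty clause can be derived, which means the clauses
     * (premise plus negated conclusion) cannot all be true at the same time.
     */
    static boolean refutes(List<Set<String>> clauses) {
        Set<Set<String>> known = new LinkedHashSet<>();
        for (Set<String> clause : clauses) {
            known.add(new TreeSet<>(clause));
        }
        boolean addedSomething = true;
        while (addedSomething) {
            addedSomething = false;
            List<Set<String>> snapshot = new ArrayList<>(known);
            for (int i = 0; i < snapshot.size(); i++) {
                for (int j = i + 1; j < snapshot.size(); j++) {
                    for (Set<String> resolvent : resolvents(snapshot.get(i), snapshot.get(j))) {
                        if (resolvent.isEmpty()) {
                            return true; // empty resolvent: theorem proved
                        }
                        if (known.add(resolvent)) {
                            addedSomething = true;
                        }
                    }
                }
            }
        }
        return false; // nothing new can be derived and no empty resolvent was found
    }

    public static void main(String[] args) {
        // Same premise and negated conclusion as in the Racket test case above.
        List<Set<String>> clauses = List.of(
                Set.of("Q", "!P"), Set.of("R", "!Q"), Set.of("S", "!R"),
                Set.of("!U", "!S"), Set.of("P"), Set.of("U"));
        System.out.println(refutes(clauses)); // true
    }
}
```

The loop always terminates because only finitely many different clauses can be built from a finite set of literals, and a clause is only added if it is not known already.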

The complete Racket code, with tests, can be found here.

References

  1. https://en.wikipedia.org/wiki/Propositional_variable
  2. https://en.wikipedia.org/wiki/Propositional_calculus
  3. https://en.wikipedia.org/wiki/Resolution_(logic)
  4. https://en.wikipedia.org/wiki/Conjunctive_normal_form
  5. Lisp by Winston, Patrick Henry; Horn, B.K.P.