Day 7 of 100 Days of AI

Decision Trees. I went through the intro to ML material on decision trees this morning.

I created some data in Excel and trained a simple decision tree model that predicts whether someone works in tech (value 1) or not (value 0). The data is synthetic, and I shaped it so that the ‘salary’ feature was the key predictor rather than age, sex, or region. Here’s the output tree.

My model had an accuracy of 77%. The confusion matrix below provides more on performance. For example, the ‘precision’ of its tech predictions was quite good, at 85.7% (12 true positives and 2 false positives out of 14 positive predictions).

However, the model’s non-tech predictions were less reliable, with a negative predictive value of 68.8% (11 true negatives out of 16 negative predictions; the other 5 were people who work in tech but were mistakenly predicted as non-tech).
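The percentages above can be checked with a few lines of arithmetic. This is a minimal sketch that uses only the counts quoted in the text. The 5 false negatives are inferred (16 negative predictions minus 11 true negatives), not read from the confusion matrix itself, and the variable names are my own.

```python
# Counts quoted in the post; fn is inferred as 16 negative predictions - 11 true negatives.
tp, fp = 12, 2   # predicted tech: correct / incorrect
tn, fn = 11, 5   # predicted non-tech: correct / incorrect

total = tp + fp + tn + fn

accuracy = (tp + tn) / total      # share of all predictions that were right
precision = tp / (tp + fp)        # how often a "tech" prediction was right
npv = tn / (tn + fn)              # how often a "non-tech" prediction was right
recall = tp / (tp + fn)           # share of actual tech workers that were found
specificity = tn / (tn + fp)      # share of actual non-tech workers that were found

metrics = {
    "accuracy": accuracy,
    "precision": precision,
    "negative predictive value": npv,
    "recall": recall,
    "specificity (true negative rate)": specificity,
}
for name, value in metrics.items():
    print(f"{name}: {value:.1%}")
```

Note that the 68.8% figure is measured against negative predictions, which makes it a negative predictive value. The true negative rate is measured against actual non-tech cases instead, so it comes out differently.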

Key takeaways:

  • Decision trees are a supervised machine learning technique for classifying data (via classification trees) or predicting numeric values (via regression trees).
  • As you can see from the first chart, decision trees are easy to interpret: you can follow the splits step by step to see how a classification was made. Cases where this is useful include:
    • Finance, e.g. loan application decisions, investment decisions.
    • Healthcare, e.g. diagnosis by going through symptoms and other features.
    • Marketing, e.g. customer segmentation via some attributes, churn prediction.
  • How do you create decision trees?
    • The easiest thing to do is to use a Python library that does this for you. Here are some simple examples I did in Python.
    • Otherwise, the general process mainly revolves around choosing which feature to split on at each step. It goes as follows:
      • Start at the root node.
      • Find the best feature that splits the data according to a metric, such as ‘information gain’ or ‘gini impurity’. You do this by going through all the features and splitting your data according to each one in isolation to see how well the data is split. Once you find the best feature, you build a branch with that feature.
      • You then repeat the above process on each branch below the split, using only the subset of data that ended up in that branch.
      • You keep going until you’re happy with the depth (note that if you go too deep, you might have issues with overfitting).
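The steps above can be sketched in plain Python. This is a simplified illustration, not how a library such as scikit-learn implements it: it uses Gini impurity as the metric, tries each observed feature value as a threshold, and stops at a fixed `max_depth`. The function names and the toy data (age, salary in thousands) are made up for this example.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels (0.0 means perfectly pure)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_split(rows, labels):
    """Try every feature and threshold; return the lowest weighted Gini split."""
    best = None  # (weighted_gini, feature_index, threshold)
    for f in range(len(rows[0])):
        for threshold in sorted({row[f] for row in rows}):
            left = [y for row, y in zip(rows, labels) if row[f] <= threshold]
            right = [y for row, y in zip(rows, labels) if row[f] > threshold]
            if not left or not right:
                continue  # this threshold does not actually split the data
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
            if best is None or score < best[0]:
                best = (score, f, threshold)
    return best


def build_tree(rows, labels, depth=0, max_depth=2):
    """Split recursively until a node is pure, unsplittable, or max_depth is hit."""
    majority = Counter(labels).most_common(1)[0][0]
    if depth >= max_depth or gini(labels) == 0.0:
        return majority  # leaf node: predict the most common class
    split = best_split(rows, labels)
    if split is None:
        return majority
    _, f, threshold = split
    left_idx = [i for i, row in enumerate(rows) if row[f] <= threshold]
    right_idx = [i for i, row in enumerate(rows) if row[f] > threshold]
    return {
        "feature": f,
        "threshold": threshold,
        "left": build_tree([rows[i] for i in left_idx],
                           [labels[i] for i in left_idx], depth + 1, max_depth),
        "right": build_tree([rows[i] for i in right_idx],
                            [labels[i] for i in right_idx], depth + 1, max_depth),
    }


def predict(tree, row):
    """Walk from the root to a leaf, following the split at each node."""
    while isinstance(tree, dict):
        branch = "left" if row[tree["feature"]] <= tree["threshold"] else "right"
        tree = tree[branch]
    return tree


# Made-up toy data: [age, salary in thousands]; label 1 = works in tech.
rows = [[25, 40], [32, 95], [41, 50], [29, 110], [50, 45], [38, 120]]
labels = [0, 1, 0, 1, 0, 1]

tree = build_tree(rows, labels)
print(tree)
print(predict(tree, [30, 100]))
```

On this toy data, salary separates the two classes cleanly while age does not, so the root split lands on the salary feature. Keeping `max_depth` small is the simplest guard against the overfitting mentioned above.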
