Дерево решений CART (Classification and Regressoin Tree) — алгоритм классификации и регрессии, основанный на бинарном дереве и являющийся фундаментальным компонентом случайного леса и бустингов, которые входят в число самых мощных алгоритмов машинного обучения на сегодняшний день. Деревья также могут быть не бинарными в зависимости от реализации. К другим популярным реализациям решающего дерева относятся следующие: ID3, C4.5, C5.0.

Ноутбук с данными алгоритмами можно загрузить на Kaggle (eng) и GitHub (rus).

Структура дерева решений

Решающее дерево состоит из следующих компонентов: корневой узел, ветви (левая и правая), решающие и листовые (терминальные) узлы. Корневой и решающие узлы представляют из себя вопросы с пороговым значением для разделения тренировочного набора на части (левая и правая), а листья являются конечными прогнозами: среднее значений в листе для регрессии и статистическая мода для классификации.

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

Структура CART

Каждый листовой узел соответствует определённой прямоугольной области на графике границ решений между двумя признаками. Если на графике соседние участки имеют одинаковое значение, то они автоматически объединяются и представляются как одна большая область.

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

График границ решений

Выбор наилучшего разбиения

Выбор наилучшего разбиения при построении решающего узла в дереве напоминает игру, в которой нужно угадать знаменитость, задавая вопросы, на которые можно лишь услышать ответ “да” либо “нет”. Логично, что для быстрого поиска правильного ответа необходимо задавать вопросы, которые исключат наибольшее количество неверных вариантов, например, вопрос про “пол” позволит исключить сразу же половину вариантов, в то время как вопрос про “возраст” будет менее информативным. Проще говоря, выбор наилучшего вопроса заключается в поиске признака, определённое значение которого лучше всего отделяет правильный ответ от неправильных.

Показатель того, насколько хорошо вопрос в решающем узле позволяет отделить верный ответ от неверных, называется мерой загрязнённости узла. В случае классификации для оценки качества разбиения узла используются следующие критерии:

  • Неопределённость (загрязнённость) Джини — мера разнообразия в распределении вероятностей классов. Если все элементы в узле принадлежат к одному классу, то неопределённость Джини равна 0, а в случае равномерного распределения классов в узле неопределённость Джини равна 0.5.

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

  • Энтропия Шеннона — мера неопределённости или беспорядка классов в узле. Она характеризует количество информации, которое необходимо для описания состояния системы: чем выше значение энтропии, тем менее упорядочена система и наоборот.

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

  • Ошибка классификации — величина, отображающая долю неправильно классифицированных элементов в узле: чем меньше данное значение, тем меньше загрязнённость в узле.

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

В данном случае

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

— это доля k-го класса среди обучающих образцов в i-ом узле.

На практике чаще всего используются неопределённость Джини и энтропия Шеннона за счёт большей информативности. Как видно из графика для случая бинарной классификации (где P+ — вероятность принадлежности к классу “+”), график удвоенной неопределённости Джини очень схож с графиком энтропии Шеннона: в первом случае будут получаться чуть менее сбалансированные деревья, однако при работе с большими датасетами неопределённость Джини более предпочтительна за счёт меньшей вычислительной сложности.

Код для отрисовки графика

import numpy as np import matplotlib.pyplot as plt   def gini(probas):     return np.array([1- (p ** 2 + (1-p) ** 2) for p in probas])   def entropy(probas):     return np.array([-1 * (p * np.log2(p) + (1-p) * np.log2(1-p)) for p in probas])   def misclass_error_rate(probas):     return np.array([1 - max([p, 1-p]) for p in probas])   probas = np.linspace(0, 1, 250) plt.plot(probas, entropy(probas), label="Shannon's entropy") plt.plot(probas, 2 * gini(probas),  label="Gini impurity x 2") plt.plot(probas, 2 * misclass_error_rate(probas), label="Misclass error x 2") plt.plot(probas, gini(probas), label="Gini impurity") plt.plot(probas, misclass_error_rate(probas), label="Misclass error") plt.title("Splitting criteria from P+ (binary classification case)") plt.xlabel("P+") plt.ylabel("Impurity") plt.legend();

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

В случае регрессии для оценки качества разбиения узла чаще всего используется среднеквадратичная ошибка, но также могут быть использованы Friedman MSE и MAE.

Функция потерь

Так как же в конечном счёте происходит выбор наилучшего разбиения? После выбора одного из критериев оценки качества разбиения узла (например, неопределённость Джини или MSE), для всех уникальных значений признака берутся их пороговые значения, отсортированные по возрастанию и представленные как среднее арифметическое между соседними значениями. Далее обучающий набор разделяется на 2 поднабора (узла): всё что меньше либо равно текущего порогового значения идёт в левый поднабор, а всё что больше — в правый. Для полученных поднаборов рассчитываются загрязнённости на основе выбранного критерия, после чего их взвешенная сумма представляется как функция потерь, значение которой будет соответствовать пороговому значению признака. Порог с наименьшим значением функции потерь в обучающем наборе (поднаборе) будет наилучшим разбиением.

Функции потерь будут иметь следующий вид:

  • для классификации:

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

  • для регрессии:

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

где

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

.

В случае с энтропией используется немного иной подход: рассчитывается так называемый информационный прирост — разница энтропий родительского и дочерних узлов. Порог с максимальным информационным приростом в обучающем наборе/поднаборе будет соответствовать наилучшему разбиению. :

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

где

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

— условие (вопрос) для разбиения поднабора

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

.

Для наглядности рассмотрим следующий пример. Допустим, у нас есть 4 кота и 4 собаки, а также нам известны их некоторые визуальные признаки: “рост”, “уши” (висячие и заострённые стоячие) и “усы” (наличие либо отсутствие). Как в дереве решений строится корневой узел для классификации собак и котов? Очень просто: сначала для каждого признака животные разделяются на 2 группы согласно вопросу, после чего для каждой из групп рассчитывается неопределённость Джини. Пороговое значение признака с наименьшим значением функции потерь (взвешенной суммой неопределённостей) будет использоваться для разбиения корневого (решающего) узла. В данном случае признак “рост” имеет наименьшую загрязнённость и вопрос в решающем узле будет выглядеть как “рост ⩽ 20 см”.

Стоит добавить, что в случае с категориальными признаками происходит их бинаризация и вопросы выглядят следующим образом: “уши ⩽ 0.5”, “усы ⩽ 0.5”. Когда категориальные признаки могут принимать более 2 значений, применяются другие виды кодирования, например label или one-hot encoding.

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

Принцип работы дерева решений (CART)

Алгоритм строится следующим образом:

  • 1) создаётся корневой узел на основе наилучшего разбиения;

  • 2) тренировочный набор разбивается на 2 поднабора: всё что соответствует условию разбиения отправляется в левый узел, остальное — в правый;

  • 3) далее рекурсивно для каждого тренировочного поднабора повторяются шаги 1-2 пока не будет достигнут один из основных критериев останова: максимальная глубина, максимальное количество листьев, минимальное количество наблюдений в листе или минимальное снижение загрязнения в узле.

Регуляризация дерева решений

Выращенная без ограничений древовидная структура в той или иной степени будет склонна к переобучению и для решения данной проблемы используется 2 подхода: pre-pruning (ограничение роста дерева во время построения любым из критериев останова) и post-pruning (отсечение лишних ветвей после полного построения). Второй подход является более деликатным так как позволяет получить несимметричную и более точную древовидную структуру, оставляя лишь самые информативные решающие узлы.

Существует 2 типа post-puning’а:

  • Top-down pruning — метод, при котором проверка и обрезка наименее информативных ветвей начинается с корневого узла. Данный метод обладает относительно низкой вычислительной сложностью, однако, как и в случае с pre-pruning’ом, его главным недостатком также является возможность недообучения за счёт удаления ветвей, которые могли потенциально содержать информативные узлы. К самым известным видам данного прунинга относятся следующие:

    • Pessimistic Error Pruning (PEP), когда обрезаются ветви с наибольшей ожидаемой ошибкой, порог которой устанавливается заранее;

    • Critical Value Pruning (CVP), когда обрезаются ветви, информативность которых меньше определённого критического значения.

  • Bottom-up pruning — метод, при котором проверка и обрезка наименее информативных ветвей начинается с листьев. В данном случае получаются более точные деревья за счёт полного обхода снизу-вверх и оценки каждого решающего узла, однако это приводит к увеличению вычислительной сложности. Самыми популярными видами данного прунинга являются следующие:

    • Minimum Error Pruning (MEP), когда происходит поиск дерева с наименьшей ожидаемой ошибкой на отложенной выборке;

    • Reduced Error Pruning (REP), когда решающие узлы удаляются до тех пор, пока не падает точность, измеренная на отложенной выборке;

    • Cost-complexity pruning (CCP), когда строится серия поддеревьев через удаление слабейших узлов в каждом из них с помощью коэффициента, рассчитанного как разность ошибки корневого узла поддерева и общей ошибки его листьев, а выбор наилучшего поддерева производится на тестовом наборе или с помощью k-fold кросс-валидации.

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

Дерево до и после post-pruning’а

Minimal cost-complexity pruning

В реализации scikit-learn для деревьев решений используется модификация cost-complexity pruning, которая работает следующим образом:

  • 1) сначала строится полное дерево без ограничений;

  • 2) далее абсолютно для всех узлов в дереве рассчитывается ошибка на основе взвешенной загрязнённости в случае классификации или взвешенной MSE в случае регрессии;

  • 3) для каждого поддерева в дереве подсчитывается совокупная ошибка его листьев;

  • 4) для каждого поддерева в дереве рассчитывается коэффициент альфа, представленный как разность ошибки корневого узла поддерева и совокупная ошибка его листьев;

  • 5) поддерево с наименьшим 

    Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

     удаляется и становится листовым узлом, а сам коэффициент хранится в массиве cost_complety_pruning_path и соответствует новому обрезанному дереву;

  • 6) шаги 2-5 рекурсивно повторяются для каждого поддерева до тех пор, пока обрезка не дойдёт до корневого узла.

Если задавать определённое значение

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

 изначально, то данный коэффициент применится к каждому поддереву и в итоге останется поддерево с наименьшей ошибкой среди всех поддеревьев, а выбор наилучшего 

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

 из cost_complety_pruning_path для получения самого точного поддерева производится на тестовом наборе или с помощью k-fold кросс-валидации.

Формулы для расчётов

Регуляризация дерева:

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

Эффективный

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

:

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

Ошибка решающего

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

или листового

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

узлов для классификации:

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

Ошибка решающего

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

или листового

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

узлов для регрессии:

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

Совокупная ошибка листьев в дереве/поддереве:

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

T — число терминальных (листовых) узлов.

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

Схема работы cost-complexity pruning’а для простейшего дерева

Импорт необходимых библиотек

import numpy as np import pandas as pd import matplotlib.pyplot as plt from sklearn.datasets import load_linnerud from sklearn.preprocessing import LabelEncoder from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score, mean_absolute_percentage_error from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree from mlxtend.plotting import plot_decision_regions from copy import deepcopy from pprint import pprint

Реализация на Python с нуля

В оригинальном дереве для создания узлов и хранения в них информации используется отдельный класс, но в данном случае дерево и вся информация об узлах хранятся в словаре с ключами в строчном формате. Такое изменение используется для возможности вывода полученного дерева и сравнения с реализацией scikit-learn.

class DecisionTreeCART:      def __init__(self, max_depth=100, min_samples=2, ccp_alpha=0.0, regression=False):         self.max_depth = max_depth         self.min_samples = min_samples         self.ccp_alpha = ccp_alpha         self.regression = regression         self.tree = None         self._y_type = None         self._num_all_samples = None      def _set_df_type(self, X, y, dtype):         X = X.astype(dtype)         y = y.astype(dtype) if self.regression else y         self._y_dtype = y.dtype          return X, y      @staticmethod     def _purity(y):         unique_classes = np.unique(y)          return unique_classes.size == 1      @staticmethod     def _is_leaf_node(node):         return not isinstance(node, dict)   # if a node/tree is a leaf      def _leaf_node(self, y):         class_index = 0          return np.mean(y) if self.regression else y.mode()[class_index]      def _split_df(self, X, y, feature, threshold):         feature_values = X[feature]         left_indexes = X[feature_values <= threshold].index         right_indexes = X[feature_values > threshold].index         sizes = np.array([left_indexes.size, right_indexes.size])          return self._leaf_node(y) if any(sizes == 0) else left_indexes, right_indexes      @staticmethod     def _gini_impurity(y):         _, counts_classes = np.unique(y, return_counts=True)         squared_probabilities = np.square(counts_classes / y.size)         gini_impurity = 1 - sum(squared_probabilities)          return gini_impurity      @staticmethod     def _mse(y):         mse = np.mean((y - y.mean()) ** 2)          return mse      @staticmethod     def _cost_function(left_df, right_df, method):         total_df_size = left_df.size + right_df.size         p_left_df = left_df.size / total_df_size         p_right_df = right_df.size / total_df_size         J_left = method(left_df)         J_right = method(right_df)         J = p_left_df*J_left + p_right_df*J_right          return J  # weighted Gini impurity or weighted mse (depends on a method)      def _node_error_rate(self, y, method):         if self._num_all_samples is None:             self._num_all_samples = y.size   # num samples of all dataframe         current_num_samples = y.size          return current_num_samples / self._num_all_samples * method(y)      def _best_split(self, X, y):         features = X.columns         min_cost_function = np.inf         best_feature, best_threshold = None, None         method = self._mse if self.regression else self._gini_impurity          for feature in features:             unique_feature_values = np.unique(X[feature])              for i in range(1, len(unique_feature_values)):                 current_value = unique_feature_values[i]                 previous_value = unique_feature_values[i-1]                 threshold = (current_value + previous_value) / 2                 left_indexes, right_indexes = self._split_df(X, y, feature, threshold)                 left_labels, right_labels = y.loc[left_indexes], y.loc[right_indexes]                 current_J = self._cost_function(left_labels, right_labels, method)                  if current_J <= min_cost_function:                     min_cost_function = current_J                     best_feature = feature                     best_threshold = threshold          return best_feature, best_threshold      def _stopping_conditions(self, y, depth, n_samples):         return self._purity(y), depth == self.max_depth, n_samples < self.min_samples      def _grow_tree(self, X, y, depth=0):         current_num_samples = y.size         X, y = self._set_df_type(X, y, np.float128)         method = self._mse if self.regression else self._gini_impurity          if any(self._stopping_conditions(y, depth, current_num_samples)):             RTi = self._node_error_rate(y, method)   # leaf node error rate             leaf_node = f'{self._leaf_node(y)} | error_rate {RTi}'             return leaf_node          Rt = self._node_error_rate(y, method)   # decision node error rate         best_feature, best_threshold = self._best_split(X, y)         decision_node = f'{best_feature} <= {best_threshold} | '                          f'as_leaf {self._leaf_node(y)} error_rate {Rt}'          left_indexes, right_indexes = self._split_df(X, y, best_feature, best_threshold)         left_X, right_X = X.loc[left_indexes], X.loc[right_indexes]         left_labels, right_labels = y.loc[left_indexes], y.loc[right_indexes]          # recursive part         tree = {decision_node: []}         left_subtree = self._grow_tree(left_X, left_labels, depth+1)         right_subtree = self._grow_tree(right_X, right_labels, depth+1)          if left_subtree == right_subtree:             tree = left_subtree         else:             tree[decision_node].extend([left_subtree, right_subtree])          return tree      def _tree_error_rate_info(self, tree, error_rates_list):         if self._is_leaf_node(tree):             *_, leaf_error_rate = tree.split()             error_rates_list.append(np.float128(leaf_error_rate))         else:             decision_node = next(iter(tree))             left_subtree, right_subtree = tree[decision_node]             self._tree_error_rate_info(left_subtree, error_rates_list)             self._tree_error_rate_info(right_subtree, error_rates_list)          RT = sum(error_rates_list)   # total leaf error rate of a tree         num_leaf_nodes = len(error_rates_list)          return RT, num_leaf_nodes      @staticmethod     def _ccp_alpha_eff(decision_node_Rt, leaf_nodes_RTt, num_leafs):          return (decision_node_Rt - leaf_nodes_RTt) / (num_leafs - 1)      def _find_weakest_node(self, tree, weakest_node_info):         if self._is_leaf_node(tree):             return tree          decision_node = next(iter(tree))         left_subtree, right_subtree = tree[decision_node]         *_, decision_node_error_rate = decision_node.split()          Rt = np.float128(decision_node_error_rate)         RTt, num_leaf_nodes = self._tree_error_rate_info(tree, [])         ccp_alpha = self._ccp_alpha_eff(Rt, RTt, num_leaf_nodes)         decision_node_index, min_ccp_alpha_index = 0, 1          if ccp_alpha <= weakest_node_info[min_ccp_alpha_index]:             weakest_node_info[decision_node_index] = decision_node             weakest_node_info[min_ccp_alpha_index] = ccp_alpha          self._find_weakest_node(left_subtree, weakest_node_info)         self._find_weakest_node(right_subtree, weakest_node_info)          return weakest_node_info      def _prune_tree(self, tree, weakest_node):         if self._is_leaf_node(tree):             return tree          decision_node = next(iter(tree))         left_subtree, right_subtree = tree[decision_node]         left_subtree_index, right_subtree_index = 0, 1         _, leaf_node = weakest_node.split('as_leaf ')          if weakest_node is decision_node:             tree = weakest_node         if weakest_node in left_subtree:             tree[decision_node][left_subtree_index] = leaf_node         if weakest_node in right_subtree:             tree[decision_node][right_subtree_index] = leaf_node          self._prune_tree(left_subtree, weakest_node)         self._prune_tree(right_subtree, weakest_node)          return tree      def cost_complexity_pruning_path(self, X: pd.DataFrame, y: pd.Series):         tree = self._grow_tree(X, y)   # grow a full tree         tree_error_rate, _ = self._tree_error_rate_info(tree, [])         error_rates = [tree_error_rate]         ccp_alpha_list = [0.0]          while not self._is_leaf_node(tree):             initial_node = [None, np.inf]             weakest_node, ccp_alpha = self._find_weakest_node(tree, initial_node)             tree = self._prune_tree(tree, weakest_node)             tree_error_rate, _ = self._tree_error_rate_info(tree, [])              error_rates.append(tree_error_rate)             ccp_alpha_list.append(ccp_alpha)          return np.array(ccp_alpha_list), np.array(error_rates)      def _ccp_tree_error_rate(self, tree_error_rate, num_leaf_nodes):          return tree_error_rate + self.ccp_alpha*num_leaf_nodes   # regularization      def _optimal_tree(self, X, y):         tree = self._grow_tree(X, y)   # grow a full tree         min_RT_alpha, final_tree = np.inf, None          while not self._is_leaf_node(tree):             RT, num_leaf_nodes = self._tree_error_rate_info(tree, [])             current_RT_alpha = self._ccp_tree_error_rate(RT, num_leaf_nodes)              if current_RT_alpha <= min_RT_alpha:                 min_RT_alpha = current_RT_alpha                 final_tree = deepcopy(tree)              initial_node = [None, np.inf]             weakest_node, _ = self._find_weakest_node(tree, initial_node)             tree = self._prune_tree(tree, weakest_node)          return final_tree      def fit(self, X: pd.DataFrame, y: pd.Series):         self.tree = self._optimal_tree(X, y)      def _traverse_tree(self, sample, tree):         if self._is_leaf_node(tree):             leaf, *_ = tree.split()             return leaf          decision_node = next(iter(tree))  # dict key         left_node, right_node = tree[decision_node]         feature, other = decision_node.split(' <=')         threshold, *_ = other.split()         feature_value = sample[feature]          if np.float128(feature_value) <= np.float128(threshold):             next_node = self._traverse_tree(sample, left_node)    # left_node         else:             next_node = self._traverse_tree(sample, right_node)   # right_node          return next_node      def predict(self, samples: pd.DataFrame):         # apply traverse_tree method for each row in a dataframe         results = samples.apply(self._traverse_tree, args=(self.tree,), axis=1)          return np.array(results.astype(self._y_dtype))

Код для отрисовки графиков

def tree_plot(sklearn_tree, Xa_train):     plt.figure(figsize=(12, 18))  # customize according to the size of your tree     plot_tree(sklearn_tree, feature_names=Xa_train.columns, filled=True, precision=6)     plt.show()   def tree_scores_plot(estimator, ccp_alphas, train_data, test_data, metric, labels):     train_scores, test_scores = [], []     X_train, y_train = train_data     X_test, y_test = test_data     x_label, y_label = labels      for ccp_alpha_i in ccp_alphas:         estimator.ccp_alpha = ccp_alpha_i         estimator.fit(X_train, y_train)         train_pred_res = estimator.predict(X_train)         test_pred_res = estimator.predict(X_test)          train_score = metric(train_pred_res, y_train)         test_score = metric(test_pred_res, y_test)         train_scores.append(train_score)         test_scores.append(test_score)      fig, ax = plt.subplots()     ax.set_xlabel(x_label)     ax.set_ylabel(y_label)     ax.set_title(f"{y_label} vs {x_label} for training and testing sets")     ax.plot(ccp_alphas, train_scores, marker="o", label="train", drawstyle="steps-post")     ax.plot(ccp_alphas, test_scores, marker="o", label="test", drawstyle="steps-post")     ax.legend()     plt.show()   def decision_boundary_plot(X, y, X_train, y_train, clf, feature_indexes, title=None):     if y.dtype != 'int':         y = pd.Series(LabelEncoder().fit_transform(y))         y_train = pd.Series(LabelEncoder().fit_transform(y_train))      feature1_name, feature2_name = X.columns[feature_indexes]     X_feature_columns = X.values[:, feature_indexes]     X_train_feature_columns = X_train.values[:, feature_indexes]     clf.fit(X_train_feature_columns, y_train.values)      plot_decision_regions(X=X_feature_columns, y=y.values, clf=clf)     plt.xlabel(feature1_name)     plt.ylabel(feature2_name)     plt.title(title)

Загрузка датасетов

Для обучения моделей будет использован Iris dataset, где необходимо верно определить типы цветков на основе их признаков. В случае регрессии используется load_linnerud dataset из scikit-learn.

df_path = "/content/drive/MyDrive/iris.csv" iris = pd.read_csv(df_path) X1, y1 = iris.iloc[:, :-1], iris.iloc[:, -1] X1_train, X1_test, y1_train, y1_test = train_test_split(X1, y1, test_size=0.3, random_state=0) print(iris)        sepal_length  sepal_width  petal_length  petal_width    species 0             5.1          3.5           1.4          0.2     setosa 1             4.9          3.0           1.4          0.2     setosa 2             4.7          3.2           1.3          0.2     setosa 3             4.6          3.1           1.5          0.2     setosa 4             5.0          3.6           1.4          0.2     setosa ..            ...          ...           ...          ...        ... 145           6.7          3.0           5.2          2.3  virginica 146           6.3          2.5           5.0          1.9  virginica 147           6.5          3.0           5.2          2.0  virginica 148           6.2          3.4           5.4          2.3  virginica 149           5.9          3.0           5.1          1.8  virginica  [150 rows x 5 columns]
X2, y2 = load_linnerud(return_X_y=True, as_frame=True) y2 = y2['Pulse'] X2_train, X2_test, y2_train, y2_test = train_test_split(X2, y2, test_size=0.2, random_state=0) print(X2, y2, sep='n')       Chins  Situps  Jumps 0     5.0   162.0   60.0 1     2.0   110.0   60.0 2    12.0   101.0  101.0 3    12.0   105.0   37.0 4    13.0   155.0   58.0 5     4.0   101.0   42.0 6     8.0   101.0   38.0 7     6.0   125.0   40.0 8    15.0   200.0   40.0 9    17.0   251.0  250.0 10   17.0   120.0   38.0 11   13.0   210.0  115.0 12   14.0   215.0  105.0 13    1.0    50.0   50.0 14    6.0    70.0   31.0 15   12.0   210.0  120.0 16    4.0    60.0   25.0 17   11.0   230.0   80.0 18   15.0   225.0   73.0 19    2.0   110.0   43.0  0     50.0 1     52.0 2     58.0 3     62.0 4     46.0 5     56.0 6     56.0 7     60.0 8     74.0 9     56.0 10    50.0 11    52.0 12    64.0 13    50.0 14    46.0 15    62.0 16    54.0 17    52.0 18    54.0 19    68.0 Name: Pulse, dtype: float64

Обучение моделей и оценка полученных результатов

В случае классификации дерево на данных ирис показало высокую точность. После прунинга точность не увеличилась, зато удалось найти более оптимальное дерево (alpha=0.0143) с такой же точностью и меньшим количеством узлов.

А вот в случае регрессии удалось получить прирост в плане точности, подобрав alpha=3.613, которое создаёт дерево с минимальной ошибкой на тестовом наборе. Все результаты приведены ниже.

Для более удобного отображения полученных результатов в ручных реализациях лучше использовать скачанный ноутбук в самом начале.

Классификация до прунинга

tree_classifier = DecisionTreeCART() tree_classifier.fit(X1_train, y1_train) clf_ccp_alphas, _ = tree_classifier.cost_complexity_pruning_path(X1_train, y1_train) clf_ccp_alphas = clf_ccp_alphas[:-1]  sk_tree_classifier = DecisionTreeClassifier(random_state=0) sk_tree_classifier.fit(X1_train, y1_train) sk_clf_path = sk_tree_classifier.cost_complexity_pruning_path(X1_train, y1_train) sk_clf_ccp_alphas = sk_clf_path.ccp_alphas[:-1]  sk_clf_estimator = DecisionTreeClassifier(random_state=0) train1_data, test1_data = [X1_train, y1_train], [X1_test, y1_test] metric = accuracy_score labels = ['Alpha', 'Accuracy']  pprint(tree_classifier.tree, width=180) tree_plot(sk_tree_classifier, X1_train) print(f'tree alphas: {clf_ccp_alphas}', f'sklearn alphas: {sk_clf_ccp_alphas}', sep='n') tree_scores_plot(sk_clf_estimator, clf_ccp_alphas, train1_data, test1_data, metric, labels)   {'petal_width <= 0.75 | as_leaf virginica error_rate 0.6643083900226757': ['setosa | error_rate 0.0',                                                                            {'petal_length <= 4.95 | as_leaf virginica error_rate 0.33480885311871234': [{'petal_width <= 1.65 | as_leaf versicolor error_rate 0.0521008403361345': ['versicolor '                                                                                                                                                                                                                                     '| '                                                                                                                                                                                                                                     'error_rate '                                                                                                                                                                                                                                     '0.0',                                                                                                                                                                                                                                     {'sepal_width <= 3.1 | as_leaf virginica error_rate 0.014285714285714287': ['virginica '                                                                                                                                                                                                                                                                                                                 '| '                                                                                                                                                                                                                                                                                                                 'error_rate '                                                                                                                                                                                                                                                                                                                 '0.0',                                                                                                                                                                                                                                                                                                                 'versicolor '                                                                                                                                                                                                                                                                                                                 '| '                                                                                                                                                                                                                                                                                                                 'error_rate '                                                                                                                                                                                                                                                                                                                 '0.0']}]},                                                                                                                                                         {'petal_width <= 1.75 | as_leaf virginica error_rate 0.018532818532818515': [{'petal_width <= 1.65 | as_leaf virginica error_rate 0.014285714285714287': ['virginica '                                                                                                                                                                                                                                                                                                                   '| '                                                                                                                                                                                                                                                                                                                   'error_rate '                                                                                                                                                                                                                                                                                                                   '0.0',                                                                                                                                                                                                                                                                                                                   'versicolor '                                                                                                                                                                                                                                                                                                                   '| '                                                                                                                                                                                                                                                                                                                   'error_rate '                                                                                                                                                                                                                                                                                                                   '0.0']},                                                                                                                                                                                                                                      'virginica '                                                                                                                                                                                                                                      '| '                                                                                                                                                                                                                                      'error_rate '                                                                                                                                                                                                                                      '0.0']}]}]} 

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

tree alphas: [0.         0.00926641 0.01428571 0.03781513 0.26417519] sklearn alphas: [0.         0.00926641 0.01428571 0.03781513 0.26417519]

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

Классификация после прунинга

tree_clf_prediction = tree_classifier.predict(X1_test) tree_clf_accuracy = accuracy_score(tree_clf_prediction, y1_test) sk_tree_clf_prediction = sk_tree_classifier.predict(X1_test) sk_clf_accuracy = accuracy_score(sk_tree_clf_prediction, y1_test)  best_clf_ccp_alpha = 0.0143 # based on a plot best_tree_classifier = DecisionTreeCART(ccp_alpha=best_clf_ccp_alpha) best_tree_classifier.fit(X1_train, y1_train) best_tree_clf_prediction = best_tree_classifier.predict(X1_test) best_tree_clf_accuracy = accuracy_score(best_tree_clf_prediction, y1_test)  best_sk_tree_classifier = DecisionTreeClassifier(random_state=0, ccp_alpha=best_clf_ccp_alpha) best_sk_tree_classifier.fit(X1_train, y1_train) best_sk_tree_clf_prediction = best_sk_tree_classifier.predict(X1_test) best_sk_clf_accuracy = accuracy_score(best_sk_tree_clf_prediction, y1_test)  print('tree prediction', tree_clf_prediction, ' ', sep='n') print('sklearn prediction', sk_tree_clf_prediction, ' ', sep='n') print('best tree prediction', best_tree_clf_prediction, ' ', sep='n') print('best sklearn prediction', best_sk_tree_clf_prediction, ' ', sep='n')  pprint(best_tree_classifier.tree, width=180) tree_plot(best_sk_tree_classifier, X1_train) print(f'our tree pruning accuracy: before {tree_clf_accuracy} -> after {best_tree_clf_accuracy}') print(f'sklearn tree pruning accuracy: before {sk_clf_accuracy} -> after {best_sk_clf_accuracy}')   tree prediction ['virginica' 'versicolor' 'setosa' 'virginica' 'setosa' 'virginica'  'setosa' 'versicolor' 'versicolor' 'versicolor' 'virginica' 'versicolor'  'versicolor' 'versicolor' 'versicolor' 'setosa' 'versicolor' 'versicolor'  'setosa' 'setosa' 'virginica' 'versicolor' 'setosa' 'setosa' 'virginica'  'setosa' 'setosa' 'versicolor' 'versicolor' 'setosa' 'virginica'  'versicolor' 'setosa' 'virginica' 'virginica' 'versicolor' 'setosa'  'virginica' 'versicolor' 'versicolor' 'virginica' 'setosa' 'virginica'  'setosa' 'setosa']   sklearn prediction ['virginica' 'versicolor' 'setosa' 'virginica' 'setosa' 'virginica'  'setosa' 'versicolor' 'versicolor' 'versicolor' 'virginica' 'versicolor'  'versicolor' 'versicolor' 'versicolor' 'setosa' 'versicolor' 'versicolor'  'setosa' 'setosa' 'virginica' 'versicolor' 'setosa' 'setosa' 'virginica'  'setosa' 'setosa' 'versicolor' 'versicolor' 'setosa' 'virginica'  'versicolor' 'setosa' 'virginica' 'virginica' 'versicolor' 'setosa'  'virginica' 'versicolor' 'versicolor' 'virginica' 'setosa' 'virginica'  'setosa' 'setosa']   best tree prediction ['virginica' 'versicolor' 'setosa' 'virginica' 'setosa' 'virginica'  'setosa' 'versicolor' 'versicolor' 'versicolor' 'virginica' 'versicolor'  'versicolor' 'versicolor' 'versicolor' 'setosa' 'versicolor' 'versicolor'  'setosa' 'setosa' 'virginica' 'versicolor' 'setosa' 'setosa' 'virginica'  'setosa' 'setosa' 'versicolor' 'versicolor' 'setosa' 'virginica'  'versicolor' 'setosa' 'virginica' 'virginica' 'versicolor' 'setosa'  'virginica' 'versicolor' 'versicolor' 'virginica' 'setosa' 'virginica'  'setosa' 'setosa']   best sklearn prediction ['virginica' 'versicolor' 'setosa' 'virginica' 'setosa' 'virginica'  'setosa' 'versicolor' 'versicolor' 'versicolor' 'virginica' 'versicolor'  'versicolor' 'versicolor' 'versicolor' 'setosa' 'versicolor' 'versicolor'  'setosa' 'setosa' 'virginica' 'versicolor' 'setosa' 'setosa' 'virginica'  'setosa' 'setosa' 'versicolor' 'versicolor' 'setosa' 'virginica'  'versicolor' 'setosa' 'virginica' 'virginica' 'versicolor' 'setosa'  'virginica' 'versicolor' 'versicolor' 'virginica' 'setosa' 'virginica'  'setosa' 'setosa']   {'petal_width <= 0.75 | as_leaf virginica error_rate 0.6643083900226757': ['setosa | error_rate 0.0',                                                                            {'petal_length <= 4.95 | as_leaf virginica error_rate 0.33480885311871234': [{'petal_width <= 1.65 | as_leaf versicolor error_rate 0.0521008403361345': ['versicolor '                                                                                                                                                                                                                                     '| '                                                                                                                                                                                                                                     'error_rate '                                                                                                                                                                                                                                     '0.0',                                                                                                                                                                                                                                     'virginica '                                                                                                                                                                                                                                     'error_rate '                                                                                                                                                                                                                                     '0.014285714285714287']},                                                                                                                                                         'virginica error_rate '                                                                                                                                                         '0.018532818532818515']}]}

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

our tree pruning accuracy: before 0.9777777777777777 -> after 0.9777777777777777 sklearn tree pruning accuracy: before 0.9777777777777777 -> after 0.9777777777777777

Визуализация решающих границ до и после прунинга

feature_indexes = [2, 3] title1 = 'Classification tree surface before pruning' decision_boundary_plot(X1, y1, X1_train, y1_train, sk_tree_classifier, feature_indexes, title1)

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

feature_indexes = [2, 3] title2 = 'Classification tree surface after pruning' decision_boundary_plot(X1, y1, X1_train, y1_train, best_sk_tree_classifier, feature_indexes, title2)

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

feature_indexes = [2, 3] plt.figure(figsize=(10, 15))  for i, alpha in enumerate(clf_ccp_alphas):     sk_tree_clf = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha)     plt.subplot(3, 2, i + 1)     plt.subplots_adjust(hspace=0.5)     title = f'ccp_alpha = {alpha}'     decision_boundary_plot(X1, y1, X1_train, y1_train, sk_tree_clf, feature_indexes, title)

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

Прунинг при всех полученных alpha ccp

Регрессия до прунинга

tree_regressor = DecisionTreeCART(regression=True) tree_regressor.fit(X2_train, y2_train) reg_ccp_alphas, _ = tree_regressor.cost_complexity_pruning_path(X2_train, y2_train) reg_ccp_alphas = reg_ccp_alphas[:-1]  sk_tree_regressor = DecisionTreeRegressor(random_state=0) sk_tree_regressor.fit(X2_train, y2_train) sk_reg_path = sk_tree_regressor.cost_complexity_pruning_path(X2_train, y2_train) sk_reg_ccp_alphas = sk_reg_path.ccp_alphas[:-1]  reg_estimator = DecisionTreeCART(regression=True) sk_reg_estimator = DecisionTreeRegressor(random_state=0) train2_data, test2_data = [X2_train, y2_train], [X2_test, y2_test] metric = mean_absolute_percentage_error labels = ['Alpha', 'MAPE']  pprint(tree_regressor.tree) tree_plot(sk_tree_regressor, X2_train)  print(f'CART alphas: {reg_ccp_alphas}') tree_scores_plot(reg_estimator, reg_ccp_alphas, train2_data, test2_data, metric, labels) print(f'sklearn_alphas: {sk_reg_ccp_alphas}') tree_scores_plot(sk_reg_estimator, sk_reg_ccp_alphas, train2_data, test2_data, metric, labels)   {'Jumps <= 90.5 | as_leaf 54.625 error_rate 29.359375': [{'Jumps <= 46.0 | as_leaf 52.90909090909091 error_rate 17.181818181818183': [{'Jumps <= 34.0 | as_leaf 54.857142857142854 error_rate 11.428571428571429': [{'Jumps <= 28.0 | as_leaf 50.0 error_rate 2.0': ['54.0 '                                                                                                                                                                                                                                                                      '| '                                                                                                                                                                                                                                                                      'error_rate '                                                                                                                                                                                                                                                                      '0.0',                                                                                                                                                                                                                                                                      '46.0 '                                                                                                                                                                                                                                                                      '| '                                                                                                                                                                                                                                                                      'error_rate '                                                                                                                                                                                                                                                                      '0.0']},                                                                                                                                                                                                                     {'Chins <= 14.5 | as_leaf 56.8 error_rate 5.3': [{'Situps <= 103.0 | as_leaf 58.5 error_rate 1.6875': ['56.0 '                                                                                                                                                                                                                                                                                                                            '| '                                                                                                                                                                                                                                                                                                                            'error_rate '                                                                                                                                                                                                                                                                                                                            '0.0',                                                                                                                                                                                                                                                                                                                            {'Jumps <= 38.5 | as_leaf 61.0 error_rate 0.125': ['62.0 '                                                                                                                                                                                                                                                                                                                                                                               '| '                                                                                                                                                                                                                                                                                                                                                                               'error_rate '                                                                                                                                                                                                                                                                                                                                                                               '0.0',                                                                                                                                                                                                                                                                                                                                                                               '60.0 '                                                                                                                                                                                                                                                                                                                                                                               '| '                                                                                                                                                                                                                                                                                                                                                                               'error_rate '                                                                                                                                                                                                                                                                                                                                                                               '0.0']}]},                                                                                                                                                                                                                                                                      '50.0 '                                                                                                                                                                                                                                                                      '| '                                                                                                                                                                                                                                                                      'error_rate '                                                                                                                                                                                                                                                                      '0.0']}]},                                                                                                                                       {'Chins <= 12.0 | as_leaf 49.5 error_rate 1.1875': [{'Jumps <= 70.0 | as_leaf 50.666666666666664 error_rate 0.16666666666666666': ['50.0 '                                                                                                                                                                                                                                                                          '| '                                                                                                                                                                                                                                                                          'error_rate '                                                                                                                                                                                                                                                                          '0.0',                                                                                                                                                                                                                                                                          '52.0 '                                                                                                                                                                                                                                                                          '| '                                                                                                                                                                                                                                                                          'error_rate '                                                                                                                                                                                                                                                                          '0.0']},                                                                                                                                                                                           '46.0 '                                                                                                                                                                                           '| '                                                                                                                                                                                           'error_rate '                                                                                                                                                                                           '0.0']}]},                                                          {'Jumps <= 110.0 | as_leaf 58.4 error_rate 5.7': [{'Jumps <= 103.0 | as_leaf 61.0 error_rate 1.125': ['58.0 '                                                                                                                                                                '| '                                                                                                                                                                'error_rate '                                                                                                                                                                '0.0',                                                                                                                                                                '64.0 '                                                                                                                                                                '| '                                                                                                                                                                'error_rate '                                                                                                                                                                '0.0']},                                                                                                            {'Chins <= 12.5 | as_leaf 56.666666666666664 error_rate 3.1666666666666665': ['62.0 '                                                                                                                                                                                          '| '                                                                                                                                                                                          'error_rate '                                                                                                                                                                                          '0.0',                                                                                                                                                                                          {'Jumps <= 182.5 | as_leaf 54.0 error_rate 0.5': ['52.0 '                                                                                                                                                                                                                                            '| '                                                                                                                                                                                                                                            'error_rate '                                                                                                                                                                                                                                            '0.0',                                                                                                                                                                                                                                            '56.0 '                                                                                                                                                                                                                                            '| '                                                                                                                                                                                                                                            'error_rate '                                                                                                                                                                                                                                            '0.0']}]}]}]}

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

CART alphas: [0.         0.125      0.16666667 0.5        1.02083333 1.125  1.5625     2.         2.0375     3.6125     4.12857143 4.56574675]

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

sklearn_alphas: [0.         0.125      0.16666667 0.5        1.02083333 1.125  1.5625     2.         2.0375     3.6125     4.12857143 4.56574675]

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

Регрессия после прунинга

tree_reg_prediction = tree_regressor.predict(X2_test) tree_reg_error = mean_absolute_percentage_error(tree_reg_prediction, y2_test) sk_tree_reg_prediction = sk_tree_regressor.predict(X2_test) sk_reg_error= mean_absolute_percentage_error(sk_tree_reg_prediction, y2_test)  best_reg_ccp_alpha = 3.613   # based on a plot best_tree_regressor = DecisionTreeCART(ccp_alpha=best_reg_ccp_alpha, regression=True) best_tree_regressor.fit(X2_train, y2_train) best_tree_reg_prediction = best_tree_regressor.predict(X2_test) lowest_tree_reg_error = mean_absolute_percentage_error(best_tree_reg_prediction, y2_test)  best_sk_tree_regressor = DecisionTreeRegressor(random_state=0, ccp_alpha=best_reg_ccp_alpha) best_sk_tree_regressor.fit(X2_train, y2_train) best_sk_tree_reg_prediction = best_sk_tree_regressor.predict(X2_test) lowest_sk_reg_error = mean_absolute_percentage_error(best_sk_tree_reg_prediction, y2_test)  print('tree prediction', tree_reg_prediction, ' ', sep='n') print('sklearn prediction', sk_tree_reg_prediction, ' ', sep='n') print('best tree prediction', best_tree_reg_prediction, ' ', sep='n') print('best sklearn prediction', best_sk_tree_reg_prediction, ' ', sep='n')  pprint(best_tree_regressor.tree) tree_plot(best_sk_tree_regressor, X2_train) print(f'tree error: before {tree_reg_error} -> after pruning {lowest_tree_reg_error}') print(f'sklearn tree error: before {sk_reg_error} -> after pruning {lowest_sk_reg_error}')   tree prediction [46. 50. 60. 50.]   sklearn prediction [46. 50. 60. 50.]   best tree prediction [49.5 49.5 56.8 56.8]   best sklearn prediction [49.5 49.5 56.8 56.8]   {'Jumps <= 90.5 | as_leaf 54.625 error_rate 29.359375': [{'Jumps <= 46.0 | as_leaf 52.90909090909091 error_rate 17.181818181818183': [{'Jumps <= 34.0 | as_leaf 54.857142857142854 error_rate 11.428571428571429': ['50.0 '                                                                                                                                                                                                                     'error_rate '                                                                                                                                                                                                                     '2.0',                                                                                                                                                                                                                     '56.8 '                                                                                                                                                                                                                     'error_rate '                                                                                                                                                                                                                     '5.3']},                                                                                                                                       '49.5 '                                                                                                                                       'error_rate '                                                                                                                                       '1.1875']},                                                          '58.4 error_rate 5.7']}

Дерево решений (CART). От теоретических основ до продвинутых техник и реализации с нуля на Python

tree error: before 0.20681159420289855 -> after pruning 0.16035353535353536 sklearn tree error: before 0.20681159420289855 -> after pruning 0.1603535353535354

Преимущества и недостатки дерева решений

Преимущества:

  • простота в интерпретации и визуализации;

  • неплохая работа с нелинейными зависимостями в данных;

  • не требуется особой подготовки тренировочного набора;

  • относительно высокая скорость обучения и прогнозирования.

Недостатки:

  • поиск оптимального дерева является NP-полной задачей;

  • нестабильность работы даже при небольшом изменении данных;

  • возможность переобучения из-за чувствительности к шуму и выбросам в данных.

Дополнительные источники

Статья «The CART Decision Tree for Mining Data Streams», Leszek Rutkowskia, Maciej Jaworskia, Lena Pietruczuka, Piotr Dudaa.

Документация:

  • описание CART ;

  • DecisionTreeClassifier ;

  • DecisionTreeRegressor ;

  • pruning .

Лекции: один, два, три, четыре.

Пошаговое построение дерева решений: один, два, три.

Pruning:

  • Статья.

  • Видео: один, два, три.

Бэггинг и случайный лес 🡆