Node Splitting Strategies for Decision Trees in Python

2023-04-14  Tags: nodes, strategies, splitting

Decision trees in Python mainly use the following node splitting strategies:

  1. Information gain (ID3): measures the purity and uncertainty of a sample set with information entropy and splits on the feature with the largest information gain. For a discrete feature, each distinct value becomes a child node; for a continuous feature, the best cut point is searched for.

  2. Gain ratio (C4.5): ID3 is biased toward features with many distinct values. C4.5 corrects this by normalizing the information gain into a gain ratio and splits on the feature with the largest gain ratio.

  3. Gini index (CART): measures the impurity and uncertainty of a sample set with the Gini index and splits on the feature with the smallest Gini index. In classification, the Gini index of the samples at each candidate child node is computed to evaluate the split.
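For comparison with the ID3 code below, the two alternative criteria can be sketched in a few lines. These function names are illustrative, not taken from any library:

```python
import math
from collections import Counter

def gini(labels):
    """Gini impurity: 1 - sum of squared class probabilities; 0 means a pure node."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def gain_ratio(info_gain, split_sizes):
    """C4.5 gain ratio: information gain divided by the split's intrinsic value."""
    n = sum(split_sizes)
    intrinsic = -sum((s / n) * math.log2(s / n) for s in split_sizes if s)
    return info_gain / intrinsic if intrinsic else 0.0

print(gini(['yes', 'yes', 'no', 'no']))  # 0.5 -- a maximally impure binary node
print(gain_ratio(1.0, [2, 2]))           # 1.0 -- gain normalized by a balanced 2-way split
```

The intrinsic-value denominator is what penalizes features that shatter the data into many small branches, which is exactly the bias C4.5 was designed to fix.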

Below is a simple Python example that implements ID3 node splitting based on information gain:

import math

# Compute the information entropy of a dataset (the last column is the class label)
def calc_entropy(dataset):
    num_samples = len(dataset)
    label_counts = {}
    for sample in dataset:
        label = sample[-1]
        if label not in label_counts:
            label_counts[label] = 0
        label_counts[label] += 1
    entropy = 0.0
    for label in label_counts:
        prob = float(label_counts[label]) / num_samples
        entropy -= prob * math.log(prob, 2)
    return entropy
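As a quick sanity check of the entropy computation above: a 50/50 binary label set has the maximum possible entropy of 1 bit.

```python
import math

# Entropy of a perfectly mixed binary label set: -2 * (1/2) * log2(1/2) = 1 bit
labels = ['yes', 'yes', 'no', 'no']
counts = {label: labels.count(label) for label in set(labels)}
entropy = -sum((c / len(labels)) * math.log2(c / len(labels)) for c in counts.values())
print(entropy)  # 1.0
```

A pure label set (all `yes` or all `no`) gives an entropy of 0, the other extreme of the scale.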

# Compute the information gain of splitting on the feature at feature_index
def calc_info_gain(dataset, feature_index):
    base_entropy = calc_entropy(dataset)
    num_samples = len(dataset)
    feature_values = {}
    for sample in dataset:
        feature = sample[feature_index]
        if feature not in feature_values:
            feature_values[feature] = []
        feature_values[feature].append(sample)
    new_entropy = 0.0
    for feature in feature_values:
        prob = float(len(feature_values[feature])) / num_samples
        new_entropy += prob * calc_entropy(feature_values[feature])
    info_gain = base_entropy - new_entropy
    return info_gain
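To see what calc_info_gain computes, here is a self-contained re-derivation on a toy two-column dataset (feature in column 0, label in column 1) where the feature separates the classes perfectly, so the gain equals the base entropy:

```python
import math
from collections import Counter

def entropy(labels):
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())

# Toy dataset: the feature in column 0 perfectly predicts the label in column 1.
data = [('a', 'yes'), ('a', 'yes'), ('b', 'no'), ('b', 'no')]
base = entropy([row[-1] for row in data])  # 1.0 bit of uncertainty before the split
groups = {}
for row in data:
    groups.setdefault(row[0], []).append(row[-1])
# Weighted entropy after the split: every group is pure, so it is 0.
remainder = sum(len(ls) / len(data) * entropy(ls) for ls in groups.values())
print(base - remainder)  # 1.0 -- the split removes all uncertainty
```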

# Choose the feature with the largest information gain
def choose_best_feature(dataset):
    num_features = len(dataset[0]) - 1  # the last column is the class label
    best_feature_index = -1
    best_info_gain = 0.0
    for i in range(1, num_features):  # skip column 0, which holds the sample ID
        info_gain = calc_info_gain(dataset, i)
        if info_gain > best_info_gain:
            best_feature_index = i
            best_info_gain = info_gain
    return best_feature_index

# Build the decision tree recursively
def build_tree(dataset):
    labels = [sample[-1] for sample in dataset]
    # Stop if all samples belong to the same class
    if labels.count(labels[0]) == len(labels):
        return labels[0]
    best_feature_index = choose_best_feature(dataset)
    # Stop if no feature yields a positive information gain; return the majority class
    if best_feature_index == -1:
        label_counts = {}
        for label in labels:
            if label not in label_counts:
                label_counts[label] = 0
            label_counts[label] += 1
        return max(label_counts, key=label_counts.get)
    best_feature_name = FEATURE_NAMES[best_feature_index - 1]  # offset by 1: column 0 is the sample ID
    tree = {best_feature_name: {}}
    feature_values = set(sample[best_feature_index] for sample in dataset)
    for value in feature_values:
        sub_dataset = [sample for sample in dataset if sample[best_feature_index] == value]
        tree[best_feature_name][value] = build_tree(sub_dataset)
    return tree

Suppose we have the following training set:

pidancode.com, sunny, hot, high, weak, no
pidancode.com, sunny, hot, high, strong, no
pidancode.com, overcast, hot, high, weak, yes
pidancode.com, rainy, mild, high, weak, yes
pidancode.com, rainy, cool, normal, weak, yes
pidancode.com, rainy, cool, normal, strong, no
pidancode.com, overcast, cool, normal, strong, yes
pidancode.com, sunny, mild, high, weak, no
pidancode.com, sunny, cool, normal, weak, yes
pidancode.com, rainy, mild, normal, weak, yes
pidancode.com, sunny, mild, normal, strong, yes
pidancode.com, overcast, mild, high, strong, yes
pidancode.com, overcast, hot, normal, weak, yes
pidancode.com, rainy, mild, high, strong, no

Here the first column is a sample ID, the middle four columns are the features, and the last column is the class label (yes or no). We can run the following code to build the decision tree:

FEATURE_NAMES = ['outlook', 'temperature', 'humidity', 'windy']

if __name__ == '__main__':
    dataset = [
        ['pidancode.com', 'sunny', 'hot', 'high', 'weak', 'no'],
        ['pidancode.com', 'sunny', 'hot', 'high', 'strong', 'no'],
        ['pidancode.com', 'overcast', 'hot', 'high', 'weak', 'yes'],
        ['pidancode.com', 'rainy', 'mild', 'high', 'weak', 'yes'],
        ['pidancode.com', 'rainy', 'cool', 'normal', 'weak', 'yes'],
        ['pidancode.com', 'rainy', 'cool', 'normal', 'strong', 'no'],
        ['pidancode.com', 'overcast', 'cool', 'normal', 'strong', 'yes'],
        ['pidancode.com', 'sunny', 'mild', 'high', 'weak', 'no'],
        ['pidancode.com', 'sunny', 'cool', 'normal', 'weak', 'yes'],
        ['pidancode.com', 'rainy', 'mild', 'normal', 'weak', 'yes'],
        ['pidancode.com', 'sunny', 'mild', 'normal', 'strong', 'yes'],
        ['pidancode.com', 'overcast', 'mild', 'high', 'strong', 'yes'],
        ['pidancode.com', 'overcast', 'hot', 'normal', 'weak', 'yes'],
        ['pidancode.com', 'rainy', 'mild', 'high', 'strong', 'no'],
    ]
    tree = build_tree(dataset)
    print(tree)

The decision tree printed is:

{
    'outlook': {
        'overcast': 'yes',
        'rainy': {
            'windy': {
                'strong': 'no',
                'weak': 'yes'
            }
        },
        'sunny': {
            'humidity': {
                'high': 'no',
                'normal': 'yes'
            }
        }
    }
}

This decision tree can now be used to classify new data. For example, given the following sample:

pidancode.com, rainy, cool, high, strong

it can be classified with the following code:

if __name__ == '__main__':
    dataset = [
        ['pidancode.com', 'sunny', 'hot', 'high', 'weak', 'no'],
        ['pidancode.com', 'sunny', 'hot', 'high', 'strong', 'no'],
        ['pidancode.com', 'overcast', 'hot', 'high', 'weak', 'yes'],
        ['pidancode.com', 'rainy', 'mild', 'high', 'weak', 'yes'],
        ['pidancode.com', 'rainy', 'cool', 'normal', 'weak', 'yes'],
        ['pidancode.com', 'rainy', 'cool', 'normal', 'strong', 'no'],
        ['pidancode.com', 'overcast', 'cool', 'normal', 'strong', 'yes'],
        ['pidancode.com', 'sunny', 'mild', 'high', 'weak', 'no'],
        ['pidancode.com', 'sunny', 'cool', 'normal', 'weak', 'yes'],
        ['pidancode.com', 'rainy', 'mild', 'normal', 'weak', 'yes'],
        ['pidancode.com', 'sunny', 'mild', 'normal', 'strong', 'yes'],
        ['pidancode.com', 'overcast', 'mild', 'high', 'strong', 'yes'],
        ['pidancode.com', 'overcast', 'hot', 'normal', 'weak', 'yes'],
        ['pidancode.com', 'rainy', 'mild', 'high', 'strong', 'no'],
    ]
    tree = build_tree(dataset)

    sample = ['pidancode.com', 'rainy', 'cool', 'high', 'strong']
    node = tree
    while isinstance(node, dict):
        feature_name = list(node.keys())[0]
        feature_value = sample[FEATURE_NAMES.index(feature_name) + 1]  # column 0 holds the sample ID, so offset by 1
        node = node[feature_name][feature_value]
    print(node) # should output: no
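The traversal loop above can be wrapped into a small reusable helper (the name `classify` is mine, not part of the article's code). Here it is applied to the tree printed earlier:

```python
def classify(tree, feature_names, sample):
    """Walk the nested-dict tree until a leaf label is reached."""
    node = tree
    while isinstance(node, dict):
        feature_name = next(iter(node))
        value = sample[feature_names.index(feature_name) + 1]  # column 0 is the sample ID
        node = node[feature_name][value]
    return node

# The tree produced by the training set above:
tree = {'outlook': {'overcast': 'yes',
                    'rainy': {'windy': {'strong': 'no', 'weak': 'yes'}},
                    'sunny': {'humidity': {'high': 'no', 'normal': 'yes'}}}}
names = ['outlook', 'temperature', 'humidity', 'windy']
print(classify(tree, names, ['pidancode.com', 'rainy', 'cool', 'high', 'strong']))  # no
```

Note that this helper, like the inline loop, raises a KeyError if the sample contains a feature value never seen during training; handling that case (e.g. by falling back to a majority label) is left as an exercise.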
