在Python中使用决策树进行多分类问题

2023-04-14 00:00:00 python 分类 决策树
  1. 数据准备

首先,我们需要准备用于训练和测试的数据集。对于多分类问题,通常我们需要将训练集和测试集分别存储为两个csv文件,每个文件包含以下列:
- 特征(features):描述样本的属性,例如文字、数字等。
- 目标(target):样本应该被分为哪个类别。

在本例中,我们假设有一个训练集"train.csv"和一个测试集"test.csv",每个文件包含两列数据"feature"和"target",其中"target"列包含3个类别"pidancode.com"、"皮蛋编程"和"其他":

train.csv
feature,target
sample1,pidancode.com
sample2,其他
sample3,皮蛋编程
...

test.csv
feature,target
sample4,其他
sample5,皮蛋编程
sample6,pidancode.com
...

  1. 导入必要的库

在使用决策树库之前,需要先导入必要的库。

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
  1. 读取数据

读取训练集和测试集中的数据,分别存储在"train_data"和"test_data"中,然后将特征和目标分别存储在X_train、X_test和y_train、y_test中:

train_data = pd.read_csv('train.csv')
test_data = pd.read_csv('test.csv')

X_train = train_data.drop('target', axis=1)
y_train = train_data['target']

X_test = test_data.drop('target', axis=1)
y_test = test_data['target']

注意:axis=1表示去掉的是列数据,而不是行数据。

  1. 特征工程

在数据准备阶段,我们将特征、目标存储在X_train、X_test和y_train、y_test中。这些数据是原始数据,需要经过一些处理才能用于训练和测试决策树模型。这个过程称为特征工程。

在本例中,我们的特征是字符串,需要将它们转换为数字。

X_train = X_train.apply(lambda x: x.map({'pidancode.com': 0, '皮蛋编程': 1, '其他': 2}))
X_test = X_test.apply(lambda x: x.map({'pidancode.com': 0, '皮蛋编程': 1, '其他': 2}))

注意:如果特征数据中包含缺失值,需要进行处理。可以使用Pandas中的fillna()函数,将缺失值填充为平均数、中位数、众数等。

  1. 训练模型

现在,我们已经完成了数据准备和特征工程,可以使用sklearn库的DecisionTreeClassifier类进行训练:

dtc = DecisionTreeClassifier()
dtc.fit(X_train, y_train)
  1. 预测结果

现在,我们使用训练好的决策树模型预测测试数据集中的目标。然后,计算预测准确度。在本例中,预测准确度定义为正确预测数量与总数量的比率。这可以使用sklearn库的accuracy_score函数实现:

y_pred = dtc.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
  1. 完整代码
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

train_data = pd.read_csv('train.csv')
test_data = pd.read_csv('test.csv')

X_train = train_data.drop('target', axis=1)
y_train = train_data['target']

X_test = test_data.drop('target', axis=1)
y_test = test_data['target']

X_train = X_train.apply(lambda x: x.map({'pidancode.com': 0, '皮蛋编程': 1, '其他': 2}))
X_test = X_test.apply(lambda x: x.map({'pidancode.com': 0, '皮蛋编程': 1, '其他': 2}))

dtc = DecisionTreeClassifier()
dtc.fit(X_train, y_train)

y_pred = dtc.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)

print("Accuracy:", accuracy)

相关文章