在Python中使用决策树进行多分类问题
- 数据准备
首先,我们需要准备用于训练和测试的数据集。对于多分类问题,通常我们需要将训练集和测试集分别存储为两个csv文件,每个文件包含以下列:
- 特征(features):描述样本的属性,例如文字、数字等。
- 目标(target):样本应该被分为哪个类别。
在本例中,我们假设有一个训练集"train.csv"和一个测试集"test.csv",每个文件包含两列数据"feature"和"target",其中"target"列包含3个类别"pidancode.com"、"皮蛋编程"和"其他":
train.csv
feature,target
sample1,pidancode.com
sample2,其他
sample3,皮蛋编程
...
test.csv
feature,target
sample4,其他
sample5,皮蛋编程
sample6,pidancode.com
...
- 导入必要的库
在使用决策树库之前,需要先导入必要的库。
import pandas as pd from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import accuracy_score
- 读取数据
读取训练集和测试集中的数据,分别存储在"train_data"和"test_data"中,然后将特征和目标分别存储在X_train、X_test和y_train、y_test中:
train_data = pd.read_csv('train.csv') test_data = pd.read_csv('test.csv') X_train = train_data.drop('target', axis=1) y_train = train_data['target'] X_test = test_data.drop('target', axis=1) y_test = test_data['target']
注意:
axis=1
表示去掉的是列数据,而不是行数据。
- 特征工程
在数据准备阶段,我们将特征、目标存储在X_train、X_test和y_train、y_test中。这些数据是原始数据,需要经过一些处理才能用于训练和测试决策树模型。这个过程称为特征工程。
在本例中,我们的特征是字符串,需要将它们转换为数字。
X_train = X_train.apply(lambda x: x.map({'pidancode.com': 0, '皮蛋编程': 1, '其他': 2})) X_test = X_test.apply(lambda x: x.map({'pidancode.com': 0, '皮蛋编程': 1, '其他': 2}))
注意:如果特征数据中包含缺失值,需要进行处理。可以使用Pandas中的fillna()函数,将缺失值填充为平均数、中位数、众数等。
- 训练模型
现在,我们已经完成了数据准备和特征工程,可以使用sklearn库的DecisionTreeClassifier类进行训练:
dtc = DecisionTreeClassifier() dtc.fit(X_train, y_train)
- 预测结果
现在,我们使用训练好的决策树模型预测测试数据集中的目标。然后,计算预测准确度。在本例中,预测准确度定义为正确预测数量与总数量的比率。这可以使用sklearn库的accuracy_score函数实现:
y_pred = dtc.predict(X_test) accuracy = accuracy_score(y_test, y_pred)
- 完整代码
import pandas as pd from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import accuracy_score train_data = pd.read_csv('train.csv') test_data = pd.read_csv('test.csv') X_train = train_data.drop('target', axis=1) y_train = train_data['target'] X_test = test_data.drop('target', axis=1) y_test = test_data['target'] X_train = X_train.apply(lambda x: x.map({'pidancode.com': 0, '皮蛋编程': 1, '其他': 2})) X_test = X_test.apply(lambda x: x.map({'pidancode.com': 0, '皮蛋编程': 1, '其他': 2})) dtc = DecisionTreeClassifier() dtc.fit(X_train, y_train) y_pred = dtc.predict(X_test) accuracy = accuracy_score(y_test, y_pred) print("Accuracy:", accuracy)
相关文章