解释用例

DT_data.csv为样本数据,共14条记录。每一条记录共4维特征,分别为Weather(天气), Temperature(温度),Humidity(湿度),Wind(风力);其中Date(约会)为标签列。

根据样本数据,建立决策树。
输入测试数据,得到预测是否约会(yes/no)。

算法实现

1. 计算熵(entropy

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
double entropy(const vector<DataPoint>& data) {
map<string, int> labelCount;
for (const auto& d : data) {
labelCount[d.date]++;
}

double entropy = 0.0;
int total = data.size();
for (const auto& pair : labelCount) {
double prob = (double)pair.second / total;
entropy -= prob * log2(prob);
}

return entropy;
}
  • 目的:计算数据集的熵,熵是信息论中的一个概念,用于衡量信息的不确定性。
  • 具体流程
    • 首先统计每个类别的数量。
    • 计算总的熵值,遍历每个类别计算其概率,并计算其对应的熵贡献(−plog⁡2(p)−plog2​(p))。

2. 计算信息增益(informationGain

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
double informationGain(const vector<DataPoint>& data, const string& feature) {
map<string, vector<DataPoint>> subsets;
// 分割数据
for (const auto& d : data) {
// 根据特征将数据划分为子集
}

double totalEntropy = entropy(data);
double weightedEntropy = 0.0;
int total = data.size();

// 计算每个子集的熵,计算加权熵
for (const auto& pair : subsets) {
double subsetEntropy = entropy(pair.second);
weightedEntropy += ((double)pair.second.size() / total) * subsetEntropy;
}

return totalEntropy - weightedEntropy; // 信息增益
}
  • 目的:根据某一特征计算信息增益,即选择该特征来划分数据所带来的信息增益。
  • 具体流程
    • 遍历数据,按特征将数据划分为不同的子集。
    • 计算整体熵和每个子集的加权熵,然后用整体熵减去加权熵得到信息增益。

3. 计算信息增益率(gainRatio

1
2
3
4
5
6
7
8
9
10
11
12
double gainRatio(const vector<DataPoint>& data, const string& feature) {
double infoGain = informationGain(data, feature);
// 计算分裂信息
double splitInfo = 0.0;

for (const auto& pair : subsets) {
double prob = (double)pair.second.size() / total;
splitInfo -= prob * log2(prob);
}

return (splitInfo != 0) ? infoGain / splitInfo : 0.0; // 信息增益率
}
  • 目的:计算信息增益率,这是 C4.5 算法中用于选择特征的指标,通过对信息增益进行标准化来避免偏向于多个取值的特征。
  • 具体流程
    • 在进行信息增益的计算后,还需计算分裂信息(即按特征划分所产生的信息量)。
    • 最后,信息增益除以分裂信息得到信息增益率。

4. 构建决策树(buildTree

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
TreeNode* buildTree(const vector<DataPoint>& data, vector<string>& features) {
// 基本情况处理
if (data.empty()) return nullptr;

// 数据纯净判断
if (labelCount.size() == 1) {
// 生成叶节点
}

// 特征选择
for (const auto& feature : features) {
double gain = useC45 ? gainRatio(data, feature) : informationGain(data, feature);
// 找到最佳特征
}

// 数据分割
for (const auto& d : data) {
// 按最佳特征划分数据
}

// 递归构建子节点
for (const auto& pair : subsets) {
node->children[pair.first] = buildTree(pair.second, remainingFeatures);
}

return node;
}
  • 目的:递归地构建决策树。
  • 具体流程
    • 首先检查基本情况,如数据集是否为空或者数据是否纯净(只有一个类别)。
    • 如果没有特征可用,则返回数量最多的类别作为叶节点。
    • 选择信息增益或信息增益率最大的特征作为当前节点的特征。
    • 将数据集按最佳特征分割成子集,并对每个子集递归调用 buildTree() 生成子节点。

5. 使用 ID3 进行预测(PredictID3)

1
2
3
4
5
6
7
8
string DecisionTree::predictID3(TreeNode * root, const DataPoint & test) {
if (root->label != "") {
// 如果到达叶节点,返回预测标签
}

// 根据特征继续向下遍历决策树
return " unknown ";
}

6. 使用 C4.5 进行预测(predictC45

1
2
3
4
5
6
7
8
string predictC45(TreeNode* root, const DataPoint& test) {
if (root->label != "") {
return root->label;
}

// 类似 ID3 的预测过程
return " unknown ";
}
  • 目的:进行预测,使用已构建的树来分类新的数据点。
  • 具体流程
    • 从树的根节点开始,判断是否为叶节点。
    • 否则,根据数据点的特征值导航至对应子节点,直到达到叶节点返回预测类别。

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
// 计算熵
double entropy(const vector<DataPoint>& data) {
// 统计每个类别的数量
map<string, int> labelCount;
for (const auto& d : data) {
labelCount[d.date]++;
}

// 计算总的熵值,遍历每个类别计算其概率,并计算其对应的熵贡献
double entropy = 0.0;
int total = data.size();
for (const auto& pair : labelCount) {
double prob = (double)pair.second / total;
entropy -= prob * log2(prob);
}

return entropy;
}

// 计算信息增益
double informationGain(const vector<DataPoint>& data, const string& feature) {
map<string, vector<DataPoint>> subsets;
// 分割数据
for (const auto& d : data) {
// 根据特征将数据划分为子集
if (feature == "weather") {
subsets[d.weather].push_back(d);
}
else if (feature == "temperature") {
subsets[d.temperature].push_back(d);
}
else if (feature == "humidity") {
subsets[d.humidity].push_back(d);
}
else if (feature == "wind") {
subsets[d.wind].push_back(d);
}
}

double totalEntropy = entropy(data);
double weightedEntropy = 0.0;
int total = data.size();

// 计算每个子集的熵,计算加权熵
for (const auto& pair : subsets) {
double subsetEntropy = entropy(pair.second);
weightedEntropy += ((double)pair.second.size() / total) * subsetEntropy;
}

return totalEntropy - weightedEntropy; // 信息增益
}

// 计算信息增益率
double gainRatio(const vector<DataPoint>& data, const string& feature) {
double infoGain = informationGain(data, feature);

map<string, vector<DataPoint>> subsets;
for (const auto& d : data) {
if (feature == "weather") {
subsets[d.weather].push_back(d);
}
else if (feature == "temperature") {
subsets[d.temperature].push_back(d);
}
else if (feature == "humidity") {
subsets[d.humidity].push_back(d);
}
else if (feature == "wind") {
subsets[d.wind].push_back(d);
}
}

// 计算分裂信息
double splitInfo = 0.0;
int total = data.size();
for (const auto& pair : subsets) {
double prob = (double)pair.second.size() / total;
splitInfo -= prob * log2(prob);
}

return (splitInfo != 0) ? infoGain / splitInfo : 0.0;
}

// 构建决策树
TreeNode* buildTree(const vector<DataPoint>& data, vector<string>& features) {
if (data.empty()) return nullptr;

// 如果数据纯净(只有一个类别),返回叶节点
map<string, int> labelCount;
for (const auto& d : data) {
labelCount[d.date]++;
}
if (labelCount.size() == 1) {
TreeNode* leaf = new TreeNode();
leaf->label = labelCount.begin()->first;
return leaf;
}

if (features.empty()) {
TreeNode* leaf = new TreeNode();
leaf->label = max_element(labelCount.begin(), labelCount.end(), [](const pair<string, int>& a, const pair<string, int>& b) {
return a.second < b.second;
})->first;
return leaf;
}

// 选择最优特征
string bestFeature;
double bestGain = -1;
for (const auto& feature : features) {
double gain = useC45 ? gainRatio(data, feature) : informationGain(data, feature);
if (gain > bestGain) {
bestGain = gain;
bestFeature = feature;
}
}

// 创建当前节点
TreeNode* node = new TreeNode();
node->feature = bestFeature;

// 按特征值分割数据
map<string, vector<DataPoint>> subsets;
for (const auto& d : data) {
if (bestFeature == "weather") {
subsets[d.weather].push_back(d);
}
else if (bestFeature == "temperature") {
subsets[d.temperature].push_back(d);
}
else if (bestFeature == "humidity") {
subsets[d.humidity].push_back(d);
}
else if (bestFeature == "wind") {
subsets[d.wind].push_back(d);
}
}

// 递归构建子节点
vector<string> remainingFeatures = features;
remainingFeatures.erase(remove(remainingFeatures.begin(), remainingFeatures.end(), bestFeature), remainingFeatures.end());
for (const auto& pair : subsets) {
node->children[pair.first] = buildTree(pair.second, remainingFeatures);
}

return node;
}

// 使用 ID3 进行预测
string predictID3(TreeNode * root, const DataPoint & test) {
if (root->label != "") {
return root->label;
}

if (root->children.find(test.weather) != root->children.end()) {
return predictID3(root->children[test.weather], test);
}
else if (root->children.find(test.temperature) != root->children.end()) {
return predictID3(root->children[test.temperature], test);
}
else if (root->children.find(test.humidity) != root->children.end()) {
return predictID3(root->children[test.humidity], test);
}
else if (root->children.find(test.wind) != root->children.end()) {
return predictID3(root->children[test.wind], test);
}

return " unknown ";
}

// 使用 C4.5 进行预测
string predictC45(TreeNode* root, const DataPoint& test) {
if (root->label != "") {
return root->label;
}

if (root->children.find(test.weather) != root->children.end()) {
return predictC45(root->children[test.weather], test);
}
else if (root->children.find(test.temperature) != root->children.end()) {
return predictC45(root->children[test.temperature], test);
}
else if (root->children.find(test.humidity) != root->children.end()) {
return predictC45(root->children[test.humidity], test);
}
else if (root->children.find(test.wind) != root->children.end()) {
return predictC45(root->children[test.wind], test);
}

return " unknown ";
}