#include <iostream>
#include <random>
#include <algorithm>
#include <vector>
#include <set>
#include "stdafx.h"
#include "utils.h"
using namespace std;
// Number of class labels; gini() and cart() iterate k in [0, classK) over
// classList and pk.
int classK = 3;
// Label values that data_y[row][0] is compared against; a leaf stores one of
// these as its prediction. NOTE(review): the array length is a literal 3 and
// is not tied to classK -- keep the two in sync by hand.
int classList[3] = { 1,3,6 };
// Scratch buffer of classK floats, zeroed and overwritten by every call to
// gini() and cart() (per-class counts, then proportions). It is allocated
// during dynamic initialization and is not freed anywhere in this section.
float* pk = new float[classK];
// Modulus used by attrs_init_random() when drawing attribute indices --
// presumably the number of feature columns in data_X; confirm against loader.
int ATTRS = 19;
// Count of ids drawn by ids_init_random() and the modulus for each draw --
// presumably the number of rows in data_X / data_y; confirm against loader.
int ROWS = 117;
// Feature table, read as data_X[row][attribute]. Not allocated in this section.
float** data_X;
// Label table, read as data_y[row][0]. Not allocated in this section.
float** data_y;
// The four values below are not referenced by any code in this section.
// NOTE(review): names suggest trace / sample geometry of the input data --
// confirm meaning and units where they are used.
int ntrace = 821;
int nSample = 476;
int SampleRate = 2000;
int headNum = 120;
// One node of the binary decision tree built by cart() and walked by
// cart_predict().
//
// Member order is part of the interface: cart() fills nodes positionally as
// { ids, attrs, attr, value, left, right }, so the fields must not be
// reordered. The default member initializers guarantee that a node created
// with `new Node` or `Node n;` starts out as a childless leaf instead of
// carrying indeterminate pointers; the struct remains an aggregate under
// C++14 and later, so positional brace initialization keeps working.
struct Node {
    std::vector<int> ids;    // row indices (into data_X / data_y) assigned to this node
    std::vector<int> attrs;  // attribute indices cart() may split on at this node
    int attr = 0;            // internal node: attribute index tested by cart_predict()
    float value = 0;         // internal node: split threshold; leaf: predicted class label
    Node* left = nullptr;    // subtree taken when predict_X[attr] <= value; null on a leaf
    Node* right = nullptr;   // subtree taken otherwise; null on a leaf
};
// Builds a list of ROWS row indices, each drawn independently as
// rand() % ROWS. No uniqueness check is made, so the same index may appear
// more than once.
std::vector<int> ids_init_random() {
    std::vector<int> sample;
    int remaining = ROWS;
    while (remaining > 0) {
        sample.push_back(rand() % ROWS);
        --remaining;
    }
    return sample;
}
// Draws a random set of distinct attribute indices from [0, ATTRS).
//
// Returns up to 5 indices in the order they were first drawn. When ATTRS is
// smaller than 5 only ATTRS distinct values exist, so the result is capped at
// that size; when ATTRS is zero or negative the result is empty. For
// ATTRS >= 5 the sequence of rand() calls and the returned indices are the
// same as a plain "retry until 5 distinct values" loop.
std::vector<int> attrs_init_random() {
    // Number of distinct attribute indices to draw.
    const int kSubsetSize = 5;

    // Asking for more distinct values than ATTRS can supply would loop
    // forever, and ATTRS == 0 would divide by zero in the modulo below.
    const int wanted = (ATTRS < kSubsetSize) ? ATTRS : kSubsetSize;

    std::vector<int> chosen;
    if (wanted <= 0) {
        return chosen;
    }
    chosen.reserve(wanted);

    while (static_cast<int>(chosen.size()) < wanted) {
        const int candidate = rand() % ATTRS;
        // Reject repeats so every returned index is unique.
        if (std::find(chosen.begin(), chosen.end(), candidate) == chosen.end()) {
            chosen.push_back(candidate);
        }
    }
    return chosen;
}
float gini(vector<int> ids) {
memset(pk, 0, sizeof(float)*classK);
for (vector<int>::iterator it = ids.begin(); it != ids.end(); it++) {
for (int k = 0; k < classK; k++) {
if (data_y[*it][0] == classList[k]) {
pk[k] = pk[k] + 1;
break;
}
}
}
float gini_p = 0;
for (int k = 0; k < classK; k++) {
if (ids.size() != 0)
pk[k] = pk[k] / ids.size();
gini_p = gini_p + pk[k] * (1 - pk[k]);
}
return gini_p;
}
// Recursively grows a CART classification tree below `root`.
//
// On entry root->ids lists the training rows assigned to this node and
// root->attrs the attribute indices that may be used for splitting.
//
// The node is finished as a leaf -- value set to the majority label from
// classList, both children set to null -- when any of these holds:
//   * the node has no rows, or no candidate split exists (empty attrs);
//   * more than 80% of its rows carry the same label;
//   * the best split would leave 5 or fewer rows on either side.
// Otherwise root->attr and root->value receive the split attribute and
// threshold that minimise the size-weighted Gini impurity of the two sides
// (rows with data_X[row][attr] <= value go left, the rest go right), and both
// children are created with the same attribute list and grown recursively.
// Candidate thresholds are the attribute values of the node's own rows; on
// ties the first candidate encountered wins.
//
// Side effects: uses the global pk buffer as scratch space and calls gini(),
// which overwrites pk as well. Child nodes are allocated with `new` and are
// owned by the tree; nothing here frees them.
//
// Always returns 0.
int cart(Node* root) {
    const double kPurityThreshold = 0.8;  // majority share above which we stop splitting
    const unsigned kMinChildRows = 5;     // a child must hold more rows than this

    const std::vector<int>& rows = root->ids;

    // Count rows per class to find the majority label.
    for (int k = 0; k < classK; ++k) {
        pk[k] = 0.0f;
    }
    for (std::vector<int>::const_iterator row = rows.begin(); row != rows.end(); ++row) {
        for (int k = 0; k < classK; ++k) {
            if (data_y[*row][0] == classList[k]) {
                pk[k] = pk[k] + 1;
                break;
            }
        }
    }
    float mostvote = 0;
    int mostid = 0;
    for (int k = 0; k < classK; ++k) {
        if (pk[k] > mostvote) {
            mostvote = pk[k];
            mostid = k;
        }
    }

    // Marks this node as a leaf predicting the majority label. The children
    // are nulled explicitly because cart_predict() recognises a leaf by both
    // pointers being null and the caller may not have initialised them.
    const auto finishAsLeaf = [root, mostid]() -> int {
        root->value = classList[mostid];
        root->left = nullptr;
        root->right = nullptr;
        return 0;
    };

    if (rows.empty() || mostvote / rows.size() > kPurityThreshold) {
        return finishAsLeaf();
    }

    // Exhaustive search: every attribute, every observed value as threshold.
    std::vector<int> leftRows;
    std::vector<int> rightRows;
    leftRows.reserve(rows.size());
    rightRows.reserve(rows.size());

    bool haveSplit = false;
    float bestScore = 0;
    int bestAttr = 0;
    float bestValue = 0;

    for (std::vector<int>::const_iterator attr = root->attrs.begin(); attr != root->attrs.end(); ++attr) {
        for (std::vector<int>::const_iterator pivot = rows.begin(); pivot != rows.end(); ++pivot) {
            const float threshold = data_X[*pivot][*attr];

            leftRows.clear();
            rightRows.clear();
            for (std::vector<int>::const_iterator row = rows.begin(); row != rows.end(); ++row) {
                if (data_X[*row][*attr] <= threshold) {
                    leftRows.push_back(*row);
                }
                else {
                    rightRows.push_back(*row);
                }
            }

            // Size-weighted impurity of the two sides.
            const float score = gini(leftRows) * leftRows.size() / rows.size()
                              + gini(rightRows) * rightRows.size() / rows.size();
            if (!haveSplit || score < bestScore) {
                haveSplit = true;
                bestScore = score;
                bestAttr = *attr;
                bestValue = threshold;
            }
        }
    }

    // Without any candidate attribute there is nothing meaningful to split on.
    if (!haveSplit) {
        return finishAsLeaf();
    }

    // Re-create the partition for the winning split.
    leftRows.clear();
    rightRows.clear();
    for (std::vector<int>::const_iterator row = rows.begin(); row != rows.end(); ++row) {
        if (data_X[*row][bestAttr] <= bestValue) {
            leftRows.push_back(*row);
        }
        else {
            rightRows.push_back(*row);
        }
    }
    if (rightRows.size() <= kMinChildRows || leftRows.size() <= kMinChildRows) {
        return finishAsLeaf();
    }

    root->attr = bestAttr;
    root->value = bestValue;

    // `new Node()` value-initialises the node, so attr/value are zero and the
    // child pointers are null before the row and attribute lists are filled.
    root->left = new Node();
    root->left->ids.swap(leftRows);
    root->left->attrs = root->attrs;
    cart(root->left);

    root->right = new Node();
    root->right->ids.swap(rightRows);
    root->right->attrs = root->attrs;
    cart(root->right);

    return 0;
}
// Classifies one sample with the tree rooted at `root`.
//
// Starting at `root`, descends until it reaches a node whose left and right
// pointers are both null and returns that node's `value`. At every other node
// the sample goes left when predict_X[node->attr] <= node->value and right
// otherwise. `root` and every child pointer followed are dereferenced
// without a null check.
float cart_predict(Node* root, float* predict_X) {
    Node* node = root;
    for (;;) {
        const bool isLeaf = (node->left == NULL) && (node->right == NULL);
        if (isLeaf) {
            return node->value;
        }
        node = (predict_X[node->attr] <= node->value) ? node->left : node->right;
    }
}