Commit a0364fe1 authored by sgebreeg

Merge branch 'faster-continuous-data-split' into 'master'

Faster continuous data split

See merge request !1
parents 97ff316e cabe18e1
continuous,continuous,continuous, … [new single-line schema file, apparently for the MNIST data added below: "continuous" repeated once per pixel column; full line elided]
\ No newline at end of file
pixel0,pixel1,pixel2, … ,pixel782,pixel783,label [new single-line feature-name file: the 784 MNIST pixel columns plus the label; middle of the list elided]
\ No newline at end of file
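Before the code diff, a note for orientation: the branch's speedup comes from replacing per-node copies of the string matrix with one shared, column-major dataset plus per-node row-index vectors. A minimal sketch of that convention, with purely hypothetical names:

```cpp
#include <string>
#include <vector>

using std::string;
using std::vector;

// Column-major storage: dataset[f][r] is the value of feature f for row r,
// and the last inner vector holds the labels (as main.cpp below assumes).
using Dataset = vector<vector<string>>;

// A tree node never copies its slice of the data; it keeps only the row
// indices it owns and reads values through the shared dataset.
inline const string &valueAt(const Dataset &dataset, int featureIdx, int rowIdx) {
    return dataset[featureIdx][rowIdx];
}
```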
[file: DecisionTree.cpp (filename inferred from the code below)]
@@ -10,26 +10,20 @@
using namespace std;
DecisionTree::DecisionTree(vector <vector<string>> data, int maxDepth, float featureWeight,
vector <FeatureType> featureType) {
// vector<int> index = randomSelect_WithoutReplacement(data.size(), featureWeight);
// for(int i=0; i<index.size();i++){
// for(int j=1; j<data[index[i]].size(); j++){
// data[index[i]][j] = data[index[i]][0];
// }
// }
DecisionTree::DecisionTree(vector <vector<string>> &data, vector<int> &trainingIndx, int maxDepth, float featureWeight,
vector <FeatureType> &featureType) {
this->maxDepth = maxDepth;
this->featureWeight = featureWeight;
this->maxDepthReached = 0;
this->root = train(data, featureType, 0.0, &this->maxDepthReached, maxDepth, featureWeight);
//TODO int<vector> take all iteration
this->root = train(data, featureType, 0.0, 0, maxDepth, featureWeight, trainingIndx);
// this->printTree(this->root, 0);
};
Node *train(vector <vector<string>> data, vector <FeatureType> featureType,
double parentEntropy, int *currentDepth, int maxDepth, float featureWeight) {
Node *train(vector <vector<string>> &data, vector <FeatureType> &featureType,
double parentEntropy, int currentDepth, int maxDepth, float featureWeight, vector<int> nodeDatasetIndices ) {
std::pair<string, double> classificationAndEntropy = classifyWithEntropy(data);
//TODO pass data pointer and index vector
std::pair<string, double> classificationAndEntropy = classifyWithEntropy(data, nodeDatasetIndices);
string classification = classificationAndEntropy.first;
double originalEntropy = classificationAndEntropy.second;
double informationGainFromParent;
@@ -45,18 +39,12 @@ Node *train(vector <vector<string>> data, vector <FeatureType> featureType,
} else {
// cout<<"Finding splits"<<endl;
//create a random subspace
//find possible splits
std::map<int, set<string>> potentialSplits = findAllSplittingPoints(data, featureType, featureWeight);
// cout<<"Finding best split"<<endl;
//TODO send data vector of index
//find best split point
BestSplitPoint bestSplit = findBestSplit(parentEntropy, currentDepth, data,
featureType, featureWeight, nodeDatasetIndices);
//find best split
BestSplitPoint bestSplit = findBestSplit(parentEntropy, *currentDepth, potentialSplits, data,
featureType);
//cout << "------ best split index "<<bestSplit.featureIdx<<endl;
if(bestSplit.featureIdx == -1 || bestSplit.featureIdx > data.size()-1 ){
Node *leaf = new Node(NULL, NULL, NULL, true, classification, originalEntropy, informationGainFromParent);
// cout<<"No more split"<<endl;
@@ -65,14 +53,16 @@ Node *train(vector <vector<string>> data, vector <FeatureType> featureType,
// cout<<"splitting data"<<endl;
//split data
FeatureSplitData featureSplitData = splitData(data, bestSplit.featureIdx, featureType,
bestSplit.splitpoint);
//TODO send data and index vector
//TODO return indices for left and right
FeatureSplitDataIndx featureSplitData = splitData(data, bestSplit.featureIdx, featureType,
bestSplit.splitpoint, nodeDatasetIndices);
//No longer splittable
if (featureSplitData.dataTrue[0].size() < 1 || featureSplitData.dataFalse[0].size() < 1) {
if (featureSplitData.dataTrue.size() < 1 || featureSplitData.dataFalse.size() < 1) {
Node *leaf = new Node(NULL, NULL, NULL, true, classification, originalEntropy, informationGainFromParent);
return leaf;
}
@@ -83,11 +73,11 @@ Node *train(vector <vector<string>> data, vector <FeatureType> featureType,
// cout<<"Next Train"<<endl;
//call train for left and right data
*currentDepth += 1;
Node *leftNode = train(featureSplitData.dataTrue, featureType, originalEntropy, currentDepth, maxDepth,
featureWeight);
Node *rightNode = train(featureSplitData.dataFalse, featureType, originalEntropy, currentDepth, maxDepth,
featureWeight);
//TODO pass int vector from splits
Node *leftNode = train(data, featureType, originalEntropy, currentDepth + 1, maxDepth,
featureWeight, featureSplitData.dataTrue);
Node *rightNode = train(data, featureType, originalEntropy, currentDepth + 1, maxDepth,
featureWeight, featureSplitData.dataFalse);
Node *node = new Node(question, leftNode, rightNode, false, classification, originalEntropy,
@@ -97,7 +87,7 @@ Node *train(vector <vector<string>> data, vector <FeatureType> featureType,
}
}
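Two things change in `train` above: child nodes now receive index vectors (`featureSplitData.dataTrue`/`dataFalse`) instead of copied datasets, and depth is passed by value as `currentDepth + 1` rather than through a shared `int *` that every node incremented. With the shared counter, `maxDepth` effectively capped the total number of splits in construction order; by value it bounds the length of each root-to-leaf path. A small self-contained illustration of the by-value behavior:

```cpp
// Sketch only (not project code): each call computes currentDepth + 1
// locally, so sibling subtrees cannot inflate each other's depth.
int pathDepth(int currentDepth, int maxDepth) {
    if (currentDepth >= maxDepth)
        return currentDepth;                  // leaf: per-path bound hit
    int left = pathDepth(currentDepth + 1, maxDepth);
    int right = pathDepth(currentDepth + 1, maxDepth);
    return left > right ? left : right;       // both reach exactly maxDepth
}
```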
string DecisionTree::predictSingle(vector <string> test, Node *treeRoot, PredictionReport *report) {
string DecisionTree::predictSingle(vector <string>& test, Node *treeRoot, PredictionReport *report) {
if (treeRoot->isLeaf == true) {
return treeRoot->classification;
}
@@ -113,7 +103,7 @@ string DecisionTree::predictSingle(vector <string> test, Node *treeRoot, Predict
Node *answer;
if (featureType == CATEGORICAL) {
if (test[splitIndex] == splitValue) {
if (test[splitIndex] <= splitValue) {
answer = treeRoot->trueBranch;
} else {
answer = treeRoot->falseBranch;
@@ -134,7 +124,7 @@ string DecisionTree::predictSingle(vector <string> test, Node *treeRoot, Predict
}
}
string DecisionTree::predictSingle(vector <string> test, Node *treeRoot) {
string DecisionTree::predictSingle(vector <string>& test, Node *treeRoot) {
if (treeRoot->isLeaf == true) {
return treeRoot->classification;
}
......
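For context on the two `predictSingle` overloads (the hunk above also changes the categorical comparison from `==` to `<=`): prediction walks from the root, testing the sample's value at each node's split feature until it reaches a leaf. An iterative, self-contained sketch with a hypothetical node layout:

```cpp
#include <string>
#include <vector>

// MiniNode approximates the project's Node; fields are inferred, not exact.
struct MiniNode {
    bool isLeaf;
    int splitIndex;              // feature index the node tests
    std::string splitValue;      // split point, stored as a string
    std::string classification;  // meaningful when isLeaf is true
    MiniNode *trueBranch;
    MiniNode *falseBranch;
};

// Comparisons are lexicographic on strings, as in the diff above.
std::string walk(const std::vector<std::string> &test, const MiniNode *node) {
    while (!node->isLeaf)
        node = (test[node->splitIndex] <= node->splitValue) ? node->trueBranch
                                                            : node->falseBranch;
    return node->classification;
}
```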
[file: DecisionTree.hpp (filename inferred from the code below)]
@@ -12,10 +12,10 @@
class DecisionTree {
public:
DecisionTree(vector <vector<string>> data, int maxDepth, float featureWeight, vector <FeatureType> featureType);
DecisionTree(vector <vector<string>> &data, vector<int> &trainingIndx, int maxDepth, float featureWeight, vector <FeatureType> &featureType);
string predictSingle(vector <string> test, Node *treeRoot, PredictionReport * report);
string predictSingle(vector <string> test, Node *treeRoot);
string predictSingle(vector <string>& test, Node *treeRoot, PredictionReport * report);
string predictSingle(vector <string>& test, Node *treeRoot);
void printTree(Node *node, int space);
@@ -30,8 +30,8 @@ private:
};
Node *train(vector <vector<string>> data, vector <FeatureType> featureType,
double entropy, int *currentDepth, int maxDepth, float featureWeight);
Node *train(vector <vector<string>> &data, vector <FeatureType> &featureType,
double entropy, int currentDepth, int maxDepth, float featureWeight, vector<int> nodeDatasetIndices);
#endif //RACER_DECISIONTREE_HPP
[file: Makefile (filename inferred from the code below)]
CC = g++
CC = g++ -g
FLAGS = -std=c++11
RACERDIR =
......
[file: RandomForest.cpp (filename inferred from the code below)]
@@ -39,14 +39,16 @@ vector<int> getParts(int trees, int cpus) {
return temp;
}
RandomForest::RandomForest(vector <vector<string>> trainingData, vector <FeatureType> featureTypes, int numTrees,
RandomForest::RandomForest(vector <vector<string>> &data, vector<int> &trainingIndx, vector <FeatureType> &featureTypes,
int numTrees,
float baggingWeight, float featureWeight, int maxDepth) {
vector < DecisionTree * > decisionTrees;
this->featureWeight = featureWeight;
this->depth = maxDepth;
unsigned num_cpus = std::thread::hardware_concurrency();
if(numTrees < num_cpus)
unsigned num_cpus = std::thread::hardware_concurrency()/2;
// unsigned num_cpus = 12;
if (numTrees < num_cpus)
num_cpus = numTrees;
// A mutex ensures orderly access.
std::mutex iomutex;
@@ -57,14 +59,14 @@ RandomForest::RandomForest(vector <vector<string>> trainingData, vector <Feature
vector<int> temp = getParts(numTrees, num_cpus); //determine how many trees to run in parallel
for (int i = 0; i < num_cpus; i++) {
if (i < temp.size())
threads[i] = std::thread([&iomutex, i, temp, trainingData,baggingWeight,
maxDepth, featureWeight, featureTypes, &decisionTrees] {
threads[i] = std::thread([&iomutex, i, temp, &data, &trainingIndx, baggingWeight,
maxDepth, featureWeight, &featureTypes, &decisionTrees] {
for (int j = 0; j < temp.at(i); j++) {
vector <vector<string>> baggedData = bagData(trainingData, baggingWeight);
cout<<"Training tree "<< j<<" in thread "<<i<<endl;
DecisionTree *tree = new DecisionTree(baggedData, maxDepth, featureWeight, featureTypes);
cout<<"Done training tree "<< j<<" in thread "<<i<<endl;
vector <int> baggedData = bagData(trainingIndx, baggingWeight); //TODO fix this
cout << "Training tree " << j << " in thread " << i << endl;
DecisionTree *tree = new DecisionTree(data, baggedData, maxDepth, featureWeight, featureTypes);
cout << "Done training tree " << j << " in thread " << i << endl;
{
// Use a lexical scope and lock_guard to safely lock the mutex only for
// the duration of vector push.
@@ -74,11 +76,10 @@ RandomForest::RandomForest(vector <vector<string>> trainingData, vector <Feature
}
});
}
for (auto& t : threads) {
for (auto &t : threads) {
t.join();
}
@@ -87,7 +88,7 @@ RandomForest::RandomForest(vector <vector<string>> trainingData, vector <Feature
}
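Each worker thread trains `temp.at(i)` trees, where `temp` comes from `getParts`. The body of `getParts` is elided in this view, so the following is only a plausible reconstruction of its intent (spread `numTrees` as evenly as possible over the worker count):

```cpp
#include <vector>

// Hypothetical stand-in for getParts: 10 trees on 4 CPUs -> {3, 3, 2, 2}.
std::vector<int> partitionTrees(int trees, int cpus) {
    std::vector<int> perThread(cpus, trees / cpus);
    for (int i = 0; i < trees % cpus; i++)
        perThread[i] += 1;      // hand out the remainder one tree at a time
    return perThread;
}
```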
vector <string>
RandomForest::getForestPrediction(vector <string> test, RandomForest *forest, vector <string> features) {
RandomForest::getForestPrediction(vector <string>& test, RandomForest *forest, vector <string> &features) {
cout << "Trees in forest: " << to_string(forest->trees.size()) << endl;
cout << "Predicting" << endl;
vector <string> predictions(forest->trees.size());
@@ -111,10 +112,10 @@ RandomForest::getForestPrediction(vector <string> test, RandomForest *forest, ve
explanations[treeIdx] = report;
map<string, vector<std::pair < int, double> >>::iterator itr;
cout << "Explanation "<<treeIdx<< " classified "<< report->classification << " because ";
for (itr = report->path.begin() ; itr != report->path.end(); ++itr) {
if(itr->first == report->classification){
map < string, vector < std::pair < int, double > >>::iterator itr;
cout << "Explanation " << treeIdx << " classified " << report->classification << " because ";
for (itr = report->path.begin(); itr != report->path.end(); ++itr) {
if (itr->first == report->classification) {
sort(itr->second, test, features);
}
}
@@ -147,6 +148,10 @@ RandomForest::getForestPrediction(vector <string> test, RandomForest *forest, ve
sort(reports, test, features);
for (int j = 0; j < explanations.size(); j++) {
delete explanations[j];
}
return predictions;
}
@@ -177,7 +182,7 @@ bool cmp(pair<int, double> &a,
return a.second > b.second;
}
vector <pair<int, double>> sort(vector<std::pair < int, double> > &M, vector <string> test, vector <string> features) {
vector <pair<int, double>> sort(vector <std::pair<int, double>> &M, vector <string> test, vector <string> features) {
// Declare vector of pairs
vector <pair<int, double>> A;
@@ -199,7 +204,7 @@ vector <pair<int, double>> sort(vector<std::pair < int, double> > &M, vector <st
if (count > 2) {
break;
}
cout << features[it.first] << " is " << test[it.first]<< "(information gain: "<<it.second << "), ";
cout << features[it.first] << " is " << test[it.first] << "(information gain: " << it.second << "), ";
count++;
}
return A;
@@ -236,21 +241,21 @@ vector <pair<int, double>> sort(map<int, double> &M, vector <string> test, vecto
return A;
}
vector <string> getBatchPrediction(vector <vector<string>> testData, RandomForest *forest) {
vector <string> getBatchPrediction(vector <vector<string>>& datasetAsString,vector <int>& testIdxs, RandomForest *forest) {
vector <string> predictions;
for (int testIndex = 0; testIndex < testData.at(0).size(); testIndex++) {
for (int testIndex = 0; testIndex < testIdxs.size(); testIndex++) {
map<string, int> votes;
vector <string> test;
string emptystring;
for (int featIndex = 0; featIndex < testData.size(); featIndex++) {
for (int featIndex = 0; featIndex < datasetAsString.size(); featIndex++) {
test.push_back(emptystring);
}
for (int featIndex = 0; featIndex < testData.size(); featIndex++) {
test.at(featIndex) = testData.at(featIndex).at(testIndex);
for (int featIndex = 0; featIndex < datasetAsString.size(); featIndex++) {
test.at(featIndex) = datasetAsString.at(featIndex).at(testIdxs[testIndex]);
}
// Get every tree in the forest's prediction
@@ -285,9 +290,12 @@ vector <string> getBatchPrediction(vector <vector<string>> testData, RandomFores
}
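The per-test tally that follows this hunk is elided, but the `votes` map above suggests plain majority voting over the trees' predictions. A hedged sketch of that step:

```cpp
#include <map>
#include <string>
#include <vector>

// Majority vote over per-tree predictions. Ties resolve to the
// lexicographically first label here, which may differ from the project.
std::string majorityVote(const std::vector<std::string> &treePredictions) {
    std::map<std::string, int> votes;
    for (const std::string &p : treePredictions)
        votes[p] += 1;
    std::string best;
    int bestCount = -1;
    for (const auto &kv : votes) {
        if (kv.second > bestCount) {
            best = kv.first;
            bestCount = kv.second;
        }
    }
    return best;
}
```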
accuracyReport RandomForest::getAccuracy(vector <vector<string>> testData) {
vector <string> predictions = getBatchPrediction(testData, this);
vector <string> labels = testData.at(testData.size() - 1);
accuracyReport RandomForest::getAccuracy(vector <vector<string>>& datasetAsString,vector <int> & testIdxs) {
vector <string> predictions = getBatchPrediction(datasetAsString,testIdxs, this);
vector <string> labels;
for(int dataidx:testIdxs){
labels.push_back(datasetAsString.at(datasetAsString.size()-1).at(dataidx));
}
std::map<std::string, int> incorrectLables;
std::map<std::string, int> correctLables;
......
[file: RandomForest.hpp (filename inferred from the code below)]
@@ -12,18 +12,18 @@ public:
float featureWeight;
int depth;
RandomForest(vector <vector<string>> trainingData, vector <FeatureType> featureType, int numTrees,
RandomForest(vector <vector<string>> &data, vector<int> &trainingIndx, vector <FeatureType> &featureType, int numTrees,
float baggingWeight, float featureWeight, int maxDepth);
vector <string> getForestPrediction(vector <string> test, RandomForest *forest, vector <string> features);
accuracyReport getAccuracy(vector<vector<string>>);
vector <string> getForestPrediction(vector <string>& test, RandomForest *forest, vector <string>& features);
accuracyReport getAccuracy(vector <vector<string>>& datasetAsString,vector <int> &testIdxs);
void printAccuracyReport(accuracyReport report);
void printAccuracyReportFile(accuracyReport report);
};
vector <string> getBatchPrediction(vector<vector<string>> testData, RandomForest *forest);
vector <string> getBatchPrediction(vector <vector<string>>& datasetAsString,vector <int>& testIdxs, RandomForest *forest);
map<int, double> explain(string classification, vector<PredictionReport *> reports);
......
This diff is collapsed.
[file: helpers.hpp (filename inferred from the code below)]
@@ -4,6 +4,7 @@
#include <iterator>
#include <bits/stdc++.h>
#include "util.hpp"
#include "Node.hpp"
#ifndef HELPERS_HPP
#define HELPERS_HPP
@@ -18,6 +19,11 @@ struct FeatureSplitData {
vector<vector<string>> dataFalse;
};
struct FeatureSplitDataIndx {
vector<int> dataTrue;
vector<int> dataFalse;
};
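`FeatureSplitDataIndx` is the core of the faster split: splitting now produces two vectors of row indices instead of two copies of the string matrix. A minimal sketch of a continuous split under this scheme (hypothetical helper; the real `splitData` receives the split point as a string and also covers categorical features):

```cpp
#include <string>
#include <vector>

// Partition a node's rows by a numeric threshold on one feature.
// data is column-major: data[featureIdx][row] holds a stringified number.
FeatureSplitDataIndx splitContinuousSketch(
        const std::vector<std::vector<std::string>> &data, int featureIdx,
        double threshold, const std::vector<int> &rows) {
    FeatureSplitDataIndx out;
    for (int row : rows) {
        if (std::stod(data[featureIdx][row]) <= threshold)
            out.dataTrue.push_back(row);   // sample satisfies the question
        else
            out.dataFalse.push_back(row);
    }
    return out;
}
```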
struct BestSplitPoint {
int featureIdx;
string splitpoint;
@@ -27,20 +33,17 @@ vector<vector<string>> parseDataToString(string dataFile);
vector <FeatureType> parseFeatureTypes(string fileName);
vector<int> randSelectIdxWithoutReplacement(int originalNum, float percentTraining);
vector<int> randSelectIdxWithReplacement(int originalNum, float percent);
void splitTrainingAndTesting(vector<int> trainingIndecies,vector<vector<string>> dataString,
vector<vector<string>>& trainingDataString,vector<vector<string>>& testDataString);
bool dataIsPure(vector <vector<string>> data);
vector<int> splitTrainingAndTesting(vector<int> trainingIndecies, vector <vector<string>> &dataString);
string classifyData(vector <vector<string>> data);
std::pair<string,double> classifyWithEntropy(vector<vector<string>> data);
std::map<int,set<string>>
findAllSplittingPoints(vector <vector<string>> data, vector <FeatureType> featureType, float featureWeight);
FeatureSplitData splitData(vector<vector<string>>data, int splitFeature,vector<FeatureType> featureTypes, string splitValue);
BestSplitPoint findBestSplit(double parentEntropy, int currentDepth,std::map<int, set<string>> potentialSplits, vector<vector<string>> data, vector<FeatureType> featureTypes);
float calculateEntropy(vector<vector<string>> data);
float calculateSplitEntropy (FeatureSplitData featsplitData);
vector<vector<string>> bagData(vector<vector<string>> data, float baggingWeight);
std::pair<string,double> classifyWithEntropy(vector<vector<string>> &data, vector<int> &indices);
FeatureSplitDataIndx splitData(vector<vector<string>>& data, int splitFeature,vector<FeatureType> featureTypes, string splitValue, vector<int> &nodeDatasetIndices );
float calculateEntropy(vector <vector<string>>& data, vector<int> indices) ;
float calculateSplitEntropy (FeatureSplitDataIndx featsplitData, vector<vector<string>> &data);
vector <int> bagData(vector <int> &indices, float baggingWeight);
vector<int> randomSelect_WithoutReplacement(int originalNum, float percentTraining);
vector<vector<string>> oversample(vector<vector<string>> data);
vector<int> oversample(vector<vector<string>> &data, vector<int> &indices);
BestSplitPoint findBestSplit(double parentEntropy, int currentDepth, vector <vector<string>> &data,
vector <FeatureType> featureType, float featureWeight, vector<int>& nodeDatasetIndices );
void cleanTree(Node *node);
#endif
\ No newline at end of file
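`calculateEntropy` and `classifyWithEntropy` now take the shared dataset plus an index list. A hedged sketch of label entropy restricted to those rows, assuming (as elsewhere in the patch) that the last feature row holds the labels:

```cpp
#include <cmath>
#include <map>
#include <string>
#include <vector>

// Shannon entropy of the label distribution over just the given rows.
float entropyAt(const std::vector<std::vector<std::string>> &data,
                const std::vector<int> &indices) {
    if (indices.empty())
        return 0.0f;                        // guard against division by zero
    const std::vector<std::string> &labels = data.back();  // last row = labels
    std::map<std::string, int> counts;
    for (int i : indices)
        counts[labels[i]] += 1;
    float entropy = 0.0f;
    for (const auto &kv : counts) {
        float p = static_cast<float>(kv.second) / indices.size();
        entropy -= p * std::log2(p);
    }
    return entropy;
}
```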
[file: main.cpp (filename inferred from the code below)]
@@ -35,7 +35,7 @@ int main(int argc, char *argv[]) {
int numTrees = atoi(argv[1]);
float baggingWeight = 0.7;
int depth = atoi(argv[2]);
float featureWeight = 0.3;
float featureWeight = 0.005;
// double featWeight = numFeatures * 0.1;
// cout << featWeight << "\n";
@@ -44,80 +44,74 @@ int main(int argc, char *argv[]) {
vector <vector<string>> datasetAsString;
vector <FeatureType> featureTypes;
vector <string> features;
datasetAsString = parseDataToString("../datasets/adult.data");
featureTypes = parseFeatureTypes("../datasets/adult.featureTypes");
features = parseFeatures("../datasets/adult.features");
datasetAsString = parseDataToString("../datasets/mnist.data");
featureTypes = parseFeatureTypes("../datasets/mnist.featureTypes");
features = parseFeatures("../datasets/mnist.features");
double accuracy = 0.0;
double time = 0.0;
for (int x = 0; x < 3; x++) {
vector<int> trainingIdxs = randSelectIdxWithoutReplacement(datasetAsString.at(0).size(), 0.7);
vector <vector<string>> trainingData;
vector <vector<string>> testingData;
//vector <vector<string>> trainingData;
vector <int> testingIdxs = splitTrainingAndTesting(trainingIdxs, datasetAsString);
cout << "Over sampling training data " << endl;
splitTrainingAndTesting(trainingIdxs, datasetAsString, trainingData, testingData);
vector<int> oversampledData = oversample(datasetAsString, trainingIdxs);
cout << "Over sampling training data " << endl;;
vector <vector<string>> oversampledData = oversample(trainingData);
// cout << "over sampled data size "<< oversampledData.at(0).size() <<endl;
for (int fIdx = 0; fIdx < trainingData.size(); ++fIdx) {
for (int oIdx = 0; oIdx < oversampledData.at(0).size(); ++oIdx) {
trainingData.at(fIdx).emplace_back(oversampledData.at(fIdx).at(oIdx));
}
}
// cout<< "training data size after oversample" << trainingData.at(0).size()<<endl;
trainingIdxs.insert(trainingIdxs.end(), oversampledData.begin(), oversampledData.end());
// sort(trainingIdxs.begin(), trainingIdxs.end());
vector <string> testData;
string emptystring;
for (int featIndex = 0; featIndex < testingData.size(); featIndex++) {
for (int featIndex = 0; featIndex < datasetAsString.size(); featIndex++) {
testData.push_back(emptystring);
}
// string data = testingData.at(1).at(0);
// cout << data << endl;
// for (int featIndex = 0; featIndex < testingData.size(); featIndex++) {
// testData.at(featIndex) = testingData.at(featIndex).at(41);
// //cout<<testingData.at(featIndex).at(0)<<", ";
// }
//cout<<endl;
auto start = high_resolution_clock::now();
RandomForest *randomForest = new RandomForest(trainingData, featureTypes, numTrees, baggingWeight, featureWeight, depth);
RandomForest *randomForest = new RandomForest(datasetAsString, trainingIdxs, featureTypes, numTrees,
baggingWeight, featureWeight, depth);
time += (high_resolution_clock::now() - start).count() / 1000000000.0;
cout << endl;
cout << "********************* Forest accuracy *****************" << endl;
accuracyReport report = randomForest->getAccuracy(testingData);
accuracyReport report = randomForest->getAccuracy(datasetAsString,testingIdxs);
accuracy += report.accuracy;
randomForest->printAccuracyReportFile(report);
cout << "**************** prediction with explanation ********** " << endl;
for (int featIndex = 0; featIndex < testingData.size(); featIndex++) {
testData.at(featIndex) = testingData.at(featIndex)[0];
cout << testingData.at(featIndex).at(0) << ", ";
for (int featIndex = 0; featIndex < datasetAsString.size(); featIndex++) {
testData.at(featIndex) = datasetAsString.at(featIndex)[testingIdxs[0]];
cout << datasetAsString.at(featIndex)[testingIdxs[0]] << ", ";
}
cout << endl;
randomForest->getForestPrediction(testData, randomForest, features);
for (int i = 0; i<randomForest->trees.size(); i++){
cleanTree(randomForest->trees[i]->root);
delete randomForest->trees[i];
}
delete randomForest;
}
ofstream outfile;
outfile.open("avg.txt", ios::app);
outfile<< "------ Report ------ " <<endl;
outfile<< numTrees<<"\t"<<depth<<"\t"<<featureWeight<<"\t"<<baggingWeight<<"\t"<<accuracy/3<<"\t"<<time/3<<endl;
outfile << "------ Report ------ " << endl;
outfile << numTrees << "\t" << depth << "\t" << featureWeight << "\t" << baggingWeight << "\t" << accuracy / 3
<< "\t" << time / 3 << endl;
// outfile<< numTrees<<"\t"<<10<<"\t"<<0.7<<"\t"<<baggingWeight<<"\t"<<accuracy/3<<"\t"<<time/3<<endl;
outfile.close();
......
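main.cpp now oversamples by index as well: `oversample(datasetAsString, trainingIdxs)` returns extra row indices that are appended to `trainingIdxs`. The exact policy is not visible in this diff, so the sketch below assumes the simplest variant, duplicating minority-class rows until every class matches the largest one:

```cpp
#include <algorithm>
#include <map>
#include <string>
#include <vector>

// Return extra training indices that re-use minority-class rows
// (hypothetical policy: cycle through each class until counts balance).
std::vector<int> oversampleSketch(
        const std::vector<std::vector<std::string>> &data,
        const std::vector<int> &trainIdx) {
    const std::vector<std::string> &labels = data.back();
    std::map<std::string, std::vector<int>> byClass;
    for (int i : trainIdx)
        byClass[labels[i]].push_back(i);
    std::size_t largest = 0;
    for (const auto &kv : byClass)
        largest = std::max(largest, kv.second.size());
    std::vector<int> extra;
    for (const auto &kv : byClass)
        for (std::size_t k = kv.second.size(); k < largest; k++)
            extra.push_back(kv.second[k % kv.second.size()]);
    return extra;
}
```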