main.cpp 5.23 KB
Newer Older
ahmedaj's avatar
ahmedaj committed
1
#include "helpers.hpp"
sgebreeg's avatar
sgebreeg committed
2
#include "DecisionTree.hpp"
ahmedaj's avatar
ahmedaj committed
3
#include "RandomForest.hpp"
ahmedaj's avatar
ahmedaj committed
4
#include "util.hpp"
sgebreeg's avatar
sgebreeg committed
5

sgebreeg's avatar
sgebreeg committed
6
using namespace std::chrono;
sgebreeg's avatar
sgebreeg committed
7
8
using namespace std;

9
10
11
// Parse CSV file of feature names
// fileName = name of the file
// Returns vector of strings, each a name of a feature
//
// Every non-blank line of the file is split on commas. Whitespace around each
// name (including a trailing '\r' from Windows line endings) is removed, while
// whitespace inside a name is kept, so "work class" stays a single feature.
// An empty field between two commas is kept as an empty name so the positions
// of the remaining names are not shifted.
// If the file cannot be opened, a warning is written to stderr and an empty
// vector is returned.
std::vector<std::string> parseFeatures(std::string fileName) {
    std::vector<std::string> features;

    std::ifstream fIn(fileName);
    if (!fIn.is_open()) {
        std::cerr << "WARNING: could not open feature file '" << fileName << "'" << std::endl;
        return features;
    }

    // Strips leading and trailing whitespace; returns "" for an all-blank string.
    const char *const whitespace = " \t\r\n\v\f";
    auto trim = [whitespace](const std::string &text) {
        const std::string::size_type first = text.find_first_not_of(whitespace);
        if (first == std::string::npos) {
            return std::string();
        }
        const std::string::size_type last = text.find_last_not_of(whitespace);
        return text.substr(first, last - first + 1);
    };

    std::string line;
    // Read whole lines: extracting with operator>> would split a name at every
    // space before the comma splitting ever sees it.
    while (std::getline(fIn, line)) {
        if (trim(line).empty()) {
            continue;  // blank line: contributes no features
        }
        std::stringstream ss(line);
        std::string word;
        while (std::getline(ss, word, ',')) {
            features.push_back(trim(word));
        }
    }
    return features;
}

ahmedaj's avatar
ahmedaj committed
28
int main(int argc, char *argv[]) {
sgebreeg's avatar
sgebreeg committed
29

mccrabb's avatar
mccrabb committed
30
31
32
    if ((argc <= 2) || (argc >= 5)) {
        cout << "Given " << to_string(argc) <<" args. Need 2 or 3." << endl;
		cout << "race [numTrees] [depth] [dataset (no suffix, optional)]" << endl;
ahmedaj's avatar
ahmedaj committed
33
34
35
36
        exit(1);
    }

    int numTrees = atoi(argv[1]);
ahmedaj's avatar
ahmedaj committed
37
    float baggingWeight = 0.7;
sgebreeg's avatar
sgebreeg committed
38
    int depth = atoi(argv[2]);
ahmedaj's avatar
ahmedaj committed
39
40
    // double featWeight = numFeatures * 0.1;

sgebreeg's avatar
sgebreeg committed
41
    // cout << featWeight << "\n";
ahmedaj's avatar
ahmedaj committed
42
    // cout << featWeight << "\n";
sgebreeg's avatar
sgebreeg committed
43

ahmedaj's avatar
ahmedaj committed
44
    vector <vector<string>> datasetAsString,encodedDatasetAsString;
sgebreeg's avatar
sgebreeg committed
45
    vector <FeatureType> featureTypes;
ahmedaj's avatar
ahmedaj committed
46
    vector <string> features,encodedfeatures;
mccrabb's avatar
mccrabb committed
47
48
	if (argc == 4) {
		cout << "Dataset: " << argv[3] << endl;
ahmedaj's avatar
ahmedaj committed
49
		datasetAsString,encodedDatasetAsString = parseDataToString( (string)argv[3] + ".data");
mccrabb's avatar
mccrabb committed
50
   		featureTypes = parseFeatureTypes( (string)argv[3] + ".featureTypes");
ahmedaj's avatar
ahmedaj committed
51
	    features,encodedfeatures = parseFeatures( (string)argv[3] + ".features");
mccrabb's avatar
mccrabb committed
52
53
54
	}
	else {
		cout << "WARNING: No dataset provided as an argument!" << endl;
ahmedaj's avatar
ahmedaj committed
55
	    datasetAsString = parseDataToString("../datasets/adult1.data");
mccrabb's avatar
mccrabb committed
56
57
58
   		featureTypes = parseFeatureTypes("../datasets/adult.featureTypes");
	    features = parseFeatures("../datasets/adult.features");
	}
ahmedaj's avatar
ahmedaj committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    
    encodedDatasetAsString = datasetAsString;
    encodedDatasetAsString.pop_back();
    encodedfeatures = features;

    vector<int> featuresToEncode;
    featuresToEncode.push_back(5);

    encodeData(datasetAsString,encodedDatasetAsString,features,encodedfeatures,featuresToEncode);
    
    for(string feature:encodedfeatures){
        cout << "   " << feature << " ";
    }
    cout<<endl;

    for(int dataIdx=0; dataIdx < encodedDatasetAsString[0].size(); dataIdx++){
        for(int i = 0; i< encodedDatasetAsString.size();i++){
            cout << encodedDatasetAsString[i][dataIdx]<<",";
        }
        cout<<endl;
    }
sgebreeg's avatar
sgebreeg committed
80

ahmedaj's avatar
ahmedaj committed
81
82
//     //pick number of features to select for random sub-spacing
//     float featureWeight = sqrt(features.size())/features.size();
ahmedaj's avatar
ahmedaj committed
83

ahmedaj's avatar
ahmedaj committed
84
85
86
87
88
89
//     double accuracy = 0.0;
//     double time = 0.0;
//     for (int x = 0; x < 3; x++) {
//         vector<int> trainingIdxs = randomSelect_WithoutReplacement(datasetAsString.at(0).size(), 0.7);
//         //vector <vector<string>> trainingData;
//         vector <int> testingIdxs = splitTrainingAndTesting(trainingIdxs, datasetAsString);
ahmedaj's avatar
ahmedaj committed
90

ahmedaj's avatar
ahmedaj committed
91
//         cout << "Over sampling training data " << endl;
ahmedaj's avatar
ahmedaj committed
92

ahmedaj's avatar
ahmedaj committed
93
//         vector<int> oversampledData = oversample(datasetAsString, trainingIdxs);
ahmedaj's avatar
ahmedaj committed
94

ahmedaj's avatar
ahmedaj committed
95
96
//         trainingIdxs.insert(trainingIdxs.end(), oversampledData.begin(), oversampledData.end());
// //        sort(trainingIdxs.begin(), trainingIdxs.end());
sgebreeg's avatar
sgebreeg committed
97

ahmedaj's avatar
ahmedaj committed
98
99
100
101
102
//         vector <string> testData;
//         string emptystring;
//         for (int featIndex = 0; featIndex < datasetAsString.size(); featIndex++) {
//             testData.push_back(emptystring);
//         }
sgebreeg's avatar
sgebreeg committed
103
104


ahmedaj's avatar
ahmedaj committed
105
//         auto start = high_resolution_clock::now();
ahmedaj's avatar
ahmedaj committed
106

ahmedaj's avatar
ahmedaj committed
107
108
//         RandomForest *randomForest = new RandomForest(datasetAsString, trainingIdxs, featureTypes, numTrees,
//                                                       baggingWeight, featureWeight, depth);
sgebreeg's avatar
sgebreeg committed
109

sgebreeg's avatar
sgebreeg committed
110

ahmedaj's avatar
ahmedaj committed
111
//         time += (high_resolution_clock::now() - start).count() / 1000000000.0;
ahmedaj's avatar
ahmedaj committed
112
113


ahmedaj's avatar
ahmedaj committed
114
//         cout << endl;
sgebreeg's avatar
sgebreeg committed
115
116
//        cout << "********************* Forest accuracy *****************" << endl;
//        accuracyReport report = randomForest->getAccuracy(datasetAsString,testingIdxs);
ahmedaj's avatar
ahmedaj committed
117

sgebreeg's avatar
sgebreeg committed
118
119
//        accuracy += report.accuracy;
//        randomForest->printAccuracyReportFile(report);
ahmedaj's avatar
ahmedaj committed
120

ahmedaj's avatar
ahmedaj committed
121

ahmedaj's avatar
ahmedaj committed
122
//         cout << "**************** prediction with explanation ********** " << endl;
ahmedaj's avatar
ahmedaj committed
123

ahmedaj's avatar
ahmedaj committed
124
125
126
127
128
//         for (int featIndex = 0; featIndex < datasetAsString.size(); featIndex++) {
//             testData.at(featIndex) = datasetAsString.at(featIndex)[testingIdxs[0]];
//             cout << datasetAsString.at(featIndex)[testingIdxs[0]] << ", ";
//         }
//         cout << endl;
ahmedaj's avatar
ahmedaj committed
129

ahmedaj's avatar
ahmedaj committed
130

ahmedaj's avatar
ahmedaj committed
131
//         randomForest->getForestPrediction(testData, randomForest, features);
ahmedaj's avatar
ahmedaj committed
132

ahmedaj's avatar
ahmedaj committed
133
//         for (int i = 0; i<randomForest->trees.size(); i++){
ahmedaj's avatar
ahmedaj committed
134
        
ahmedaj's avatar
ahmedaj committed
135
136
137
138
//             cleanTree(randomForest->trees[i]->root);
//             delete randomForest->trees[i];
//         }
//         delete randomForest;
ahmedaj's avatar
ahmedaj committed
139

ahmedaj's avatar
ahmedaj committed
140

ahmedaj's avatar
ahmedaj committed
141
142
143
144
145
146
147
148
149
//     }
//     ofstream outfile;
//     outfile.open("avg.txt", ios::app);
//     outfile << "------ Report ------  " << endl;
//     outfile << numTrees << "\t" << depth << "\t" << featureWeight << "\t" << baggingWeight << "\t" << accuracy / 3
//             << "\t" << time / 3 << endl;
// //    outfile<< numTrees<<"\t"<<10<<"\t"<<0.7<<"\t"<<baggingWeight<<"\t"<<accuracy/3<<"\t"<<time/3<<endl;

//     outfile.close();
ahmedaj's avatar
ahmedaj committed
150

sgebreeg's avatar
sgebreeg committed
151
152
153
    return 0;

}