import java.io.*;
import java.util.Random;

import weka.*;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.rules.DecisionTable;
import weka.classifiers.rules.PART;
import weka.classifiers.trees.DecisionStump;
import weka.classifiers.trees.J48;
import weka.core.FastVector;
import weka.core.Instances;
/**
 * Evaluates several Weka classifiers on one data set with k-fold cross-validation and prints
 * the overall accuracy of each classifier, pooled over all folds.
 */
public class Prediction {

    /** Number of cross-validation folds used by {@link #main(String[])}. */
    private static final int NUM_FOLDS = 10;

    /** Data file read by {@link #main(String[])} when no path is given on the command line. */
    private static final String DEFAULT_DATA_FILE = "elections.txt";

    /**
     * Opens {@code file} for buffered reading.
     *
     * @param file path of the data file to open
     * @return a reader over the file, or {@code null} if the file cannot be found (the exception
     *     is printed to standard error); the caller owns the reader and must close it
     */
    public static BufferedReader readDataFile(String file) {
        BufferedReader inp = null;
        try {
            inp = new BufferedReader(new FileReader(file));
        } catch (FileNotFoundException e) {
            System.err.println(e);
        }
        return inp;
    }

    /**
     * Trains {@code model} on {@code trainingSet} and evaluates it on {@code testingSet}.
     *
     * <p>The {@link Evaluation} is initialised from the training set, and
     * {@code buildClassifier} is called on every invocation, so one classifier instance can be
     * reused across folds.
     *
     * @param model the classifier to train; it is modified in place
     * @param trainingSet instances used to build the model
     * @param testingSet instances the trained model is evaluated on
     * @return the evaluation holding the predictions for {@code testingSet}
     * @throws Exception if Weka fails to build or evaluate the classifier
     */
    public static Evaluation classify(Classifier model,
            Instances trainingSet, Instances testingSet) throws Exception {
        Evaluation eval = new Evaluation(trainingSet);
        model.buildClassifier(trainingSet);
        eval.evaluateModel(model, testingSet);
        return eval;
    }

    /**
     * Computes the percentage of predictions whose predicted value equals the actual value.
     *
     * @param preds pooled predictions; every element must be a {@link NominalPrediction}
     *     (otherwise a {@link ClassCastException} is thrown)
     * @return accuracy in percent (0-100); {@code NaN} when {@code preds} is empty
     */
    public static double calculateAccuracy(FastVector preds) {
        int total = preds.size();
        double ct = 0;
        for (int i = 0; i < total; i++) {
            NominalPrediction np = (NominalPrediction) preds.elementAt(i);
            if (np.predicted() == np.actual()) {
                ct++;
            }
        }
        return 100 * ct / total;
    }

    /**
     * Splits {@code data} into {@code numberOfFolds} train/test pairs.
     *
     * <p>Folds are taken in the data's current order, and the data is not shuffled here. Callers
     * should randomize (and, for a nominal class, stratify) the data first.
     *
     * @param data the full data set
     * @param numberOfFolds number of folds (Weka rejects values below 2 or above the number of
     *     instances)
     * @return {@code split[0][i]} is the training set and {@code split[1][i]} the test set of
     *     fold {@code i}
     */
    public static Instances[][] crossValidationSplit(Instances data, int numberOfFolds) {
        Instances[][] split = new Instances[2][numberOfFolds];
        for (int i = 0; i < numberOfFolds; i++) {
            split[0][i] = data.trainCV(numberOfFolds, i);
            split[1][i] = data.testCV(numberOfFolds, i);
        }
        return split;
    }

    /**
     * Loads a data set, runs cross-validation for each classifier and prints its accuracy.
     *
     * @param args optional; {@code args[0]} is the data file path (defaults to
     *     {@code "elections.txt"})
     * @throws FileNotFoundException if the data file cannot be opened
     * @throws Exception if Weka fails to parse the data or train/evaluate a classifier
     */
    public static void main(String[] args) throws Exception {
        String fileName = args.length > 0 ? args[0] : DEFAULT_DATA_FILE;
        BufferedReader datafile = readDataFile(fileName);
        if (datafile == null) {
            // readDataFile returns null instead of throwing; fail clearly rather than with an NPE.
            throw new FileNotFoundException("Cannot open data file: " + fileName);
        }

        Instances data;
        // Close the reader once the data set has been parsed.
        try (BufferedReader reader = datafile) {
            data = new Instances(reader);
        }
        // The last attribute is treated as the class to predict.
        data.setClassIndex(data.numAttributes() - 1);

        // trainCV/testCV take folds in the current instance order, so shuffle first (fixed seed
        // for reproducible output) and stratify a nominal class so every fold sees all classes.
        data.randomize(new Random(1));
        if (data.classAttribute().isNominal()) {
            data.stratify(NUM_FOLDS);
        }

        Instances[][] split = crossValidationSplit(data, NUM_FOLDS);
        Instances[] trainingSplits = split[0];
        Instances[] testingSplits = split[1];

        Classifier[] models = {
            new J48(),
            new PART(),
            new DecisionTable(),
            new DecisionStump()
        };

        for (Classifier model : models) {
            // Pool every fold's predictions so accuracy covers the whole data set.
            FastVector preds = new FastVector();
            for (int i = 0; i < trainingSplits.length; i++) {
                Evaluation validation = classify(model, trainingSplits[i], testingSplits[i]);
                preds.appendElements(validation.predictions());
            }
            double accuracy = calculateAccuracy(preds);
            System.out.println("Accuracy of " + model.getClass().getSimpleName() + ": "
                    + String.format("%.2f%%", accuracy)
                    + "\n---------------------------------");
        }
    }
}
// Web-page residue from where this example was copied: "Comments" / "Leave a comment".