package us.hall.weka;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import weka.core.*;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.Vote;
import weka.classifiers.meta.AdaBoostM1;
import weka.classifiers.meta.Bagging;
import weka.classifiers.lazy.IBk;
import weka.classifiers.trees.RandomForest;
import weka.classifiers.functions.SMO;
import weka.classifiers.functions.supportVector.Puk;

/**
 * Kaggle Forest Cover Type competition: a Vote ensemble of reasonably well-performing classifiers.
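 * <p>
 * Usage: {@code java us.hall.weka.ForestVote -t training.arff -T test.arff [-e]}
 * <p>
 * With -e the ensemble is evaluated by 10-fold cross-validation on the training
 * data; otherwise it is built on the training data and a class prediction is
 * printed for each test instance.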
 **/
public class ForestVote {
	private Instances trainingData = null, testData = null;
	private Classifier classifier = null;
	private File test = null, training = null;
	private boolean evaluate = false;
	
	public static void main(String[] args) {
		File training = null;
		File test = null;
		boolean evaluate = false;
		
		for (int i=0;i<args.length;i++) {
			if (args[i].equals("-t") && i+1 < args.length) {		// -t <file> : training ARFF
				training = new File(args[++i]);
			}
			else if (args[i].equals("-T") && i+1 < args.length) {	// -T <file> : test ARFF
				test = new File(args[++i]);
			}
			else if (args[i].equals("-e")) {						// -e : evaluate instead of predict
				evaluate = true;
			}
		}
		if (training == null || !training.exists()) {
			System.out.println("Training ARFF dataset " + training + " is invalid or missing");
			return;
		}
		if (test == null || !test.exists()) {
			System.out.println("Test ARFF dataset " + test + " is invalid or missing");
			return;
		}
		ForestVote fv = new ForestVote(training,test,evaluate);
		try {
			fv.process();
		}
		catch (Exception ex) { ex.printStackTrace(); }
	}
	
	public ForestVote(File training,File test,boolean evaluate) {
		this.training = training;
		this.test = test;
		this.evaluate = evaluate;
		try {
			Vote vote = new Vote();
			classifier = vote;
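			// Vote's default combination rule (average of the members' class probability estimates) applies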
			Bagging bagging = new Bagging();			// Get the bagging classifier wrapper
			AdaBoostM1 boost = new AdaBoostM1();		// Get the AdaBoostM1 meta classifier
			RandomForest rf = new RandomForest();		// Get the RandomForest classifier
			rf.setMaxDepth(0);							// 0 = unlimited depth
			rf.setNumFeatures(14);						// attributes randomly considered at each split
			rf.setNumTrees(35);							// 35 trees in the forest
			boost.setClassifier(bagging);
			bagging.setClassifier(rf);
			vote.aggregate(boost);						// 1st member: AdaBoostM1 -> Bagging -> RandomForest
			AdaBoostM1 boost2 = new AdaBoostM1();		
			SMO smo = new SMO();
			smo.setBuildLogisticModels(true);			// fit logistic models so SMO yields proper probability estimates
			smo.setC(6.0d);								// complexity constant
			Puk puk = new Puk();						// Pearson VII function-based universal kernel, default parameters
			smo.setKernel(puk);
			boost2.setClassifier(smo);
			vote.aggregate(boost2);						// 2nd member: AdaBoostM1 -> SMO
			IBk knn = new IBk();
			knn.setKNN(2);								// 2 nearest neighbours
			knn.setDistanceWeighting(new SelectedTag(IBk.WEIGHT_SIMILARITY,IBk.TAGS_WEIGHTING));	// weight neighbours by 1 - distance
			vote.aggregate(knn);						// 3rd member: plain IBk, no meta classifier
			vote.finalizeAggregation();
		}
		catch (Exception ex) { ex.printStackTrace(); }
	}
	
	private void process() throws Exception {
		BufferedReader trainingRdr = new BufferedReader(new FileReader(training));
		trainingData = new Instances(trainingRdr);
		trainingRdr.close();
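		// the class attribute (Cover_Type in the Kaggle data) is assumed to be the last column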
		trainingData.setClassIndex(trainingData.numAttributes() - 1);
		if (evaluate) {
			evaluate();
		}
		else {
			build();
			predict();
		}
	}
	
	private void build() throws Exception {
		classifier.buildClassifier(trainingData);
		trainingData = null;							// release the training data for garbage collection
	}
	
	private void evaluate() throws Exception {
		Evaluation eval = new Evaluation(trainingData);
		eval.crossValidateModel(classifier, trainingData, 10, new java.util.Random(1));	// 10-fold CV, fixed seed for reproducibility
		System.out.println(eval.toSummaryString("\nResults\n======\n", false));
		System.out.println("");
		System.out.println(eval.toMatrixString());
		System.out.println("");
		System.out.println(eval.toClassDetailsString());
	}
	
	private void predict() throws Exception {
		BufferedReader testRdr = new BufferedReader(new FileReader(test));
		testData = new Instances(testRdr);
		testRdr.close();
		testData.setClassIndex(testData.numAttributes() - 1);
		for (int i = 0; i < testData.numInstances(); i++) {
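			// classifyInstance returns a 0-based class index; Kaggle Cover_Type labels run 1-7, hence the +1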
			int pred = (int)classifier.classifyInstance(testData.instance(i)) + 1;
			System.out.println(pred);
		}
	}
}