package us.hall.weka;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import weka.core.*;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.AdaBoostM1;
import weka.classifiers.meta.Bagging;
import weka.classifiers.lazy.IBk;

/**
 * Wraps a Bagging classifier around an AdaBoostM1 boosting of an IBk
 * (k-nearest-neighbor) classifier, or the reverse nesting with -r.
 **/
public class BaggingBoostIBk {
	private Instances trainingData = null, testData = null;	// loaded lazily in process()/predict()
	private int knn = 1;									// nearest-neighbor count passed to IBk
	private Classifier classifier = null;					// outermost classifier of the meta chain
	private File test = null, training = null;				// ARFF input files
	private boolean reverse = false, evaluate = false;		// -r and -e flags
	
	/**
	 * Command-line entry point.
	 * <p>
	 * Options:
	 * <pre>
	 *   -t file   training ARFF dataset (required, must exist)
	 *   -T file   test ARFF dataset (required, must exist)
	 *   -K n      number of nearest neighbors for IBk (default 1)
	 *   -r        reverse the nesting: AdaBoostM1(Bagging(IBk)) instead of Bagging(AdaBoostM1(IBk))
	 *   -e        10-fold cross-validate on the training set instead of predicting
	 * </pre>
	 *
	 * @param args command-line options as described above
	 */
	public static void main(String[] args) {
		File training = null;
		File test = null;
		int knn = 1;
		boolean reverse = false, evaluate = false;
		
		for (int i=0;i<args.length;i++) {
			if (args[i].equals("-t")) {			// Q. training ARFF?
				if (++i >= args.length) {		// guard against a trailing option with no value
					System.out.println("-t requires a file argument");
					return;
				}
				training = new File(args[i]);
			}
			else if (args[i].equals("-T")) {		// Q. test ARFF?
				if (++i >= args.length) {
					System.out.println("-T requires a file argument");
					return;
				}
				test = new File(args[i]);
			}	
			else if (args[i].equals("-K")) {		// Q. number of nearest neighbors?
				if (++i >= args.length) {
					System.out.println("-K requires a numeric argument");
					return;
				}
				try {
					// parseInt avoids the deprecated Integer(String) constructor
					knn = Integer.parseInt(args[i]);
				}
				catch (NumberFormatException nfe) {
					System.out.println("-K argument " + args[i] + " is not a valid integer");
					return;
				}
			}
			else if (args[i].equals("-r")) {		// Reverse Bagging and AdaBoostM1?
				reverse = true;
			}
			else if (args[i].equals("-e")) {		// Evaluate instead of predict?
				evaluate = true;
			}
		}
		if (training == null || !training.exists()) {
			System.out.println("Training arff dataset " + training + " is invalid or missing");
			return;
		}
		if (test == null || !test.exists()) {
			System.out.println("Test arff dataset " + test + " is invalid or missing");
			return;
		}
		BaggingBoostIBk bbIBk = new BaggingBoostIBk(training,test,knn,reverse,evaluate);
		try {
			bbIBk.process();
		}
		catch (Exception ex) { ex.printStackTrace(); }
	}
	
	/**
	 * Builds the nested meta-classifier chain.
	 * Default nesting is Bagging(AdaBoostM1(IBk)); with {@code reverse} it is
	 * AdaBoostM1(Bagging(IBk)). None of the wiring calls throw checked
	 * exceptions, so the former catch-and-swallow block has been removed — it
	 * could only have masked programming errors.
	 *
	 * @param training training ARFF file (read later in {@link #process()})
	 * @param test     test ARFF file (read later in {@link #predict()})
	 * @param knn      number of nearest neighbors for IBk
	 * @param reverse  swap the Bagging/AdaBoostM1 nesting order
	 * @param evaluate cross-validate instead of building and predicting
	 */
	public BaggingBoostIBk(File training,File test,int knn,boolean reverse,boolean evaluate) {
		this.knn = knn;
		this.training = training;
		this.test = test;
		this.reverse = reverse;
		this.evaluate = evaluate;
		Bagging bagging = new Bagging();			// Get the bagging classifier wrapper
		AdaBoostM1 boost = new AdaBoostM1();		// Get the AdaBoostM1 meta classifier
		IBk ibk = new IBk();						// Get the IBk classifier
		ibk.setKNN(knn);
		if (reverse) {								// AdaBoostM1(Bagging(IBk))
			classifier = boost;
			boost.setClassifier(bagging);
			bagging.setClassifier(ibk);
		}
		else {										// Bagging(AdaBoostM1(IBk)) — the default
			classifier = bagging;
			bagging.setClassifier(boost);
			boost.setClassifier(ibk);
		}
	}
	
	/**
	 * Loads the training data, then either cross-validates ({@code -e}) or
	 * builds the classifier and predicts the test set.
	 * The last attribute is assumed to be the class attribute.
	 *
	 * @throws Exception on I/O failure or any Weka processing error
	 */
	private void process() throws Exception {
		// try-with-resources: the reader was previously never closed (leak)
		try (BufferedReader trainingRdr = new BufferedReader(new FileReader(training))) {
			trainingData = new Instances(trainingRdr);
		}
		trainingData.setClassIndex(trainingData.numAttributes() - 1);
		if (evaluate) {
			evaluate();
		}
		else {
			build();
			predict();
		}
	}
	
	/**
	 * Trains the classifier on the loaded training data, then releases the
	 * data reference so it can be garbage-collected before prediction.
	 *
	 * @throws Exception if Weka fails to build the classifier
	 */
	private void build() throws Exception {
		classifier.buildClassifier(trainingData);
		trainingData = null;
	}
	
	/**
	 * Runs stratified 10-fold cross-validation (fixed seed 1, for
	 * reproducibility) on the training data and prints the summary,
	 * confusion matrix, and per-class details.
	 *
	 * @throws Exception if Weka evaluation fails
	 */
	private void evaluate() throws Exception {
		Evaluation eval = new Evaluation(trainingData);
		eval.crossValidateModel(classifier, trainingData, 10, new java.util.Random(1));
		System.out.println(eval.toSummaryString("\nResults\n======\n", false));
		System.out.println("");
		System.out.println(eval.toMatrixString());
		System.out.println("");
		System.out.println(eval.toClassDetailsString());
	}
	
	/**
	 * Loads the test data and prints one predicted class per line.
	 * classifyInstance returns the zero-based class index; +1 converts it to
	 * a one-based label for output.
	 *
	 * @throws Exception on I/O failure or if classification fails
	 */
	private void predict() throws Exception {
		// try-with-resources: the reader was previously never closed (leak)
		try (BufferedReader testRdr = new BufferedReader(new FileReader(test))) {
			testData = new Instances(testRdr);
		}
		testData.setClassIndex(testData.numAttributes() - 1);
		for (int i = 0; i < testData.numInstances(); i++) {
			int pred = (int)classifier.classifyInstance(testData.instance(i)) + 1;
			System.out.println(pred);
		}
	}
}