package us.hall.weka;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import weka.core.*;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.AdaBoostM1;
import weka.classifiers.meta.Bagging;
import weka.classifiers.functions.SMO;
import weka.classifiers.functions.supportVector.Puk;

/**
 * Wrap a Bagging ensemble around an AdaBoostM1 boosting of an SMO classifier
 * (or the reverse nesting with -r).
 *
 * Usage: java us.hall.weka.BaggingBoostSMO -t train.arff -T test.arff [-r] [-e]
 **/
public class BaggingBoostSMO {
	private Instances trainingData = null, testData = null;
	private Classifier classifier = null;
	private File test = null, training = null;
	private boolean reverse = false, evaluate = false;
	
	public static void main(String[] args) {
		File training = null;
		File test = null;
		boolean reverse = false, evaluate = false;
		
		for (int i = 0; i < args.length; i++) {
			if (args[i].equals("-t") && i + 1 < args.length) {		// training ARFF?
				training = new File(args[++i]);
			}
			else if (args[i].equals("-T") && i + 1 < args.length) {	// test ARFF?
				test = new File(args[++i]);
			}
			else if (args[i].equals("-r")) {		// reverse Bagging and AdaBoostM1?
				reverse = true;
			}
			else if (args[i].equals("-e")) {		// evaluate instead of predict?
				evaluate = true;
			}
		}
		if (training == null || !training.exists()) {
			System.out.println("Training ARFF dataset " + training + " is missing or invalid");
			return;
		}
		if (test == null || !test.exists()) {
			System.out.println("Test ARFF dataset " + test + " is missing or invalid");
			return;
		}
		BaggingBoostSMO bbsmo = new BaggingBoostSMO(training, test, reverse, evaluate);
		try {
			bbsmo.process();
		}
		catch (Exception ex) { ex.printStackTrace(); }
	}
	
	public BaggingBoostSMO(File training, File test, boolean reverse, boolean evaluate) {
		this.training = training;
		this.test = test;
		this.reverse = reverse;
		this.evaluate = evaluate;
		try {
			Bagging bagging = new Bagging();			// bagging meta classifier
			AdaBoostM1 boost = new AdaBoostM1();		// AdaBoostM1 boosting meta classifier
			SMO smo = new SMO();						// SMO support vector classifier
			smo.setBuildLogisticModels(true);			// fit logistic models to the SVM outputs for probability estimates
			smo.setC(6.0d);								// complexity (soft-margin) constant C
			Puk puk = new Puk();						// Pearson VII function-based universal kernel
			smo.setKernel(puk);							// replace SMO's default polynomial kernel
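			// Both meta classifiers keep Weka's default ensemble size (10 iterations);
			// bagging.setNumIterations(n) and boost.setNumIterations(n) could scale them up.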
			if (reverse) {								// reversed: AdaBoostM1 of Bagging of SMO
				classifier = boost;
				boost.setClassifier(bagging);
				bagging.setClassifier(smo);
			}
			else {										// default: Bagging of AdaBoostM1 of SMO
				classifier = bagging;
				bagging.setClassifier(boost);
				boost.setClassifier(smo);
			}
		}
		catch (Exception ex) { ex.printStackTrace(); }
	}
	
	private void process() throws Exception {
		try (BufferedReader trainingRdr = new BufferedReader(new FileReader(training))) {
			trainingData = new Instances(trainingRdr);
		}
		trainingData.setClassIndex(trainingData.numAttributes() - 1);	// class is the last attribute
		if (evaluate) {
			evaluate();
		}
		else {
			build();
			predict();
		}
	}
	
	private void build() throws Exception {
		classifier.buildClassifier(trainingData);
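		// drop the reference so the training data can be garbage collected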
		trainingData = null;
	}
	
	private void evaluate() throws Exception {
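		// 10-fold cross-validation with a fixed seed so runs are repeatable;
		// crossValidateModel trains on internal copies, so the classifier field itself is left unbuilt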
		Evaluation eval = new Evaluation(trainingData);
		eval.crossValidateModel(classifier, trainingData, 10, new java.util.Random(1));
		System.out.println(eval.toSummaryString("\nResults\n======\n", false));
		System.out.println("");
		System.out.println(eval.toMatrixString());
		System.out.println("");
		System.out.println(eval.toClassDetailsString());
	}
	
	private void predict() throws Exception {
		try (BufferedReader testRdr = new BufferedReader(new FileReader(test))) {
			testData = new Instances(testRdr);
		}
		testData.setClassIndex(testData.numAttributes() - 1);
		for (int i = 0; i < testData.numInstances(); i++) {
			// classifyInstance returns the zero-based class index; shift to one-based for output
			int pred = (int)classifier.classifyInstance(testData.instance(i)) + 1;
			System.out.println(pred);
		}
	}
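	
	/**
	 * Illustrative sketch (not wired into main): how the text label for a
	 * prediction could be printed instead of the one-based numeric index,
	 * by looking the predicted class index up in the class attribute.
	 **/
	private void predictLabels() throws Exception {
		try (BufferedReader testRdr = new BufferedReader(new FileReader(test))) {
			testData = new Instances(testRdr);
		}
		testData.setClassIndex(testData.numAttributes() - 1);
		for (int i = 0; i < testData.numInstances(); i++) {
			int pred = (int)classifier.classifyInstance(testData.instance(i));
			System.out.println(testData.classAttribute().value(pred));
		}
	}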
}