package weka.classifiers.functions;

import java.util.ArrayList;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.PrincipalComponents;

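/**
 * Principal components regression (PCR): applies the unsupervised
 * PrincipalComponents filter to the input attributes and fits a
 * LinearRegression model to the resulting components via a
 * FilteredClassifier. If no maximum number of components is set, a
 * cross-validated search over the number of components is performed first.
 * <p/>
 * A minimal usage sketch (the dataset name "data.arff" is a placeholder and
 * assumes a numeric class attribute in the last position):
 *
 * <pre>
 * Instances data = new weka.core.converters.ConverterUtils.DataSource("data.arff").getDataSet();
 * data.setClassIndex(data.numAttributes() - 1);
 * PCR pcr = new PCR();
 * pcr.buildClassifier(data);
 * double prediction = pcr.distributionForInstance(data.instance(0))[0];
 * </pre>
 */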
public class PCR extends AbstractClassifier implements
  OptionHandler, WeightedInstancesHandler {
  
  /** Wrapper that applies the PCA filter before the regression model. */
  protected FilteredClassifier classifier = new FilteredClassifier();

  /** PCA filter used to transform the input attributes. */
  protected PrincipalComponents pcaFilter = new PrincipalComponents();

  /** Linear regression model fitted to the principal components. */
  protected LinearRegression regressor = new LinearRegression();

  /** Training (-t) and test (-T) file names passed on to the wrapped classifier. */
  protected String t = null, T = null;

  /** Whether to print additional progress information to the console. */
  protected boolean debug = false;
  
  public PCR() {
    m_numDecimalPlaces = 4;
    pcaFilter.setVarianceCovered(1d);
    pcaFilter.setCenterData(true);
    classifier.setFilter(pcaFilter);
    // Disable attribute selection in the regressor; the PCA filter already
    // performs the dimensionality reduction.
    regressor.setAttributeSelectionMethod(new SelectedTag(LinearRegression.SELECTION_NONE,
      LinearRegression.TAGS_SELECTION));
    regressor.setEliminateColinearAttributes(true);
    classifier.setClassifier(regressor);
  }

  /**
   * Generates a principal components regression classifier.
   *
   * @param argv the options
   */
  public static void main(String[] argv) {
    runClassifier(new PCR(), argv);
  }
  
  /**
   * Returns a string describing this classifier
   *
   * @return a description of the classifier suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String globalInfo() {
    return "Class for using linear regression on results of Principal Components "
      + "filter. aka principal components regression.";
  }
  
  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  @Override
  public Capabilities getCapabilities() {
     return regressor.getCapabilities();
  }

  /**
   * Classifies a given instance after filtering.
   *
   * @param instance the instance to be classified
   * @return the class distribution for the given instance
   * @throws Exception if instance could not be classified successfully
   */
  @Override
  public double[] distributionForInstance(Instance instance) throws Exception {
     return classifier.distributionForInstance(instance);
  }
    
  /**
   * Builds a regression model for the given data.
   *
   * @param data the training data to be used for generating the linear
   *          regression function
   * @throws Exception if the classifier could not be built successfully
   */
  @Override
  public void buildClassifier(Instances data) throws Exception {
     // Check whether the PCA filter can handle the data.
     Capabilities caps = pcaFilter.getCapabilities();
     // the filtered classifier always needs a class
     caps.disable(Capability.NO_CLASS);
     caps.testWithFail(data);
     // can the regressor handle the data?
     regressor.getCapabilities().testWithFail(data);

     // If the maximum number of components is not set, search for the best value.
     int bestM = pcaFilter.getMaximumAttributes();
     if (bestM == -1) {
        System.out.println("Comp #\tCorrelation     \tRMSE");
        System.out.println("######\t###########     \t####");
        bestM = search(data);
        System.out.println("Best number of components found was " + bestM);
        pcaFilter.setMaximumAttributes(bestM);
     }
     else {
        System.out.println("PCA maximum attributes already set to " + bestM);
     }

     // Configure the wrapped FilteredClassifier: a PrincipalComponents filter
     // (all variance covered, centered data, at most bestM components) followed
     // by a LinearRegression base classifier.
     ArrayList<String> options = new ArrayList<String>();
     options.add("-F");
     options.add("weka.filters.unsupervised.attribute.PrincipalComponents -R 1 -C -M " + bestM);
     options.add("-W");
     options.add("weka.classifiers.functions.LinearRegression");
     // Pass any captured training/test file names through unchanged.
     if (t != null) {
        options.add("-t");
        options.add(t);
     }
     if (T != null) {
        options.add("-T");
        options.add(T);
     }
     // Options after "--" go to the base classifier: -S 1 turns off attribute
     // selection; colinear attribute elimination is left at its default, as in
     // the cross-validated search.
     options.add("--");
     options.add("-S");
     options.add("1");
     classifier.setOptions(options.toArray(new String[0]));
     classifier.buildClassifier(data);
  }
  
  /**
   * Determines the best number of principal components to retain by running
   * cross-validated trials and narrowing the range between the first non-zero
   * eigenvalue and the first eigenvalue of at least one.
   *
   * @param data the training data
   * @return the number of components to retain (used as the PCA filter's
   *         maximum number of attributes)
   * @throws Exception if a trial evaluation fails
   */
  private int search(Instances data) throws Exception {
     AccessiblePCA pca = new AccessiblePCA();
     pca.setVarianceCovered(1d);
     pca.setCenterData(true);
     pca.setup(data);
     int nonZero = pca.getNonZeroIndex();
     int gtOne = pca.getOneValueIndex();
     if (debug)
        System.out.println("First non-zero component at " + nonZero + ". 1 or more at " + gtOne);
     Instances fData = clone(data);
     double[] result = trial(gtOne, fData);
     double r = result[0];            // Correlation for eigenvalues greater than one
     double rInit = r;                // Initial correlation value
     fData = clone(data);
     result = trial(nonZero, fData);
     boolean done = false;
     int rightEdge = gtOne;
     int leftEdge = gtOne;
     int compNum = nonZero;           // Initial component number is the first non-zero one
     final int LEFT = 1;
     final int RIGHT = 0;
     int dir = LEFT;
     int bestAttrNum = -1;
     while (!done) {
        fData = clone(data);
        int currCompNum = compNum;
        if (r > result[0]) {          // If the prior result was better than the current one
           if (rInit > result[0]) {   // If worse than the initial result, too many components
              leftEdge = compNum;     // This becomes the left edge
              compNum = leftEdge - ((leftEdge - rightEdge) / 2);  // move right
              dir = RIGHT;
           }
           else {
              if (dir == LEFT) {
                 leftEdge = compNum;
                 compNum = leftEdge - ((leftEdge - rightEdge) / 2);
                 dir = RIGHT;
              }
              else {
                 rightEdge = compNum;
                 compNum = rightEdge + ((leftEdge - rightEdge) / 2);
                 dir = LEFT;
              }
           }
        }
        else {
           r = result[0];             // Record the new best correlation
           if (dir == LEFT) {         // Keep going in the current direction
              rightEdge = compNum;
              compNum = rightEdge + ((leftEdge - rightEdge) / 2);
           }
           else {
              leftEdge = compNum;
              compNum = leftEdge - ((leftEdge - rightEdge) / 2);
           }
        }
        if (Math.abs(compNum - currCompNum) <= 5) {
           // Close enough: translate the component index into the number of retained components.
           bestAttrNum = fData.numAttributes() - 1 - currCompNum;
           done = true;
           continue;
        }
        result = trial(compNum, fData);
     }
     return bestAttrNum;
  }
  
  /**
   * Runs a 10-fold cross-validation of linear regression on the data reduced
   * to a given number of principal components.
   *
   * @param numComps index into the sorted eigenvalues; the data is reduced to
   *          (numAttributes - 1 - numComps) components
   * @param data the data to evaluate on
   * @return the correlation coefficient (index 0) and root mean squared error
   *         (index 1) of the cross-validation
   * @throws Exception if filtering or evaluation fails
   */
  private double[] trial(int numComps, Instances data) throws Exception {
     AccessiblePCA pca = new AccessiblePCA();
     pca.setVarianceCovered(1d);
     pca.setCenterData(true);
     pca.setInputFormat(data); // filter capabilities are checked here
     int maxAttr = data.numAttributes() - 1 - numComps;
     pca.setMaximumAttributes(maxAttr);
     data = Filter.useFilter(data, pca);
     LinearRegression regressor = new LinearRegression();
     regressor.setAttributeSelectionMethod(new SelectedTag(LinearRegression.SELECTION_NONE,
        LinearRegression.TAGS_SELECTION));
     regressor.setEliminateColinearAttributes(true);
     double[] result = evaluate(data, regressor);
     System.out.println(maxAttr + "\t" + result[0] + "\t" + result[1]);
     return result;
  }
  
  /**
   * Parses a given list of options.
   * <p/>
   *
   * <!-- options-start --> Valid options are:
   * <p/>
   *
   * <pre>
   * -t &lt;training file&gt;
   *  Name of the training file; stored and passed on to the
   *  wrapped FilteredClassifier.
   * </pre>
   *
   * <pre>
   * -T &lt;test file&gt;
   *  Name of the test file; stored and passed on to the
   *  wrapped FilteredClassifier.
   * </pre>
   *
   * <pre>
   * -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console
   * </pre>
   *
   * <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
     // Only the pass-through file options and the debug flag are handled here;
     // the wrapped FilteredClassifier is configured in buildClassifier().
     String tmpStr = Utils.getOption('t', options);
     if (tmpStr.length() > 0) {
        t = tmpStr;
     }
     tmpStr = Utils.getOption('T', options);
     if (tmpStr.length() > 0) {
        T = tmpStr;
     }
     debug = Utils.getFlag('D', options);
     super.setOptions(options);
  }
  
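  /**
   * Runs a 10-fold cross-validation of the given classifier on the data.
   *
   * @param instances the data to evaluate on
   * @param classifier the classifier to cross-validate
   * @return an array holding the correlation coefficient (index 0) and the
   *         root mean squared error (index 1)
   * @throws Exception if the evaluation fails
   */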
  private double[] evaluate(Instances instances, Classifier classifier) throws Exception {
     double[] result = new double[2];
     Evaluation eval = new Evaluation(instances);
     eval.crossValidateModel(classifier, instances, 10, new java.util.Random(1));
     result[0] = eval.correlationCoefficient();
     result[1] = eval.rootMeanSquaredError();
     return result;
  }
	
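  /**
   * Returns a full copy of the given dataset.
   *
   * @param in the dataset to copy
   * @return a copy containing the header and all instances
   */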
  private Instances clone(Instances in) {
     // new Instances(in) already copies both the header and all of the
     // instances; adding the instances again would duplicate the dataset.
     return new Instances(in);
  }
}

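/**
 * Subclass of the PrincipalComponents filter that exposes the setup step and
 * the eigenvalue information to PCR, so the search for the number of
 * components can inspect the eigenvalue spectrum directly.
 */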
class AccessiblePCA extends PrincipalComponents {

  /**
   * Initializes the filter with the given input data.
   * 
   * @param instances the data to process
   * @throws Exception in case the processing goes wrong
   * @see #batchFinished()
   */
  protected void setup(Instances instances) throws Exception {
     super.setup(instances);
  }
  
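  /**
   * Returns the position of the first strictly positive eigenvalue when the
   * eigenvalues are sorted in ascending order, or -1 if none is positive.
   */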
  public int getNonZeroIndex() {
     for (int idx = 0; idx < m_Eigenvalues.length; idx++) {
        if (m_Eigenvalues[m_SortedEigens[idx]] > 0) {
           return idx;
        }
     }
     return -1;
  }
  
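  /**
   * Returns the position of the first eigenvalue of at least one when the
   * eigenvalues are sorted in ascending order, or the last position if no
   * eigenvalue reaches one.
   */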
  public int getOneValueIndex() {
     for (int idx = 0; idx < m_Eigenvalues.length; idx++) {
        if (m_Eigenvalues[m_SortedEigens[idx]] >= 1) {
           return idx;
        }
     }
     // If no eigenvalue reaches one, fall back to the last (largest) position.
     return m_Eigenvalues.length - 1;
  }
}