weka.classifiers
Class CVParameterSelection

java.lang.Object
  |
  +--weka.classifiers.Classifier
        |
        +--weka.classifiers.CVParameterSelection
All Implemented Interfaces:
java.lang.Cloneable, OptionHandler, java.io.Serializable, Summarizable

public class CVParameterSelection
extends Classifier
implements OptionHandler, Summarizable

Class for performing parameter selection by cross-validation for any classifier. For more information, see

R. Kohavi (1995). Wrappers for Performance Enhancement and Oblivious Decision Graphs. PhD Thesis. Department of Computer Science, Stanford University.

Valid options are:

-D
Turn on debugging output.

-W classname
Specify the full class name of classifier to perform cross-validation selection on.

-X num
Number of folds used for cross-validation (default 10).

-S seed
Random number seed (default 1).

-P "N 1 5 10"
Sets an optimisation parameter for the classifier with name -N, lower bound 1, upper bound 5, and 10 optimisation steps. The upper bound may be the character 'A' or 'I' to substitute the number of attributes or instances in the training data, respectively. This parameter may be supplied more than once to optimise over several classifier options simultaneously.

Options after -- are passed to the designated sub-classifier.
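Splitting a command line at `--` can be sketched as follows. This is a hypothetical helper (`OptionSplitSketch.partition` is not part of this class's API), and `SomeClassifier` is a placeholder class name. The only assumption is that everything after the first `--` belongs to the sub-classifier.

```java
import java.util.Arrays;

// Hypothetical helper, not part of Weka: splits an argument array at the first "--".
public class OptionSplitSketch {

    // Returns {ownOptions, subClassifierOptions}.
    static String[][] partition(String[] argv) {
        int sep = Arrays.asList(argv).indexOf("--");
        if (sep < 0) {
            // No separator: every option belongs to the meta-classifier itself.
            return new String[][] { argv.clone(), new String[0] };
        }
        String[] own = Arrays.copyOfRange(argv, 0, sep);
        String[] sub = Arrays.copyOfRange(argv, sep + 1, argv.length);
        return new String[][] { own, sub };
    }

    public static void main(String[] args) {
        String[] argv = { "-X", "5", "-P", "N 1 5 10", "-W", "SomeClassifier", "--", "-M", "2" };
        String[][] parts = partition(argv);
        System.out.println(Arrays.toString(parts[0])); // [-X, 5, -P, N 1 5 10, -W, SomeClassifier]
        System.out.println(Arrays.toString(parts[1])); // [-M, 2]
    }
}
```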

Author:
Len Trigg (trigg@cs.waikato.ac.nz)
See Also:
Serialized Form

Inner Class Summary
protected  class CVParameterSelection.CVParameter
           
 
Field Summary
protected  java.lang.String[] m_BestClassifierOptions
          The set of all classifier options as determined by cross-validation
protected  double m_BestPerformance
          The cross-validated performance of the best options
protected  Classifier m_Classifier
          The generated base classifier
protected  java.lang.String[] m_ClassifierOptions
          The base classifier options (not including those being set by cross-validation)
protected  FastVector m_CVParams
          The set of parameters to cross-validate over
protected  boolean m_Debug
          Debugging mode, gives extra output if true
protected  int m_NumAttributes
          The number of attributes in the data
protected  int m_NumFolds
          The number of folds used in cross-validation
protected  int m_Seed
          Random number seed
protected  int m_TrainFoldSize
          The number of instances in a training fold
 
Constructor Summary
CVParameterSelection()
           
 
Method Summary
 void addCVParameter(java.lang.String cvParam)
          Adds a scheme parameter to the list of parameters to be set by cross-validation
 void buildClassifier(Instances instances)
          Generates the classifier.
 double classifyInstance(Instance instance)
          Predicts the class value for the given test instance.
protected  java.lang.String[] createOptions()
          Create the options array to pass to the classifier.
protected  void findParamsByCrossValidation(int depth, Instances trainData)
          Finds the best parameter combination.
 Classifier getClassifier()
          Gets the base classifier whose parameters are being selected
 java.lang.String getCVParameter(int index)
          Gets the scheme parameter with the given index.
 boolean getDebug()
          Gets whether debugging is turned on
 int getNumFolds()
          Get the number of folds used for cross-validation.
 java.lang.String[] getOptions()
          Gets the current settings of the Classifier.
 int getSeed()
          Gets the random number seed.
 java.util.Enumeration listOptions()
          Returns an enumeration describing the available options
static void main(java.lang.String[] argv)
          Main method for testing this class.
 void setClassifier(Classifier newClassifier)
          Sets the base classifier whose parameters are selected by cross-validation.
 void setDebug(boolean debug)
          Sets debugging mode
 void setNumFolds(int newNumFolds)
          Set the number of folds used for cross-validation.
 void setOptions(java.lang.String[] options)
          Parses a given list of options.
 void setSeed(int seed)
          Sets the seed for random number generation.
 java.lang.String toString()
          Returns description of the cross-validated classifier.
 java.lang.String toSummaryString()
          Returns a string that summarizes the object.
 
Methods inherited from class weka.classifiers.Classifier
forName, makeCopies
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
 

Field Detail

m_Classifier

protected Classifier m_Classifier
The generated base classifier

m_ClassifierOptions

protected java.lang.String[] m_ClassifierOptions
The base classifier options (not including those being set by cross-validation)

m_BestClassifierOptions

protected java.lang.String[] m_BestClassifierOptions
The set of all classifier options as determined by cross-validation

m_BestPerformance

protected double m_BestPerformance
The cross-validated performance of the best options

m_CVParams

protected FastVector m_CVParams
The set of parameters to cross-validate over

m_NumAttributes

protected int m_NumAttributes
The number of attributes in the data

m_TrainFoldSize

protected int m_TrainFoldSize
The number of instances in a training fold
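How large a training fold is depends on how the data are divided into folds. The sketch below assumes the common scheme in which `n` instances are split into `k` folds whose sizes differ by at most one (the first `n % k` folds receive one extra instance), so a training fold holds every instance outside the held-out fold. The class and method names are hypothetical.

```java
// Hypothetical sketch: test and training fold sizes for k-fold cross-validation.
public class FoldSizeSketch {

    // Size of held-out fold `fold` (0-based) when n instances are split into k folds.
    static int testFoldSize(int n, int k, int fold) {
        if (k < 2 || k > n) throw new IllegalArgumentException("need 2 <= k <= n");
        int size = n / k;
        // The first n % k folds absorb the remainder, one instance each.
        if (fold < n % k) size++;
        return size;
    }

    // A training fold is every instance outside the held-out fold.
    static int trainFoldSize(int n, int k, int fold) {
        return n - testFoldSize(n, k, fold);
    }

    public static void main(String[] args) {
        for (int f = 0; f < 3; f++) {
            // Prints train 6/test 4, then train 7/test 3 twice.
            System.out.println("fold " + f + ": train " + trainFoldSize(10, 3, f)
                    + ", test " + testFoldSize(10, 3, f));
        }
    }
}
```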

m_NumFolds

protected int m_NumFolds
The number of folds used in cross-validation

m_Seed

protected int m_Seed
Random number seed

m_Debug

protected boolean m_Debug
Debugging mode, gives extra output if true

Constructor Detail

CVParameterSelection

public CVParameterSelection()

Method Detail

createOptions

protected java.lang.String[] createOptions()
Create the options array to pass to the classifier. The parameter values and positions are taken from m_ClassifierOptions and m_CVParams.
Returns:
the options array
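A sketch of how such an options array might be assembled. It assumes each cross-validated parameter is emitted as a flag `-c` followed by its current value, placed before the fixed base options. `OptionsSketch` and its `createOptions(String[], char[], double[])` signature are hypothetical. The real method reads `m_ClassifierOptions` and `m_CVParams` instead of taking arguments, and its ordering and number formatting may differ.

```java
import java.util.Arrays;

// Hypothetical sketch: prepend "-c value" pairs to a copy of the base options.
public class OptionsSketch {

    static String[] createOptions(String[] baseOptions, char[] paramChars, double[] values) {
        if (paramChars.length != values.length) throw new IllegalArgumentException("length mismatch");
        String[] options = new String[2 * paramChars.length + baseOptions.length];
        int pos = 0;
        for (int i = 0; i < paramChars.length; i++) {
            options[pos++] = "-" + paramChars[i];
            // Print whole numbers without ".0" so integer-valued options parse cleanly.
            double v = values[i];
            options[pos++] = (v == Math.rint(v)) ? Long.toString((long) v) : Double.toString(v);
        }
        // The fixed (not cross-validated) options follow the tuned ones.
        System.arraycopy(baseOptions, 0, options, pos, baseOptions.length);
        return options;
    }

    public static void main(String[] args) {
        String[] opts = createOptions(new String[] { "-M", "2" }, new char[] { 'N' }, new double[] { 3.0 });
        System.out.println(Arrays.toString(opts)); // [-N, 3, -M, 2]
    }
}
```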

findParamsByCrossValidation

protected void findParamsByCrossValidation(int depth,
                                           Instances trainData)
                                    throws java.lang.Exception
Finds the best parameter combination, recursing once for each parameter being optimised.
Parameters:
depth - the index of the parameter to be optimised at this level
trainData - the training data
Throws:
java.lang.Exception - if an error occurs
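The recursion can be pictured as a depth-first walk over a grid: level `depth` fixes one parameter to each of its candidate values in turn, and the deepest level scores the completed combination. In this hypothetical sketch the cross-validated evaluation is abstracted as a caller-supplied scoring function where higher is better. In the class itself, the score would instead come from cross-validating the base classifier on `trainData`.

```java
import java.util.Arrays;
import java.util.function.ToDoubleFunction;

// Hypothetical sketch of an exhaustive recursive grid search.
public class GridSearchSketch {

    private final double[][] grid;                  // candidate values, one row per parameter
    private final ToDoubleFunction<double[]> score; // stands in for a CV estimate; higher is better
    private final double[] current;
    private double[] best;
    private double bestScore = Double.NEGATIVE_INFINITY;

    GridSearchSketch(double[][] grid, ToDoubleFunction<double[]> score) {
        this.grid = grid;
        this.score = score;
        this.current = new double[grid.length];
    }

    // Fixes parameter `depth` to each candidate, then recurses to the next parameter.
    private void search(int depth) {
        if (depth == grid.length) {
            double s = score.applyAsDouble(current);
            if (s > bestScore) { // strict '>' keeps the first combination on ties
                bestScore = s;
                best = current.clone();
            }
            return;
        }
        for (double v : grid[depth]) {
            current[depth] = v;
            search(depth + 1);
        }
    }

    double[] run() {
        search(0);
        return best;
    }

    public static void main(String[] args) {
        double[][] grid = { { 1, 2, 3 }, { 1, 2, 3, 4, 5 } };
        // Toy score peaked at (2, 3).
        GridSearchSketch g = new GridSearchSketch(grid,
                p -> -Math.pow(p[0] - 2, 2) - Math.pow(p[1] - 3, 2));
        System.out.println(Arrays.toString(g.run())); // [2.0, 3.0]
    }
}
```

The number of evaluations is the product of the grid sizes, so each additional parameter multiplies the cost.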

listOptions

public java.util.Enumeration listOptions()
Returns an enumeration describing the available options
Specified by:
listOptions in interface OptionHandler
Returns:
an enumeration of all the available options

setOptions

public void setOptions(java.lang.String[] options)
                throws java.lang.Exception
Parses a given list of options. Valid options are:

-D
Turn on debugging output.

-W classname
Specify the full class name of classifier to perform cross-validation selection on.

-X num
Number of folds used for cross-validation (default 10).

-S seed
Random number seed (default 1).

-P "N 1 5 10"
Sets an optimisation parameter for the classifier with name -N, lower bound 1, upper bound 5, and 10 optimisation steps. The upper bound may be the character 'A' or 'I' to substitute the number of attributes or instances in the training data, respectively. This parameter may be supplied more than once to optimise over several classifier options simultaneously.

Options after -- are passed to the designated sub-classifier.

Specified by:
setOptions in interface OptionHandler
Parameters:
options - the list of options as an array of strings
Throws:
java.lang.Exception - if an option is not supported

getOptions

public java.lang.String[] getOptions()
Gets the current settings of the Classifier.
Specified by:
getOptions in interface OptionHandler
Returns:
an array of strings suitable for passing to setOptions

buildClassifier

public void buildClassifier(Instances instances)
                     throws java.lang.Exception
Generates the classifier.
Overrides:
buildClassifier in class Classifier
Parameters:
instances - set of instances serving as training data
Throws:
java.lang.Exception - if the classifier has not been generated successfully
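The overall procedure (estimate each candidate setting by cross-validation, keep the best, then train on all the data with it) can be sketched under simplifying assumptions. There is a single parameter: the neighbourhood size `k` of a toy one-dimensional nearest-neighbour rule written here for illustration, standing in for an arbitrary base classifier. Labels are binary (0/1), and folds are assigned round-robin without shuffling. All names are hypothetical, and the candidate with the lowest cross-validated error wins.

```java
import java.util.Arrays;

// Hypothetical sketch: choose k for a 1-D k-nearest-neighbour rule by k-fold
// cross-validation, then use the chosen k with all of the data.
public class CvSelectSketch {

    // Majority vote among the k training points closest to x (a tied vote goes to class 0).
    static int knnPredict(double[] xs, int[] ys, double x, int k) {
        Integer[] idx = new Integer[xs.length];
        for (int i = 0; i < idx.length; i++) idx[i] = i;
        Arrays.sort(idx, (a, b) -> Double.compare(Math.abs(xs[a] - x), Math.abs(xs[b] - x)));
        int used = Math.min(k, xs.length), ones = 0;
        for (int i = 0; i < used; i++) ones += ys[idx[i]];
        return 2 * ones > used ? 1 : 0;
    }

    // Fraction misclassified over numFolds folds; instance i belongs to fold i % numFolds.
    static double cvError(double[] xs, int[] ys, int k, int numFolds) {
        int errors = 0;
        for (int f = 0; f < numFolds; f++) {
            double[] trX = new double[xs.length];
            int[] trY = new int[xs.length];
            int n = 0;
            for (int i = 0; i < xs.length; i++) {
                if (i % numFolds != f) { trX[n] = xs[i]; trY[n] = ys[i]; n++; }
            }
            double[] tx = Arrays.copyOf(trX, n);
            int[] ty = Arrays.copyOf(trY, n);
            // Evaluate on the held-out fold only.
            for (int i = f; i < xs.length; i += numFolds) {
                if (knnPredict(tx, ty, xs[i], k) != ys[i]) errors++;
            }
        }
        return (double) errors / xs.length;
    }

    // Returns the candidate with the lowest CV error (the first one wins on ties).
    static int selectK(double[] xs, int[] ys, int[] candidates, int numFolds) {
        int bestK = candidates[0];
        double bestErr = Double.POSITIVE_INFINITY;
        for (int k : candidates) {
            double err = cvError(xs, ys, k, numFolds);
            if (err < bestErr) { bestErr = err; bestK = k; }
        }
        return bestK;
    }

    public static void main(String[] args) {
        double[] xs = { 0, 1, 2, 3, 4, 10, 11, 12, 13, 14 };
        int[] ys = { 0, 0, 0, 0, 0, 1, 1, 1, 1, 1 };
        int k = selectK(xs, ys, new int[] { 1, 3, 9 }, 5);
        // Final model: all the data, with the selected k.
        System.out.println("k=" + k + " predict(12.2)=" + knnPredict(xs, ys, 12.2, k)); // k=1 predict(12.2)=1
    }
}
```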

classifyInstance

public double classifyInstance(Instance instance)
                        throws java.lang.Exception
Predicts the class value for the given test instance.
Overrides:
classifyInstance in class Classifier
Parameters:
instance - the instance to be classified
Returns:
the predicted class value
Throws:
java.lang.Exception - if an error occurred during the prediction

setSeed

public void setSeed(int seed)
Sets the seed for random number generation.
Parameters:
seed - the random number seed
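Presumably the seed controls how instances are shuffled before they are divided into folds, so the same seed reproduces the same folds and therefore the same selected parameters. Below is a minimal sketch of a seeded shuffle using `java.util.Random` and `Collections.shuffle`. The class name is hypothetical, and Weka's own randomisation may order instances differently for the same seed.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch: a reproducible permutation of instance indices.
public class SeededShuffleSketch {

    // Returns a shuffled copy of 0..n-1; the same seed always yields the same order.
    static List<Integer> shuffledIndices(int n, long seed) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < n; i++) idx.add(i);
        Collections.shuffle(idx, new Random(seed));
        return idx;
    }

    public static void main(String[] args) {
        System.out.println(shuffledIndices(10, 1).equals(shuffledIndices(10, 1))); // true
        System.out.println(shuffledIndices(10, 1));
    }
}
```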

getSeed

public int getSeed()
Gets the random number seed.
Returns:
the random number seed

addCVParameter

public void addCVParameter(java.lang.String cvParam)
                    throws java.lang.Exception
Adds a scheme parameter to the list of parameters to be set by cross-validation
Parameters:
cvParam - the string representation of a scheme parameter. The format is:
param_char lower_bound upper_bound number_of_steps
e.g., to search a parameter -P between 1 and 10 in 10 steps:
P 1 10 10
Throws:
java.lang.Exception - if the parameter specifier is of the wrong format
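Parsing such a specifier and listing the values it would try can be sketched as follows. This follows the -P option description (lower bound, upper bound, number of steps) and additionally assumes the values are evenly spaced with both bounds included. `CVParamSketch` is a hypothetical class, and it throws `IllegalArgumentException` (a subclass of `java.lang.Exception`) on a malformed specifier.

```java
import java.util.Arrays;

// Hypothetical sketch: parse "param_char lower upper steps" and enumerate candidate values.
public class CVParamSketch {

    final char paramChar;
    final double lower;
    final String upper; // a number, or "A" / "I"
    final int steps;

    CVParamSketch(String spec) {
        String[] t = spec.trim().split("\\s+");
        if (t.length != 4 || t[0].length() != 1) {
            throw new IllegalArgumentException("Expected: param_char lower_bound upper_bound steps");
        }
        paramChar = t[0].charAt(0);
        lower = Double.parseDouble(t[1]); // NumberFormatException is an IllegalArgumentException
        upper = t[2];
        if (!upper.equals("A") && !upper.equals("I")) {
            Double.parseDouble(upper); // validate a numeric upper bound early
        }
        steps = Integer.parseInt(t[3]);
        if (steps < 1) throw new IllegalArgumentException("steps must be positive");
    }

    // 'A' and 'I' stand for the number of attributes / training instances.
    double[] values(int numAttributes, int numInstances) {
        double hi = upper.equals("A") ? numAttributes
                : upper.equals("I") ? numInstances
                : Double.parseDouble(upper);
        double inc = steps > 1 ? (hi - lower) / (steps - 1) : 0.0;
        double[] v = new double[steps];
        for (int i = 0; i < steps; i++) v[i] = lower + i * inc;
        return v;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(new CVParamSketch("N 1 5 5").values(0, 0)));  // [1.0, 2.0, 3.0, 4.0, 5.0]
        System.out.println(Arrays.toString(new CVParamSketch("K 1 A 3").values(9, 100))); // [1.0, 5.0, 9.0]
    }
}
```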

getCVParameter

public java.lang.String getCVParameter(int index)
Gets the scheme parameter with the given index.
Parameters:
index - the index of the scheme parameter
Returns:
the scheme parameter, as a string

setDebug

public void setDebug(boolean debug)
Sets debugging mode
Parameters:
debug - true if debug output should be printed

getDebug

public boolean getDebug()
Gets whether debugging is turned on
Returns:
true if debugging output is on

getNumFolds

public int getNumFolds()
Get the number of folds used for cross-validation.
Returns:
the number of folds used for cross-validation.

setNumFolds

public void setNumFolds(int newNumFolds)
Set the number of folds used for cross-validation.
Parameters:
newNumFolds - the number of folds used for cross-validation.

setClassifier

public void setClassifier(Classifier newClassifier)
Sets the base classifier whose parameters are selected by cross-validation.
Parameters:
newClassifier - the Classifier to use.

getClassifier

public Classifier getClassifier()
Gets the base classifier whose parameters are being selected.
Returns:
the base classifier

toString

public java.lang.String toString()
Returns description of the cross-validated classifier.
Overrides:
toString in class java.lang.Object
Returns:
description of the cross-validated classifier as a string

toSummaryString

public java.lang.String toSummaryString()
Description copied from interface: Summarizable
Returns a string that summarizes the object.
Specified by:
toSummaryString in interface Summarizable
Following copied from interface: weka.core.Summarizable
Returns:
the object summarized as a string

main

public static void main(java.lang.String[] argv)
Main method for testing this class.
Parameters:
argv - the options