weka.classifiers
Class ThresholdSelector

java.lang.Object
  |
  +--weka.classifiers.Classifier
        |
        +--weka.classifiers.DistributionClassifier
              |
              +--weka.classifiers.ThresholdSelector
All Implemented Interfaces:
java.lang.Cloneable, OptionHandler, java.io.Serializable

public class ThresholdSelector
extends DistributionClassifier
implements OptionHandler

Selects a threshold on the probability output by a distribution classifier. The threshold is chosen to optimize a given performance measure, which is currently the F-measure. Performance is measured on the training data, on a hold-out set, or by cross-validation. In addition, the probabilities returned by the base learner can have their range expanded so that the output probabilities span the interval from 0 to 1 (this is useful if the scheme normally produces probabilities in a very narrow range).

Valid options are:

-C num
The class for which threshold is determined. Valid values are: 1, 2 (for first and second classes, respectively), 3 (for whichever class is least frequent), 4 (for whichever class value is most frequent), and 5 (for the first class named any of "yes","pos(itive)", "1", or method 3 if no matches). (default 5).

-W classname
Specify the full class name of the base classifier.

-X num
Number of folds used for cross validation. If just a hold-out set is used, this determines the size of the hold-out set (default 3).

-R integer
Sets whether confidence range correction is applied. This can be used to ensure the confidences range from 0 to 1. Use 0 for no range correction, 1 for correction based on the min/max values seen during threshold selection (default 0).

-S seed
Random number seed (default 1).

-E integer
Sets the evaluation mode. Use 0 for evaluation using cross-validation, 1 for evaluation using hold-out set, and 2 for evaluation on the training data (default 1).

Options after -- are passed to the designated sub-classifier.
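The option list above can be assembled into the string array that setOptions expects. The following is a minimal sketch in plain Java, not part of the Weka API: the class ThresholdSelectorOptionsSketch, the method buildOptions, the base classifier name my.pkg.MyDistributionClassifier and the flag -hypotheticalFlag are all hypothetical placeholders chosen for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, not part of Weka: builds an option array in the
// layout described above (wrapper options, then "--", then base options).
class ThresholdSelectorOptionsSketch {

    static String[] buildOptions(String baseClassName, int designatedClass,
                                 int folds, int rangeMode, int seed,
                                 int evalMode, String... baseOptions) {
        List<String> opts = new ArrayList<>();
        opts.add("-C"); opts.add(Integer.toString(designatedClass));
        opts.add("-W"); opts.add(baseClassName);
        opts.add("-X"); opts.add(Integer.toString(folds));
        opts.add("-R"); opts.add(Integer.toString(rangeMode));
        opts.add("-S"); opts.add(Integer.toString(seed));
        opts.add("-E"); opts.add(Integer.toString(evalMode));
        // Everything after "--" is meant for the base classifier.
        if (baseOptions.length > 0) {
            opts.add("--");
            opts.addAll(Arrays.asList(baseOptions));
        }
        return opts.toArray(new String[0]);
    }

    public static void main(String[] args) {
        // Placeholder class name and base option, for illustration only.
        String[] options = buildOptions("my.pkg.MyDistributionClassifier",
                5, 3, 1, 1, 0, "-hypotheticalFlag");
        System.out.println(String.join(" ", options));
    }
}
```

An array built this way would be handed to setOptions(String[]) on a ThresholdSelector instance; that call is omitted here so the sketch compiles without Weka on the classpath.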

Author:
Eibe Frank (eibe@cs.waikato.ac.nz)
See Also:
Serialized Form

Field Summary
static int EVAL_CROSS_VALIDATION
           
static int EVAL_TRAINING_SET
           
static int EVAL_TUNED_SPLIT
           
protected  double m_BestThreshold
          The threshold that led to the best performance
protected  double m_BestValue
          The best value that has been observed
protected  DistributionClassifier m_Classifier
          The generated base classifier
protected  int m_ClassMode
          Method to determine which class to optimize for
protected  int m_DesignatedClass
          Designated class value, determined during building
protected  int m_EvalMode
          The evaluation mode
protected  double m_HighThreshold
          The upper threshold used as the basis of correction
protected  double m_LowThreshold
          The lower threshold used as the basis of correction
protected  int m_NumXValFolds
          The number of folds used in cross-validation
protected  int m_RangeMode
          The range correction mode
protected  int m_Seed
          Random number seed
protected static double MIN_VALUE
          The minimum value for the criterion.
static int OPTIMIZE_0
           
static int OPTIMIZE_1
           
static int OPTIMIZE_LFREQ
           
static int OPTIMIZE_MFREQ
           
static int OPTIMIZE_POS_NAME
           
static int RANGE_BOUNDS
           
static int RANGE_NONE
           
static Tag[] TAGS_EVAL
           
static Tag[] TAGS_OPTIMIZE
           
static Tag[] TAGS_RANGE
           
 
Constructor Summary
ThresholdSelector()
           
 
Method Summary
 void buildClassifier(Instances instances)
          Generates the classifier.
 java.lang.String designatedClassTipText()
           
 java.lang.String distributionClassifierTipText()
           
 double[] distributionForInstance(Instance instance)
          Calculates the class membership probabilities for the given test instance.
 java.lang.String evaluationModeTipText()
           
protected  void findThreshold(FastVector predictions)
          Finds the best threshold; this implementation searches for the highest FMeasure.
 SelectedTag getDesignatedClass()
          Gets the method to determine which class value to optimize.
 DistributionClassifier getDistributionClassifier()
          Get the DistributionClassifier used as the classifier.
 SelectedTag getEvaluationMode()
          Gets the evaluation mode used.
 int getNumXValFolds()
          Get the number of folds used for cross-validation.
 java.lang.String[] getOptions()
          Gets the current settings of the Classifier.
protected  FastVector getPredictions(Instances instances, int mode, int numFolds)
          Collects the classifier predictions using the specified evaluation method.
 SelectedTag getRangeCorrection()
          Gets the confidence range correction mode used.
 int getSeed()
          Gets the random number seed.
 java.lang.String globalInfo()
           
 java.util.Enumeration listOptions()
          Returns an enumeration describing the available options
static void main(java.lang.String[] argv)
          Main method for testing this class.
 java.lang.String numXValFoldsTipText()
           
 java.lang.String rangeCorrectionTipText()
           
 java.lang.String seedTipText()
           
 void setDesignatedClass(SelectedTag newMethod)
          Sets the method to determine which class value to optimize.
 void setDistributionClassifier(DistributionClassifier newClassifier)
          Set the DistributionClassifier for which threshold is set.
 void setEvaluationMode(SelectedTag newMethod)
          Sets the evaluation mode used.
 void setNumXValFolds(int newNumFolds)
          Set the number of folds used for cross-validation.
 void setOptions(java.lang.String[] options)
          Parses a given list of options.
 void setRangeCorrection(SelectedTag newMethod)
          Sets the confidence range correction mode used.
 void setSeed(int seed)
          Sets the seed for random number generation.
 java.lang.String toString()
          Returns a description of the classifier.
 
Methods inherited from class weka.classifiers.DistributionClassifier
classifyInstance
 
Methods inherited from class weka.classifiers.Classifier
forName, makeCopies
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
 

Field Detail

RANGE_NONE

public static final int RANGE_NONE

RANGE_BOUNDS

public static final int RANGE_BOUNDS

TAGS_RANGE

public static final Tag[] TAGS_RANGE

EVAL_TRAINING_SET

public static final int EVAL_TRAINING_SET

EVAL_TUNED_SPLIT

public static final int EVAL_TUNED_SPLIT

EVAL_CROSS_VALIDATION

public static final int EVAL_CROSS_VALIDATION

TAGS_EVAL

public static final Tag[] TAGS_EVAL

OPTIMIZE_0

public static final int OPTIMIZE_0

OPTIMIZE_1

public static final int OPTIMIZE_1

OPTIMIZE_LFREQ

public static final int OPTIMIZE_LFREQ

OPTIMIZE_MFREQ

public static final int OPTIMIZE_MFREQ

OPTIMIZE_POS_NAME

public static final int OPTIMIZE_POS_NAME

TAGS_OPTIMIZE

public static final Tag[] TAGS_OPTIMIZE

m_Classifier

protected DistributionClassifier m_Classifier
The generated base classifier

m_HighThreshold

protected double m_HighThreshold
The upper threshold used as the basis of correction

m_LowThreshold

protected double m_LowThreshold
The lower threshold used as the basis of correction

m_BestThreshold

protected double m_BestThreshold
The threshold that led to the best performance

m_BestValue

protected double m_BestValue
The best value that has been observed

m_NumXValFolds

protected int m_NumXValFolds
The number of folds used in cross-validation

m_Seed

protected int m_Seed
Random number seed

m_DesignatedClass

protected int m_DesignatedClass
Designated class value, determined during building

m_ClassMode

protected int m_ClassMode
Method to determine which class to optimize for

m_EvalMode

protected int m_EvalMode
The evaluation mode

m_RangeMode

protected int m_RangeMode
The range correction mode

MIN_VALUE

protected static final double MIN_VALUE
The minimum value for the criterion. If threshold adjustment yields less than that, the default threshold of 0.5 is used.
Constructor Detail

ThresholdSelector

public ThresholdSelector()
Method Detail

getPredictions

protected FastVector getPredictions(Instances instances,
                                    int mode,
                                    int numFolds)
                             throws java.lang.Exception
Collects the classifier predictions using the specified evaluation method.
Parameters:
instances - the set of Instances to generate predictions for.
mode - the evaluation mode.
numFolds - the number of folds to use if not evaluating on the full training set.
Returns:
a FastVector containing the predictions.
Throws:
java.lang.Exception - if an error occurs generating the predictions.
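To illustrate the hold-out evaluation mode, here is a minimal sketch of a seeded split of instance indices. It assumes that the hold-out set holds 1/numFolds of the data, which is one reading of the -X option description and is not confirmed by this page. HoldOutSplitSketch and split are hypothetical names, and the sketch works on plain indices rather than on Instances.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch, not Weka code: splits n instance indices into a
// training part and a hold-out part of size n / numFolds (an assumption).
class HoldOutSplitSketch {

    /** Returns a two-element list: training indices, then hold-out indices. */
    static List<List<Integer>> split(int n, int numFolds, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        // A fixed seed makes the shuffle, and so the split, reproducible.
        Collections.shuffle(indices, new Random(seed));
        int holdOutSize = n / numFolds;
        List<Integer> holdOut = new ArrayList<>(indices.subList(0, holdOutSize));
        List<Integer> train = new ArrayList<>(indices.subList(holdOutSize, n));
        List<List<Integer>> result = new ArrayList<>();
        result.add(train);
        result.add(holdOut);
        return result;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = split(9, 3, 1L);
        System.out.println("train:    " + parts.get(0));
        System.out.println("hold-out: " + parts.get(1));
    }
}
```

In this reading, the base classifier would be trained on the first part and its predictions on the second part would be collected for threshold selection.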

findThreshold

protected void findThreshold(FastVector predictions)
Finds the best threshold; this implementation searches for the highest FMeasure. If no FMeasure higher than MIN_VALUE is found, the default threshold of 0.5 is used.
Parameters:
predictions - a FastVector containing the predictions.
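The search described above can be sketched in plain Java as follows. This is an illustration under stated assumptions, not the class's actual code: predictions are represented as parallel arrays instead of a FastVector, each observed probability is tried as a candidate threshold, an instance counts as positive when its probability is at least the threshold, and the value 0.05 for MIN_VALUE is a placeholder because this page does not state it. FMeasureThresholdSketch and its members are hypothetical names.

```java
import java.util.Arrays;

// Hypothetical sketch, not Weka code: picks the threshold with the
// highest F-measure, falling back to 0.5 when nothing beats MIN_VALUE.
class FMeasureThresholdSketch {

    static final double MIN_VALUE = 0.05;        // assumed placeholder value
    static final double DEFAULT_THRESHOLD = 0.5; // default named in the text

    /** F-measure of the designated class when predicting positive for p >= threshold. */
    static double fMeasure(double[] probs, boolean[] positive, double threshold) {
        int tp = 0, fp = 0, fn = 0;
        for (int i = 0; i < probs.length; i++) {
            boolean predicted = probs[i] >= threshold;
            if (predicted && positive[i]) tp++;
            else if (predicted) fp++;
            else if (positive[i]) fn++;
        }
        if (tp == 0) {
            return 0.0; // precision and recall are both zero or undefined
        }
        double precision = tp / (double) (tp + fp);
        double recall = tp / (double) (tp + fn);
        return 2 * precision * recall / (precision + recall);
    }

    /** Tries every observed probability as a threshold and keeps the best one. */
    static double findThreshold(double[] probs, boolean[] positive) {
        double bestThreshold = DEFAULT_THRESHOLD;
        double bestValue = MIN_VALUE;
        double[] candidates = probs.clone();
        Arrays.sort(candidates);
        for (double t : candidates) {
            double f = fMeasure(probs, positive, t);
            if (f > bestValue) {
                bestValue = f;
                bestThreshold = t;
            }
        }
        return bestThreshold;
    }

    public static void main(String[] args) {
        double[] probs = {0.1, 0.4, 0.35, 0.8};
        boolean[] positive = {false, false, true, true};
        System.out.println(findThreshold(probs, positive));
    }
}
```

Starting bestValue at MIN_VALUE means a candidate replaces the default only when its F-measure is strictly higher than MIN_VALUE, which mirrors the fallback rule stated above.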

listOptions

public java.util.Enumeration listOptions()
Returns an enumeration describing the available options
Specified by:
listOptions in interface OptionHandler
Returns:
an enumeration of all the available options

setOptions

public void setOptions(java.lang.String[] options)
                throws java.lang.Exception
Parses a given list of options. Valid options are:

-C num
The class for which threshold is determined. Valid values are: 1, 2 (for first and second classes, respectively), 3 (for whichever class is least frequent), 4 (for whichever class value is most frequent), and 5 (for the first class named any of "yes","pos(itive)", "1", or method 3 if no matches). (default 3).

-W classname
Specify the full class name of the base classifier.

-X num
Number of folds used for cross validation. If just a hold-out set is used, this determines the size of the hold-out set (default 3).

-R integer
Sets whether confidence range correction is applied. This can be used to ensure the confidences range from 0 to 1. Use 0 for no range correction, 1 for correction based on the min/max values seen during threshold selection (default 0).

-S seed
Random number seed (default 1).

-E integer
Sets the evaluation mode. Use 0 for evaluation using cross-validation, 1 for evaluation using hold-out set, and 2 for evaluation on the training data (default 1).

Options after -- are passed to the designated sub-classifier.

Specified by:
setOptions in interface OptionHandler
Parameters:
options - the list of options as an array of strings
Throws:
java.lang.Exception - if an option is not supported

getOptions

public java.lang.String[] getOptions()
Gets the current settings of the Classifier.
Specified by:
getOptions in interface OptionHandler
Returns:
an array of strings suitable for passing to setOptions

buildClassifier

public void buildClassifier(Instances instances)
                     throws java.lang.Exception
Generates the classifier.
Overrides:
buildClassifier in class Classifier
Parameters:
instances - set of instances serving as training data
Throws:
java.lang.Exception - if the classifier has not been generated successfully

distributionForInstance

public double[] distributionForInstance(Instance instance)
                                 throws java.lang.Exception
Calculates the class membership probabilities for the given test instance.
Overrides:
distributionForInstance in class DistributionClassifier
Parameters:
instance - the instance to be classified
Returns:
predicted class probability distribution
Throws:
java.lang.Exception - if instance could not be classified successfully

globalInfo

public java.lang.String globalInfo()
Returns:
a description of the classifier suitable for displaying in the explorer/experimenter gui

designatedClassTipText

public java.lang.String designatedClassTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

getDesignatedClass

public SelectedTag getDesignatedClass()
Gets the method to determine which class value to optimize. Will be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ, OPTIMIZE_POS_NAME.
Returns:
the class selection mode.

setDesignatedClass

public void setDesignatedClass(SelectedTag newMethod)
Sets the method to determine which class value to optimize. Will be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ, OPTIMIZE_POS_NAME.
Parameters:
newMethod - the new class selection mode.
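The five class selection modes can be sketched as a small lookup over class names and class counts. This is a hedged illustration, not the class's actual code: the enum Mode stands in for the OPTIMIZE_* constants (whose numeric values this page does not give), "pos(itive)" is read as matching either "pos" or "positive", and the case-insensitive comparison is an assumption. DesignatedClassSketch and its members are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

// Hypothetical sketch, not Weka code: resolves a class selection mode
// to the index of the class whose threshold is optimized.
class DesignatedClassSketch {

    // Stand-ins for OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ,
    // OPTIMIZE_MFREQ and OPTIMIZE_POS_NAME.
    enum Mode { FIRST, SECOND, LEAST_FREQUENT, MOST_FREQUENT, POSITIVE_NAME }

    static final List<String> POSITIVE_NAMES =
            Arrays.asList("yes", "pos", "positive", "1");

    static int designatedClass(Mode mode, String[] classNames, int[] classCounts) {
        switch (mode) {
            case FIRST:
                return 0;
            case SECOND:
                return 1;
            case MOST_FREQUENT:
                return extreme(classCounts, true);
            case POSITIVE_NAME:
                for (int i = 0; i < classNames.length; i++) {
                    String name = classNames[i].toLowerCase(Locale.ROOT);
                    if (POSITIVE_NAMES.contains(name)) {
                        return i; // first class with a "positive" name
                    }
                }
                // No match: fall back to the least frequent class.
                return extreme(classCounts, false);
            case LEAST_FREQUENT:
            default:
                return extreme(classCounts, false);
        }
    }

    /** Index of the largest (or smallest) count; ties go to the lowest index. */
    private static int extreme(int[] counts, boolean largest) {
        int best = 0;
        for (int i = 1; i < counts.length; i++) {
            if (largest ? counts[i] > counts[best] : counts[i] < counts[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        String[] names = {"no", "yes"};
        int[] counts = {10, 90};
        System.out.println(designatedClass(Mode.POSITIVE_NAME, names, counts));
    }
}
```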

evaluationModeTipText

public java.lang.String evaluationModeTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

setEvaluationMode

public void setEvaluationMode(SelectedTag newMethod)
Sets the evaluation mode used. Will be one of EVAL_TRAINING_SET, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION.
Parameters:
newMethod - the new evaluation mode.

getEvaluationMode

public SelectedTag getEvaluationMode()
Gets the evaluation mode used. Will be one of EVAL_TRAINING_SET, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION.
Returns:
the evaluation mode.

rangeCorrectionTipText

public java.lang.String rangeCorrectionTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

setRangeCorrection

public void setRangeCorrection(SelectedTag newMethod)
Sets the confidence range correction mode used. Will be one of RANGE_NONE or RANGE_BOUNDS.
Parameters:
newMethod - the new correction mode.

getRangeCorrection

public SelectedTag getRangeCorrection()
Gets the confidence range correction mode used. Will be one of RANGE_NONE or RANGE_BOUNDS.
Returns:
the confidence correction mode.
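As an illustration of range correction based on observed bounds, the sketch below rescales a probability linearly between the lowest and highest values seen. This page does not give the formula the class uses, so plain min/max rescaling is an assumption, and the actual mapping may differ (for example, it may also take the selected threshold into account). RangeCorrectionSketch and its methods are hypothetical names.

```java
// Hypothetical sketch, not Weka code: stretches probabilities observed
// in a narrow range [low, high] so that they cover [0, 1].
class RangeCorrectionSketch {

    /** Returns {min, max} of the observed probabilities. */
    static double[] bounds(double[] observed) {
        double low = Double.POSITIVE_INFINITY;
        double high = Double.NEGATIVE_INFINITY;
        for (double p : observed) {
            low = Math.min(low, p);
            high = Math.max(high, p);
        }
        return new double[] {low, high};
    }

    /** Linear min/max rescaling, clamped to [0, 1]. */
    static double correct(double p, double low, double high) {
        if (high <= low) {
            return p; // degenerate range: leave the value unchanged
        }
        double scaled = (p - low) / (high - low);
        return Math.max(0.0, Math.min(1.0, scaled));
    }

    public static void main(String[] args) {
        double[] seen = {0.42, 0.47, 0.55, 0.58};
        double[] b = bounds(seen);
        for (double p : seen) {
            System.out.println(p + " -> " + correct(p, b[0], b[1]));
        }
    }
}
```

The clamp handles test-time probabilities that fall outside the bounds seen during threshold selection.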

seedTipText

public java.lang.String seedTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

setSeed

public void setSeed(int seed)
Sets the seed for random number generation.
Parameters:
seed - the random number seed

getSeed

public int getSeed()
Gets the random number seed.
Returns:
the random number seed

numXValFoldsTipText

public java.lang.String numXValFoldsTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

getNumXValFolds

public int getNumXValFolds()
Get the number of folds used for cross-validation.
Returns:
the number of folds used for cross-validation.

setNumXValFolds

public void setNumXValFolds(int newNumFolds)
Set the number of folds used for cross-validation.
Parameters:
newNumFolds - the number of folds used for cross-validation.

distributionClassifierTipText

public java.lang.String distributionClassifierTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

setDistributionClassifier

public void setDistributionClassifier(DistributionClassifier newClassifier)
Set the DistributionClassifier for which threshold is set.
Parameters:
newClassifier - the DistributionClassifier to use.

getDistributionClassifier

public DistributionClassifier getDistributionClassifier()
Get the DistributionClassifier used as the classifier.
Returns:
the DistributionClassifier used as the base classifier

toString

public java.lang.String toString()
Returns a description of the classifier.
Overrides:
toString in class java.lang.Object
Returns:
a description of the classifier as a string

main

public static void main(java.lang.String[] argv)
Main method for testing this class.
Parameters:
argv - the options