Class CrossValidationIterator

  • All Implemented Interfaces:
    java.io.Serializable, java.util.Iterator<InstanceList[]>

    public class CrossValidationIterator
    extends java.lang.Object
    implements java.util.Iterator<InstanceList[]>, java.io.Serializable
    An iterator that splits an InstanceList into n folds and iterates over them for use in n-fold cross-validation. On each iteration, list[0] is an InstanceList containing n-1 folds, typically used for training, and list[1] is an InstanceList containing the remaining fold, typically used for validation. This class uses MultiInstanceList to avoid creating a new InstanceList on each iteration. TODO: the assignment of instances to folds is currently completely random; an improvement would be to provide a stratified random distribution.
    Author:
    Aron Culotta culotta@cs.umass.edu
    See Also:
    MultiInstanceList, InstanceList, Serialized Form
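The split semantics described above can be illustrated with a minimal, self-contained sketch over a plain java.util.List instead of MALLET's InstanceList. The class name SimpleCrossValidationIterator and its List-based return type are hypothetical. The sketch shuffles with the supplied Random, deals items into n folds, and on each call returns the n-1 training folds followed by the single held-out fold. Unlike the class documented here, it copies items into new lists on every call rather than using MultiInstanceList.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;

// Hypothetical, simplified stand-in for CrossValidationIterator over a plain
// java.util.List. Not MALLET's implementation: it copies data into new lists
// on every call instead of using MultiInstanceList views.
class SimpleCrossValidationIterator<T> implements Iterator<List<List<T>>> {
    private final List<List<T>> folds = new ArrayList<>();
    private int heldOut = 0; // index of the fold used for validation on the next call

    SimpleCrossValidationIterator(List<T> data, int nfolds, Random r) {
        if (nfolds < 2 || nfolds > data.size()) {
            throw new IllegalArgumentException("need 2 <= nfolds <= data.size()");
        }
        List<T> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, r); // purely random assignment, not stratified
        for (int f = 0; f < nfolds; f++) {
            folds.add(new ArrayList<>());
        }
        // Deal items round-robin so fold sizes differ by at most one.
        for (int i = 0; i < shuffled.size(); i++) {
            folds.get(i % nfolds).add(shuffled.get(i));
        }
    }

    @Override
    public boolean hasNext() {
        return heldOut < folds.size();
    }

    // Returns [training, validation]: all folds except heldOut, then heldOut itself.
    @Override
    public List<List<T>> next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        List<T> train = new ArrayList<>();
        for (int f = 0; f < folds.size(); f++) {
            if (f != heldOut) {
                train.addAll(folds.get(f));
            }
        }
        List<T> validation = new ArrayList<>(folds.get(heldOut));
        heldOut++;
        List<List<T>> split = new ArrayList<>();
        split.add(train);
        split.add(validation);
        return split;
    }
}
```

Iterating this sketch to exhaustion yields nfolds splits, and every item appears in exactly one validation fold.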
    • Method Summary

      Modifier and Type   Method                          Description
      void                clear()                         Calls clear on each fold.
      boolean             hasNext()
      InstanceList[]      next()                          Returns the next training/testing split.
      InstanceList[]      nextSplit()                     Returns the next training/testing split.
      InstanceList[]      nextSplit(int numTrainFolds)    Returns the next training/testing split.
      void                remove()
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
      • Methods inherited from interface java.util.Iterator

        forEachRemaining
    • Constructor Detail

      • CrossValidationIterator

        public CrossValidationIterator(InstanceList ilist,
                                       int nfolds,
                                       java.util.Random r)
        Constructs a new n-fold cross-validation iterator that uses r to shuffle the instances into folds.
        Parameters:
        ilist - instance list to split into folds and iterate over
        nfolds - number of folds to split InstanceList into
        r - The source of randomness to use in shuffling.
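The role of r can be illustrated with a hypothetical helper, FoldAssignment.assign, which is not part of MALLET. It shuffles instance indices with the given java.util.Random and records which fold each instance lands in. Whether CrossValidationIterator assigns folds in exactly this way is an assumption; the point it demonstrates is that two Random objects created with the same seed produce the same shuffle, so passing a seeded Random makes the folds reproducible across runs.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

final class FoldAssignment {
    private FoldAssignment() {}

    // Hypothetical helper: returns foldOf[i], the fold (0..nfolds-1) that
    // instance i is assigned to after shuffling the indices with r.
    static int[] assign(int size, int nfolds, Random r) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            order.add(i);
        }
        Collections.shuffle(order, r); // deterministic for a given Random state
        int[] foldOf = new int[size];
        // The k-th index in shuffled order goes to fold k % nfolds.
        for (int k = 0; k < size; k++) {
            foldOf[order.get(k)] = k % nfolds;
        }
        return foldOf;
    }
}
```

Calling assign(10, 3, new Random(7)) twice returns identical arrays; fold 0 receives four instances and folds 1 and 2 receive three each.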
      • CrossValidationIterator

        public CrossValidationIterator(InstanceList ilist,
                                       int _nfolds)
        Constructs a new n-fold cross-validation iterator.
        Parameters:
        ilist - instance list to split into folds and iterate over
        _nfolds - number of folds to split InstanceList into
    • Method Detail

      • clear

        public void clear()
        Calls clear on each fold. It is recommended that this always be called when the iterator is no longer needed, so that InstanceList implementations such as PagedInstanceList can clean up any temporary data they keep outside the JVM.
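One way to follow this recommendation, sketched here under the assumption that you control the calling code, is a small adapter that turns any clear() method into an AutoCloseable, so try-with-resources invokes it even when training throws. Neither Clearable nor ClearOnClose is part of MALLET; both names are hypothetical.

```java
// Hypothetical functional interface matching a no-argument clear() method.
@FunctionalInterface
interface Clearable {
    void clear();
}

// Hypothetical adapter: lets try-with-resources call clear() automatically.
final class ClearOnClose implements AutoCloseable {
    private final Clearable target;

    ClearOnClose(Clearable target) {
        this.target = target;
    }

    @Override
    public void close() {
        target.clear(); // runs when the try block exits, normally or by exception
    }
}
```

With a CrossValidationIterator named cv, this would read try (ClearOnClose guard = new ClearOnClose(cv::clear)) { ... }. A plain try/finally block that calls cv.clear() achieves the same effect without the adapter.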
      • hasNext

        public boolean hasNext()
        Specified by:
        hasNext in interface java.util.Iterator<InstanceList[]>
      • nextSplit

        public InstanceList[] nextSplit()
        Returns the next training/testing split.
        Returns:
        A two-element array of InstanceList, where InstanceList[0] contains n-1 folds for training and InstanceList[1] contains the remaining fold for testing.
      • nextSplit

        public InstanceList[] nextSplit(int numTrainFolds)
        Returns the next training/testing split.
        Returns:
        A two-element array of InstanceList, where InstanceList[0] contains numTrainFolds folds for training and InstanceList[1] contains the remaining n - numTrainFolds folds for testing.
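The documentation does not say which folds nextSplit(int numTrainFolds) places on each side. The sketch below assumes one plausible rule, purely for illustration: starting at a rotating index start, the next n - numTrainFolds folds (wrapping around) form the test set and the remaining numTrainFolds folds form the training set. It works on fold indices only, and FoldSelection and start are hypothetical names, not MALLET's.

```java
import java.util.ArrayList;
import java.util.List;

final class FoldSelection {
    private FoldSelection() {}

    // Hypothetical rule, for illustration only: the (nfolds - numTrainFolds)
    // folds beginning at 'start' (wrapping around) are test folds; all other
    // folds are training folds. Returns [trainingFoldIndices, testFoldIndices].
    static List<List<Integer>> split(int nfolds, int start, int numTrainFolds) {
        if (numTrainFolds < 1 || numTrainFolds >= nfolds) {
            throw new IllegalArgumentException("need 1 <= numTrainFolds < nfolds");
        }
        int numTestFolds = nfolds - numTrainFolds;
        List<Integer> train = new ArrayList<>();
        List<Integer> test = new ArrayList<>();
        for (int k = 0; k < nfolds; k++) {
            int fold = (start + k) % nfolds;
            if (k < numTestFolds) {
                test.add(fold);
            } else {
                train.add(fold);
            }
        }
        List<List<Integer>> result = new ArrayList<>();
        result.add(train);
        result.add(test);
        return result;
    }
}
```

Under this rule, split(5, 3, 2) yields training folds [1, 2] and test folds [3, 4, 0], while split(5, 3, 4) leaves the single test fold [3] with training folds [4, 0, 1, 2], matching the n-1/1 shape of nextSplit().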
      • next

        public InstanceList[] next()
        Returns the next training/testing split.
        Specified by:
        next in interface java.util.Iterator<InstanceList[]>
        Returns:
        A two-element array of InstanceList, where InstanceList[0] contains n-1 folds for training and InstanceList[1] contains the remaining fold for testing.
        See Also:
        Iterator.next()
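Because the class implements java.util.Iterator<InstanceList[]>, it can be consumed with a plain hasNext()/next() loop or with the inherited forEachRemaining default method. The generic helper below, averageScore, is a hypothetical sketch: the caller supplies a function that scores one split (for example, by training on split[0] and evaluating on split[1]), and the helper returns the mean score over all splits.

```java
import java.util.Iterator;
import java.util.function.ToDoubleFunction;

final class CrossValidationDriver {
    private CrossValidationDriver() {}

    // Hypothetical helper: averages a per-split score over every split the
    // iterator yields, using Iterator.forEachRemaining.
    static <S> double averageScore(Iterator<S> splits, ToDoubleFunction<? super S> scoreSplit) {
        double[] sum = {0.0}; // single-element arrays so the lambda can update them
        int[] count = {0};
        splits.forEachRemaining(split -> {
            sum[0] += scoreSplit.applyAsDouble(split);
            count[0]++;
        });
        if (count[0] == 0) {
            throw new IllegalArgumentException("iterator yielded no splits");
        }
        return sum[0] / count[0];
    }
}
```

With MALLET, a call might look like averageScore(cv, split -> evaluate(train(split[0]), split[1])), where train and evaluate are placeholders for your own trainer and metric.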
      • remove

        public void remove()
        Specified by:
        remove in interface java.util.Iterator<InstanceList[]>