Package cc.mallet.types
Class CrossValidationIterator
java.lang.Object
    cc.mallet.types.CrossValidationIterator
All Implemented Interfaces:
java.io.Serializable, java.util.Iterator<InstanceList[]>
public class CrossValidationIterator extends java.lang.Object implements java.util.Iterator<InstanceList[]>, java.io.Serializable
An iterator which splits an InstanceList into n folds and iterates over the folds for use in n-fold cross-validation. On each iteration, list[0] contains an InstanceList with n-1 folds, typically used for training, and list[1] contains an InstanceList with 1 fold, typically used for validation. This class uses MultiInstanceList to avoid creating a new InstanceList on each iteration.
TODO: the distribution of instances across folds is currently completely random; an improvement would be to provide a stratified random distribution.
Author:
Aron Culotta culotta@cs.umass.edu
See Also:
MultiInstanceList, InstanceList, Serialized Form
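The splitting scheme described above can be illustrated with plain java.util collections. This is only a minimal sketch of n-fold splitting in general, not MALLET's implementation: the class FoldSplitSketch and its methods makeFolds and split are hypothetical names, and dealing the shuffled items round-robin into folds is an assumption made for the example.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical stand-in for illustration only; not part of MALLET.
class FoldSplitSketch {

    // Shuffles a copy of the items and deals them round-robin into nfolds folds.
    // (Round-robin dealing is an assumption of this sketch.)
    static <T> List<List<T>> makeFolds(List<T> items, int nfolds, Random r) {
        List<T> shuffled = new ArrayList<>(items);
        Collections.shuffle(shuffled, r);
        List<List<T>> folds = new ArrayList<>();
        for (int i = 0; i < nfolds; i++) {
            folds.add(new ArrayList<>());
        }
        for (int i = 0; i < shuffled.size(); i++) {
            folds.get(i % nfolds).add(shuffled.get(i));
        }
        return folds;
    }

    // Returns a two-element list: element 0 holds the n-1 training folds,
    // element 1 holds the single held-out fold.
    static <T> List<List<T>> split(List<List<T>> folds, int heldOut) {
        List<T> train = new ArrayList<>();
        for (int i = 0; i < folds.size(); i++) {
            if (i != heldOut) {
                train.addAll(folds.get(i));
            }
        }
        List<List<T>> result = new ArrayList<>();
        result.add(train);
        result.add(new ArrayList<>(folds.get(heldOut)));
        return result;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            data.add(i);
        }
        List<List<Integer>> folds = makeFolds(data, 5, new Random(1));
        // One split per fold: each fold is held out exactly once.
        for (int k = 0; k < folds.size(); k++) {
            List<List<Integer>> s = split(folds, k);
            System.out.println("train=" + s.get(0).size() + " validation=" + s.get(1).size());
        }
    }
}
```

With 10 items and 5 folds, every split in this sketch has 8 training items and 2 validation items. Unlike the sketch, which copies items into new lists, the class documented here uses MultiInstanceList to avoid building a new InstanceList per iteration.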
Constructor Summary
Constructor                                                                    Description
CrossValidationIterator(InstanceList ilist, int _nfolds)                       Constructs a new n-fold cross-validation iterator.
CrossValidationIterator(InstanceList ilist, int nfolds, java.util.Random r)    Constructs a new n-fold cross-validation iterator.
Method Summary
Modifier and Type    Method                          Description
void                 clear()                         Calls clear on each fold.
boolean              hasNext()
InstanceList[]       next()                          Returns the next training/testing split.
InstanceList[]       nextSplit()                     Returns the next training/testing split.
InstanceList[]       nextSplit(int numTrainFolds)    Returns the next training/testing split.
void                 remove()
Constructor Detail
CrossValidationIterator
public CrossValidationIterator(InstanceList ilist, int nfolds, java.util.Random r)
Constructs a new n-fold cross-validation iterator.
Parameters:
ilist - instance list to split into folds and iterate over
nfolds - number of folds to split the InstanceList into
r - the source of randomness to use in shuffling
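The constructor above takes a java.util.Random as its source of randomness for shuffling. Assuming the shuffle depends only on that Random (this page does not confirm it), a Random built from a fixed seed should make the fold assignment repeatable across runs. The sketch below shows only the underlying property with plain java.util collections; SeededShuffleSketch and shuffled are hypothetical names, not MALLET API.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper for illustration only; not part of MALLET.
class SeededShuffleSketch {

    // Builds the list 0..size-1 and shuffles it with a Random created from the given seed.
    static List<Integer> shuffled(int size, long seed) {
        List<Integer> items = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            items.add(i);
        }
        Collections.shuffle(items, new Random(seed));
        return items;
    }

    public static void main(String[] args) {
        // Two shuffles driven by the same seed produce the same order.
        boolean same = shuffled(10, 42L).equals(shuffled(10, 42L));
        System.out.println("same order with same seed: " + same);
    }
}
```

Fixing the seed this way is useful when cross-validation results need to be compared between runs.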
CrossValidationIterator
public CrossValidationIterator(InstanceList ilist, int _nfolds)
Constructs a new n-fold cross-validation iterator.
Parameters:
ilist - instance list to split into folds and iterate over
_nfolds - number of folds to split the InstanceList into
Method Detail
clear
public void clear()
Calls clear on each fold. It is recommended that this always be called when the iterator is no longer needed, so that implementations of InstanceList such as PagedInstanceList can clean up any temporary data they may have outside the JVM.
hasNext
public boolean hasNext()
Specified by:
hasNext in interface java.util.Iterator<InstanceList[]>
nextSplit
public InstanceList[] nextSplit()
Returns the next training/testing split.
Returns:
A two-element array of InstanceList, where InstanceList[0] contains n-1 folds for training and InstanceList[1] contains 1 fold for testing.
nextSplit
public InstanceList[] nextSplit(int numTrainFolds)
Returns the next training/testing split.
Returns:
A two-element array of InstanceList, where InstanceList[0] contains numTrainFolds folds for training and InstanceList[1] contains n - numTrainFolds folds for testing.
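To make the numTrainFolds split concrete, here is a minimal sketch using plain java.util lists. It is not MALLET's implementation: MultiFoldSplitSketch and its split method are hypothetical, and the choice to take the training folds consecutively from a starting fold, wrapping around the end, is an assumption made for the example.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for illustration only; not part of MALLET.
class MultiFoldSplitSketch {

    // Returns a two-element list: element 0 holds numTrainFolds folds for training,
    // taken consecutively from fold `start` and wrapping around (an assumption of
    // this sketch); element 1 holds the remaining n - numTrainFolds folds for testing.
    static <T> List<List<T>> split(List<List<T>> folds, int start, int numTrainFolds) {
        int n = folds.size();
        if (numTrainFolds < 0 || numTrainFolds > n) {
            throw new IllegalArgumentException("numTrainFolds must be between 0 and " + n);
        }
        List<T> train = new ArrayList<>();
        List<T> test = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            List<T> fold = folds.get((start + i) % n);
            if (i < numTrainFolds) {
                train.addAll(fold);
            } else {
                test.addAll(fold);
            }
        }
        List<List<T>> result = new ArrayList<>();
        result.add(train);
        result.add(test);
        return result;
    }

    public static void main(String[] args) {
        // Five folds of two items each: [0,1], [2,3], [4,5], [6,7], [8,9].
        List<List<Integer>> folds = new ArrayList<>();
        for (int f = 0; f < 5; f++) {
            folds.add(Arrays.asList(2 * f, 2 * f + 1));
        }
        List<List<Integer>> s = split(folds, 0, 3);
        System.out.println("train=" + s.get(0) + " test=" + s.get(1));
    }
}
```

With five folds and numTrainFolds = 3, the training part holds three folds and the testing part holds the other two, matching the n - numTrainFolds count in the Returns description above.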
next
public InstanceList[] next()
Returns the next training/testing split.
Specified by:
next in interface java.util.Iterator<InstanceList[]>
Returns:
A two-element array of InstanceList, where InstanceList[0] contains n-1 folds for training and InstanceList[1] contains 1 fold for testing.
See Also:
Iterator.next()
remove
public void remove()
Specified by:
remove in interface java.util.Iterator<InstanceList[]>