Cross Validator

This node represents the CrossValidator from Spark ML, which uses k-fold cross-validation to select the best parameters for an Estimator.

Input

It takes a DataFrame, an Estimator, and an Evaluator as input.

Output

The incoming DataFrame is passed to the output.

Type

ml-crossvalidator

Class

fire.nodes.ml.NodeCrossValidator

Fields

Name | Title | Description
numFolds | Num Folds | The number of folds.
parallelism | Parallelism | The number of threads to use when running parallel algorithms.
collectSubModels | Collect SubModels | Param for whether to collect a list of sub-models trained during tuning.
seed | Seed | Random seed.
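
As a rough sketch, the fields above correspond to setters on Spark's CrossValidator. The snippet assumes the node forwards each field to the setter of the same name, which this page does not confirm; the estimator, parameter grid, and values shown are illustrative only.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Illustrative estimator and grid; in the node these come from the connected inputs.
val lr = new LogisticRegression()
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

// Assumed mapping of the node fields to CrossValidator setters.
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)             // numFolds
  .setParallelism(2)          // parallelism
  .setCollectSubModels(true)  // collectSubModels
  .setSeed(42L)               // seed
```

Collecting sub-models keeps every model trained during tuning in memory, so it is usually worth enabling only when those models need to be inspected afterwards.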

Details

CrossValidator begins by splitting the dataset into a set of folds, which are used as separate training and test datasets. For example, with k=3 folds, CrossValidator generates 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. To evaluate a particular ParamMap, CrossValidator computes the average evaluation metric for the 3 Models produced by fitting the Estimator on the 3 different (training, test) dataset pairs.

After identifying the best ParamMap, CrossValidator finally re-fits the Estimator using the best ParamMap and the entire dataset.
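
The procedure described above can be sketched in plain Scala without Spark. This is not Spark's implementation: the names `KFoldSketch`, `kFold`, `selectBest`, and `score` are hypothetical, rows are assigned to folds by index rather than randomly with a seed, and the sketch assumes a larger metric is better.

```scala
object KFoldSketch {
  // Build k (training, test) pairs; rows whose index is congruent to `fold` mod k form the test set.
  def kFold[A](data: Seq[A], k: Int): Seq[(Seq[A], Seq[A])] = {
    require(k >= 2, "k must be at least 2")
    val indexed = data.zipWithIndex
    (0 until k).map { fold =>
      val (test, train) = indexed.partition { case (_, i) => i % k == fold }
      (train.map(_._1), test.map(_._1))
    }
  }

  // For each candidate parameter, average a caller-supplied metric over the k pairs,
  // then return the candidate with the highest average (assumes larger is better).
  def selectBest[A, P](data: Seq[A], k: Int, candidates: Seq[P])
                      (score: (P, Seq[A], Seq[A]) => Double): (P, Double) = {
    val folds = kFold(data, k)
    candidates.map { p =>
      val avg = folds.map { case (train, test) => score(p, train, test) }.sum / k
      (p, avg)
    }.maxBy(_._2)
  }

  def main(args: Array[String]): Unit = {
    // With 9 rows and k=3, each pair uses 6 rows for training and 3 for testing.
    kFold((1 to 9).toSeq, 3).foreach { case (train, test) =>
      println(s"train=${train.size} test=${test.size}")
    }
  }
}
```

After the best candidate is chosen, the final step in the text corresponds to fitting once more with that candidate on all of `data`.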

For more details, see the Spark ML tuning documentation: https://spark.apache.org/docs/latest/ml-tuning.html#cross-validation

Examples

The example below is taken from the Spark ML tuning documentation: https://spark.apache.org/docs/latest/ml-tuning.html#cross-validation

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.Row

// Assumes a SparkSession named `spark` is in scope, as in spark-shell.

// Prepare training data from a list of (id, text, label) tuples.
val training = spark.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0),
  (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0),
  (3L, "hadoop mapreduce", 0.0),
  (4L, "b spark who", 1.0),
  (5L, "g d a y", 0.0),
  (6L, "spark fly", 1.0),
  (7L, "was mapreduce", 0.0),
  (8L, "e spark program", 1.0),
  (9L, "a e c l", 0.0),
  (10L, "spark compile", 1.0),
  (11L, "hadoop software", 0.0)
)).toDF("id", "text", "label")

// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")
val hashingTF = new HashingTF()
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
val lr = new LogisticRegression()
  .setMaxIter(10)
val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))

// We use a ParamGridBuilder to construct a grid of parameters to search over.
// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
// This will allow us to jointly choose parameters for all Pipeline stages.
// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric
// is areaUnderROC.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(2)  // Use 3+ in practice
  .setParallelism(2)  // Evaluate up to 2 parameter settings in parallel

// Run cross-validation, and choose the best set of parameters.
val cvModel = cv.fit(training)

// Prepare test documents, which are unlabeled (id, text) tuples.
val test = spark.createDataFrame(Seq(
  (4L, "spark i j k"),
  (5L, "l m n"),
  (6L, "mapreduce spark"),
  (7L, "apache hadoop")
)).toDF("id", "text")

// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test)
  .select("id", "text", "probability", "prediction")
  .collect()
  .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
    println(s"($id, $text) --> prob=$prob, prediction=$prediction")
  }
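
To see the average evaluation metric that CrossValidator computed for each ParamMap, a fitted model can be inspected along these lines. This is a minimal sketch: the helper name `describeMetrics` is hypothetical, and it only relies on the `avgMetrics` and `getEstimatorParamMaps` members of Spark's CrossValidatorModel.

```scala
import org.apache.spark.ml.tuning.CrossValidatorModel

// Hypothetical helper: print each candidate ParamMap next to its cross-validated average metric.
def describeMetrics(model: CrossValidatorModel): Unit = {
  model.getEstimatorParamMaps.zip(model.avgMetrics).foreach { case (params, metric) =>
    println(s"avgMetric=$metric for:\n$params")
  }
}
```

Whether a larger or smaller value is better depends on the Evaluator in use; for the BinaryClassificationEvaluator default metric, areaUnderROC, larger is better.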