Decision Tree Classifier
========================

It supports both binary and multiclass labels, as well as both continuous and categorical features.

Input
--------------

It takes in a DataFrame and performs Decision Tree Classification.

Output
--------------

The Decision Tree Model generated is passed along to the next nodes. The input DataFrame is also passed along to the next nodes.

Type
---------

ml-estimator

Class
---------

fire.nodes.ml.NodeDecisionTreeClassifier

Fields
---------

.. list-table::
   :widths: 10 5 10
   :header-rows: 1

   * - Name
     - Title
     - Description
   * - featuresCol
     - Features Column
     - Features column of type VectorUDT for model fitting.
   * - labelCol
     - Label Column
     - The label column for model fitting.
   * - predictionCol
     - Prediction Column
     - The prediction column created during model scoring.
   * - splitRatio
     - Split Ratio
     - Split Ratio
   * - impurity
     - Impurity
     - The criterion used for information gain calculation.
   * - maxBins
     - Max Bins
     - The maximum number of bins used for discretizing continuous features. Must be >= 2 and >= number of categories in any categorical feature.
   * - maxDepth
     - Max Depth
     - The maximum depth of a tree.
   * - minInfoGain
     - Min Information Gain
     - The minimum information gain for a split to be considered at a tree node.
   * - minInstancesPerNode
     - Min Instances Per Node
     - The minimum number of instances each child must have after a split.
   * - seed
     - Seed
     - The random seed.
   * - cacheNodeIds
     - Cache Node Ids
     - Whether to cache node IDs for each instance. Caching can speed up training of deeper trees.
   * - checkpointInterval
     - Checkpoint Interval
     - The checkpoint interval. E.g. 10 means that the cache will get checkpointed every 10 iterations. Set a checkpoint interval (>= 1) or disable checkpointing (-1).
   * - maxMemoryInMB
     - Max memory
     - Maximum memory in MB allocated to histogram aggregation.
   * - minWeightFractionPerNode
     - Min Weight Fraction per Node
     - Minimum fraction of the weighted sample count that each child must have after a split.
   * - weightCol
     - Weight Column
     - The weight column name. If this is not set or is empty, all instance weights are treated as 1.0.
   * - gridSearch
     - Grid Search
     -
   * - minInfoGainGrid
     - Min Information Gain Param Grid Search
     - Min Information Gain parameters for Grid Search.
   * - maxBinsGrid
     - Max Bins Param Grid Search
     - Max Bins parameters for Grid Search.
   * - maxDepthGrid
     - Max Depth Param Grid Search
     - Max Depth parameters for Grid Search.
   * - confusionMatrix
     - Confusion Matrix
     -
   * - output_confusion_matrix_chart
     - Output Confusion Matrix Chart
     - Whether to display the confusion matrix chart.
   * - cm_chart_title
     - Confusion Matrix Chart Title
     - Title to display in the Confusion Matrix Chart.
   * - cm_chart_description
     - Confusion Matrix Chart Description
     - Description to display in the Confusion Matrix Chart.
   * - confusionMatrixTargetLegend
     - Confusion Matrix Target Legend
     - Legend name to display for Target in the Confusion Matrix.
   * - confusionMatrixPredictedLabelLegend
     - Confusion Matrix PredictedLabel Legend
     - Legend name to display for Predicted Label in the Confusion Matrix.
   * - confusionMatrixCountLegend
     - Confusion Matrix Count Legend
     - Legend name to display for Count in the Confusion Matrix.
   * - Description
     - Confusion Matrix Description
     -
   * - confusionMatrixRowDescription
     - Confusion Matrix Outcome description
     - Business details describing the outcome of each confusion matrix row.
   * - ROC Curve
     - ROC Curve
     -
   * - output_roc_curve
     - Output ROC Curve
     - Whether to display the ROC curve chart.
   * - roc_title
     - ROC Curve Chart Title
     - Title to display in the ROC Curve Chart.
   * - roc_description
     - ROC Curve Chart Description
     - Description to display in the ROC Curve Chart.
   * - xlabel
     - X Label
     - X label
   * - ylabel
     - Y Label
     - Y label

Details
-------

Decision trees support both binary and multiclass labels, as well as both continuous and categorical features.
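The ``impurity`` and ``minInfoGain`` fields work together: a candidate split is scored by how much it reduces impurity, and a split whose gain falls below ``minInfoGain`` is not considered. The sketch below illustrates that arithmetic in plain Scala for the Gini and entropy criteria. It is an illustration only, not Spark's or this node's implementation; the object name ``ImpuritySketch`` and its methods are hypothetical.

```scala
// Illustrative sketch only: plain Scala, no Spark dependency.
// ImpuritySketch and its method names are hypothetical helpers.
object ImpuritySketch {
  private def log2(x: Double): Double = math.log(x) / math.log(2.0)

  /** Gini impurity of a node, given per-class instance counts. */
  def gini(counts: Seq[Long]): Double = {
    val total = counts.sum.toDouble
    if (total == 0.0) 0.0
    else 1.0 - counts.map(c => math.pow(c / total, 2)).sum
  }

  /** Entropy (base 2) of a node, given per-class instance counts. */
  def entropy(counts: Seq[Long]): Double = {
    val total = counts.sum.toDouble
    if (total == 0.0) 0.0
    else counts.filter(_ > 0).map { c =>
      val p = c / total
      -p * log2(p)
    }.sum
  }

  /** Information gain: parent impurity minus the size-weighted child impurities. */
  def infoGain(impurity: Seq[Long] => Double,
               left: Seq[Long],
               right: Seq[Long]): Double = {
    // The parent's class counts are the element-wise sum of the two children.
    val parent = left.zip(right).map { case (l, r) => l + r }
    val n = parent.sum.toDouble
    impurity(parent) -
      (left.sum / n) * impurity(left) -
      (right.sum / n) * impurity(right)
  }
}
```

For a parent node with five instances of each of two classes, a split that separates the classes perfectly (``Seq(5L, 0L)`` and ``Seq(0L, 5L)``) has a Gini gain of 0.5, while a split that leaves both children evenly mixed (``Seq(2L, 2L)`` and ``Seq(3L, 3L)``) has a gain of zero, up to floating-point rounding.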
More at the Spark MLlib/ML docs page: http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier

Examples
--------

The example below is available at: https://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier

.. code-block:: scala

   import org.apache.spark.ml.Pipeline
   import org.apache.spark.ml.classification.DecisionTreeClassificationModel
   import org.apache.spark.ml.classification.DecisionTreeClassifier
   import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
   import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

   // Load the data stored in LIBSVM format as a DataFrame.
   val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

   // Index labels, adding metadata to the label column.
   // Fit on whole dataset to include all labels in index.
   val labelIndexer = new StringIndexer()
     .setInputCol("label")
     .setOutputCol("indexedLabel")
     .fit(data)

   // Automatically identify categorical features, and index them.
   val featureIndexer = new VectorIndexer()
     .setInputCol("features")
     .setOutputCol("indexedFeatures")
     .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.
     .fit(data)

   // Split the data into training and test sets (30% held out for testing).
   val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

   // Train a DecisionTree model.
   val dt = new DecisionTreeClassifier()
     .setLabelCol("indexedLabel")
     .setFeaturesCol("indexedFeatures")

   // Convert indexed labels back to original labels.
   val labelConverter = new IndexToString()
     .setInputCol("prediction")
     .setOutputCol("predictedLabel")
     .setLabels(labelIndexer.labelsArray(0))

   // Chain indexers and tree in a Pipeline.
   val pipeline = new Pipeline()
     .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

   // Train model. This also runs the indexers.
   val model = pipeline.fit(trainingData)

   // Make predictions.
   val predictions = model.transform(testData)

   // Select example rows to display.
   predictions.select("predictedLabel", "label", "features").show(5)

   // Select (prediction, true label) and compute test error.
   val evaluator = new MulticlassClassificationEvaluator()
     .setLabelCol("indexedLabel")
     .setPredictionCol("prediction")
     .setMetricName("accuracy")
   val accuracy = evaluator.evaluate(predictions)
   println(s"Test Error = ${(1.0 - accuracy)}")

   // The fitted tree is the third pipeline stage (index 2).
   val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
   println(s"Learned classification tree model:\n ${treeModel.toDebugString}")
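The ``maxDepthGrid``, ``maxBinsGrid`` and ``minInfoGainGrid`` fields describe a parameter grid. As a rough sketch of what such a search can look like in plain Spark ML, the block below builds a grid with ``ParamGridBuilder`` and selects the best combination with ``CrossValidator``. How this node runs its grid search internally is not documented here, so the use of ``CrossValidator``, the fold count, the grid values, the column names and the object name ``DecisionTreeGridSketch`` are all assumptions made for illustration.

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.DataFrame

// Hypothetical helper object; not part of Spark or of this node.
object DecisionTreeGridSketch {
  // Assumed column names: "label" (numeric class index) and "features" (vector).
  val dt: DecisionTreeClassifier = new DecisionTreeClassifier()
    .setLabelCol("label")
    .setFeaturesCol("features")
    .setImpurity("gini")
    .setSeed(42L)

  // Illustrative grid values: 3 depths x 2 bin counts x 2 gain thresholds.
  val paramGrid: Array[ParamMap] = new ParamGridBuilder()
    .addGrid(dt.maxDepth, Array(3, 5, 10))
    .addGrid(dt.maxBins, Array(32, 64))
    .addGrid(dt.minInfoGain, Array(0.0, 0.01))
    .build()

  /** Fits every grid combination with 3-fold cross-validation and keeps the most accurate. */
  def fitBest(training: DataFrame): CrossValidatorModel = {
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    new CrossValidator()
      .setEstimator(dt)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3) // assumed fold count
      .fit(training)
  }
}
```

With a training DataFrame such as ``trainingData`` from the example above, ``DecisionTreeGridSketch.fitBest(trainingData)`` returns a ``CrossValidatorModel`` whose ``bestModel`` holds the winning tree. Each added grid value multiplies the number of models trained, so small grids are a sensible starting point.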