An important task in ML is model selection, or using data to find the best model or parameters for a given task. This is also called tuning. You can tune an individual Estimator (such as LogisticRegression) or an entire Pipeline that includes multiple algorithms, featurization, and other steps. Users can tune the entire Pipeline at once instead of tuning each element in the Pipeline individually.
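MLlib supports model selection using tools such as CrossValidator and TrainValidationSplit. These tools require the following:
- Estimator: the algorithm or Pipeline to tune.
- Set of ParamMaps: the parameters to choose from, sometimes called a "parameter grid" to search over.
- Evaluator: a metric that measures how well a fitted Model performs on held-out test data.
These model selection tools work as follows:
- They split the input data into separate training and test datasets.
- For each (training, test) pair, they iterate through the set of ParamMaps. For each ParamMap, they fit the Estimator using those parameters to get the fitted Model, and use the Evaluator to evaluate the Model's performance.
- They select the Model produced by the best-performing set of parameters.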
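The loop described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not Spark's implementation: `select_best`, `fit_constant`, and `neg_mean_abs_error` are hypothetical stand-ins for the Estimator and Evaluator, and the `larger_is_better` flag plays the role of an Evaluator's `isLargerBetter()`.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple

ParamMap = Dict[str, Any]


def select_best(
    param_maps: Sequence[ParamMap],
    splits: Sequence[Tuple[list, list]],
    fit: Callable[[list, ParamMap], Any],
    evaluate: Callable[[Any, list], float],
    larger_is_better: bool = True,
) -> Tuple[ParamMap, List[float]]:
    """Return the best ParamMap and the average metric of every ParamMap."""
    avg_metrics: List[float] = []
    for params in param_maps:
        # Fit one model per (training, test) pair and score it on the held-out data.
        scores = [evaluate(fit(train, params), test) for train, test in splits]
        avg_metrics.append(sum(scores) / len(scores))
    choose = max if larger_is_better else min
    best_index = choose(range(len(param_maps)), key=lambda i: avg_metrics[i])
    return param_maps[best_index], avg_metrics


def fit_constant(train: list, params: ParamMap) -> float:
    # Hypothetical toy "model": a constant prediction taken from the ParamMap.
    return params["c"]


def neg_mean_abs_error(model: float, test: list) -> float:
    # Negated mean absolute error, so a larger value is better.
    return -sum(abs(x - model) for x in test) / len(test)


data = [1.0, 2.0, 3.0, 4.0]
splits = [(data[:2], data[2:]), (data[2:], data[:2])]
grid = [{"c": 0.0}, {"c": 2.5}, {"c": 10.0}]
best, metrics = select_best(grid, splits, fit_constant, neg_mean_abs_error)
print(best, metrics)  # {'c': 2.5} [-2.5, -1.0, -7.5]
```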
To help construct the parameter grid, users can use the ParamGridBuilder utility. By default, sets of parameters from the parameter grid are evaluated serially. Parameter evaluation can be done in parallel by setting the parallelism parameter to a value of 2 or more (a value of 1 is serial) before running model selection with CrossValidator or TrainValidationSplit. The value of parallelism should be chosen carefully to maximize parallelism without exceeding cluster resources; larger values do not always improve performance. Generally speaking, a value of up to 10 should be sufficient for most clusters.
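ParamGridBuilder yields every combination of the values passed to addGrid (a Cartesian product), and parallelism controls how many of those combinations are evaluated at once. The sketch below imitates both ideas with `itertools.product` and a thread pool; `build_grid`, `evaluate_all`, and `fake_score` are hypothetical helpers, not PySpark APIs.

```python
import itertools
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List


def build_grid(grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Expand {param: [values, ...]} into one dict per combination of values."""
    names = list(grid)
    return [dict(zip(names, combo)) for combo in itertools.product(*(grid[n] for n in names))]


def evaluate_all(
    param_maps: List[Dict[str, Any]],
    score: Callable[[Dict[str, Any]], float],
    parallelism: int = 1,
) -> List[float]:
    """Score every ParamMap: serially when parallelism is 1, else with a thread pool."""
    if parallelism <= 1:
        return [score(p) for p in param_maps]
    with ThreadPoolExecutor(max_workers=parallelism) as pool:
        # pool.map keeps the input order, so results line up with param_maps.
        return list(pool.map(score, param_maps))


def fake_score(params: Dict[str, Any]) -> float:
    # Hypothetical stand-in for "fit the Estimator and run the Evaluator".
    return params["numFeatures"] * params["regParam"]


grid = build_grid({"numFeatures": [10, 100, 1000], "regParam": [0.1, 0.01]})
print(len(grid))  # 6
print(grid[0])    # {'numFeatures': 10, 'regParam': 0.1}
serial = evaluate_all(grid, fake_score)
parallel = evaluate_all(grid, fake_score, parallelism=2)
print(serial == parallel)  # True
```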
CrossValidator
CrossValidator begins by splitting the dataset into a set of folds, which are used as separate training and test datasets. For example, with k=3 folds, CrossValidator generates 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. To evaluate a particular ParamMap, CrossValidator computes the average evaluation metric for the 3 Models produced by fitting the Estimator on the 3 different (training, test) dataset pairs.
After identifying the best ParamMap, CrossValidator finally re-fits the Estimator using the best ParamMap and the entire dataset.
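A simplified k-fold cross-validation following the two paragraphs above can be sketched in plain Python: split into k folds, average the metric per ParamMap, then refit the winner on all the data. Assumptions: rows are dealt into folds round-robin after a seeded shuffle (Spark instead assigns rows to folds randomly, so its fold sizes are only approximately equal), a larger metric is better, and `fit_mean`/`neg_mean_abs_error` are hypothetical toy stand-ins for an Estimator and an Evaluator.

```python
import random
from typing import Any, Callable, Dict, List, Sequence, Tuple

ParamMap = Dict[str, Any]


def k_fold_splits(rows: Sequence[Any], k: int, seed: int = 0) -> List[Tuple[list, list]]:
    """Shuffle the rows, deal them into k folds, and return k (training, test) pairs."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    folds = [shuffled[i::k] for i in range(k)]
    pairs = []
    for i in range(k):
        train = [row for j, fold in enumerate(folds) if j != i for row in fold]
        pairs.append((train, folds[i]))
    return pairs


def cross_validate(
    rows: Sequence[Any],
    param_maps: Sequence[ParamMap],
    fit: Callable[[list, ParamMap], Any],
    evaluate: Callable[[Any, list], float],
    k: int = 3,
    seed: int = 0,
) -> Tuple[Any, ParamMap, List[float]]:
    splits = k_fold_splits(rows, k, seed)
    avg_metrics = []
    for params in param_maps:
        # One model per fold; this ParamMap's metric is the average over the k folds.
        scores = [evaluate(fit(train, params), test) for train, test in splits]
        avg_metrics.append(sum(scores) / k)
    best = param_maps[max(range(len(param_maps)), key=avg_metrics.__getitem__)]
    # Final step: refit the Estimator with the best ParamMap on the entire dataset.
    return fit(list(rows), best), best, avg_metrics


def fit_mean(train: list, params: ParamMap) -> float:
    # Hypothetical toy estimator: training mean plus a "shift" parameter.
    return sum(train) / len(train) + params["shift"]


def neg_mean_abs_error(model: float, test: list) -> float:
    # Negated mean absolute error, so a larger value is better.
    return -sum(abs(x - model) for x in test) / len(test)


rows = list(range(12))
print([(len(tr), len(te)) for tr, te in k_fold_splits(rows, k=3)])  # [(8, 4), (8, 4), (8, 4)]
model, best, metrics = cross_validate(rows, [{"shift": 0.0}, {"shift": 50.0}], fit_mean, neg_mean_abs_error)
print(best, model)  # {'shift': 0.0} 5.5
```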
```python
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession

# Create (or reuse) a SparkSession; the snippet assumes `spark` is available.
spark = SparkSession.builder.getOrCreate()

# Prepare training documents, which are labeled.
training = spark.createDataFrame([
    (0, "a b c d e spark", 1.0),
    (1, "b d", 0.0),
    (2, "spark f g h", 1.0),
    (3, "hadoop mapreduce", 0.0),
    (4, "b spark who", 1.0),
    (5, "g d a y", 0.0),
    (6, "spark fly", 1.0),
    (7, "was mapreduce", 0.0),
    (8, "e spark program", 1.0),
    (9, "a e c l", 0.0),
    (10, "spark compile", 1.0),
    (11, "hadoop software", 0.0)
], ["id", "text", "label"])

# Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

# We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
# This allows us to jointly choose parameters for all Pipeline stages.
# A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
# We use a ParamGridBuilder to construct a grid of parameters to search over.
# With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
# this grid has 3 x 2 = 6 parameter settings for CrossValidator to choose from.
paramGrid = (ParamGridBuilder()
             .addGrid(hashingTF.numFeatures, [10, 100, 1000])
             .addGrid(lr.regParam, [0.1, 0.01])
             .build())

crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=BinaryClassificationEvaluator(),
                          numFolds=2)  # use 3+ folds in practice

# Run cross-validation, and choose the best set of parameters.
cvModel = crossval.fit(training)

# Prepare test documents, which are unlabeled.
test = spark.createDataFrame([
    (4, "spark i j k"),
    (5, "l m n"),
    (6, "mapreduce spark"),
    (7, "apache hadoop")
], ["id", "text"])

# Make predictions on test documents. cvModel uses the best model found (lrModel).
prediction = cvModel.transform(test)
selected = prediction.select("id", "text", "probability", "prediction")
for row in selected.collect():
    print(row)
```
Train-Validation Split
In addition to CrossValidator, Spark also provides TrainValidationSplit for hyperparameter tuning. TrainValidationSplit evaluates each combination of parameters only once, as opposed to k times in the case of CrossValidator. It is therefore less expensive, but it will not produce results as reliable when the training dataset is not sufficiently large.
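A rough way to see the cost difference is to count model fits: one per ParamMap per (training, test) pair, plus the final refit on the whole dataset. The helper `count_fits` is hypothetical and the count ignores any optimizations Spark may apply internally. The inputs are the two grids used in this article's examples: 3 x 2 = 6 ParamMaps with numFolds=2 for CrossValidator, and 2 x 2 x 3 = 12 ParamMaps for TrainValidationSplit.

```python
def count_fits(num_param_maps: int, num_splits: int) -> int:
    """Models trained during tuning: one per ParamMap per split, plus one final refit."""
    return num_param_maps * num_splits + 1


# CrossValidator example: 6 ParamMaps with numFolds=2, and with the recommended 3 folds.
print(count_fits(6, 2))   # 13
print(count_fits(6, 3))   # 19
# TrainValidationSplit example: 12 ParamMaps, always a single split.
print(count_fits(12, 1))  # 13
# The same 12-point grid under 3-fold cross-validation.
print(count_fits(12, 3))  # 37
```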
Unlike CrossValidator, TrainValidationSplit creates a single (training, test) dataset pair. It splits the dataset into these two parts using the trainRatio parameter. For example, with trainRatio=0.75, TrainValidationSplit generates a training and test dataset pair in which 75% of the data is used for training and 25% for validation.
Like CrossValidator, TrainValidationSplit finally fits the Estimator using the best ParamMap and the entire dataset.
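A simplified train-validation split can be sketched in the same plain-Python style: shuffle once, cut at trainRatio, fit and evaluate every ParamMap a single time, then refit the best one on all rows. Assumptions: the cut is exact (Spark's split is randomized, so its proportions are only approximate), a larger metric is better, and `fit_mean`/`neg_mean_abs_error` are hypothetical toy stand-ins for an Estimator and an Evaluator.

```python
import random
from typing import Any, Callable, Dict, List, Sequence, Tuple

ParamMap = Dict[str, Any]


def train_validation_split(
    rows: Sequence[Any],
    param_maps: Sequence[ParamMap],
    fit: Callable[[list, ParamMap], Any],
    evaluate: Callable[[Any, list], float],
    train_ratio: float = 0.75,
    seed: int = 0,
) -> Tuple[Any, ParamMap, List[float]]:
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_ratio)
    train, validation = shuffled[:cut], shuffled[cut:]
    # Each ParamMap is fitted and evaluated exactly once, on the single split.
    metrics = [evaluate(fit(train, params), validation) for params in param_maps]
    best = param_maps[max(range(len(param_maps)), key=metrics.__getitem__)]
    # Like CrossValidator, refit the Estimator with the best ParamMap on all rows.
    return fit(list(rows), best), best, metrics


def fit_mean(train: list, params: ParamMap) -> float:
    # Hypothetical toy estimator: training mean plus a "shift" parameter.
    return sum(train) / len(train) + params["shift"]


def neg_mean_abs_error(model: float, test: list) -> float:
    # Negated mean absolute error, so a larger value is better.
    return -sum(abs(x - model) for x in test) / len(test)


rows = list(range(20))
model, best, metrics = train_validation_split(
    rows, [{"shift": 0.0}, {"shift": 50.0}], fit_mean, neg_mean_abs_error, train_ratio=0.75
)
print(len(metrics), best, model)  # 2 {'shift': 0.0} 9.5
```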
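```python
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
from pyspark.sql import SparkSession

# Create (or reuse) a SparkSession; the snippet assumes `spark` is available.
spark = SparkSession.builder.getOrCreate()

# Prepare training and test data (path relative to the Spark distribution directory).
data = (spark.read.format("libsvm")
        .load("data/mllib/sample_linear_regression_data.txt"))
train, test = data.randomSplit([0.9, 0.1], seed=12345)

lr = LinearRegression(maxIter=10)

# We use a ParamGridBuilder to construct a grid of parameters to search over.
# TrainValidationSplit will try all combinations of values and determine the best model using the evaluator.
paramGrid = (ParamGridBuilder()
             .addGrid(lr.regParam, [0.1, 0.01])
             .addGrid(lr.fitIntercept, [False, True])
             .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
             .build())

# In this case the estimator is simply the linear regression.
# A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
tvs = TrainValidationSplit(estimator=lr,
                           estimatorParamMaps=paramGrid,
                           evaluator=RegressionEvaluator(),
                           # 80% of the data will be used for training, 20% for validation.
                           trainRatio=0.8)

# Run TrainValidationSplit, and choose the best set of parameters.
model = tvs.fit(train)

# Make predictions on test data. model is the model with the combination of parameters
# that performed best.
(model.transform(test)
    .select("features", "label", "prediction")
    .show())
```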