GridSearchCV

class hana_ml.algorithms.pal.model_selection.GridSearchCV(estimator, param_grid, train_control, scoring)

Exhaustive search over specified parameter values for an estimator, using cross-validation (CV).
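hana_ml hands the actual search to PAL in the database. The minimal pure-Python sketch below only illustrates what "exhaustive search with cross-validation" means: every combination in the grid is scored by k-fold CV and the best mean score wins. The helpers `grid_search_cv` and `k_fold_indices` and the toy threshold estimator are hypothetical and not part of hana_ml. PAL's own fold assignment, shuffling, and refit behavior may differ.

```python
from itertools import product
from statistics import mean


def k_fold_indices(n, k):
    """Split range(n) into k contiguous folds whose sizes differ by at most one."""
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in sizes:
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def grid_search_cv(fit, score, X, y, param_grid, k=5):
    """Evaluate every combination in param_grid with k-fold CV.

    Returns the combination with the highest mean fold score (higher is better here).
    """
    names = sorted(param_grid)
    folds = k_fold_indices(len(X), k)
    best_params, best_score = None, float("-inf")
    # product(...) enumerates the full Cartesian grid of parameter values.
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        fold_scores = []
        for test_idx in folds:
            held_out = set(test_idx)
            train_idx = [i for i in range(len(X)) if i not in held_out]
            model = fit([X[i] for i in train_idx], [y[i] for i in train_idx], **params)
            fold_scores.append(
                score(model, [X[i] for i in test_idx], [y[i] for i in test_idx]))
        cv_score = mean(fold_scores)
        if cv_score > best_score:
            best_params, best_score = params, cv_score
    return best_params, best_score


# Hypothetical toy estimator: predicts 1 when x >= threshold (training data is ignored).
def fit_threshold(X_train, y_train, threshold):
    return threshold


def accuracy_score(threshold, X_test, y_test):
    preds = [1 if x >= threshold else 0 for x in X_test]
    return mean(1.0 if p == t else 0.0 for p, t in zip(preds, y_test))


X = list(range(10))
y = [0] * 5 + [1] * 5
best_params, best_score = grid_search_cv(
    fit_threshold, accuracy_score, X, y, {"threshold": [2, 5, 8]}, k=5)
print(best_params, best_score)  # {'threshold': 5} 1.0
```

Note that a metric such as "error_rate" must be minimized rather than maximized, so a real search picks the direction according to the scoring metric.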

Parameters:
estimator : estimator object

This is assumed to implement the PAL estimator interface.

param_grid : dict

Dictionary with parameter names (strings) as keys and lists of parameter settings to try as values. Every combination in the grid spanned by these lists is explored, which enables searching over any sequence of parameter settings.

train_control : dict

Parameters that control model evaluation and parameter selection.

scoring : str

The scoring method used to evaluate the predictions. The available options depend on the algorithm used by the estimator.

Examples

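>>> from hana_ml.algorithms.pal.model_selection import GridSearchCV
>>> from hana_ml.algorithms.pal.unified_classification import UnifiedClassification
>>> uhgc = UnifiedClassification(func='HybridGradientBoostingTree')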

Create a "GridSearchCV" object:

>>> gscv = GridSearchCV(estimator=uhgc,
                        param_grid={'learning_rate': [0.1, 0.4, 0.7, 1],
                                    'n_estimators': [4, 6, 8, 10],
                                    'split_threshold': [0.1, 0.4, 0.7, 1]},
                        train_control=dict(fold_num=5,
                                           resampling_method='cv',
                                           random_state=1,
                                           ref_metric=['error_rate']),
                        scoring='error_rate')

Invoke fit():

>>> gscv.fit(data=df_train,
             key='ID',
             label='CLASS',
             partition_method='stratified',
             partition_random_state=1,
             stratified_column='CLASS')
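Because the search is exhaustive, its cost grows multiplicatively with the size of the grid. A quick count for the example above, assuming each candidate is trained once per fold of the 5-fold CV and not counting any final refit of the winning setting:

```python
from math import prod

# Values copied from the GridSearchCV example above.
param_grid = {
    "learning_rate": [0.1, 0.4, 0.7, 1],
    "n_estimators": [4, 6, 8, 10],
    "split_threshold": [0.1, 0.4, 0.7, 1],
}
fold_num = 5

# An exhaustive grid contains one candidate per combination of values.
n_candidates = prod(len(values) for values in param_grid.values())
# Assumption: every candidate is trained once per CV fold (final refit not counted).
n_fits = n_candidates * fold_num
print(n_candidates, n_fits)  # 64 320
```

Adding a fifth value to any single parameter would raise the grid to 80 candidates.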

Attributes:
estimator : estimator object

Methods

fit(data, **kwargs)

Fit the model to the training dataset.

predict(data, **kwargs)

Make predictions with the fitted estimator.

set_resampling_method(method)

Specifies the resampling method for model evaluation or parameter selection.

set_scoring_metric(metric)

Specifies the scoring metric.

set_seed(seed[, seed_name])

Specifies the seed for random generation.

set_timeout(timeout)

Specifies the maximum running time for model evaluation or parameter selection.

fit(data, **kwargs)

Fit the model to the training dataset.

Parameters:
data : DataFrame

Input DataFrame.

**kwargs : dict

Keyword arguments passed to the fit() function of the estimator. Refer to the documentation of the specific estimator for the available parameters.

predict(data, **kwargs)

Make predictions with the fitted estimator.

Parameters:
data : DataFrame

Input DataFrame.

**kwargs : dict

Keyword arguments passed to the predict() function of the estimator. Refer to the documentation of the specific estimator for the available parameters.

set_resampling_method(method)

Specifies the resampling method for model evaluation or parameter selection.

Parameters:
method : str

The resampling method for model evaluation or parameter selection. Valid options:

  • "cv"

  • "cv_sha"

  • "cv_hyperband"

  • "stratified_cv"

  • "stratified_cv_sha"

  • "stratified_cv_hyperband"

  • "bootstrap"

  • "bootstrap_sha"

  • "bootstrap_hyperband"

  • "stratified_bootstrap"

  • "stratified_bootstrap_sha"

  • "stratified_bootstrap_hyperband"

Resampling methods with the prefix "stratified" can only be applied to classification algorithms.
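Stratified methods need class labels to split on, which is why they are limited to classification. The sketch below shows textbook versions of two of the resampling schemes: stratified k-fold assignment, and a bootstrap draw whose out-of-bag rows serve as a validation set (one common way to evaluate a bootstrap resample). The function names are hypothetical, PAL's exact splitting may differ, and the "_sha" and "_hyperband" variants are not covered.

```python
import random
from collections import defaultdict


def stratified_k_fold(labels, k, seed=1):
    """Assign row indices to k folds so that each fold keeps the class proportions."""
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    rng = random.Random(seed)
    folds = [[] for _ in range(k)]
    counter = 0
    for label in sorted(by_class):
        members = by_class[label]
        rng.shuffle(members)
        # Deal this class's rows round-robin so every fold gets a near-equal share.
        for i in members:
            folds[counter % k].append(i)
            counter += 1
    return folds


def bootstrap_split(n, seed=1):
    """Draw n row indices with replacement; rows never drawn form the out-of-bag set."""
    rng = random.Random(seed)
    train = [rng.randrange(n) for _ in range(n)]
    out_of_bag = sorted(set(range(n)) - set(train))
    return train, out_of_bag


labels = ["A"] * 6 + ["B"] * 4
folds = stratified_k_fold(labels, k=2)
for fold in folds:
    print(sorted(labels[i] for i in fold))  # each fold: three 'A' rows, two 'B' rows

train, oob = bootstrap_split(10)
print(len(train), oob)
```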

set_scoring_metric(metric)

Specifies the scoring metric.

Parameters:
metric : str

The evaluation metric for model evaluation or parameter selection. Valid options:

  • "accuracy"

  • "error_rate"

  • "f1_score"

  • "rmse"

  • "mae"

  • "auc"

  • "nll" (negative log likelihood)
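For reference, the standard textbook definitions of these metrics for the binary case can be written as follows. This is a sketch only: the function names are hypothetical, and PAL's handling of multi-class averaging, AUC computation, and whether NLL is summed or averaged may differ.

```python
import math


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def error_rate(y_true, y_pred):
    return 1.0 - accuracy(y_true, y_pred)


def f1_score(y_true, y_pred, positive=1):
    """Binary F1 for the given positive class: 2TP / (2TP + FP + FN)."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    return 2 * tp / (2 * tp + fp + fn)


def rmse(y_true, y_pred):
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def mae(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def auc(y_true, scores):
    """Probability that a random positive scores above a random negative (ties count 0.5)."""
    pos = [s for s, t in zip(scores, y_true) if t == 1]
    neg = [s for s, t in zip(scores, y_true) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def nll(y_true, probs):
    """Mean negative log likelihood of binary labels given predicted P(y = 1)."""
    return -sum(math.log(p) if t == 1 else math.log(1 - p)
                for t, p in zip(y_true, probs)) / len(y_true)


print(accuracy([1, 0, 1, 1], [1, 0, 0, 1]))      # 0.75
print(f1_score([1, 0, 1, 1], [1, 0, 0, 1]))      # 0.8
print(auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Higher is better for "accuracy", "f1_score", and "auc"; lower is better for "error_rate", "rmse", "mae", and "nll", so the search has to maximize or minimize depending on the chosen metric.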

set_seed(seed, seed_name=None)

Specifies the seed for random generation. Use system time when 0 is specified.

Parameters:
seed : int

The random seed number.

seed_name : int, optional

The name of the random seed.

Defaults to None.
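A fixed non-zero seed makes the random parts of the search (for example fold assignment or bootstrap draws) reproducible, while 0 trades reproducibility for a time-based seed. The sketch below mirrors that documented rule with a hypothetical `resolve_seed` helper; it is not hana_ml or PAL code.

```python
import random
import time


def resolve_seed(seed):
    """Hypothetical helper: 0 means 'derive the seed from the system time'."""
    if seed == 0:
        return time.time_ns()  # changes over time, so results are not reproducible
    return seed


def shuffled_rows(n, seed):
    rng = random.Random(resolve_seed(seed))
    rows = list(range(n))
    rng.shuffle(rows)
    return rows


# The same non-zero seed always yields the same shuffle, and hence the same CV folds.
print(shuffled_rows(10, seed=1) == shuffled_rows(10, seed=1))  # True
```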

set_timeout(timeout)

Specifies the maximum running time, in seconds, for model evaluation or parameter selection. No timeout is applied when 0 is specified.

Parameters:
timeout : int

The maximum running time, in seconds.
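One way to picture a time budget on a parameter search is a loop that stops evaluating candidates once the budget is spent, with 0 meaning "no limit". The sketch below assumes the search keeps the best result found so far when time runs out; PAL's actual behavior on timeout is not described here, and `search_with_timeout` is a hypothetical name.

```python
import time


def search_with_timeout(candidates, evaluate, timeout):
    """Evaluate candidates in order, stopping once `timeout` seconds have elapsed.

    timeout == 0 disables the limit, mirroring the documented convention.
    Returns the best (highest-scoring) candidate seen and how many were evaluated.
    """
    start = time.monotonic()
    best, best_score, evaluated = None, float("-inf"), 0
    for candidate in candidates:
        if timeout > 0 and time.monotonic() - start >= timeout:
            break  # budget exhausted: keep the best result found so far (assumption)
        score = evaluate(candidate)
        evaluated += 1
        if score > best_score:
            best, best_score = candidate, score
    return best, evaluated


best, n = search_with_timeout([1, 2, 3], lambda c: -abs(c - 2), timeout=0)
print(best, n)  # 2 3
```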

Inherited Methods from PALBase

Besides the methods described above, the GridSearchCV class also inherits methods from the PALBase class. Refer to PAL Base for more details.