TextClassificationWithModel

class hana_ml.text.tm.TextClassificationWithModel(language=None, enable_stopwords=True, keep_numeric=None, allowed_list=None, notallowed_list=None, n_estimators=None, max_depth=None, split_threshold=None, min_samples_leaf=None)

Text classification class. This class first trains a Random Decision Trees (RDT) classifier on TF-IDF vectorized text data and then applies the trained model for inference.
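To make "TF-IDF vectorized" concrete, here is a minimal pure-Python sketch of one common TF-IDF variant. It assumes whitespace tokenization, term frequency normalized by document length, and IDF = log(N / df). The function name `tf_idf` is hypothetical, and this is not the PAL implementation, whose exact tokenization and weighting are not described on this page.

```python
import math
from collections import Counter


def tf_idf(docs):
    """Return one {term: weight} dict per document (illustrative TF-IDF variant)."""
    # Assumption: lowercase + whitespace split; PAL's tokenizer is not shown here.
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    # Document frequency: number of documents that contain each term at least once.
    df = Counter(term for tokens in tokenized for term in set(tokens))
    vectors = []
    for tokens in tokenized:
        counts = Counter(tokens)
        total = len(tokens)
        vectors.append({
            # TF normalized by document length, times IDF = log(N / df).
            term: (count / total) * math.log(n_docs / df[term])
            for term, count in counts.items()
        })
    return vectors


docs = ["the cat sat", "the dog sat", "the dog barked"]
vecs = tf_idf(docs)
```

A term that occurs in every document ("the") gets weight 0, while a term unique to one document ("cat") outweighs one shared by two documents ("sat"). The classifier then learns from such weighted vectors rather than from raw text.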

Parameters:
language : str, optional

Specifies the language. An SAP HANA Cloud instance currently supports 'EN', 'DE', 'ES', 'FR', 'RU' and 'PT'. If None, the language is detected automatically.

Defaults to None (auto detection).

enable_stopwords : bool, optional

Determines whether stopword removal is turned on.

Defaults to True.

keep_numeric : bool, optional

Determine whether to keep numbers.

Valid only when enable_stopwords is True.

Defaults to False.

allowed_list : list of str, optional

A list of words that are retained by the stopwords logic.

Valid only when enable_stopwords is True.

notallowed_list : list of str, optional

A list of words that are recognized and deleted by the stopwords logic.

Valid only when enable_stopwords is True.
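The interaction of the four stopword options can be sketched as a token filter. This is an illustration under stated assumptions, not PAL's logic: the stopword set is a toy stand-in for PAL's language-specific list, `filter_tokens` is a hypothetical name, and giving `allowed_list` precedence over both the stopword list and number removal is an assumption.

```python
# Toy stopword list (assumption); PAL uses its own language-specific list.
STOPWORDS = {"the", "a", "is", "of", "not"}


def filter_tokens(tokens, enable_stopwords=True, keep_numeric=False,
                  allowed_list=None, notallowed_list=None):
    """Illustrative stopword filtering driven by the documented options."""
    if not enable_stopwords:
        # keep_numeric, allowed_list and notallowed_list only apply when
        # stopword handling is on, so everything is kept unchanged here.
        return list(tokens)
    allowed = set(allowed_list or [])
    # notallowed_list extends the set of words that get deleted.
    stop = STOPWORDS | set(notallowed_list or [])
    result = []
    for tok in tokens:
        if tok in allowed:
            result.append(tok)      # explicitly retained (assumed to win)
        elif tok in stop:
            continue                # deleted as a stopword
        elif tok.isdigit() and not keep_numeric:
            continue                # numbers dropped unless keep_numeric
        else:
            result.append(tok)
    return result


tokens = ["the", "price", "is", "not", "42", "euros"]
```

For example, `filter_tokens(tokens)` keeps only `["price", "euros"]`, while passing `keep_numeric=True, allowed_list=["not"]` also retains the negation and the number.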

n_estimators : int, optional

Specifies the number of decision trees in the RDT model.

Defaults to 100.

max_depth : int, optional

The maximum depth of a tree in RDT, where -1 means unlimited.

Defaults to 56.

split_threshold : float, optional

Specifies the stopping condition of the tree-growing process in the RDT model: if the improvement of the best split is less than this value, the tree stops growing.

Defaults to 1e-5.

min_samples_leaf : int, optional

Specifies the minimum number of records in a leaf of a tree in the RDT model.

Defaults to 1.

Note

The parameters n_estimators, max_depth, split_threshold and min_samples_leaf are all used to build the RDT model for text classification; the remaining parameters control text preprocessing.
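The three tree-growth parameters act as stopping rules for each tree. The sketch below shows how such rules are commonly combined in decision-tree learners; it mirrors the documented defaults and the meaning of -1 for `max_depth`, but the function `can_split` and the exact comparisons (root at depth 0, strict "less than" for the threshold) are assumptions, not PAL internals. `n_estimators` is not involved here because it only sets how many such trees are grown.

```python
def can_split(depth, left_size, right_size, gain,
              max_depth=56, min_samples_leaf=1, split_threshold=1e-5):
    """Return True if a candidate split passes the three stopping rules (illustrative)."""
    # max_depth = -1 means unlimited depth; otherwise a node at depth
    # max_depth (root = depth 0, assumed) may not split further.
    if max_depth != -1 and depth >= max_depth:
        return False
    # Each resulting leaf must keep at least min_samples_leaf records.
    if min(left_size, right_size) < min_samples_leaf:
        return False
    # Stop growing if the best split improves by less than split_threshold.
    return gain >= split_threshold
```

With the settings from the example below (`max_depth=6`, `min_samples_leaf=2`, `split_threshold=1e-6`), trees are kept shallow and tiny leaves are avoided, which tends to reduce overfitting on sparse TF-IDF features.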

Examples

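>>> from hana_ml.text.tm import TextClassificationWithModel
>>> # document_file_train_data: hana_ml DataFrame with ID, content and category columns
>>> tc = TextClassificationWithModel(enable_stopwords=True,
...                                  n_estimators=50,
...                                  max_depth=6,
...                                  min_samples_leaf=2,
...                                  split_threshold=1e-6)
>>> tc.fit(data=document_file_train_data)
>>> # document_file_test_data: hana_ml DataFrame with ID and content columns
>>> pred_res = tc.predict(data=document_file_test_data)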

Methods

fit(data[, seed, thread_ratio])

Train the model.

predict(data[, rdt_top_n, thread_ratio])

Predict document categories with the trained model.

fit(data, seed=None, thread_ratio=None)

Train the model.

Parameters:
data : DataFrame

Input data, structured as follows:

  • 1st column, ID.

  • 2nd column, Document content.

  • 3rd column, Document category.

seed : int, optional

Specify the seed for random number generation.

thread_ratio : float, optional

The ratio of the total number of threads that can be used by this function.
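As an illustration of what a thread ratio means, the sketch below maps a ratio to a thread count. It assumes the general PAL convention (0 means a single thread, 1 means all available threads, values in between use that fraction, and values outside [0, 1] let the server choose); the function `threads_for_ratio` is hypothetical, and the actual decision is made inside SAP HANA, not in client code.

```python
import os


def threads_for_ratio(thread_ratio, available=None):
    """Illustrative mapping from thread_ratio to a thread count (assumed semantics)."""
    if available is None:
        available = os.cpu_count() or 1
    if thread_ratio is None or not 0 <= thread_ratio <= 1:
        # Not set, or outside [0, 1]: the server picks heuristically.
        # None is returned here to mean "let the server decide".
        return None
    if thread_ratio == 0:
        return 1  # 0 means a single thread
    # Use the given fraction of the available threads, but at least one.
    return max(1, int(available * thread_ratio))
```

On an 8-thread host, for example, `thread_ratio=0.5` would correspond to 4 threads under these assumptions.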

Returns:
A fitted instance of class TextClassificationWithModel.

predict(data, rdt_top_n=None, thread_ratio=None)

Predict document categories with the trained model.

Parameters:
data : DataFrame

Input data, structured as follows:

  • 1st column, ID.

  • 2nd column, Document content.
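Since fit expects (ID, content, category) and predict expects only (ID, content), a labeled corpus is usually split and the category column dropped for the prediction input, keeping it aside as ground truth. The sketch below does this on plain Python tuples; `split_for_fit_and_predict` is a hypothetical helper. Turning the rows into hana_ml DataFrames (for example via pandas and `hana_ml.dataframe.create_dataframe_from_pandas`) requires a live connection and is not shown.

```python
import random


def split_for_fit_and_predict(rows, test_fraction=0.25, seed=42):
    """Split (id, content, category) rows into fit input, predict input and truth."""
    rows = list(rows)
    rng = random.Random(seed)       # fixed seed for a reproducible split
    rng.shuffle(rows)
    n_test = max(1, int(len(rows) * test_fraction))
    test, train = rows[:n_test], rows[n_test:]
    fit_rows = train                                        # ID, content, category
    predict_rows = [(doc_id, text) for doc_id, text, _ in test]  # ID, content
    truth = {doc_id: category for doc_id, _, category in test}   # held-out labels
    return fit_rows, predict_rows, truth


corpus = [(i, f"document text {i}", "sports" if i % 2 else "finance")
          for i in range(1, 9)]
fit_rows, predict_rows, truth = split_for_fit_and_predict(corpus)
```

Keeping the held-out labels in `truth` makes it possible to score the prediction result afterwards.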

rdt_top_n : int, optional

Specify the number of top terms to be used for the Random Decision Tree algorithm.

thread_ratio : float, optional

The ratio of the total number of threads that can be used by this function.

Returns:
DataFrame

The prediction result.
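A common next step is to compare predicted categories with held-out labels. The sketch below computes accuracy from two dictionaries keyed by document ID. It assumes you have already extracted (ID, predicted category) pairs from the result, for example after `pred_res.collect()`; the result's column names are not documented on this page, so the extraction step and the sample values are hypothetical.

```python
def accuracy(predicted, truth):
    """Fraction of documents whose predicted category equals the true category."""
    common = [doc_id for doc_id in truth if doc_id in predicted]
    if not common:
        raise ValueError("no overlapping document IDs")
    hits = sum(predicted[doc_id] == truth[doc_id] for doc_id in common)
    return hits / len(common)


# Hypothetical values standing in for data pulled from the prediction result.
predicted = {1: "sports", 2: "finance", 3: "sports"}
truth = {1: "sports", 2: "politics", 3: "sports"}
score = accuracy(predicted, truth)
```

Only IDs present in both mappings are scored, so documents missing from the result do not silently count as correct.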

Inherited Methods from PALBase

In addition to the methods above, the TextClassificationWithModel class inherits methods from the PALBase class; refer to PAL Base for more details.