TextClassificationWithModel

class hana_ml.text.tm.TextClassificationWithModel(language=None, enable_stopwords=True, keep_numeric=None, allowed_list=None, notallowed_list=None, n_estimators=None, max_depth=None, split_threshold=None, min_samples_leaf=None)

Text classification class. It first trains an RDT (random decision trees) classifier on TF-IDF vectorized text data and then applies the trained model for inference.
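To make the term "TF-IDF vectorized" concrete, the following is a minimal pure-Python sketch of the textbook weighting (term frequency multiplied by the log of inverse document frequency). The helper name tf_idf and the whitespace tokenization are assumptions made for illustration only; the weighting that PAL applies internally may differ in its details.

```python
import math
from collections import Counter


def tf_idf(docs):
    """Return one {term: weight} dict per document (hypothetical helper)."""
    # Naive tokenization: lowercase and split on whitespace (an assumption).
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    # Document frequency: number of documents that contain each term.
    doc_freq = Counter(term for tokens in tokenized for term in set(tokens))
    weights = []
    for tokens in tokenized:
        term_freq = Counter(tokens)
        # Textbook weighting: raw count * log(N / document frequency).
        weights.append({
            term: count * math.log(n_docs / doc_freq[term])
            for term, count in term_freq.items()
        })
    return weights


docs = ["the cat sat", "the dog barked", "the cat purred"]
vectors = tf_idf(docs)
```

A term that occurs in every document (here "the") receives weight 0, while terms confined to fewer documents receive larger weights, which is what makes the resulting vectors useful as classifier features.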

Parameters
language : str, optional

Specifies the language of the text. HANA Cloud instances currently support 'EN', 'DE', 'ES', 'FR', 'RU', 'PT'. If None, the language is detected automatically.

Defaults to None (auto detection).

enable_stopwords : bool, optional

Determines whether stopword handling is turned on.

Defaults to True.

keep_numeric : bool, optional

Determines whether numbers are kept.

Valid only when enable_stopwords is True.

Defaults to False.

allowed_list : list of str, optional

A list of words that are retained by the stopwords logic.

Valid only when enable_stopwords is True.

notallowed_list : list of str, optional

A list of words that are recognized and deleted by the stopwords logic.

Valid only when enable_stopwords is True.
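The interplay of enable_stopwords, keep_numeric, allowed_list and notallowed_list can be pictured with a small pure-Python sketch. Everything here is hypothetical: the filter_tokens helper, the tiny DEFAULT_STOPWORDS set, and the rule that allowed_list takes precedence over notallowed_list are assumptions for illustration, not the behaviour of the PAL procedure itself.

```python
# Tiny stand-in stopword set; the real list is language specific (assumption).
DEFAULT_STOPWORDS = {"the", "a", "is", "of"}


def filter_tokens(tokens, enable_stopwords=True, keep_numeric=False,
                  allowed_list=None, notallowed_list=None):
    """Hypothetical token filter mirroring the documented parameters."""
    if not enable_stopwords:
        # The other three options are only valid when stopwords are enabled.
        return list(tokens)
    allowed = {w.lower() for w in (allowed_list or [])}
    not_allowed = {w.lower() for w in (notallowed_list or [])}
    kept = []
    for token in tokens:
        word = token.lower()
        if word in allowed:
            kept.append(token)      # explicitly retained
        elif word in not_allowed or word in DEFAULT_STOPWORDS:
            continue                # recognized and deleted
        elif word.isdigit() and not keep_numeric:
            continue                # numbers dropped unless keep_numeric
        else:
            kept.append(token)
    return kept


tokens = ["The", "price", "of", "42", "apples"]
default_result = filter_tokens(tokens)
numeric_result = filter_tokens(tokens, keep_numeric=True)
```

With the default settings only "price" and "apples" survive; passing keep_numeric=True also retains "42".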

n_estimators : int, optional

Specifies the number of decision trees in the RDT model.

Defaults to 100.

max_depth : int, optional

The maximum depth of a tree in the RDT model, where -1 means unlimited.

Defaults to 56.

split_threshold : float, optional

Specifies the stopping condition of the tree-growing process in RDT model: if the improvement value of the best split is less than this value, the tree stops growing.

Defaults to 1e-5.

min_samples_leaf : int, optional

Specifies the minimum number of records in a leaf of a tree in RDT model.

Defaults to 1.

Note

The parameters n_estimators, max_depth, split_threshold and min_samples_leaf all control how the RDT model for text classification is built.
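As a rough illustration of how max_depth, split_threshold and min_samples_leaf constrain tree growth, the sketch below checks whether a node may be split. The may_split helper and the exact depth convention (a node at depth equal to max_depth is not split) are assumptions for illustration; the default values are the ones documented above.

```python
def may_split(depth, n_left, n_right, best_improvement,
              max_depth=56, split_threshold=1e-5, min_samples_leaf=1):
    """Hypothetical check combining the three documented stopping rules."""
    # max_depth: -1 means unlimited depth.
    if max_depth != -1 and depth >= max_depth:
        return False
    # split_threshold: stop when the best split improves by less than this.
    if best_improvement < split_threshold:
        return False
    # min_samples_leaf: both children must hold enough records.
    return n_left >= min_samples_leaf and n_right >= min_samples_leaf


can_grow = may_split(depth=3, n_left=10, n_right=10, best_improvement=0.01)
too_deep = may_split(depth=56, n_left=10, n_right=10, best_improvement=0.01)
```

n_estimators is not part of this check because it sets the number of trees in the ensemble rather than the shape of any single tree.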

Attributes
tf_idf_ : DataFrame

The TF-IDF result table generated during model training.

doc_term_freq_ : DataFrame

The document term frequency table generated during model training.

doc_category_ : DataFrame

The document category table generated during model training.

model_ : list of DataFrame

A list of DataFrames comprising the TF-IDF result table, the document term frequency table, the document category table and the trained RDT model table.

Methods

fit(data[, seed, thread_ratio])

Train the model.

predict(data[, rdt_top_n, thread_ratio])

Make predictions with the trained model.

Examples

>>> from hana_ml.text.tm import TextClassificationWithModel
>>> tc = TextClassificationWithModel(enable_stopwords=True,
...                                  n_estimators=50,
...                                  max_depth=6,
...                                  min_samples_leaf=2,
...                                  split_threshold=1e-6)
>>> tc.fit(data=document_file_train_data)
>>> pred_res = tc.predict(data=document_file_test_data)

fit(data, seed=None, thread_ratio=None)

Train the model.

Parameters
data : DataFrame

Input data, structured as follows:

  • 1st column, ID.

  • 2nd column, Document content.

  • 3rd column, Document category.

seed : int, optional

Specifies the seed for random number generation.

  • 0: Uses the current time (in seconds) as the seed.

  • Others: Uses the specified value as the seed.

Defaults to 0.
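The seed convention can be mimicked in plain Python as follows. This is only a sketch of the convention (0 selects a time-based seed, any other value is used as given); the make_rng helper is hypothetical and Python's random module is a stand-in, not the generator PAL uses.

```python
import random
import time


def make_rng(seed=0):
    """Hypothetical helper: 0 means 'seed from the current time in seconds'."""
    effective_seed = int(time.time()) if seed == 0 else seed
    return random.Random(effective_seed)


# A fixed non-zero seed gives a repeatable sequence.
first = make_rng(42).random()
second = make_rng(42).random()
```

In practice this means a non-zero seed should be passed to fit whenever a reproducible model is wanted, since the default of 0 yields a different seed on each run.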

thread_ratio : float, optional

Specifies the ratio of threads that can be used by this function. The range of this parameter is from 0 to 1, where 0 means only using one thread, and 1 means using at most all the currently available threads. Values outside this range are ignored and this function heuristically determines the number of threads to use.

Defaults to 0.0.
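The thread_ratio rule can be sketched as a small function. The resolve_threads helper is hypothetical, and the linear scaling for values strictly between 0 and 1 is an assumption: the text above only fixes the two endpoints and the out-of-range behaviour.

```python
import math
import os


def resolve_threads(thread_ratio, available=None):
    """Hypothetical mapping from thread_ratio to a thread count.

    Returns None when the ratio is out of range, meaning the value is
    ignored and the count is left to a heuristic.
    """
    if available is None:
        available = os.cpu_count() or 1
    if not 0.0 <= thread_ratio <= 1.0:
        return None
    # 0 -> a single thread; 1 -> at most all available threads.
    # Linear scaling in between is an assumption for illustration.
    return max(1, math.ceil(thread_ratio * available))


single = resolve_threads(0.0, available=8)
everything = resolve_threads(1.0, available=8)
```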

Returns
A fitted instance of class TextClassificationWithModel.

predict(data, rdt_top_n=None, thread_ratio=None)

Make predictions with the trained model.

Parameters
data : DataFrame

Input data, structured as follows:

  • 1st column, ID.

  • 2nd column, Document content.

rdt_top_n : int, optional

Controls how many results to output.

Defaults to 1.

thread_ratio : float, optional

Specifies the ratio of threads that can be used by this function. The range of this parameter is from 0 to 1, where 0 means only using one thread, and 1 means using at most all the currently available threads. Values outside this range are ignored and this function heuristically determines the number of threads to use.

Defaults to 0.0.

Returns
DataFrame

The prediction result.

Inherited Methods from PALBase

In addition to the methods described above, the TextClassificationWithModel class inherits methods from the PALBase class. Please refer to PAL Base for more details.