TextClassificationWithModel
- class hana_ml.text.tm.TextClassificationWithModel(language=None, enable_stopwords=True, keep_numeric=None, allowed_list=None, notallowed_list=None)
Text classification class. This class trains a Random Decision Trees (RDT) classifier on TF-IDF vectorized text data and then applies the trained model for inference.
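The description above only says the text is "TF-IDF vectorized". The pure-Python sketch below illustrates that step under stated assumptions: whitespace tokenization, lowercasing, and the common weighting "term count / document length × log(N / document frequency)" are choices made here for illustration. PAL's internal tokenizer and weighting may differ, and the RDT classifier itself is not reproduced.

```python
import math
from collections import Counter


def tfidf_vectors(docs):
    """Return one {term: weight} dict per document (illustrative only).

    Assumed weighting: tf = count / doc_length, idf = log(N / df).
    """
    # Assumption: simple lowercase whitespace tokenization.
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    # Document frequency: number of documents containing each term.
    df = Counter(term for tokens in tokenized for term in set(tokens))
    vectors = []
    for tokens in tokenized:
        counts = Counter(tokens)
        total = len(tokens)
        vectors.append({
            term: (count / total) * math.log(n_docs / df[term])
            for term, count in counts.items()
        })
    return vectors


docs = ["cheap flights to paris", "paris hotel deals", "cheap hotel deals"]
vecs = tfidf_vectors(docs)
print(vecs[0])
```

A term that appears in only one document ("flights") gets a higher weight than one shared by several documents ("cheap"); a term present in every document would get weight 0.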
- Parameters:
- language : str, optional
The language of the documents. SAP HANA Cloud currently supports 'EN', 'DE', 'ES', 'FR', 'RU', and 'PT'. If None, the language is detected automatically.
Defaults to None (auto detection).
- enable_stopwords : bool, optional
Whether to enable stopword removal.
Defaults to True.
- keep_numeric : bool, optional
Whether to keep numbers.
Valid only when enable_stopwords is True. Defaults to False.
- allowed_list : list of str, optional
A list of words that are retained by the stopword logic.
Valid only when enable_stopwords is True.
- notallowed_list : list of str, optional
A list of words that are recognized and deleted by the stopword logic.
Valid only when enable_stopwords is True.
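To make the interplay of the four stopword-related parameters concrete, here is a hedged pure-Python sketch that mimics their documented semantics on a token list. The function `filter_tokens`, the tiny `STOPWORDS` set, the `isdigit()` test for "numbers", and the rule that `allowed_list` wins when a word appears in both lists are all assumptions for illustration, not PAL's actual implementation.

```python
# Hypothetical, tiny stopword list for illustration only.
STOPWORDS = {"the", "a", "of", "to", "and", "not"}


def filter_tokens(tokens, enable_stopwords=True, keep_numeric=False,
                  allowed_list=None, notallowed_list=None):
    """Mimic the documented parameter semantics (illustrative sketch)."""
    if not enable_stopwords:
        # keep_numeric, allowed_list and notallowed_list only apply
        # when stopword handling is enabled.
        return list(tokens)
    allowed = {w.lower() for w in (allowed_list or [])}
    blocked = {w.lower() for w in (notallowed_list or [])}
    kept = []
    for tok in tokens:
        word = tok.lower()
        if word in allowed:
            kept.append(tok)   # assumption: allowed_list takes precedence
        elif word in blocked or word in STOPWORDS:
            continue           # recognized and deleted
        elif tok.isdigit() and not keep_numeric:
            continue           # simplification: digit-only tokens are "numbers"
        else:
            kept.append(tok)
    return kept


tokens = "The price of 2 tickets to Paris is not refundable".split()
print(filter_tokens(tokens))
print(filter_tokens(tokens, keep_numeric=True, allowed_list=["not"]))
```

Keeping "not" through `allowed_list` is a typical reason to override a stopword list, since negation can change a document's category.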
Examples
>>> from hana_ml.text.tm import TextClassificationWithModel
>>> tc = TextClassificationWithModel(enable_stopwords=True)
>>> tc.fit(data=document_file_train_data)
>>> pred_res = tc.predict(data=document_file_test_data)
Methods
- fit(data[, seed, thread_ratio]): Train the model.
- get_model_metrics(): Get the model metrics.
- get_score_metrics(): Get the score metrics.
- predict(data[, rdt_top_n, thread_ratio]): Make predictions with the trained model.
- fit(data, seed=None, thread_ratio=None)
Train the model.
- Parameters:
- data : DataFrame
Input data, structured as follows:
1st column, ID.
2nd column, Document content.
3rd column, Document category.
- seed : int, optional
Specify the seed for random number generation.
- thread_ratio : float, optional
The ratio of the total number of threads that can be used by this function.
- Returns:
- A fitted instance of class TextClassificationWithModel.
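Because the training input is described by column position (ID, content, category), a minimal stdlib sketch of assembling rows in that order may help. The helper `make_training_rows` and the integer IDs starting at 1 are hypothetical choices for illustration. In practice such rows would typically be placed in a pandas DataFrame and uploaded to SAP HANA, for example with `hana_ml.dataframe.create_dataframe_from_pandas`, before being passed to `fit`.

```python
def make_training_rows(documents, categories, start_id=1):
    """Build (ID, content, category) tuples in the column order fit() expects.

    Hypothetical helper: it only arranges data, it does not train anything.
    """
    if len(documents) != len(categories):
        raise ValueError("each document needs exactly one category")
    return [
        (start_id + offset, text, label)
        for offset, (text, label) in enumerate(zip(documents, categories))
    ]


rows = make_training_rows(
    ["Quarterly revenue rose 8%", "The striker scored twice"],
    ["finance", "sports"],
)
print(rows)
# [(1, 'Quarterly revenue rose 8%', 'finance'), (2, 'The striker scored twice', 'sports')]
```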
- predict(data, rdt_top_n=None, thread_ratio=None)
Make predictions with the trained model.
- Parameters:
- data : DataFrame
Input data, structured as follows:
1st column, ID.
2nd column, Document content.
- rdt_top_n : int, optional
Specify the number of top terms to be used for the Random Decision Tree algorithm.
- thread_ratio : float, optional
The ratio of the total number of threads that can be used by this function.
- Returns:
- DataFrame
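Prediction input drops the category column, so a common workflow is to hold out part of a labeled set: train on the three-column rows and predict on the two-column (ID, content) rows, keeping the true labels aside for comparison. The stdlib sketch below shows that reshaping; the helper name, the 25% test fraction, and the seeded shuffle are illustrative assumptions, not part of the class.

```python
import random


def split_for_fit_and_predict(rows, test_fraction=0.25, seed=0):
    """Split labeled (ID, content, category) rows (illustrative sketch).

    Returns (train_rows, predict_rows, held_out_labels), where predict_rows
    follow the 2-column (ID, content) layout described for predict().
    """
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)  # reproducible shuffle
    n_test = int(len(shuffled) * test_fraction)
    test, train = shuffled[:n_test], shuffled[n_test:]
    # Drop the category column for the prediction input.
    predict_rows = [(doc_id, content) for doc_id, content, _ in test]
    # Keep the true categories, keyed by ID, for later comparison.
    labels = {doc_id: category for doc_id, _, category in test}
    return train, predict_rows, labels


labeled = [(i, f"document {i}", "a" if i % 2 else "b") for i in range(1, 9)]
train, predict_rows, labels = split_for_fit_and_predict(labeled)
print(len(train), len(predict_rows))
```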
- get_model_metrics()
Get the model metrics.
- Returns:
- DataFrame
The model metrics.
- get_score_metrics()
Get the score metrics.
- Returns:
- DataFrame
The score metrics.
Inherited Methods from PALBase
Besides the methods listed above, the TextClassificationWithModel class also inherits methods from the PALBase class. Please refer to PAL Base for more details.