TextClassificationWithModel
- class hana_ml.text.tm.TextClassificationWithModel(language=None, enable_stopwords=True, keep_numeric=None, allowed_list=None, notallowed_list=None, n_estimators=None, max_depth=None, split_threshold=None, min_samples_leaf=None)
Text classification class. This class first trains a Random Decision Trees (RDT) classifier on TF-IDF vectorized text data and then applies the fitted model for inference.
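The description above assumes familiarity with TF-IDF vectorization. As a rough illustration only, the pure-Python sketch below computes raw-count term frequency times log inverse document frequency for a few toy documents. The whitespace tokenizer, the lowercasing, and the exact weighting formula are assumptions for illustration and are not necessarily what PAL applies internally.

```python
import math
from collections import Counter


def tfidf(docs):
    """Return one {term: weight} dict per document (raw-count TF x log IDF).

    Illustrative sketch only; PAL's own tokenization and weighting may differ.
    """
    tokenized = [doc.lower().split() for doc in docs]  # assumption: whitespace tokens
    n_docs = len(tokenized)
    # Document frequency: in how many documents each term appears at least once.
    df = Counter(term for tokens in tokenized for term in set(tokens))
    vectors = []
    for tokens in tokenized:
        tf = Counter(tokens)
        # Terms present in every document get log(1) = 0, i.e. no weight.
        vectors.append({t: c * math.log(n_docs / df[t]) for t, c in tf.items()})
    return vectors


docs = ["the cat sat", "the dog sat", "the cat ran"]
vecs = tfidf(docs)
print(vecs[1])  # "dog" gets the highest weight; "the" gets 0.0
```

A term such as "the" that occurs in every document carries no discriminating weight, which is why such vectors are a reasonable input for a tree-based classifier.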
- Parameters:
- language : str, optional
Specifies the language type. The SAP HANA Cloud instance currently supports 'EN', 'DE', 'ES', 'FR', 'RU' and 'PT'. If None, the language is detected automatically.
Defaults to None (auto detection).
- enable_stopwords : bool, optional
Determines whether to apply stopword removal.
Defaults to True.
- keep_numeric : bool, optional
Determines whether to keep numbers.
Valid only when enable_stopwords is True.
Defaults to False.
- allowed_list : list of str, optional
A list of words that are retained by the stopwords logic.
Valid only when enable_stopwords is True.
- notallowed_list : list of str, optional
A list of words that are recognized and removed by the stopwords logic.
Valid only when enable_stopwords is True.
- n_estimators : int, optional
Specifies the number of decision trees in the RDT model.
Defaults to 100.
- max_depth : int, optional
The maximum depth of a tree in the RDT model, where -1 means unlimited.
Defaults to 56.
- split_threshold : float, optional
Specifies the stopping condition of the tree-growing process in the RDT model: if the improvement value of the best split is less than this value, the tree stops growing.
Defaults to 1e-5.
- min_samples_leaf : int, optional
Specifies the minimum number of records in a leaf of a tree in the RDT model.
Defaults to 1.
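The four stopword-related parameters interact: keep_numeric, allowed_list and notallowed_list only take effect when enable_stopwords is True. The sketch below is a hypothetical model of that interaction, not PAL's implementation. The function name filter_tokens, the rule that allowed_list takes precedence over every other rule, and the isdigit() test for "numeric" are all assumptions made for illustration.

```python
def filter_tokens(tokens, stopwords, enable_stopwords=True, keep_numeric=False,
                  allowed_list=None, notallowed_list=None):
    """Hypothetical model of the stopword options (not PAL's actual logic)."""
    if not enable_stopwords:
        # keep_numeric, allowed_list and notallowed_list only apply when
        # stopword removal is enabled, so tokens pass through unchanged.
        return list(tokens)
    allowed = set(allowed_list or [])
    blocked = set(stopwords) | set(notallowed_list or [])
    kept = []
    for tok in tokens:
        if tok in allowed:
            kept.append(tok)        # assumption: allowed_list overrides all other rules
        elif tok in blocked:
            continue                # built-in stopword or user-supplied notallowed word
        elif tok.isdigit() and not keep_numeric:
            continue                # numbers are dropped unless keep_numeric=True
        else:
            kept.append(tok)
    return kept


tokens = ["the", "price", "is", "42", "not", "cheap"]
stop = {"the", "is", "not"}  # hypothetical stopword set
print(filter_tokens(tokens, stop, allowed_list=["not"], notallowed_list=["cheap"]))
# ['price', 'not']
```

Keeping a negation such as "not" through allowed_list is a typical reason to override a default stopword list for classification tasks.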
Note
The parameters n_estimators, max_depth, split_threshold and min_samples_leaf are all used for building the RDT model for text classification.
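Three of these parameters act as stopping rules while a tree grows. The sketch below is a hypothetical check that mirrors their documented meaning (max_depth=-1 means unlimited, each leaf needs at least min_samples_leaf records, a split whose improvement is below split_threshold is rejected). The function name can_split, counting the root at depth 0, and applying min_samples_leaf to both children are assumptions; PAL's internal tree builder is not shown in this section.

```python
def can_split(depth, n_left, n_right, best_gain,
              max_depth=56, min_samples_leaf=1, split_threshold=1e-5):
    """Hypothetical split check using the documented defaults of the RDT parameters."""
    if max_depth != -1 and depth >= max_depth:
        return False  # depth limit reached (assumption: root is depth 0)
    if min(n_left, n_right) < min_samples_leaf:
        return False  # a child leaf would hold fewer than min_samples_leaf records
    if best_gain < split_threshold:
        return False  # improvement of the best split is too small: stop growing
    return True


# Settings taken from the example in this section.
print(can_split(3, 10, 5, 0.01, max_depth=6, min_samples_leaf=2, split_threshold=1e-6))
```

Lower max_depth and higher min_samples_leaf or split_threshold all stop trees earlier, trading some fit on the training documents for smaller, faster trees.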
Examples
>>> from hana_ml.text.tm import TextClassificationWithModel
>>> tc = TextClassificationWithModel(enable_stopwords=True,
...                                  n_estimators=50,
...                                  max_depth=6,
...                                  min_samples_leaf=2,
...                                  split_threshold=1e-6)
>>> tc.fit(data=document_file_train_data)
>>> pred_res = tc.predict(data=document_file_test_data)
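The example does not define document_file_train_data or document_file_test_data. According to the method descriptions, fit expects three columns (ID, document content, document category) while predict expects two (ID, document content). The pure-Python sketch below derives both layouts from one hypothetical labeled set and keeps the held-out labels aside for scoring. The function split_documents and the sample rows are hypothetical; in practice the resulting rows would still have to be loaded into SAP HANA DataFrames (for example through a pandas DataFrame and hana_ml.dataframe.create_dataframe_from_pandas), which needs a database connection and is not shown.

```python
import random


def split_documents(rows, test_fraction=0.25, seed=0):
    """Hypothetical helper: split (ID, content, category) rows into fit/predict layouts."""
    rows = list(rows)                  # copy so the caller's list is not shuffled
    random.Random(seed).shuffle(rows)
    n_test = max(1, int(len(rows) * test_fraction))
    test, train = rows[:n_test], rows[n_test:]
    # fit() layout: 1st column ID, 2nd column content, 3rd column category.
    train_rows = [(doc_id, text, label) for doc_id, text, label in train]
    # predict() layout: 1st column ID, 2nd column content (no category).
    test_rows = [(doc_id, text) for doc_id, text, _ in test]
    test_labels = {doc_id: label for doc_id, _, label in test}
    return train_rows, test_rows, test_labels


rows = [  # hypothetical sample documents
    (1, "the match ended in a draw", "sports"),
    (2, "parliament passed the budget", "politics"),
    (3, "new phone has a faster chip", "tech"),
    (4, "striker scores twice in final", "sports"),
    (5, "minister resigns after vote", "politics"),
    (6, "update fixes battery drain", "tech"),
    (7, "coach praises young squad", "sports"),
    (8, "election turnout hits record", "politics"),
]
train_rows, test_rows, test_labels = split_documents(rows)
print(len(train_rows), len(test_rows))  # 6 2
```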
Methods
- fit(data[, seed, thread_ratio]): Train the model.
- predict(data[, rdt_top_n, thread_ratio]): Make predictions with the fitted model.
- fit(data, seed=None, thread_ratio=None)
Train the model.
- Parameters:
- data : DataFrame
Input data, structured as follows:
1st column, ID.
2nd column, Document content.
3rd column, Document category.
- seed : int, optional
Specifies the seed for random number generation.
- thread_ratio : float, optional
The ratio of the total number of threads that can be used by this function.
- Returns:
- A fitted instance of class TextClassificationWithModel.
- predict(data, rdt_top_n=None, thread_ratio=None)
Make predictions with the fitted model.
- Parameters:
- data : DataFrame
Input data, structured as follows:
1st column, ID.
2nd column, Document content.
- rdt_top_n : int, optional
Specifies the number of top terms to be used for the Random Decision Trees algorithm.
- thread_ratio : float, optional
The ratio of the total number of threads that can be used by this function.
- Returns:
- DataFrame
The prediction result.
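The exact columns of the returned DataFrame are not specified in this section. Conceptually, a tree ensemble such as RDT turns the outputs of its n_estimators trees into a single category per document, and one common rule is a majority vote. The sketch below assumes that rule and reports the winning vote share as a rough confidence; the function name aggregate_votes is hypothetical, and PAL's actual aggregation and output format may differ.

```python
from collections import Counter


def aggregate_votes(tree_predictions):
    """Hypothetical majority vote over per-tree category predictions for one document."""
    if not tree_predictions:
        raise ValueError("need at least one tree prediction")
    counts = Counter(tree_predictions)
    # most_common breaks ties by first occurrence in the input (Python 3.7+).
    category, votes = counts.most_common(1)[0]
    return category, votes / len(tree_predictions)


print(aggregate_votes(["sports", "politics", "sports", "tech", "sports"]))
# ('sports', 0.6)
```

A vote share close to 1.0 means the trees largely agree; shares near 1 / (number of categories) indicate an uncertain document.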
Inherited Methods from PALBase
Besides the methods mentioned above, the TextClassificationWithModel class also inherits methods from the PALBase class; please refer to PAL Base for more details.