Text Classification


Analytics Zoo provides pre-defined text classification models with different encoders. A model can be passed directly to NNFrames or to the BigDL Optimizer for training.


Build a TextClassifier Model

Use the following API, in Scala or Python, to create a TextClassifier whose first layer is built from pre-trained GloVe word embeddings.

Scala

import com.intel.analytics.zoo.models.textclassification.TextClassifier
val textClassifier = TextClassifier(classNum, embeddingFile, wordIndex = null, sequenceLength = 500, encoder = "cnn", encoderOutputDim = 256)

See the Scala example, which trains a TextClassifier model on the 20 Newsgroup dataset and then uses it for prediction.

Python

from zoo.models.textclassification import TextClassifier
text_classifier = TextClassifier(class_num, embedding_file, word_index=None, sequence_length=500, encoder="cnn", encoder_output_dim=256)
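The constructor takes a sequence_length (500 above), which suggests that each input text reaches the model as a fixed-length sequence of word indices. The API shown here does not say how raw texts are prepared. The following is only an illustrative sketch of that step, written under stated assumptions. The helpers tokenize, build_word_index and to_sequence are hypothetical and are not Analytics Zoo functions. Reserving index 0 for padding and unknown words, and padding at the end rather than the front, are also assumptions.

```python
# Illustrative sketch (assumption, not the Analytics Zoo API): turn raw texts
# into fixed-length sequences of word indices of length sequence_length.
# Index 0 is reserved here for padding and for words missing from word_index.
import re
from collections import Counter


def tokenize(text):
    """Lower-case the text and split it into alphabetic tokens."""
    return re.findall(r"[a-z]+", text.lower())


def build_word_index(texts, max_words=None):
    """Map each word to an index starting at 1, most frequent words first.

    Ties keep the order in which words were first seen (Counter.most_common).
    """
    counts = Counter(w for t in texts for w in tokenize(t))
    ordered = [w for w, _ in counts.most_common(max_words)]
    return {w: i + 1 for i, w in enumerate(ordered)}


def to_sequence(text, word_index, sequence_length=500):
    """Map tokens to indices, then truncate or zero-pad to sequence_length."""
    ids = [word_index.get(w, 0) for w in tokenize(text)]  # unknown word -> 0
    ids = ids[:sequence_length]                          # truncate long texts
    return ids + [0] * (sequence_length - len(ids))      # pad short texts


# Usage: a tiny two-document corpus.
texts = ["The cat sat", "the dog sat down"]
word_index = build_word_index(texts)
padded = to_sequence("the bird sat", word_index, sequence_length=5)
```

In this toy corpus, "the" and "sat" each appear twice, so they receive indices 1 and 2. The unseen word "bird" maps to 0. A real pipeline may pad at the front, cap the vocabulary differently, or use a different tokenizer.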

See the Python example, which trains a TextClassifier model on the 20 Newsgroup dataset and then uses it for prediction.
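Using pre-trained GloVe word embeddings as the first layer amounts to a lookup table. Row i of that table holds the vector of the word whose index is i. The page does not describe how TextClassifier reads embedding_file. The sketch below is therefore a generic illustration built on assumptions: it turns GloVe-format text, where each line is a word followed by its vector components, into such a table. The helpers load_glove and embedding_matrix are hypothetical, not library functions. Using all-zero rows for index 0 and for words absent from the GloVe data is an assumption.

```python
# Illustrative sketch (assumption, not the Analytics Zoo implementation):
# build an embedding lookup table from GloVe-format text lines of the form
# "word v1 v2 ... vd". Row i holds the vector for the word with index i;
# row 0 and rows for words not found in the GloVe data stay all-zero.
import io


def load_glove(lines):
    """Parse GloVe-format lines into a {word: [float, ...]} dictionary."""
    vectors = {}
    for line in lines:
        parts = line.split()
        if len(parts) < 2:
            continue  # skip blank or malformed lines
        vectors[parts[0]] = [float(x) for x in parts[1:]]
    return vectors


def embedding_matrix(vectors, word_index):
    """Return a (max_index + 1) x dim table aligned with word_index."""
    dim = len(next(iter(vectors.values())))
    size = max(word_index.values()) + 1
    table = [[0.0] * dim for _ in range(size)]
    for word, idx in word_index.items():
        if word in vectors:
            table[idx] = list(vectors[word])  # copy the pre-trained vector
    return table


# Usage: an in-memory stand-in for a (much larger) GloVe file.
glove_text = "the 0.1 0.2\ncat 0.3 0.4\n"
vectors = load_glove(io.StringIO(glove_text))
table = embedding_matrix(vectors, {"the": 1, "cat": 2, "dog": 3})
```

Here "dog" has no GloVe vector, so its row remains zero, in the same way as the reserved row 0.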


Save Model

After building and training a TextClassifier model, you can save it for future use.

Scala

textClassifier.saveModel(path, weightPath = null, overWrite = false)

Python

text_classifier.save_model(path, weight_path=None, over_write=False)


Load Model

To load a TextClassifier model (with weights) saved above:

Scala

val textClassifier = TextClassifier.loadModel(path, weightPath = null)

Python

text_classifier = TextClassifier.load_model(path, weight_path=None)