Text Classification


Analytics Zoo provides pre-defined text classification models with a choice of encoders. A model can be passed directly to NNFrames or the BigDL Optimizer for training.


Build a TextClassifier Model

Scala

val textClassifier = TextClassifier(classNum, tokenLength, sequenceLength = 500, encoder = "cnn", encoderOutputDim = 256)

See here for the Scala example that trains the TextClassifier model on the 20 Newsgroups dataset and uses the trained model for prediction.

Python

text_classifier = TextClassifier(class_num, token_length, sequence_length=500, encoder="cnn", encoder_output_dim=256)

See here for the Python example that trains the TextClassifier model on the 20 Newsgroups dataset and uses the trained model for prediction.
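
The constructor arguments can be checked before the model is built. The following is a minimal sketch, not part of Analytics Zoo: the helpers `classifier_kwargs` and `build_text_classifier` are hypothetical, the list of encoder names ("cnn", "lstm", "gru") and the import path `zoo.models.textclassification` are assumptions to verify against your installed version, and the parameter descriptions in the comments are a reading of the argument names rather than an authoritative definition.

```python
# Encoder names as listed in the Analytics Zoo docs (an assumption here;
# check the version you have installed).
SUPPORTED_ENCODERS = ("cnn", "lstm", "gru")


def classifier_kwargs(class_num, token_length, sequence_length=500,
                      encoder="cnn", encoder_output_dim=256):
    """Validate the constructor arguments and return them as a dict.

    This helper is hypothetical; it is not part of Analytics Zoo.
    """
    sizes = {
        "class_num": class_num,                    # number of target categories
        "token_length": token_length,              # size of each word vector
        "sequence_length": sequence_length,        # tokens per input text
        "encoder_output_dim": encoder_output_dim,  # encoder output size
    }
    for name, value in sizes.items():
        if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
    if encoder.lower() not in SUPPORTED_ENCODERS:
        raise ValueError(f"unsupported encoder: {encoder!r}")
    return dict(sizes, encoder=encoder.lower())


def build_text_classifier(**kwargs):
    """Build the model; requires Analytics Zoo to be installed."""
    # Imported lazily so the validation above works without Analytics Zoo.
    # The module path is an assumption based on the Python package layout.
    from zoo.models.textclassification import TextClassifier
    return TextClassifier(**classifier_kwargs(**kwargs))
```

For example, `build_text_classifier(class_num=20, token_length=200)` would request a CNN-based classifier for 20 classes over 200-dimensional word vectors, leaving the other arguments at the defaults shown above.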


Model Save

After building and training a TextClassifier model, you can save it for future use.

Scala

textClassifier.saveModel(path, weightPath = null, overWrite = false)

Python

text_classifier.save_model(path, weight_path=None, over_write=False)

Model Load

To load a TextClassifier model (with weights) saved above:

Scala

TextClassifier.loadModel[Float](path, weightPath = null)

Python

TextClassifier.load_model(path, weight_path=None)
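
Saving and loading can be combined into a round trip, for instance to confirm that a saved model can be read back. The sketch below is illustrative only: `save_and_reload` is a hypothetical helper, and it relies solely on the `save_model` and `load_model` signatures shown above. The loader is passed in as an argument so the helper itself does not import Analytics Zoo.

```python
def save_and_reload(model, load_fn, path, weight_path=None, over_write=False):
    """Save `model` to `path`, then load and return a fresh copy.

    `load_fn` is expected to behave like TextClassifier.load_model,
    i.e. accept (path, weight_path=None). This helper is hypothetical
    and not part of Analytics Zoo.
    """
    # Persist the model; over_write=False leaves the decision about
    # replacing an existing file to save_model itself.
    model.save_model(path, weight_path=weight_path, over_write=over_write)
    # Read the model back from the same location.
    return load_fn(path, weight_path=weight_path)


# Intended usage with Analytics Zoo installed (the path is a placeholder):
#   reloaded = save_and_reload(text_classifier, TextClassifier.load_model,
#                              "/tmp/text_classifier.model", over_write=True)
```

Passing `over_write=True` is a deliberate choice in the usage comment so that repeated runs against the same placeholder path do not depend on how an existing file is handled.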