Text Classification API

Analytics Zoo provides pre-defined models with different encoders that can be used for text classification.


  1. Easy-to-use Keras-style models that provide compile and fit methods for training. Alternatively, they can be fed into NNFrames or the BigDL Optimizer.
  2. The supported encoders are CNN, LSTM and GRU.

Build a TextClassifier model

You can call the following API in Scala and Python respectively to create a TextClassifier with pre-trained GloVe word embeddings as the first layer.


Scala:

val textClassifier = TextClassifier(classNum, embeddingFile, wordIndex = null, sequenceLength = 500, encoder = "cnn", encoderOutputDim = 256)


Python:

text_classifier = TextClassifier(class_num, embedding_file, word_index=None, sequence_length=500, encoder="cnn", encoder_output_dim=256)
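To illustrate what the pre-trained GloVe embedding layer does with the word index, here is a minimal pure-Python sketch (not the Analytics Zoo API). It assumes the common GloVe file format where each line is a word followed by its vector components, and reserves index 0 for padding:

```python
# Illustrative sketch: assembling an embedding matrix from a word index
# and pre-trained GloVe vectors. Words missing from the file get zeros.

def build_embedding_matrix(glove_lines, word_index, dim):
    """Map each indexed word to its GloVe vector; row 0 is the padding row."""
    vectors = {}
    for line in glove_lines:
        parts = line.split()
        vectors[parts[0]] = [float(x) for x in parts[1:]]
    matrix = [[0.0] * dim for _ in range(len(word_index) + 1)]
    for word, idx in word_index.items():
        if word in vectors:
            matrix[idx] = vectors[word]
    return matrix

# Toy 2-dimensional "GloVe file" and word index for demonstration only.
glove = ["the 0.1 0.2", "cat 0.3 0.4"]
word_index = {"the": 1, "cat": 2, "dog": 3}
matrix = build_embedding_matrix(glove, word_index, dim=2)
```

When wordIndex is null/None, the model builds the index from the training corpus instead; passing your own index only makes sense if it matches the one used to preprocess the data.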

Train a TextClassifier model

After building the model, we can call compile and fit to train it (with validation).

For training and validation data, you can first read files as TextSet (see here) and then do preprocessing (see here).
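The preprocessing step above can be pictured with a minimal pure-Python sketch (not the TextSet API): tokenize each text, map tokens to indices with the word index, then pad or truncate to sequence_length so every sample has the same shape:

```python
# Illustrative sketch of text preprocessing: tokenize, word-to-index,
# then pad/truncate each sequence to a fixed length.

def preprocess(texts, word_index, sequence_length):
    sequences = []
    for text in texts:
        tokens = text.lower().split()                 # tokenize
        ids = [word_index.get(t, 0) for t in tokens]  # word2idx; 0 = unknown/pad
        ids = ids[:sequence_length]                   # truncate long texts
        ids += [0] * (sequence_length - len(ids))     # pad short texts
        sequences.append(ids)
    return sequences

word_index = {"good": 1, "movie": 2, "bad": 3}
seqs = preprocess(["Good movie", "bad"], word_index, sequence_length=4)
# seqs == [[1, 2, 0, 0], [3, 0, 0, 0]]
```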


Scala:

model.compile(optimizer = new Adagrad(learningRate), loss = SparseCategoricalCrossEntropy(), metrics = List(new Accuracy()))
model.fit(trainSet, batchSize, nbEpoch, validateSet)


Python:

model.compile(optimizer=Adagrad(learning_rate), loss="sparse_categorical_crossentropy", metrics=['accuracy'])
model.fit(train_set, batch_size, nb_epoch, validate_set)
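The loss used above, sparse categorical cross-entropy, is simply the negative log-probability the model assigns to the true class; "sparse" means labels are integer class indices rather than one-hot vectors. A minimal sketch of the computation:

```python
import math

# Illustrative sketch: sparse categorical cross-entropy for one sample.

def sparse_categorical_crossentropy(probs, label):
    """Negative log-likelihood of the true class under the predicted distribution."""
    return -math.log(probs[label])

probs = [0.1, 0.7, 0.2]                                 # predicted distribution over 3 classes
loss = sparse_categorical_crossentropy(probs, label=1)  # true class index is 1
```

The more confident the model is in the correct class, the closer the loss is to zero.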

Do prediction

After training, the model can be used to predict probability distributions over the classes.


Scala:

val predictSet = textClassifier.predict(validateSet)


Python:

predict_set = text_classifier.predict(validate_set)
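Each prediction is a probability distribution over the class_num classes. To turn it into a class label, take the index of the largest probability, as this minimal sketch shows:

```python
# Illustrative sketch: converting a predicted probability distribution
# into a class label by taking the argmax.

def predicted_class(distribution):
    return max(range(len(distribution)), key=distribution.__getitem__)

distribution = [0.05, 0.15, 0.80]       # e.g. class_num = 3
label = predicted_class(distribution)   # class with the highest probability
```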


We provide an example that trains the TextClassifier model on the 20 Newsgroups dataset and uses the model for prediction.

See here for the Scala example.

See here for the Python example.