Text Classification API


Analytics Zoo provides pre-defined text classification models that can be built with different encoders.

Highlights

  1. Easy-to-use Keras-style models that provide compile and fit methods for training. Alternatively, they can be fed into NNFrames or the BigDL Optimizer.
  2. The encoders we support include CNN, LSTM and GRU.

Build a TextClassifier model

You can call the following API in Scala and Python respectively to create a TextClassifier with pre-trained GloVe word embeddings as the first layer.

Scala

val textClassifier = TextClassifier(classNum, embeddingFile, wordIndex = null, sequenceLength = 500, encoder = "cnn", encoderOutputDim = 256)

Python

text_classifier = TextClassifier(class_num, embedding_file, word_index=None, sequence_length=500, encoder="cnn", encoder_output_dim=256)

Train a TextClassifier model

After building the model, we can call compile and fit to train it (with validation).

For training and validation data, you can first read files as TextSet (see here) and then do preprocessing (see here).
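To make the preprocessing step more concrete, here is a minimal plain-Python sketch of what it has to produce: every text becomes a list of word indices with exactly sequence_length entries. The helpers build_word_index and to_fixed_length are hypothetical names used only for illustration and are not part of the Analytics Zoo API. The sketch also assumes that index 0 is reserved for padding and unknown words and that long texts are cut at the end; the actual TextSet preprocessing may differ on both points.

```python
from collections import Counter


def build_word_index(texts):
    """Map each word to a positive integer, most frequent words first.

    Index 0 is kept free for padding and unknown words (an assumption of this sketch).
    """
    counts = Counter(word for text in texts for word in text.lower().split())
    return {word: i + 1 for i, (word, _) in enumerate(counts.most_common())}


def to_fixed_length(text, word_index, sequence_length):
    """Convert a text into exactly sequence_length word indices."""
    indices = [word_index.get(word, 0) for word in text.lower().split()]
    indices = indices[:sequence_length]  # truncate texts that are too long
    return indices + [0] * (sequence_length - len(indices))  # pad short texts


texts = ["the cat sat", "the dog sat down"]
word_index = build_word_index(texts)
encoded = [to_fixed_length(t, word_index, sequence_length=5) for t in texts]
```

Each row of encoded now has the same length, which is the shape a model built with a matching sequence_length expects as input.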

Scala

model.compile(optimizer = new Adagrad(learningRate), loss = SparseCategoricalCrossEntropy(), metrics = List(new Accuracy()))
model.fit(trainSet, batchSize, nbEpoch, validateSet)

Python

model.compile(optimizer=Adagrad(learning_rate), loss="sparse_categorical_crossentropy", metrics=['accuracy'])
model.fit(train_set, batch_size, nb_epoch, validate_set)
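As a rough illustration of the loss and metric named in the compile call above, the following plain-Python sketch computes sparse categorical cross-entropy and accuracy by hand for a tiny batch. It is not the Analytics Zoo implementation. The function names are made up for this example, and the labels are assumed to be zero-based integer class indices.

```python
import math


def sparse_categorical_crossentropy(probabilities, label):
    """Loss for one sample: negative log of the probability given to the true class.

    label is an integer class index (assumed zero-based here), not a one-hot vector.
    """
    return -math.log(probabilities[label])


def accuracy(batch_probabilities, labels):
    """Fraction of samples whose most probable class equals the true label."""
    correct = 0
    for probabilities, label in zip(batch_probabilities, labels):
        predicted = max(range(len(probabilities)), key=probabilities.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


batch = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
labels = [0, 1]
losses = [sparse_categorical_crossentropy(p, y) for p, y in zip(batch, labels)]
mean_loss = sum(losses) / len(losses)
batch_accuracy = accuracy(batch, labels)
```

The first sample is classified correctly and gets a small loss. The second puts most of its probability on the wrong class, so its loss is larger and it lowers the accuracy.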

Do prediction

After training, the model can be used to predict a probability distribution over the classes for each input text.

Scala

val predictSet = textClassifier.predict(validateSet)

Python

predict_set = text_classifier.predict(validate_set)
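A predicted distribution can be turned into a single class label by taking the index of its largest probability. The sketch below shows this in plain Python, assuming the predictions have already been collected as lists of floats, one list per text. The helper to_class_labels is a hypothetical name and not an Analytics Zoo function, and the sample values are made up.

```python
def to_class_labels(distributions):
    """Return, for each distribution, the index of the most probable class."""
    return [max(range(len(p)), key=p.__getitem__) for p in distributions]


# Made-up distributions over three classes for two texts.
predictions = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]
labels = to_class_labels(predictions)
```

If two classes tie for the highest probability, this sketch returns the lower index.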

Examples

We provide an example that trains the TextClassifier model on the 20 Newsgroups dataset and uses the trained model for prediction.

See here for the Scala example.

See here for the Python example.