Text Classification API


Analytics Zoo provides pre-defined models with different encoders that can be used to classify texts.

Highlights

  1. Easy-to-use models that can be fed into NNFrames or BigDL Optimizer for training.
  2. The encoders we support include CNN, LSTM and GRU.

Build a TextClassifier model

Scala

import com.intel.analytics.zoo.models.textclassification.TextClassifier

val textClassifier = TextClassifier(classNum, tokenLength, sequenceLength = 500, encoder = "cnn", encoderOutputDim = 256)

Python

from zoo.models.textclassification import TextClassifier

text_classifier = TextClassifier(class_num, token_length, sequence_length=500, encoder="cnn", encoder_output_dim=256)
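To make the constructor arguments more concrete, the pure-Python sketch below shows roughly what a CNN encoder does with one input sequence. It assumes the input is `sequence_length` word vectors of `token_length` values each: convolution filters slide over time, each filter is max-pooled down to one value (giving `encoder_output_dim` features), and those features are mapped to `class_num` softmax probabilities. The kernel width, the single dense layer, the toy sizes and all function names are illustrative assumptions, not Analytics Zoo's actual layer stack.

```python
import math
import random


def conv1d_relu(seq, kernels, width):
    """'Valid' 1-D convolution over the time axis, followed by ReLU.

    seq:     sequence_length word vectors, each of size token_length.
    kernels: encoder_output_dim filters, each a flat list of width * token_length weights.
    Returns (sequence_length - width + 1) feature vectors of size encoder_output_dim.
    """
    out = []
    for t in range(len(seq) - width + 1):
        window = [x for vec in seq[t:t + width] for x in vec]  # flatten `width` word vectors
        out.append([max(0.0, sum(w * x for w, x in zip(k, window))) for k in kernels])
    return out


def global_max_pool(features):
    # Keep the largest activation of each filter over all time steps.
    return [max(column) for column in zip(*features)]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify(seq, kernels, width, dense):
    encoded = global_max_pool(conv1d_relu(seq, kernels, width))  # encoder output
    logits = [sum(w * h for w, h in zip(row, encoded)) for row in dense]  # one row per class
    return softmax(logits)


# Hypothetical toy sizes; the API above defaults to sequence_length=500.
random.seed(0)
class_num, token_length, sequence_length, encoder_output_dim, width = 3, 4, 10, 8, 5
seq = [[random.uniform(-1, 1) for _ in range(token_length)] for _ in range(sequence_length)]
kernels = [[random.uniform(-1, 1) for _ in range(width * token_length)]
           for _ in range(encoder_output_dim)]
dense = [[random.uniform(-1, 1) for _ in range(encoder_output_dim)] for _ in range(class_num)]
probs = classify(seq, kernels, width, dense)
```

Whatever the sequence length, global max pooling yields a fixed-size vector of `encoder_output_dim` values, which is what lets a fixed classifier layer sit on top of the encoder.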

Train a TextClassifier model

After building the model, we can use BigDL Optimizer to train it on an RDD of Sample, with validation.

Note that raw text data may need to go through tokenization and vectorization before being fed into the Optimizer. You can refer to the examples we provide for data pre-processing.
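As a rough, self-contained illustration of that pre-processing, the sketch below tokenizes raw text, builds a word index, and turns each document into a fixed-size `sequence_length` x `token_length` input. It assumes each token is represented by a pre-trained word vector of `token_length` values; the tokenizer, the end-padding choice, the toy corpus and all helper names are assumptions for illustration, not the code in the provided examples.

```python
import re


def tokenize(text):
    # Deliberately simple tokenizer: lower-case, keep runs of letters and digits.
    return re.findall(r"[a-z0-9]+", text.lower())


def build_word_index(token_lists, max_words):
    # Map the max_words most frequent tokens to indices starting at 1
    # (ties broken alphabetically so the result is deterministic).
    counts = {}
    for tokens in token_lists:
        for tok in tokens:
            counts[tok] = counts.get(tok, 0) + 1
    ranked = sorted(counts, key=lambda w: (-counts[w], w))
    return {w: i + 1 for i, w in enumerate(ranked[:max_words])}


def vectorize(tokens, word_index, embeddings, sequence_length, token_length):
    # Drop out-of-vocabulary tokens, truncate to sequence_length, look up one
    # token_length-sized vector per token, then zero-pad at the end.
    ids = [word_index[t] for t in tokens if t in word_index][:sequence_length]
    vectors = [embeddings[i] for i in ids]
    padding = [[0.0] * token_length for _ in range(sequence_length - len(vectors))]
    return vectors + padding


# Hypothetical toy corpus with 0-based category ids.
texts = ["Hockey playoffs start tonight", "New GPU drivers released", "The hockey team won"]
categories = [0, 1, 0]
token_lists = [tokenize(t) for t in texts]
word_index = build_word_index(token_lists, max_words=5000)

token_length, sequence_length = 3, 6
# Stand-in for pre-trained word vectors; real vectors would be loaded from a file.
embeddings = {i: [i / 10.0] * token_length for i in word_index.values()}

# (features, label) pairs; labels are shifted to 1-based for ClassNLLCriterion.
samples = [(vectorize(toks, word_index, embeddings, sequence_length, token_length), c + 1)
           for toks, c in zip(token_lists, categories)]
```

In a real pipeline, each `(features, label)` pair would then be wrapped in a BigDL `Sample` (in Python, for instance via `Sample.from_ndarray` from `bigdl.util.common`) and distributed as an RDD.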

Scala

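import com.intel.analytics.bigdl.nn.ClassNLLCriterion
import com.intel.analytics.bigdl.optim.{Adagrad, Optimizer, Top1Accuracy, Trigger}

// With logProbAsInput = false, the criterion takes the log of the model's probability outputs itself.
val optimizer = Optimizer(
  model = textClassifier,
  sampleRDD = trainRDD,
  criterion = ClassNLLCriterion[Float](logProbAsInput = false),
  batchSize = 128)

// Train for 20 epochs, evaluating top-1 accuracy on valRDD at the end of every epoch.
optimizer
  .setOptimMethod(new Adagrad[Float](learningRate = 0.01, learningRateDecay = 0.001))
  .setValidation(Trigger.everyEpoch, valRDD, Array(new Top1Accuracy[Float]), 128)
  .setEndWhen(Trigger.maxEpoch(20))
  .optimize()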
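The `learningRate` and `learningRateDecay` arguments control Adagrad's step size. The sketch below assumes BigDL's Adagrad follows the classic Torch-style rule: after `n` updates the effective learning rate is `learningRate / (1 + n * learningRateDecay)`, and each gradient is divided by the square root of that parameter's accumulated squared gradients. `AdagradSketch` is a hypothetical name, and this is an illustration on plain floats, not BigDL's implementation.

```python
import math


class AdagradSketch:
    """Hypothetical Torch-style Adagrad with learning-rate decay, on plain Python floats."""

    def __init__(self, learning_rate=0.01, learning_rate_decay=0.001, eps=1e-10):
        self.learning_rate = learning_rate
        self.learning_rate_decay = learning_rate_decay
        self.eps = eps
        self.n_evals = 0    # number of updates performed so far
        self.accum = None   # per-parameter running sum of squared gradients

    def current_lr(self):
        # Effective step size shrinks as 1 / (1 + n_evals * decay).
        return self.learning_rate / (1 + self.n_evals * self.learning_rate_decay)

    def step(self, params, grads):
        if self.accum is None:
            self.accum = [0.0] * len(params)
        clr = self.current_lr()
        updated = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.accum[i] += g * g
            # Parameters with a history of large gradients take smaller steps.
            updated.append(p - clr * g / (math.sqrt(self.accum[i]) + self.eps))
        self.n_evals += 1
        return updated


# One update with the hyperparameters used above.
opt = AdagradSketch(learning_rate=0.01, learning_rate_decay=0.001)
params = opt.step([0.0], [-6.0])  # the first step moves by roughly learning_rate
```

Because the first update divides a gradient by its own magnitude, it moves each parameter by about `learningRate` regardless of gradient scale; later steps shrink both through the decay term and through the growing gradient history.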

Python

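from bigdl.nn.criterion import ClassNLLCriterion
from bigdl.optim.optimizer import Adagrad, EveryEpoch, MaxEpoch, Optimizer, Top1Accuracy

optimizer = Optimizer(
    model=text_classifier,
    training_rdd=train_rdd,
    criterion=ClassNLLCriterion(logProbAsInput=False),
    end_trigger=MaxEpoch(20),
    batch_size=128,
    optim_method=Adagrad(learningrate=0.01, learningrate_decay=0.001))

# Evaluate top-1 accuracy on val_rdd at the end of every epoch.
optimizer.set_validation(
    batch_size=128,
    val_rdd=val_rdd,
    trigger=EveryEpoch(),
    val_method=[Top1Accuracy()])

# Start training; optimize() returns the trained model.
trained_model = optimizer.optimize()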
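To show what the criterion and the validation metric compute, here is a small pure-Python sketch of mean negative log-likelihood on probability outputs (the case `logProbAsInput=False` describes) and of top-1 accuracy. It assumes 1-based targets, the convention BigDL's ClassNLLCriterion uses, and ignores any numerical clipping the real criterion may apply; the function names and the toy batch are hypothetical.

```python
import math


def class_nll_from_probs(probs, targets):
    """Mean negative log-likelihood for probability outputs (the logProbAsInput=False case).

    probs:   one probability distribution per sample.
    targets: 1-based class indices (assumed BigDL ClassNLLCriterion convention).
    """
    return -sum(math.log(p[t - 1]) for p, t in zip(probs, targets)) / len(targets)


def top1_accuracy(probs, targets):
    # A sample counts as correct when its highest-probability class equals its target.
    correct = sum(1 for p, t in zip(probs, targets) if p.index(max(p)) + 1 == t)
    return correct / len(targets)


# Hypothetical validation batch: three samples, three classes.
probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.25, 0.5, 0.25]]
targets = [1, 3, 1]
loss = class_nll_from_probs(probs, targets)
accuracy = top1_accuracy(probs, targets)  # 2 of 3 predictions are correct
```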

Make predictions

After training the model, it can be used to predict probabilities or class labels.

Scala

// Predict a probability distribution over the classes for each sample.
val results = textClassifier.predict(rdd)
// Predict class labels. By default, labels start from 0.
val resultClasses = textClassifier.predictClasses(rdd)

Python

# Predict a probability distribution over the classes for each sample.
results = text_classifier.predict(rdd)
# Predict class labels. By default, labels start from 0.
result_classes = text_classifier.predict_classes(rdd)
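A class label is the index of the most probable class in a predicted distribution. Assuming that `predictClasses`/`predict_classes` takes an argmax over the `predict` output, the sketch below shows why the default 0-based labels are one less than the 1-based labels used with ClassNLLCriterion during training. `to_classes` and its `zero_based_label` flag are hypothetical helpers, not part of the library.

```python
def to_classes(prob_rows, zero_based_label=True):
    """Map each probability distribution to the index of its most likely class."""
    offset = 0 if zero_based_label else 1
    return [row.index(max(row)) + offset for row in prob_rows]


# Hypothetical outputs of predict for two samples and three classes.
prob_rows = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]
zero_based = to_classes(prob_rows)                         # labels starting from 0
one_based = to_classes(prob_rows, zero_based_label=False)  # matches 1-based training labels
```

When comparing predictions against 1-based training labels, add 1 to the default 0-based classes (or subtract 1 from the labels) first.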

Examples

We provide an example that trains the TextClassifier model on the 20 Newsgroups dataset and uses the model for prediction.

See here for the Scala example.

See here for the Python example.