Analytics Zoo provides a built-in BERTClassifier in TFPark for Natural Language Processing (NLP) classification tasks based on TFEstimator and BERT.
Bidirectional Encoder Representations from Transformers (BERT) is Google's state-of-the-art pre-trained NLP model. Refer to the BERT paper for more details.
BERTClassifier is a pre-built TFEstimator that performs classification using the hidden state of the first token.
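To illustrate what "the hidden state of the first token" means, here is a minimal, library-free sketch of this classification scheme: the classifier head looks only at the vector for the first token ([CLS]) and ignores the rest of the sequence. The weights and dimensions below are toy values, not part of the actual BERTClassifier API.

```python
import math

def classify_from_first_token(hidden_states, weights, bias):
    """Toy BERT-style classification head: only the hidden state of the
    first token feeds the classifier; all other token states are ignored.

    hidden_states: list of per-token hidden vectors (seq_len x hidden_size)
    weights: num_classes x hidden_size matrix
    bias: list of num_classes biases
    """
    cls = hidden_states[0]  # hidden state of the first token
    logits = [sum(w * h for w, h in zip(row, cls)) + b
              for row, b in zip(weights, bias)]
    # numerically stable softmax over the logits
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy example: 3 tokens, hidden size 2, 2 classes.
hidden = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
probs = classify_from_first_token(
    hidden, weights=[[2.0, 0.0], [0.0, 2.0]], bias=[0.0, 0.0])
```

Changing the vectors of the later tokens has no effect on `probs`, which is the point of first-token pooling.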
- You need to install tensorflow==1.15.0 on your driver node.
- Your operating system (OS) must be one of the following 64-bit systems: Ubuntu 16.04 or later, or macOS 10.12.6 or later.
- To run on other systems, you need to manually compile the TensorFlow source code. Instructions can be found here.
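On a supported system, the driver-node setup typically amounts to a single pip install (shown here as a sketch; your environment may additionally need the analytics-zoo package itself):

```shell
# Install the TensorFlow version required on the driver node
pip install tensorflow==1.15.0
```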
```python
from zoo.tfpark.text.estimator import BERTClassifier

estimator = BERTClassifier(num_classes, bert_config_file, init_checkpoint=None,
                           use_one_hot_embeddings=False, optimizer=None, model_dir=None)
```
- `num_classes`: Positive int. The number of classes to classify.
- `bert_config_file`: The path to the JSON file containing the BERT configuration.
- `init_checkpoint`: The path to the initial checkpoint of the pre-trained BERT model, if any. Default is None.
- `use_one_hot_embeddings`: Boolean. Whether to use one-hot word embeddings. Default is False.
- `optimizer`: The optimizer used to train the estimator. It can be either an instance of tf.train.Optimizer or its corresponding string representation. Default is None, for the case where no training is involved.
- `model_dir`: The output directory to which model checkpoints are written, if any. Default is None.
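Putting the arguments above together, a construction call for fine-tuning might look like the following sketch. The file paths are placeholders, and the `train` call with an `input_fn` follows the usual TFEstimator pattern; check the Analytics Zoo TFPark API docs for the exact input format expected (typically a features dict with `input_ids`, `input_mask`, and `token_type_ids`, plus labels).

```python
from zoo.tfpark.text.estimator import BERTClassifier

# Paths below are placeholders for an unpacked pre-trained BERT release.
estimator = BERTClassifier(
    num_classes=2,
    bert_config_file="/path/to/bert_config.json",
    init_checkpoint="/path/to/bert_model.ckpt",
    optimizer="adam",              # string form of the optimizer
    model_dir="/tmp/bert_classifier")

# train_input_fn is assumed to be a user-defined function returning
# BERT-formatted features and labels, as with any TFEstimator.
estimator.train(train_input_fn, steps=1000)
```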