Sequence to Sequence
Analytics Zoo provides a Seq2seq model, a general-purpose encoder-decoder framework that can be used for chatbots, machine translation and more. The model can be fed into NNFrames or the BigDL Optimizer directly for training.
Build a Seq2seq Model
Before building a Seq2seq model, you need to build an Encoder and a Decoder, and optionally a Bridge if you want to transform the encoder states before passing them to the decoder.
Build an Encoder
Currently we only support RNNEncoder, which lets you put RNN layers into the encoder.
You can call the following API in Scala or Python to create an RNNEncoder.
Scala
val encoder = RNNEncoder(rnnType, numLayers, hiddenSize, embedding)
rnnType: style of recurrent unit, one of [SimpleRNN, LSTM, GRU]
numLayers: number of layers used in the encoder
hiddenSize: hidden size of the encoder
embedding: embedding layer in the encoder, default is null
You can also define RNN layers yourself
val encoder = RNNEncoder(rnns, embedding, inputShape)
rnns: rnn layers used for the encoder; stacked rnn layers are supported
embedding: embedding layer in the encoder, default is null
Python
encoder = RNNEncoder.initialize(rnn_type, nlayers, hidden_size, embedding)
rnn_type: style of recurrent unit, one of [SimpleRNN, LSTM, GRU]
nlayers: number of layers used in the encoder
hidden_size: hidden size of the encoder
embedding: embedding layer in the encoder, default is None
Or
encoder = RNNEncoder(rnns, embedding, input_shape)
rnns: rnn layers used for the encoder; stacked rnn layers are supported
embedding: embedding layer in the encoder, default is None
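To illustrate what the encoder contributes, here is a minimal, purely illustrative Python sketch (hypothetical, not the Analytics Zoo implementation): an RNN encoder folds a variable-length input sequence into a fixed-size hidden state that summarizes it.

```python
# Illustrative sketch of what an RNN encoder does (hypothetical, not the
# Analytics Zoo code): consume the sequence step by step, updating a
# fixed-size hidden state.

def toy_encoder(sequence, hidden_size=4):
    """Fold a token sequence into a fixed-size hidden state."""
    state = [0.0] * hidden_size
    for token in sequence:
        # Stand-in for the recurrent cell: mix each token into every unit.
        state = [0.5 * s + 0.5 * (token % (i + 2)) for i, s in enumerate(state)]
    return state

final_state = toy_encoder([3, 1, 4, 1, 5])
print(len(final_state))  # fixed-size summary regardless of input length
```

However long the input is, the decoder only ever sees this fixed-size state (possibly transformed by a Bridge, described below).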
Build a Decoder
Similar to the Encoder, we only support RNNDecoder, and its API is almost the same as RNNEncoder's.
Scala
val decoder = RNNDecoder(rnnType, numLayers, hiddenSize, embedding)
rnnType: style of recurrent unit, one of [SimpleRNN, LSTM, GRU]
numLayers: number of layers used in the decoder
hiddenSize: hidden size of the decoder
embedding: embedding layer in the decoder, default is null
You can also define RNN layers yourself
val decoder = RNNDecoder(rnns, embedding, inputShape)
rnns: rnn layers used for the decoder; stacked rnn layers are supported
embedding: embedding layer in the decoder, default is null
Python
decoder = RNNDecoder.initialize(rnn_type, nlayers, hidden_size, embedding)
rnn_type: style of recurrent unit, one of [SimpleRNN, LSTM, GRU]
nlayers: number of layers used in the decoder
hidden_size: hidden size of the decoder
embedding: embedding layer in the decoder, default is None
Or
decoder = RNNDecoder(rnns, embedding, input_shape)
rnns: rnn layers used for the decoder; stacked rnn layers are supported
embedding: embedding layer in the decoder, default is None
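The decoder's role can be sketched the same way (again a hypothetical illustration, not the Analytics Zoo code): it starts from the encoder's final state and emits one output per decoding step.

```python
# Illustrative sketch of an RNN decoder (hypothetical, not the Analytics
# Zoo code): start from the encoder's final state and produce one output
# per step.

def toy_decoder(encoder_state, num_steps):
    state = list(encoder_state)  # the decoder starts where the encoder ended
    outputs = []
    for _ in range(num_steps):
        # Stand-in recurrent update followed by a stand-in output projection.
        state = [0.9 * s + 0.1 for s in state]
        outputs.append(round(sum(state), 3))
    return outputs

outs = toy_decoder([0.2, 0.4], num_steps=3)
print(len(outs))  # one output per decoding step
```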
Build a Bridge
By default, encoder states are fed directly into the decoder, in which case you don't need to build a Bridge. But if you want to transform the encoder states before feeding them to the decoder,
please use the following API to create a Bridge.
Scala
val bridge = Bridge(bridgeType, decoderHiddenSize)
bridgeType: currently only "dense" and "densenonlinear" are supported
decoderHiddenSize: hidden size of the decoder
You can also specify Keras layers as a Bridge
val bridge = Bridge(bridge)
bridge: Keras layers used to do the transformation
Python
bridge = Bridge.initialize(bridge_type, decoder_hidden_size)
bridge_type: currently only "dense" and "densenonlinear" are supported
decoder_hidden_size: hidden size of the decoder
Or
bridge = Bridge.initialize_from_keras_layer(bridge)
bridge: Keras layers used to do the transformation
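What a "dense" Bridge does can be sketched as a linear map that reshapes the encoder state to the decoder's hidden size. This is a hypothetical illustration with fixed weights; in the real model the Bridge's weights are trained.

```python
# Hypothetical sketch of a "dense" Bridge: a linear map from the encoder's
# hidden size to the decoder's hidden size. Weights are fixed here purely
# for illustration; in practice they are learned.

def dense_bridge(encoder_state, weights):
    """weights: one row per decoder hidden unit."""
    return [sum(w * s for w, s in zip(row, encoder_state)) for row in weights]

# Map a 3-unit encoder state to a 2-unit decoder state.
decoder_init = dense_bridge([1.0, 2.0, 3.0],
                            weights=[[0.1, 0.2, 0.3],
                                     [0.0, 1.0, 0.0]])
print(len(decoder_init))  # matches the decoder's hidden size
```

A "densenonlinear" bridge would additionally apply a nonlinearity to the mapped state.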
Build a Seq2seq
Scala
val seq2seq = Seq2seq(encoder,
decoder,
inputShape,
outputShape,
bridge,
generator)
encoder: an encoder object
decoder: a decoder object
inputShape: shape of the encoder input; for variable length, use -1
outputShape: shape of the decoder input; for variable length, use -1
bridge: connects the encoder and decoder; null is supported
generator: feeds the decoder output into the generator to produce the final result; null is supported
See here for the Scala example that trains the Seq2seq model and uses the model to do prediction.
Python
seq2seq = Seq2seq(encoder, decoder, input_shape, output_shape, bridge, generator)
encoder: an encoder object
decoder: a decoder object
input_shape: shape of the encoder input; for variable length, use -1
output_shape: shape of the decoder input; for variable length, use -1
bridge: connects the encoder and decoder; None is supported
generator: feeds the decoder output into the generator to produce the final result; None is supported
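Putting the pieces together, the data flow that the Seq2seq constructor wires up can be sketched end to end (a hypothetical illustration, not the Analytics Zoo API): encode the input, optionally bridge the states, then decode.

```python
# End-to-end sketch (hypothetical, not the Analytics Zoo API) of the
# seq2seq flow: encoder -> bridge -> decoder.

def toy_encoder(sequence, hidden_size=3):
    state = [0.0] * hidden_size
    for token in sequence:
        state = [0.5 * s + 0.5 * (token % (i + 2)) for i, s in enumerate(state)]
    return state

def identity_bridge(state):
    # The default behavior: encoder states fed to the decoder unchanged.
    return state

def toy_decoder(state, num_steps):
    outputs = []
    for _ in range(num_steps):
        state = [0.9 * s + 0.1 for s in state]
        outputs.append(round(sum(state), 3))
    return outputs

predictions = toy_decoder(identity_bridge(toy_encoder([7, 2, 9])), num_steps=4)
print(len(predictions))  # one prediction per decoding step
```

In the real model a generator can additionally be applied to each decoder output to produce the final result, and all components are trained jointly.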