Self Attention Layers
TransformerLayer
A network architecture based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Refer to the paper Attention Is All You Need (https://arxiv.org/abs/1706.03762) for more details.
Input is a Table which consists of 2 tensors.
- Token id tensor: shape (batch, seqLen) with the word token indices in the vocabulary
- Position id tensor: shape (batch, seqLen) with positions in the sentence.
Output is a Table as well.
- The states of the Transformer layer.
- The pooled output, which processes the hidden state of the last layer with regard to the first token of the sequence. This can be useful for segment-level tasks.
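As a sketch of the expected input format, the two tensors can be prepared with NumPy as follows (the batch, sequence-length and vocabulary sizes here are made up for illustration):

```python
import numpy as np

batch, seq_len, vocab = 2, 20, 200  # hypothetical sizes

# Token id tensor: word indices into the vocabulary, shape (batch, seq_len)
token_ids = np.random.randint(vocab, size=(batch, seq_len))

# Position id tensor: positions 0..seq_len-1 for each sequence, shape (batch, seq_len)
position_ids = np.tile(np.arange(seq_len), (batch, 1))

# The layer's forward takes the two tensors as a list (a Table on the Scala side)
model_input = [token_ids, position_ids]
```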
With Default Embedding:
Scala:
TransformerLayer[Float](vocab = 40990,
seqLen = 77,
nBlock = 12,
residPdrop = 0.1,
attnPdrop = 0.1,
nHead = 12,
hiddenSize = 768,
embeddingDrop = 0,
initializerRange = 0.02,
bidirectional = false,
outputAllBlock = false)
Python:
TransformerLayer.init(vocab=40990, seq_len=77, n_block=12, hidden_drop=0.1,
attn_drop=0.1, n_head=12, hidden_size=768,
embedding_drop=0.1, initializer_range=0.02,
bidirectional=False, output_all_block=False)
Parameters:
vocab
: vocabulary size of training data, default is 40990
seqLen
: max sequence length of training data, default is 77
nBlock
: block number, default is 12
residPdrop
: drop probability of projection, default is 0.1
attnPdrop
: drop probability of attention, default is 0.1
nHead
: head number, default is 12
hiddenSize
: hidden size, which is also the embedding size, default is 768
embeddingDrop
: drop probability of embedding layer, default is 0.1
initializerRange
: weight initialization range, default is 0.02
bidirectional
: whether the model is unidirectional or bidirectional, default is false
outputAllBlock
: whether to output all blocks' output, default is false
With Customized Embedding:
Scala:
TransformerLayer[Float](nBlock = 12,
residPdrop = 0.1,
attnPdrop = 0.1,
nHead = 12,
bidirectional = false,
initializerRange = 0.02,
outputAllBlock = true,
embeddingLayer = embedding.asInstanceOf[KerasLayer[Activity, Tensor[Float], Float]])
Python:
TransformerLayer(n_block=12,
hidden_drop=0.1,
attn_drop=0.1,
n_head=12,
initializer_range=0.02,
bidirectional=False,
output_all_block=False,
embedding_layer=embedding,
input_shape=((seq_len,), (seq_len,)),
intermediate_size=0)
Parameters:
nBlock
: block number
residPdrop
: drop probability of projection
attnPdrop
: drop probability of attention
nHead
: head number
initializerRange
: weight initialization range
bidirectional
: whether the model is unidirectional or bidirectional
outputAllBlock
: whether to output all blocks' output
embeddingLayer
: embedding layer
Scala example:
val shape1 = Shape(20)
val shape2 = Shape(20)
val input1 = Variable[Float](shape1)
val input2 = Variable[Float](shape2)
val input = Array(input1, input2)
val seq = TransformerLayer[Float](200, hiddenSize = 128, nHead = 8,
seqLen = 20, nBlock = 1).from(input: _*)
val model = Model[Float](input, seq)
val trainToken = Tensor[Float](1, 20).rand()
val trainPos = Tensor.ones[Float](1, 20)
val input3 = T(trainToken, trainPos)
val output = model.forward(input3)
Input is:
{
2: 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0
[com.intel.analytics.bigdl.tensor.DenseTensor$mcF$sp of size 1x20]
1: 0.8087359 0.16409875 0.7404631 0.4836999 0.034994964 0.033039592 0.6694243 0.84700763 0.32154092 0.17410904 0.66117364 0.30495027 0.19573595 0.058101892 0.65923077 0.84077805 0.50113535 0.48393667 0.06523132 0.0667426
[com.intel.analytics.bigdl.tensor.DenseTensor of size 1x20]
}
Output is:
{
2: 0.83383083 0.72725344 0.16394942 -0.79005975 0.8877357 -0.9060916 -0.6796065 0.46835706 -0.4700584 0.43868023 0.6641587 0.6711142 -0.70056283 -0.42694178 0.7615595 -0.25590983 0.21654142 0.35254374 0.83790034 0.1103606 -0.20419843 -0.9739706 0.6150182 0.4499923 0.3355538 -0.01543447 -0.99528116 0.45984524 -0.22544041 0.10049125 0.8418835 -0.116228305 -0.112435654 0.5183222 -0.59375525 0.31828925 0.50506884 0.14892755 0.94327587 -0.19001998 0.54074824 -0.07616825 -0.79334164 -0.49726814 0.23889944 -0.91731304 -0.5484148 0.5048103 0.9743351 0.10505025 0.81167877 -0.47498485 -0.83443964 -0.89340115 0.6443838 0.10184191 -0.38618097 -0.32026938 0.51587516 -0.40602723 -0.2931675 -0.86100364 0.109585665 0.9023708 0.46609795 0.0028693299 -0.5746851 -0.45607233 -0.9075561 -0.91294044 0.8077997 0.23019081 0.51124465 -0.39125186 0.16946821 -0.36827865 -0.32563296 0.62560886 -0.7278883 0.8076773 0.89344263 -0.9259615 0.21476166 0.67077845 0.5857905 -0.32905066 -0.16318946 0.6435858 -0.28905967 -0.6991412 -0.5289766 -0.6954091 0.1577004 0.5618301 -0.6290018 0.114078626 -0.52474076 0.27916297 -0.76610357 0.67119384 -0.4308661 0.063731246 -0.5281069 -0.65910465 0.5383283 -0.2875557 0.24594739 -0.6789035 0.7002648 -0.64659894 -0.70994437 -0.8416273 0.4666695 -0.55062526 0.14995292 -0.978979 0.40934727 -0.9028927 0.38194665 0.2334618 -0.9481384 -0.51903373 -0.947906 0.2667679 -0.76987743 -0.7490675 0.6777159 0.9593161
[com.intel.analytics.bigdl.tensor.DenseTensor of size 1x128]
1: (1,.,.) =
0.8369983 -0.9907519 0.74404025 ... 0.6154673 0.107825294 -0.806892
0.7676861 -0.962961 0.73240614 ... 0.534349 0.0049344404 -0.81643736
0.7487803 -0.9717681 0.7315394 ... 0.59831613 0.010904985 -0.82502025
...
0.06956328 -1.2103055 1.4155688 ... -0.759053 0.6966926 -0.53496075
0.0759853 -1.2265961 1.4023252 ... -0.7500985 0.68647313 -0.52275336
0.06356962 -1.2309887 1.3984702 ... -0.751963 0.69192046 -0.52820134
[com.intel.analytics.bigdl.tensor.DenseTensor of size 1x20x128]
}
Python example:
model = TransformerLayer.init(
vocab=200, hidden_size=128, n_head=4, seq_len=20)
train_token = np.random.randint(20, size=(2, 20))
train_pos = np.zeros((2, 20), dtype=np.int32)
input = [train_token, train_pos]
output = model.forward(input)
Input is:
<type 'list'>: [array([[11, 2, 16, 6, 17, 18, 2, 4, 5, 16, 18, 15, 13, 19, 5, 15,
14, 14, 2, 9],
[10, 15, 13, 6, 12, 0, 11, 3, 16, 13, 6, 13, 17, 13, 3, 4,
15, 5, 7, 15]]), array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
dtype=int32)]
Output is:
<type 'list'>: [array([[[ 0.26004127, -0.31793368, -1.1605529 , ..., -0.81875914,
-0.02121837, -0.8328352 ],
[-0.8622302 , -0.35201085, 0.63190293, ..., 2.0652232 ,
1.5278 , 0.38224357],
[-2.5103235 , 1.4465114 , 0.71134603, ..., 1.1776686 ,
0.6882701 , 0.3534629 ],
...,
[-0.22725764, 1.2112792 , -0.40597847, ..., 2.2241254 ,
0.2580125 , -1.1470895 ],
[-0.56174546, 1.3353435 , -0.7445968 , ..., 1.1259638 ,
0.6951011 , -1.1421459 ],
[-0.6615135 , 1.1899865 , -0.81727505, ..., 2.0474243 ,
0.20160393, -0.7789728 ]],
[[-1.1624268 , -0.5375418 , -0.7274868 , ..., -0.99061227,
-0.57117355, 1.0684316 ],
[ 0.11317759, -0.7231343 , 0.7723393 , ..., 1.6518786 ,
1.0916579 , 0.18682887],
[-1.9651127 , 0.9987117 , 0.32025027, ..., 0.94719195,
-0.21028236, -0.02251417],
...,
[-0.6677234 , 0.69822913, -0.9714249 , ..., 2.208334 ,
0.7719772 , -0.93855625],
[-0.63691545, 1.3876344 , -0.8491991 , ..., 2.060551 ,
0.34702447, -0.8160082 ],
[-0.6608573 , 1.2608795 , -0.46634364, ..., 2.100828 ,
0.2967869 , -1.0938305 ]]], dtype=float32), array([[ 0.06879381, 0.6821829 , -0.8267953 , -0.02695777, -0.53899264,
0.8241045 , 0.6976903 , 0.31741282, 0.23590134, 0.5565326 ,
0.95292866, 0.5658284 , -0.2916065 , -0.37934095, -0.2774958 ,
0.73409927, -0.71731025, 0.07897043, 0.88609815, -0.27966806,
0.93520796, 0.72740096, 0.1626402 , -0.26063287, 0.28597558,
-0.12945679, 0.7151408 , -0.8463592 , -0.48385444, -0.29313505,
0.86453205, -0.93834317, 0.41815573, 0.92436415, 0.8209114 ,
0.6627246 , -0.574135 , 0.607416 , 0.04769071, -0.29779348,
-0.26268572, -0.78998053, -0.7522611 , 0.89941144, -0.15754697,
0.9298859 , -0.8327022 , -0.63423705, -0.63789636, -0.14168388,
-0.56104964, -0.80995566, 0.9244693 , 0.4679966 , -0.16284083,
0.8478645 , 0.29836348, -0.15369722, -0.4490478 , 0.11052075,
0.23767054, 0.59320366, -0.79055625, 0.22201608, -0.88366413,
-0.4410687 , 0.8762162 , -0.6516914 , -0.5993653 , -0.5972125 ,
-0.86697286, -0.17644943, 0.95839834, -0.06382846, 0.7430881 ,
-0.59690744, 0.3901914 , 0.06803267, 0.9142394 , 0.7583274 ,
-0.18442968, 0.56280667, -0.37844184, -0.41195455, -0.8376329 ,
0.87641823, -0.98970294, -0.6764397 , -0.86945957, -0.69273126,
0.9911777 , 0.417286 , -0.8774987 , 0.17141937, 0.7204654 ,
-0.62387246, -0.8795049 , 0.62618923, -0.29725042, -0.4565646 ,
-0.47798416, -0.97555065, -0.94241685, -0.97800356, 0.8523641 ,
-0.96860206, 0.5378995 , -0.73754525, -0.01649606, -0.4274561 ,
-0.5290453 , 0.11851768, 0.48821065, 0.4822751 , 0.49497148,
-0.5734494 , -0.29612035, -0.7254394 , -0.1418346 , -0.56686646,
0.03665365, -0.9586826 , -0.0983429 , -0.09348761, -0.96338177,
0.76481736, 0.87975204, 0.70463663],
[-0.09654156, 0.78266025, -0.9125131 , -0.6706971 , -0.58709925,
-0.94729275, -0.32309514, -0.95263994, 0.2036015 , -0.9297767 ,
0.6164713 , 0.3484337 , 0.46247053, 0.21615174, -0.8382687 ,
-0.55828595, -0.59234536, -0.9643932 , 0.9310115 , -0.12657425,
0.63812125, 0.80040973, -0.47581342, 0.9823402 , -0.5400171 ,
0.5864317 , -0.19979174, -0.5721838 , 0.9190707 , 0.31628668,
0.08952013, 0.8719338 , 0.26684833, 0.8955768 , -0.9275499 ,
-0.81994563, 0.28863704, -0.16376448, 0.15855551, 0.04302022,
0.4440408 , -0.7293209 , 0.2255107 , 0.16333969, 0.38721767,
-0.04512435, -0.5473172 , -0.5812051 , -0.8219114 , -0.43659028,
-0.04860768, -0.8912252 , 0.62100273, 0.7187475 , -0.06158534,
0.6554498 , -0.62163985, 0.63035303, 0.19207267, -0.68847877,
0.10341872, -0.88906926, -0.38804066, -0.8157233 , -0.81641346,
0.8846337 , -0.70225614, 0.6281251 , -0.81235796, 0.77828485,
0.9393982 , -0.42554784, 0.4150426 , -0.32612413, -0.721988 ,
0.96166253, -0.6080237 , -0.7312329 , 0.06843777, -0.09806018,
-0.7357863 , -0.28613612, -0.8895085 , -0.9027925 , 0.56311375,
0.85699487, -0.32128897, 0.80635303, -0.01190906, -0.23292968,
-0.5115769 , 0.17153661, -0.79993784, 0.6232265 , -0.06049479,
-0.83510727, 0.9652135 , 0.08310007, -0.9671807 , -0.17466563,
0.48009604, 0.594712 , 0.19612817, -0.9279629 , -0.59968966,
-0.36079255, -0.7250685 , 0.59395283, 0.7574965 , -0.4377294 ,
0.45312116, 0.7117049 , -0.82085943, -0.10442825, 0.73688287,
0.38598123, 0.35439053, -0.3862137 , -0.56253886, 0.7388591 ,
-0.6024478 , -0.699977 , -0.46581215, -0.79513186, 0.09657894,
0.280869 , -0.38445532, -0.98311806]], dtype=float32)]
BERT
Bidirectional Encoder Representations from Transformers. Refer to https://arxiv.org/pdf/1810.04805.pdf for more details.
Input is a Table which consists of 4 tensors.
- Token id tensor: shape (batch, seqLen) with the word token indices in the vocabulary
- Token type id tensor: shape (batch, seqLen) with the token types in (0, 1). 0 means sentence A and 1 means sentence B (see the BERT paper for more details).
- Position id tensor: shape (batch, seqLen) with positions in the sentence.
- Attention_mask tensor: shape (batch, seqLen) with indices in (0, 1). It's a mask to be used if the input sequence length is smaller than seqLen in the current batch.
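For instance, a mask of shape (batch, 1, 1, seqLen) that marks real tokens with 1 and padded positions with 0 can be derived from the actual (unpadded) sequence lengths. This is a minimal NumPy sketch with made-up lengths, not code taken from the library itself:

```python
import numpy as np

seq_len = 20
lengths = np.array([18, 12])  # hypothetical unpadded lengths of two sequences

# 1 where the position index is below the sequence length, 0 for padding
mask = (np.arange(seq_len)[None, :] < lengths[:, None]).astype(np.int32)

# Reshape to (batch, 1, 1, seq_len) so it broadcasts over heads and query positions
attention_mask = mask[:, None, None, :]
```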
Output is a Table as well.
- The states of the BERT layer.
- The pooled output, which processes the hidden state of the last layer with regard to the first token of the sequence. This can be useful for segment-level tasks.
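In the original BERT model, the pooled output is produced by feeding the first token's last-layer hidden state through a dense layer with a tanh activation. A minimal NumPy sketch of that idea, with random made-up weights rather than the library's internals:

```python
import numpy as np

batch, seq_len, hidden_size = 1, 10, 10
rng = np.random.default_rng(0)

states = rng.standard_normal((batch, seq_len, hidden_size))  # last-layer states
W = rng.standard_normal((hidden_size, hidden_size))           # pooler weights (made up)
b = np.zeros(hidden_size)

first_token = states[:, 0, :]          # hidden state of the first token
pooled = np.tanh(first_token @ W + b)  # dense + tanh, as in the BERT pooler
```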
With Default Embedding:
Scala:
BERT[Float](vocab: Int = 40990,
hiddenSize: Int = 768,
nBlock: Int = 12,
nHead: Int = 12,
maxPositionLen: Int = 512,
intermediateSize: Int = 3072,
hiddenPDrop: Double = 0.1,
attnPDrop: Double = 0.1,
initializerRange: Double = 0.02,
outputAllBlock: Boolean = true,
inputSeqLen: Int = -1)
Python:
BERT.init(vocab=40990,
hidden_size=768,
n_block=12,
n_head=12,
seq_len=512,
intermediate_size=3072,
hidden_drop=0.1,
attn_drop=0.1,
initializer_range=0.02,
output_all_block=True)
Parameters:
vocab
: vocabulary size of training data, default is 40990
hiddenSize
: size of the encoder layers, default is 768
nBlock
: block number, default is 12
nHead
: head number, default is 12
maxPositionLen
: sequence length, default is 512
intermediateSize
: the size of the "intermediate" (i.e., feed-forward) layer, default is 3072
hiddenPDrop
: the dropout probability for all fully connected layers, default is 0.1
attnPDrop
: drop probability of attention, default is 0.1
initializerRange
: weight initialization range, default is 0.02
outputAllBlock
: whether to output all blocks' output, default is true
inputSeqLen
: sequence length of input, default is -1, which means the same as maxPositionLen
With Customized Embedding:
Scala:
BERT[Float](nBlock = 12,
nHead = 12,
intermediateSize = 3072,
hiddenPDrop = 0.1,
attnPDrop = 0.1,
initializerRange = 0.02,
outputAllBlock = true,
embeddingLayer = embedding)
Python:
BERT(n_block=12,
n_head=12,
intermediate_size=3072,
hidden_drop=0.1,
attn_drop=0.1,
initializer_range=0.02,
output_all_block=True,
embedding_layer=embedding,
input_shape=((seq_len,), (seq_len,), (seq_len,), (1, 1, seq_len)))
Parameters:
nBlock
: block number
nHead
: head number
intermediateSize
: the size of the "intermediate" (i.e., feed-forward) layer
hiddenPDrop
: the dropout probability for all fully connected layers
attnPDrop
: drop probability of attention
initializerRange
: weight initialization range
outputAllBlock
: whether to output all blocks' output
embeddingLayer
: embedding layer
Loading from existing pretrained model:
Scala:
BERT[Float](path = "",
weightPath = null,
inputSeqLen = 11,
hiddenPDrop = 0.1,
attnPDrop = 0.1,
outputAllBlock = true)
Python:
BERT.init_from_existing_model(path="",
weight_path=None,
input_seq_len=-1.0,
hidden_drop=-1.0,
attn_drop=-1.0,
output_all_block=True)
Parameters:
path
: the path of the pre-defined model. Local file system, HDFS and Amazon S3 are supported. An Amazon S3 path should be like "s3a://bucket/xxx"
weightPath
: the path of the pre-trained weights, if any
inputSeqLen
: sequence length of input; ignored if the existing model is built with a customized embedding
hiddenPDrop
: the dropout probability for all fully connected layers; ignored if the existing model is built with a customized embedding
attnPDrop
: drop probability of attention; ignored if the existing model is built with a customized embedding
Scala example:
val layer = BERT[Float](vocab = 100,
hiddenSize = 10,
nBlock = 3,
nHead = 2,
intermediateSize = 64,
hiddenPDrop = 0.1,
attnPDrop = 0.1,
maxPositionLen = 10,
outputAllBlock = false,
inputSeqLen = 10)
val shape = Shape(List(Shape(1, 10), Shape(1, 10), Shape(1, 10), Shape(1, 1, 1, 10)))
layer.build(shape)
val inputIds = Tensor[Float](Array[Float](7, 20, 39, 27, 10,
39, 30, 21, 17, 15), Array(1, 10))
val segmentIds = Tensor[Float](Array[Float](0, 0, 0, 0, 0, 1, 1, 1, 1, 1), Array(1, 10))
val positionIds = Tensor[Float](Array[Float](0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(1, 10))
val masks = Tensor[Float](1, 1, 1, 10).fill(1.0f)
val output = layer.forward(T(inputIds, segmentIds, positionIds, masks))
Input is:
{
2: 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0
[com.intel.analytics.bigdl.tensor.DenseTensor of size 1x10]
4: (1,1,.,.) =
1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0
[com.intel.analytics.bigdl.tensor.DenseTensor of size 1x1x1x10]
1: 7.0 20.0 39.0 27.0 10.0 39.0 30.0 21.0 17.0 15.0
[com.intel.analytics.bigdl.tensor.DenseTensor of size 1x10]
3: 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0
[com.intel.analytics.bigdl.tensor.DenseTensor of size 1x10]
}
Output is:
{
2: 0.5398573 0.08571402 -0.9461041 -0.35362077 -0.24374364 0.24349216 0.9587727 -0.03278971 -0.826852 -0.8808889
[com.intel.analytics.bigdl.tensor.DenseTensor of size 1x10]
1: (1,.,.) =
1.3381815 1.7575556 -1.1870699 0.8455374 -1.6000531 0.115945406 -0.33695826 -0.39254665 -0.33637434 -0.20421773
-0.08370285 0.056055143 -0.91990083 1.6324282 -0.093128644 -0.4484297 -2.0828273 0.10244746 0.577287 1.2597716
0.3563086 0.37092525 -0.5089354 0.4525072 1.7706354 0.65231055 -2.0269241 -0.2548585 0.3711578 -1.1831268
0.2429675 -0.023419544 -0.28389466 0.6601246 -0.009858845 -0.028412571 -2.5104556 1.0338438 1.3621751 -0.44306967
1.7147139 1.1627073 -0.19394834 0.8043055 -1.0080436 -1.7716306 -0.7668168 -0.19861369 0.45103902 -0.19371253
0.077525005 0.0722655 1.0745171 0.07997274 0.06562643 1.6474637 0.18938908 -2.377528 -0.6107291 -0.21850263
-1.3190242 1.7057956 0.32655835 0.5711799 -0.80318034 0.2776545 1.4860673 -0.676896 -0.39734793 -1.1708072
-0.4327645 -0.19849697 0.3695452 -0.08213705 1.2378154 0.591234 -1.505518 1.684885 -1.6251724 -0.03939093
0.6422535 -0.582018 1.6665243 -1.0995792 0.19488664 1.3563607 -0.60793823 -0.05846788 -1.7225715 0.21054967
-1.0927358 -0.37666538 0.70802236 -2.0131714 0.94964516 1.4701655 0.053027537 0.051168486 -0.58528 0.83582383
[com.intel.analytics.bigdl.tensor.DenseTensor of size 1x10x10]
}
Python example:
layer = BERT.init(
vocab=200, hidden_size=128, n_head=4, seq_len=20, intermediate_size=20)
train_token = np.random.randint(20, size=(2, 20))
token_type_id = np.zeros((2, 20), dtype=np.int32)
train_pos = np.zeros((2, 20), dtype=np.int32)
mask_attention = np.ones((2, 1, 1, 20), dtype=np.int32)
input = [train_token, token_type_id, train_pos, mask_attention]
output = layer.forward(input)
Input is:
<type 'list'>: [array([[ 8, 19, 5, 8, 4, 13, 13, 12, 1, 6, 16, 14, 19, 0, 11, 18,
1, 17, 0, 0],
[17, 10, 15, 19, 15, 2, 18, 8, 1, 11, 10, 17, 7, 2, 0, 0,
9, 14, 11, 6]]), array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
dtype=int32), array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
dtype=int32), array([[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],
[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]],
dtype=int32)]
Output is:
<type 'list'>: [array([[[ 1.01330066e+00, 4.74100798e-01, 3.81211847e-01, ...,
-8.00989151e-01, -6.18815482e-01, -1.09804094e+00],
[-6.29726201e-02, 6.71391249e-01, -2.28019580e-01, ...,
3.97501498e-01, -3.32217604e-01, -2.04850674e+00],
[ 5.04802346e-01, 1.00434709e+00, -7.63663530e-01, ...,
-1.11675525e+00, 1.41550392e-01, -6.47688091e-01],
...,
[ 5.07664750e-04, 6.47843182e-01, -5.33694960e-02, ...,
-2.01566055e-01, -6.62943959e-01, -2.93835902e+00],
[-1.49255514e-01, -5.47551095e-01, -3.36264402e-01, ...,
1.11121520e-01, -4.42977905e-01, -2.13847613e+00],
[-1.33390293e-01, -5.50503194e-01, -3.49355727e-01, ...,
9.34313685e-02, -4.41935956e-01, -8.25921223e-02]],
[[ 1.34967836e-02, 3.90778482e-02, -2.22317409e-02, ...,
-9.81532633e-02, -6.08992100e-01, -2.77224326e+00],
[-5.21431923e-01, -6.74456656e-02, -4.66892511e-01, ...,
1.05466165e-01, -2.67068297e-01, -9.00964141e-01],
[ 9.41378593e-01, 5.21076620e-01, -5.35079956e-01, ...,
-3.15736473e-01, 8.08603615e-02, -2.44178265e-01],
...,
[ 4.16017324e-02, -5.65874994e-01, 6.68676615e-01, ...,
-6.28256857e-01, -9.09847617e-02, -2.42878512e-01],
[-1.36971796e+00, 5.37231266e-01, -1.33729517e+00, ...,
-1.47498712e-01, 8.00304264e-02, -5.09030581e-01],
[ 3.98404837e-01, -7.18296226e-03, -1.08256066e+00, ...,
-5.17360926e-01, 5.50065935e-01, -2.32753420e+00]]],
dtype=float32), array([[[ 1.0795887 , 0.44977495, 0.45561683, ..., -0.729603 ,
-0.6098092 , -1.0323973 ],
[ 0.0154253 , 0.6424524 , -0.15503715, ..., 0.45100495,
-0.3161888 , -1.9826275 ],
[ 0.58491707, 0.9876782 , -0.69952184, ..., -1.0432141 ,
0.1380458 , -0.5642554 ],
...,
[ 0.06806332, 0.61824507, 0.02341641, ..., -0.21342012,
-0.63312817, -2.8557966 ],
[-0.06217962, -0.5528828 , -0.34740448, ..., 0.16651583,
-0.41633344, -2.064906 ],
[-0.04626712, -0.57442385, -0.277238 , ..., 0.13806444,
-0.43256086, -0.01180306]],
[[ 0.10532085, 0.01057051, 0.07536474, ..., -0.03406155,
-0.572023 , -2.6935408 ],
[-0.42477775, -0.10768362, -0.37653154, ..., 0.17155378,
-0.27841952, -0.8244427 ],
[ 1.0290473 , 0.5059685 , -0.5359356 , ..., -0.25725254,
0.1034779 , -0.16898313],
...,
[ 0.14118548, -0.5945706 , 0.7681386 , ..., -0.55807835,
-0.07778832, -0.15940095],
[-1.2648381 , 0.50598496, -1.2431567 , ..., -0.06980868,
0.10642368, -0.4181047 ],
[ 0.48330045, -0.05184587, -0.9985824 , ..., -0.5360492 ,
0.56541353, -2.2607849 ]]], dtype=float32), array([[[ 1.1541002 , 0.47630545, 0.40673187, ..., -0.7284888 ,
-0.55945337, -1.0810231 ],
[ 0.10070852, 0.64252985, -0.2007717 , ..., 0.4489277 ,
-0.24709189, -2.0173872 ],
[ 0.67520154, 0.9793912 , -0.7441366 , ..., -1.0376649 ,
0.20359974, -0.6060102 ],
...,
[ 0.07809319, 0.63523245, -0.02464442, ..., -0.21328981,
-0.5693355 , -2.8386393 ],
[ 0.02228299, -0.5229728 , -0.33483037, ..., 0.16430138,
-0.40036577, -2.094183 ],
[ 0.02232744, -0.54113156, -0.3307599 , ..., 0.14321396,
-0.3796677 , -0.04973204]],
[[ 0.20497712, 0.02804335, 0.028764 , ..., -0.01617111,
-0.5416485 , -2.7333891 ],
[-0.3361876 , -0.08618001, -0.41299412, ..., 0.17708196,
-0.23643918, -0.8763187 ],
[ 1.1118197 , 0.5178778 , -0.57264006, ..., -0.2597192 ,
0.15024357, -0.23373066],
...,
[ 0.23304611, -0.57528406, 0.71815467, ..., -0.5524511 ,
-0.04103457, -0.15449452],
[-1.1629226 , 0.5377656 , -1.2816569 , ..., -0.05795323,
0.1603044 , -0.47194824],
[ 0.5773567 , -0.04114214, -1.0306932 , ..., -0.52537155,
0.5703101 , -2.3124278 ]]], dtype=float32), array([[[ 1.1319652e+00, 4.2663044e-01, 3.9611375e-01, ...,
-7.7264631e-01, -5.3006041e-01, -1.0942854e+00],
[ 7.3858641e-02, 6.1578143e-01, -2.0985913e-01, ...,
4.0289888e-01, -2.2484708e-01, -2.0233095e+00],
[ 6.3545388e-01, 9.4610500e-01, -7.6165521e-01, ...,
-1.0820770e+00, 2.2266804e-01, -6.0132843e-01],
...,
[ 3.9479308e-02, 6.0636342e-01, -2.8302141e-02, ...,
-2.6316714e-01, -5.5309945e-01, -2.8510940e+00],
[-2.8412668e-03, -5.5100703e-01, -3.4540960e-01, ...,
1.5979633e-01, -3.8844827e-01, -2.0994248e+00],
[-1.4572166e-02, -5.7526213e-01, -3.3382124e-01, ...,
1.0289014e-01, -3.6059290e-01, -6.5041430e-02]],
[[ 1.5256011e-01, 3.3955947e-03, 1.7648729e-02, ...,
-4.9600061e-02, -5.1613468e-01, -2.7417533e+00],
[-3.6988521e-01, -8.8330485e-02, -4.2416954e-01, ...,
1.2959087e-01, -2.1623056e-01, -8.8821554e-01],
[ 1.0618008e+00, 5.0827748e-01, -5.8256608e-01, ...,
-2.9023758e-01, 1.6930477e-01, -2.3869993e-01],
...,
[ 1.8475071e-01, -5.8594310e-01, 7.0973599e-01, ...,
-5.9211296e-01, -1.7043589e-02, -1.5649734e-01],
[-1.2073172e+00, 5.1577950e-01, -1.2952001e+00, ...,
-1.0562765e-01, 1.8499596e-01, -4.6483174e-01],
[ 5.8209622e-01, -5.3714752e-02, -1.0255412e+00, ...,
-5.6718546e-01, 5.9832001e-01, -2.3260906e+00]]], dtype=float32), array([[[ 1.1358043 , 0.38664085, 0.43075162, ..., -0.762137 ,
-0.53836805, -1.1419276 ],
[ 0.0705715 , 0.61749744, -0.1978054 , ..., 0.39686537,
-0.23118263, -2.0863478 ],
[ 0.6286853 , 0.9499371 , -0.75073713, ..., -1.0837915 ,
0.20451419, -0.64585996],
...,
[ 0.04084783, 0.5485716 , 0.02199897, ..., -0.265642 ,
-0.54954815, -2.8985202 ],
[-0.00433184, -0.5782148 , -0.28893095, ..., 0.15305014,
-0.3942154 , -2.1390564 ],
[-0.01938614, -0.6034715 , -0.3210429 , ..., 0.11286073,
-0.3612479 , -0.12291119]],
[[ 0.13899514, -0.04281238, 0.05966739, ..., -0.05543021,
-0.51721877, -2.7725601 ],
[-0.38874203, -0.13524944, -0.37960985, ..., 0.12579904,
-0.23764463, -0.94251025],
[ 1.0436667 , 0.4891924 , -0.5470476 , ..., -0.30531114,
0.143379 , -0.28663573],
...,
[ 0.17597033, -0.6172772 , 0.75050735, ..., -0.59396976,
-0.02840331, -0.20918237],
[-1.2121452 , 0.47985265, -1.2640744 , ..., -0.11457531,
0.17777829, -0.5216857 ],
[ 0.5724598 , -0.08497301, -0.99838203, ..., -0.569392 ,
0.5878865 , -2.3820512 ]]], dtype=float32), array([[[ 1.1923821 , 0.40376186, 0.4216827 , ..., -0.7753511 ,
-0.58085346, -1.1371452 ],
[ 0.13859609, 0.6558944 , -0.1899949 , ..., 0.37358993,
-0.27038255, -2.0870223 ],
[ 0.7083764 , 0.977537 , -0.76046735, ..., -1.101789 ,
0.1981793 , -0.6461577 ],
...,
[ 0.10095584, 0.5967537 , 0.02207649, ..., -0.28193793,
-0.5789527 , -2.9001386 ],
[ 0.05901275, -0.53504837, -0.28481779, ..., 0.13802934,
-0.41621858, -2.1443312 ],
[ 0.04528951, -0.5612042 , -0.31392562, ..., 0.09289672,
-0.38395336, -0.12475596]],
[[ 0.19919002, -0.0038989 , 0.06975131, ..., -0.05898362,
-0.5476832 , -2.7802918 ],
[-0.32966843, -0.10950038, -0.3582222 , ..., 0.08165199,
-0.2624505 , -0.93702954],
[ 1.119693 , 0.5142474 , -0.5341173 , ..., -0.3251373 ,
0.10789905, -0.30592436],
...,
[ 0.24882485, -0.5699529 , 0.77695113, ..., -0.63034034,
-0.07624292, -0.2281592 ],
[-1.1650497 , 0.5175082 , -1.2281002 , ..., -0.14287077,
0.15133552, -0.532626 ],
[ 0.64303875, -0.05680082, -0.9739305 , ..., -0.5787345 ,
0.5447517 , -2.403577 ]]], dtype=float32), array([[[ 1.1747141 , 0.44174162, 0.3848741 , ..., -0.8011676 ,
-0.5708256 , -1.143519 ],
[ 0.11874287, 0.7037242 , -0.22899102, ..., 0.36200705,
-0.22287843, -2.0832918 ],
[ 0.68290263, 1.0014081 , -0.8112288 , ..., -1.0980991 ,
0.22316812, -0.637702 ],
...,
[ 0.07541095, 0.63492006, 0.02669529, ..., -0.27486983,
-0.53397936, -2.8813968 ],
[ 0.07104072, -0.54481 , -0.33232585, ..., 0.12730087,
-0.37563673, -2.1450465 ],
[ 0.03186123, -0.51601535, -0.35520643, ..., 0.09008651,
-0.33910847, -0.11906879]],
[[ 0.19884765, 0.06095114, 0.00477777, ..., -0.0960753 ,
-0.49155453, -2.7463722 ],
[-0.32222956, -0.08950429, -0.4053724 , ..., 0.05162536,
-0.21072339, -0.9155606 ],
[ 1.1117101 , 0.56429935, -0.59156317, ..., -0.3369357 ,
0.14969075, -0.29045773],
...,
[ 0.25635508, -0.5126209 , 0.7268977 , ..., -0.62107044,
-0.01715574, -0.21087953],
[-1.1460723 , 0.56120336, -1.2668271 , ..., -0.16734022,
0.19381218, -0.517316 ],
[ 0.63978064, -0.01486263, -1.0128225 , ..., -0.56719303,
0.58368987, -2.3722165 ]]], dtype=float32), array([[[ 1.20512652e+00, 4.72038895e-01, 3.60962778e-01, ...,
-9.00623977e-01, -5.82258105e-01, -1.14907408e+00],
[ 1.55033678e-01, 7.28412509e-01, -2.72546947e-01, ...,
2.74131984e-01, -2.24478737e-01, -2.06169677e+00],
[ 7.13116527e-01, 1.01666617e+00, -8.43635857e-01, ...,
-1.18351495e+00, 2.21053749e-01, -6.17874563e-01],
...,
[ 1.19287886e-01, 6.56103075e-01, -7.38978712e-03, ...,
-3.54864419e-01, -5.37513494e-01, -2.87535477e+00],
[ 1.16591401e-01, -5.55387378e-01, -3.71099353e-01, ...,
2.78753694e-02, -3.70597601e-01, -2.16417289e+00],
[ 8.63942727e-02, -4.81025964e-01, -3.91344100e-01, ...,
-1.84133966e-02, -3.39215338e-01, -1.10263892e-01]],
[[ 2.64633745e-01, 7.25395158e-02, -1.39633343e-02, ...,
-1.85173869e-01, -5.20042717e-01, -2.70796871e+00],
[-2.58721560e-01, -6.35206550e-02, -4.14235502e-01, ...,
4.62563671e-02, -2.47269630e-01, -8.77729058e-01],
[ 1.17845881e+00, 5.93900442e-01, -6.18097663e-01, ...,
-4.23726231e-01, 1.19810022e-01, -2.55170494e-01],
...,
[ 3.06802571e-01, -4.83913153e-01, 7.05836833e-01, ...,
-6.99279726e-01, -5.64565696e-02, -1.74492225e-01],
[-1.07746637e+00, 5.83848476e-01, -1.28454113e+00, ...,
-2.29663596e-01, 1.96212217e-01, -5.23399591e-01],
[ 7.02956200e-01, 2.42653489e-03, -1.03614473e+00, ...,
-6.54396653e-01, 5.55146933e-01, -2.35132337e+00]]],
dtype=float32), array([[[ 1.1732985 , 0.48227564, 0.4141097 , ..., -0.9222415 ,
-0.5581617 , -1.1467376 ],
[ 0.10594734, 0.75098526, -0.23052076, ..., 0.23048519,
-0.23638739, -2.033264 ],
[ 0.6784295 , 1.0418617 , -0.810519 , ..., -1.2120562 ,
0.24596576, -0.60291487],
...,
[ 0.07915874, 0.6732479 , 0.02875949, ..., -0.38714084,
-0.5037479 , -2.8571 ],
[ 0.07176737, -0.52642834, -0.31701285, ..., -0.01229298,
-0.34029967, -2.1321528 ],
[ 0.04126783, -0.46029058, -0.34176344, ..., -0.05429364,
-0.31155083, -0.10451217]],
[[ 0.22600769, 0.09270672, 0.02146479, ..., -0.22232075,
-0.48217994, -2.6969097 ],
[-0.29719839, -0.05968198, -0.37710896, ..., 0.02224515,
-0.20888865, -0.872187 ],
[ 1.1335284 , 0.60064685, -0.58743286, ..., -0.45202363,
0.13883159, -0.2602308 ],
...,
[ 0.27142784, -0.47967467, 0.70926106, ..., -0.71909636,
-0.01251143, -0.1811402 ],
[-1.095905 , 0.6111897 , -1.2443895 , ..., -0.27054876,
0.22430526, -0.5081292 ],
[ 0.7027022 , 0.01059689, -1.0006222 , ..., -0.6746712 ,
0.58800125, -2.352779 ]]], dtype=float32), array([[[ 1.2434555 , 0.44558033, 0.4151337 , ..., -0.8851603 ,
-0.5718673 , -1.1482117 ],
[ 0.11962966, 0.72577155, -0.25604928, ..., 0.2687037 ,
-0.2457071 , -2.0307996 ],
[ 0.763747 , 1.0119921 , -0.8592167 , ..., -1.1870402 ,
0.2256221 , -0.6277423 ],
...,
[ 0.08759235, 0.64535457, 0.03834408, ..., -0.35554865,
-0.5139612 , -2.8475935 ],
[ 0.13679026, -0.55156755, -0.32664305, ..., 0.01780019,
-0.3558066 , -2.1313274 ],
[ 0.11818186, -0.4789727 , -0.3590175 , ..., -0.01446133,
-0.32358617, -0.10938768]],
[[ 0.29418862, 0.09591774, 0.00587363, ..., -0.18236086,
-0.49953887, -2.694323 ],
[-0.22257486, -0.07352418, -0.4024905 , ..., 0.05929026,
-0.22622454, -0.8882016 ],
[ 1.19449 , 0.5713768 , -0.6041051 , ..., -0.4231223 ,
0.1122172 , -0.28642637],
...,
[ 0.3465784 , -0.4939726 , 0.68308717, ..., -0.68765515,
-0.04439323, -0.20094715],
[-1.0187647 , 0.59667283, -1.2578517 , ..., -0.24298497,
0.19871093, -0.53237087],
[ 0.71921486, 0.00342394, -1.026827 , ..., -0.63569874,
0.5502706 , -2.3338537 ]]], dtype=float32), array([[[ 1.2440691 , 0.477769 , 0.40438044, ..., -0.8634442 ,
-0.516493 , -1.2156196 ],
[ 0.09270077, 0.7480357 , -0.26808444, ..., 0.2696572 ,
-0.19946648, -2.0294373 ],
[ 0.7727248 , 1.0203358 , -0.8804002 , ..., -1.1491328 ,
0.26056868, -0.70166045],
...,
[ 0.07362438, 0.6428618 , 0.02992919, ..., -0.3656686 ,
-0.47058144, -2.8904293 ],
[ 0.11636666, -0.52916086, -0.3162666 , ..., 0.0035085 ,
-0.32273746, -2.211205 ],
[ 0.09450418, -0.46651682, -0.38872302, ..., -0.02868051,
-0.284238 , -0.20115192]],
[[ 0.24763471, 0.10591529, -0.02833212, ..., -0.19653179,
-0.44746324, -2.6951957 ],
[-0.2705241 , -0.05053078, -0.38580215, ..., 0.0737243 ,
-0.18193349, -0.9320284 ],
[ 1.1450722 , 0.5878723 , -0.6142542 , ..., -0.4155699 ,
0.1653907 , -0.3516124 ],
...,
[ 0.31275982, -0.48012054, 0.6611108 , ..., -0.6956505 ,
0.01540092, -0.20229349],
[-1.0758246 , 0.6180846 , -1.2664212 , ..., -0.20339848,
0.25243902, -0.5957573 ],
[ 0.6647455 , 0.02233337, -1.0082287 , ..., -0.6396673 ,
0.6092897 , -2.3386178 ]]], dtype=float32), array([[[ 1.16829467e+00, 4.73457277e-01, 3.55845094e-01, ...,
-8.79491448e-01, -5.13881445e-01, -1.31691360e+00],
[ 2.06558462e-02, 7.50393629e-01, -2.97149330e-01, ...,
2.75190771e-01, -1.88330606e-01, -2.13159895e+00],
[ 6.88944340e-01, 1.01597929e+00, -9.25147057e-01, ...,
-1.15003026e+00, 2.47588441e-01, -7.96685874e-01],
...,
[ 2.88018938e-02, 6.88098967e-01, -2.68854816e-02, ...,
-3.84741157e-01, -4.76365626e-01, -2.96942425e+00],
[ 5.18963784e-02, -4.60506737e-01, -3.59559625e-01, ...,
-1.14138005e-02, -3.24753582e-01, -2.23346853e+00],
[ 7.39114806e-02, -4.01132107e-01, -4.24620807e-01, ...,
-3.33240032e-02, -2.84682035e-01, -3.12762260e-01]],
[[ 2.14906067e-01, 1.72432721e-01, -8.47420394e-02, ...,
-2.10198522e-01, -4.30923790e-01, -2.79578996e+00],
[-3.10599983e-01, 1.65538397e-03, -4.26579088e-01, ...,
5.09980768e-02, -1.56658575e-01, -1.01907480e+00],
[ 1.09326482e+00, 6.40747488e-01, -6.59607410e-01, ...,
-4.15965647e-01, 1.87072530e-01, -4.48307097e-01],
...,
[ 2.86490113e-01, -4.34108675e-01, 6.18547022e-01, ...,
-7.17846394e-01, 3.06018218e-02, -2.91922033e-01],
[-1.13503909e+00, 6.14412308e-01, -1.32140326e+00, ...,
-2.04109967e-01, 2.61365235e-01, -6.73554838e-01],
[ 6.21745884e-01, 8.54183882e-02, -1.06341887e+00, ...,
-6.34258091e-01, 6.15509987e-01, -2.41539836e+00]]],
dtype=float32), array([[ 0.5620128 , -0.2001392 , -0.11440954, -0.04526514, -0.8816746 ,
-0.5549258 , 0.51452374, 0.13439347, 0.53412014, 0.46277392,
0.8692565 , -0.90509814, 0.31514823, 0.8086619 , 0.58900446,
-0.3894673 , -0.45003602, 0.37346584, 0.69269675, 0.21574067,
-0.72299725, 0.528553 , -0.83846116, 0.98062813, -0.05183166,
0.33388335, -0.63176596, 0.21661893, 0.43943346, 0.33758652,
-0.24407507, -0.17800584, 0.59364974, 0.47616154, 0.558793 ,
0.27490366, -0.9666731 , -0.8721832 , 0.743239 , 0.04293209,
-0.5673905 , -0.14399827, -0.41138482, 0.8764746 , -0.11112919,
0.21457899, 0.88060266, 0.88843846, 0.18521515, -0.84538144,
-0.57872075, 0.7840174 , 0.8682007 , -0.5286343 , 0.2563142 ,
0.9634152 , 0.03505438, 0.91062546, 0.3279442 , -0.61855054,
0.22826263, 0.42789218, -0.48171976, -0.13283452, -0.86695194,
0.9060679 , 0.78916115, 0.16227603, -0.36374012, -0.5703023 ,
0.19644596, -0.6927085 , 0.19042683, -0.43984833, -0.7866716 ,
0.9690585 , -0.42288277, 0.8037468 , 0.70858365, 0.87470776,
0.630474 , 0.17134413, 0.99327976, -0.46532467, -0.00972999,
0.9460259 , 0.09055056, 0.7293024 , -0.9081666 , 0.15192512,
-0.8813194 , -0.7241285 , 0.11484392, 0.5220332 , 0.6182944 ,
0.5697724 , 0.80298615, -0.916839 , 0.8679731 , 0.3047138 ,
-0.7162764 , -0.852553 , 0.8317937 , 0.6582049 , -0.06668244,
0.36977607, 0.80465484, 0.10356631, 0.5558003 , 0.29966184,
0.93551975, -0.89290446, 0.15027076, -0.66376805, -0.6382408 ,
0.6717352 , 0.9509484 , -0.79286 , -0.18582785, -0.36172768,
-0.9791676 , -0.94657 , -0.47834975, 0.4030182 , 0.8983884 ,
0.74833804, -0.24173705, -0.6059107 ],
[ 0.01938096, -0.08230975, 0.2434453 , -0.8368162 , -0.31632444,
-0.137336 , 0.8550923 , -0.51500845, 0.5093535 , 0.7847338 ,
0.2958318 , -0.3608949 , 0.3377346 , 0.7592404 , -0.10613706,
0.45210105, -0.39598942, -0.32519925, 0.89480245, 0.6049605 ,
-0.81980604, 0.6129146 , 0.5854233 , 0.8875059 , 0.8888534 ,
0.38860276, -0.81435287, -0.8599018 , -0.7989145 , 0.87724596,
-0.09401844, 0.8232204 , 0.5973142 , -0.47759202, -0.2035289 ,
0.86339366, -0.78711975, -0.8843788 , 0.50736386, 0.7154904 ,
0.02759624, -0.29022685, -0.21070601, 0.37119249, -0.93711466,
-0.41830346, 0.49852479, 0.7634121 , -0.73495114, -0.8023139 ,
-0.56360126, 0.7008394 , 0.9837745 , -0.09430382, 0.35603583,
0.98780483, 0.3609371 , -0.31916958, -0.48238578, -0.65934813,
0.67085034, -0.43169144, 0.86363876, -0.1453511 , -0.8397705 ,
-0.35035503, 0.88129896, 0.16335464, 0.34733585, 0.24485897,
-0.5006221 , -0.9430847 , -0.80959797, -0.8578838 , -0.7431067 ,
0.49626076, -0.03579912, -0.5582668 , 0.9786438 , 0.2536843 ,
0.895339 , -0.42590025, 0.9813974 , 0.4913268 , -0.95859706,
0.5229873 , -0.75750285, 0.01685579, -0.37524623, -0.4403388 ,
-0.91602516, -0.63672376, 0.28235126, 0.5060775 , 0.03505507,
0.8782664 , 0.06858374, -0.81789017, 0.41628596, 0.9114354 ,
0.79067975, -0.76645094, 0.90893763, 0.95445615, -0.8870664 ,
0.50881255, 0.30905575, 0.4437762 , -0.2528932 , -0.14799164,
0.93950725, -0.7908481 , 0.44684762, -0.9644589 , 0.37588173,
0.9690541 , -0.6058538 , 0.2965665 , -0.07335383, -0.6774956 ,
-0.9477332 , -0.8670143 , 0.03564278, -0.8282162 , 0.24308446,
0.5860108 , -0.93586445, -0.8312509 ]], dtype=float32)]