Self Attention Layers


TransformerLayer

A network architecture based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Refer to the paper "Attention Is All You Need" for more details.
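As a rough illustration of what "based solely on attention" means, the following NumPy sketch implements single-head scaled dot-product self-attention. It is a generic textbook formulation, not this library's implementation; the function names, the `causal` flag, and the weight matrices are all hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v, causal=False):
    """Single-head scaled dot-product self-attention (illustrative only).

    x: (batch, seq_len, hidden); w_q, w_k, w_v: (hidden, d).
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    # Similarity of every position with every other position: (batch, seq, seq).
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(q.shape[-1])
    if causal:
        # Hide future positions so each token attends only to itself and the past.
        seq_len = x.shape[1]
        future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
        scores = np.where(future, -1e9, scores)
    weights = softmax(scores)
    return weights @ v, weights

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 8))
w_q, w_k, w_v = (rng.standard_normal((8, 8)) for _ in range(3))
out, weights = self_attention(x, w_q, w_k, w_v, causal=True)
print(out.shape, weights.shape)
```

Each output position is a weighted average of the value vectors, so no recurrence or convolution is needed to mix information across the sequence.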

Input is a Table which consists of 2 tensors.

  1. Token id tensor: shape (batch, seqLen), holding the index of each word token in the vocabulary.
  2. Position id tensor: shape (batch, seqLen), holding the position of each token in the sentence.
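The two inputs can be prepared with NumPy along the following lines. This is a minimal sketch: the helper name `make_transformer_inputs` is hypothetical, and 0-based positions are an assumption.

```python
import numpy as np

def make_transformer_inputs(token_ids):
    """Return [token_ids, position_ids], each of shape (batch, seq_len)."""
    token_ids = np.asarray(token_ids, dtype=np.int32)
    batch, seq_len = token_ids.shape
    # Assumption: positions run 0..seq_len-1 and are the same for every sequence.
    position_ids = np.tile(np.arange(seq_len, dtype=np.int32), (batch, 1))
    return [token_ids, position_ids]

tokens = [[11, 2, 16, 6], [10, 15, 13, 6]]
token_ids, position_ids = make_transformer_inputs(tokens)
print(token_ids.shape, position_ids.shape)
```

Both arrays share the same (batch, seqLen) shape, so the layer can look up a token embedding and a position embedding for each slot.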

Output is a Table as well.

  1. The states of Transformer layer.
  2. The pooled output, computed from the hidden state of the last layer at the first token of the sequence. It is useful for segment-level tasks.
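To make the pooled output more concrete, here is a sketch of a BERT-style pooler: take the last layer's hidden state at the first token and pass it through a dense layer with tanh. That the library pools exactly this way is an assumption; the function name, weights, and shapes below are hypothetical.

```python
import numpy as np

def pool_first_token(hidden_states, weight, bias):
    """hidden_states: (batch, seq_len, hidden) -> pooled: (batch, hidden)."""
    first_token = hidden_states[:, 0, :]          # hidden state of the first token
    return np.tanh(first_token @ weight + bias)   # assumed dense + tanh pooler

rng = np.random.default_rng(0)
hidden = rng.standard_normal((2, 20, 128)).astype(np.float32)
weight = (rng.standard_normal((128, 128)) * 0.02).astype(np.float32)
bias = np.zeros(128, dtype=np.float32)

pooled = pool_first_token(hidden, weight, bias)
print(pooled.shape)
```

The result has one vector per sequence, which is why it suits segment-level tasks such as classification.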

With Default Embedding:

Scala:

TransformerLayer[Float](vocab = 40990,
    seqLen = 77,
    nBlock = 12,
    residPdrop = 0.1,
    attnPdrop = 0.1,
    nHead = 12,
    hiddenSize = 768,
    embeddingDrop = 0,
    initializerRange = 0.02,
    bidirectional = false,
    outputAllBlock = false)

Python:

TransformerLayer.init(vocab=40990, seq_len=77, n_block=12, hidden_drop=0.1,
    attn_drop=0.1, n_head=12, hidden_size=768,
    embedding_drop=0.1, initializer_range=0.02,
    bidirectional=False, output_all_block=False)

Parameters:

With Customized Embedding:

Scala:

TransformerLayer[Float](nBlock = 12,
    residPdrop = 0.1,
    attnPdrop = 0.1,
    nHead = 12,
    bidirectional = false,
    initializerRange = 0.02,
    outputAllBlock = true,
    embeddingLayer = embedding.asInstanceOf[KerasLayer[Activity, Tensor[Float], Float]])

Python:

TransformerLayer(n_block=12,
    hidden_drop=0.1,
    attn_drop=0.1,
    n_head=12,
    initializer_range=0.02,
    bidirectional=False,
    output_all_block=False,
    embedding_layer=embedding,
    input_shape=((seq_len,), (seq_len,)),
    intermediate_size=0)

Parameters:

Scala example:

val shape1 = Shape(20)
val shape2 = Shape(20)
val input1 = Variable[Float](shape1)
val input2 = Variable[Float](shape2)
val input = Array(input1, input2)
val seq = TransformerLayer[Float](200, hiddenSize = 128, nHead = 8,
  seqLen = 20, nBlock = 1).from(input: _*)
val model = Model[Float](input, seq)

val trainToken = Tensor[Float](1, 20).rand()
val trainPos = Tensor.ones[Float](1, 20)
val input3 = T(trainToken, trainPos)
val output = model.forward(input3)

Input is:

{
2: 1.0  1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0
   [com.intel.analytics.bigdl.tensor.DenseTensor$mcF$sp of size 1x20]
1: 0.8087359    0.16409875  0.7404631   0.4836999   0.034994964 0.033039592 0.6694243   0.84700763  0.32154092  0.17410904  0.66117364  0.30495027  0.19573595  0.058101892 0.65923077  0.84077805  0.50113535  0.48393667  0.06523132  0.0667426
   [com.intel.analytics.bigdl.tensor.DenseTensor of size 1x20]
}

Output is:

{
2: 0.83383083   0.72725344  0.16394942  -0.79005975 0.8877357   -0.9060916  -0.6796065  0.46835706  -0.4700584  0.43868023  0.6641587   0.6711142   -0.70056283 -0.42694178 0.7615595   -0.25590983 0.21654142  0.35254374  0.83790034  0.1103606   -0.20419843 -0.9739706  0.6150182   0.4499923   0.3355538   -0.01543447 -0.99528116 0.45984524  -0.22544041 0.10049125  0.8418835   -0.116228305    -0.112435654    0.5183222   -0.59375525 0.31828925  0.50506884  0.14892755  0.94327587  -0.19001998 0.54074824  -0.07616825 -0.79334164 -0.49726814 0.23889944  -0.91731304 -0.5484148  0.5048103   0.9743351   0.10505025  0.81167877  -0.47498485 -0.83443964 -0.89340115 0.6443838   0.10184191  -0.38618097 -0.32026938 0.51587516  -0.40602723 -0.2931675  -0.86100364 0.109585665 0.9023708   0.46609795  0.0028693299    -0.5746851  -0.45607233 -0.9075561  -0.91294044 0.8077997   0.23019081  0.51124465  -0.39125186 0.16946821  -0.36827865 -0.32563296 0.62560886  -0.7278883  0.8076773   0.89344263  -0.9259615  0.21476166  0.67077845  0.5857905   -0.32905066 -0.16318946 0.6435858   -0.28905967 -0.6991412  -0.5289766  -0.6954091  0.1577004   0.5618301   -0.6290018  0.114078626 -0.52474076 0.27916297  -0.76610357 0.67119384  -0.4308661  0.063731246 -0.5281069  -0.65910465 0.5383283   -0.2875557  0.24594739  -0.6789035  0.7002648   -0.64659894 -0.70994437 -0.8416273  0.4666695   -0.55062526 0.14995292  -0.978979   0.40934727  -0.9028927  0.38194665  0.2334618   -0.9481384  -0.51903373 -0.947906   0.2667679   -0.76987743 -0.7490675  0.6777159   0.9593161
   [com.intel.analytics.bigdl.tensor.DenseTensor of size 1x128]
1: (1,.,.) =
   0.8369983    -0.9907519  0.74404025  ... 0.6154673   0.107825294 -0.806892
   0.7676861    -0.962961   0.73240614  ... 0.534349    0.0049344404    -0.81643736
   0.7487803    -0.9717681  0.7315394   ... 0.59831613  0.010904985 -0.82502025
   ...
   0.06956328   -1.2103055  1.4155688   ... -0.759053   0.6966926   -0.53496075
   0.0759853    -1.2265961  1.4023252   ... -0.7500985  0.68647313  -0.52275336
   0.06356962   -1.2309887  1.3984702   ... -0.751963   0.69192046  -0.52820134

   [com.intel.analytics.bigdl.tensor.DenseTensor of size 1x20x128]
}

Python example:

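import numpy as np

# Build a TransformerLayer with the default embedding.
model = TransformerLayer.init(
    vocab=200, hidden_size=128, n_head=4, seq_len=20)

# Two sequences of 20 random token ids, with all position ids set to 0.
train_token = np.random.randint(20, size=(2, 20))
train_pos = np.zeros((2, 20), dtype=np.int32)
input = [train_token, train_pos]
output = model.forward(input)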

Input is:

<type 'list'>: [array([[11,  2, 16,  6, 17, 18,  2,  4,  5, 16, 18, 15, 13, 19,  5, 15,
    14, 14,  2,  9],
   [10, 15, 13,  6, 12,  0, 11,  3, 16, 13,  6, 13, 17, 13,  3,  4,
    15,  5,  7, 15]]), array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
  dtype=int32)]

Output is:

<type 'list'>: [array([[[ 0.26004127, -0.31793368, -1.1605529 , ..., -0.81875914,
 -0.02121837, -0.8328352 ],
[-0.8622302 , -0.35201085,  0.63190293, ...,  2.0652232 ,
  1.5278    ,  0.38224357],
[-2.5103235 ,  1.4465114 ,  0.71134603, ...,  1.1776686 ,
  0.6882701 ,  0.3534629 ],
...,
[-0.22725764,  1.2112792 , -0.40597847, ...,  2.2241254 ,
  0.2580125 , -1.1470895 ],
[-0.56174546,  1.3353435 , -0.7445968 , ...,  1.1259638 ,
  0.6951011 , -1.1421459 ],
[-0.6615135 ,  1.1899865 , -0.81727505, ...,  2.0474243 ,
  0.20160393, -0.7789728 ]],

[[-1.1624268 , -0.5375418 , -0.7274868 , ..., -0.99061227,
 -0.57117355,  1.0684316 ],
[ 0.11317759, -0.7231343 ,  0.7723393 , ...,  1.6518786 ,
  1.0916579 ,  0.18682887],
[-1.9651127 ,  0.9987117 ,  0.32025027, ...,  0.94719195,
 -0.21028236, -0.02251417],
...,
[-0.6677234 ,  0.69822913, -0.9714249 , ...,  2.208334  ,
  0.7719772 , -0.93855625],
[-0.63691545,  1.3876344 , -0.8491991 , ...,  2.060551  ,
  0.34702447, -0.8160082 ],
[-0.6608573 ,  1.2608795 , -0.46634364, ...,  2.100828  ,
  0.2967869 , -1.0938305 ]]], dtype=float32), array([[ 0.06879381,  0.6821829 , -0.8267953 , -0.02695777, -0.53899264,
 0.8241045 ,  0.6976903 ,  0.31741282,  0.23590134,  0.5565326 ,
 0.95292866,  0.5658284 , -0.2916065 , -0.37934095, -0.2774958 ,
 0.73409927, -0.71731025,  0.07897043,  0.88609815, -0.27966806,
 0.93520796,  0.72740096,  0.1626402 , -0.26063287,  0.28597558,
-0.12945679,  0.7151408 , -0.8463592 , -0.48385444, -0.29313505,
 0.86453205, -0.93834317,  0.41815573,  0.92436415,  0.8209114 ,
 0.6627246 , -0.574135  ,  0.607416  ,  0.04769071, -0.29779348,
-0.26268572, -0.78998053, -0.7522611 ,  0.89941144, -0.15754697,
 0.9298859 , -0.8327022 , -0.63423705, -0.63789636, -0.14168388,
-0.56104964, -0.80995566,  0.9244693 ,  0.4679966 , -0.16284083,
 0.8478645 ,  0.29836348, -0.15369722, -0.4490478 ,  0.11052075,
 0.23767054,  0.59320366, -0.79055625,  0.22201608, -0.88366413,
-0.4410687 ,  0.8762162 , -0.6516914 , -0.5993653 , -0.5972125 ,
-0.86697286, -0.17644943,  0.95839834, -0.06382846,  0.7430881 ,
-0.59690744,  0.3901914 ,  0.06803267,  0.9142394 ,  0.7583274 ,
-0.18442968,  0.56280667, -0.37844184, -0.41195455, -0.8376329 ,
 0.87641823, -0.98970294, -0.6764397 , -0.86945957, -0.69273126,
 0.9911777 ,  0.417286  , -0.8774987 ,  0.17141937,  0.7204654 ,
-0.62387246, -0.8795049 ,  0.62618923, -0.29725042, -0.4565646 ,
-0.47798416, -0.97555065, -0.94241685, -0.97800356,  0.8523641 ,
-0.96860206,  0.5378995 , -0.73754525, -0.01649606, -0.4274561 ,
-0.5290453 ,  0.11851768,  0.48821065,  0.4822751 ,  0.49497148,
-0.5734494 , -0.29612035, -0.7254394 , -0.1418346 , -0.56686646,
 0.03665365, -0.9586826 , -0.0983429 , -0.09348761, -0.96338177,
 0.76481736,  0.87975204,  0.70463663],
[-0.09654156,  0.78266025, -0.9125131 , -0.6706971 , -0.58709925,
-0.94729275, -0.32309514, -0.95263994,  0.2036015 , -0.9297767 ,
 0.6164713 ,  0.3484337 ,  0.46247053,  0.21615174, -0.8382687 ,
-0.55828595, -0.59234536, -0.9643932 ,  0.9310115 , -0.12657425,
 0.63812125,  0.80040973, -0.47581342,  0.9823402 , -0.5400171 ,
 0.5864317 , -0.19979174, -0.5721838 ,  0.9190707 ,  0.31628668,
 0.08952013,  0.8719338 ,  0.26684833,  0.8955768 , -0.9275499 ,
-0.81994563,  0.28863704, -0.16376448,  0.15855551,  0.04302022,
 0.4440408 , -0.7293209 ,  0.2255107 ,  0.16333969,  0.38721767,
-0.04512435, -0.5473172 , -0.5812051 , -0.8219114 , -0.43659028,
-0.04860768, -0.8912252 ,  0.62100273,  0.7187475 , -0.06158534,
 0.6554498 , -0.62163985,  0.63035303,  0.19207267, -0.68847877,
 0.10341872, -0.88906926, -0.38804066, -0.8157233 , -0.81641346,
 0.8846337 , -0.70225614,  0.6281251 , -0.81235796,  0.77828485,
 0.9393982 , -0.42554784,  0.4150426 , -0.32612413, -0.721988  ,
 0.96166253, -0.6080237 , -0.7312329 ,  0.06843777, -0.09806018,
-0.7357863 , -0.28613612, -0.8895085 , -0.9027925 ,  0.56311375,
 0.85699487, -0.32128897,  0.80635303, -0.01190906, -0.23292968,
-0.5115769 ,  0.17153661, -0.79993784,  0.6232265 , -0.06049479,
-0.83510727,  0.9652135 ,  0.08310007, -0.9671807 , -0.17466563,
 0.48009604,  0.594712  ,  0.19612817, -0.9279629 , -0.59968966,
-0.36079255, -0.7250685 ,  0.59395283,  0.7574965 , -0.4377294 ,
 0.45312116,  0.7117049 , -0.82085943, -0.10442825,  0.73688287,
 0.38598123,  0.35439053, -0.3862137 , -0.56253886,  0.7388591 ,
-0.6024478 , -0.699977  , -0.46581215, -0.79513186,  0.09657894,
 0.280869  , -0.38445532, -0.98311806]], dtype=float32)]

BERT

Bidirectional Encoder Representations from Transformers. Refer to https://arxiv.org/pdf/1810.04805.pdf for more details.

Input is a Table which consists of 4 tensors.

  1. Token id tensor: shape (batch, seqLen), holding the index of each word token in the vocabulary.
  2. Token type id tensor: shape (batch, seqLen), with token types in (0, 1). 0 means sentence A and 1 means sentence B (see the BERT paper for more details).
  3. Position id tensor: shape (batch, seqLen), holding the position of each token in the sentence.
  4. Attention mask tensor: shape (batch, seqLen), with values in (0, 1). The mask is used when an input sequence in the current batch is shorter than seqLen. Note that the examples below pass this mask with shape (batch, 1, 1, seqLen).
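The four inputs could be assembled from variable-length sequences roughly as follows. This is a sketch under stated assumptions: the helper `make_bert_inputs` and the `pad_id` parameter are hypothetical, the mask is assumed to be 1 for real tokens and 0 for padding, and the mask is reshaped to (batch, 1, 1, seqLen) to match the examples in this section.

```python
import numpy as np

def make_bert_inputs(sequences, sentence_a_lens, seq_len, pad_id=0):
    """Pad token id lists to seq_len and build the four BERT inputs."""
    batch = len(sequences)
    token_ids = np.full((batch, seq_len), pad_id, dtype=np.int32)
    token_type_ids = np.zeros((batch, seq_len), dtype=np.int32)
    mask = np.zeros((batch, seq_len), dtype=np.int32)
    for i, (seq, a_len) in enumerate(zip(sequences, sentence_a_lens)):
        n = min(len(seq), seq_len)        # truncate sequences that are too long
        token_ids[i, :n] = seq[:n]
        token_type_ids[i, a_len:n] = 1    # tokens after sentence A belong to sentence B
        mask[i, :n] = 1                   # assumption: 1 = real token, 0 = padding
    position_ids = np.tile(np.arange(seq_len, dtype=np.int32), (batch, 1))
    return [token_ids, token_type_ids, position_ids,
            mask.reshape(batch, 1, 1, seq_len)]

inputs = make_bert_inputs([[7, 20, 39], [27, 10, 39, 30, 21]],
                          sentence_a_lens=[2, 3], seq_len=6)
for arr in inputs:
    print(arr.shape)
```

Padding every sequence to the same length keeps the batch rectangular, while the mask records which slots hold real tokens.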

Output is a Table as well.

  1. The states of BERT layer.
  2. The pooled output, computed from the hidden state of the last layer at the first token of the sequence. It is useful for segment-level tasks.

With Default Embedding:

Scala:

BERT[Float](vocab: Int = 40990,
    hiddenSize: Int = 768,
    nBlock: Int = 12,
    nHead: Int = 12,
    maxPositionLen: Int = 512,
    intermediateSize: Int = 3072,
    hiddenPDrop: Double = 0.1,
    attnPDrop: Double = 0.1,
    initializerRange: Double = 0.02,
    outputAllBlock: Boolean = true,
    inputSeqLen: Int = -1)

Python:

BERT.init(vocab=40990,
    hidden_size=768,
    n_block=12,
    n_head=12,
    seq_len=512,
    intermediate_size=3072,
    hidden_drop=0.1,
    attn_drop=0.1,
    initializer_range=0.02,
    output_all_block=True)

Parameters:

With Customized Embedding:

Scala:

BERT[Float](nBlock = 12,
    nHead = 12,
    intermediateSize = 3072,
    hiddenPDrop = 0.1,
    attnPDrop = 0.1,
    initializerRange = 0.02,
    outputAllBlock = true,
    embeddingLayer = embedding)

Python:

BERT(n_block=12,
    n_head=12,
    intermediate_size=3072,
    hidden_drop=0.1,
    attn_drop=0.1,
    initializer_range=0.02,
    output_all_block=True,
    embedding_layer=embedding,
    input_shape=((seq_len,), (seq_len,), (seq_len,), (1, 1, seq_len)))

Parameters:

Loading from existing pretrained model:

Scala:

BERT[Float](path = "",
    weightPath = null,
    inputSeqLen = 11,
    hiddenPDrop = 0.1,
    attnPDrop = 0.1,
    outputAllBlock = true)

Python:

BERT.init_from_existing_model(path="",
    weight_path=None,
    input_seq_len=-1.0,
    hidden_drop=-1.0,
    attn_drop=-1.0,
    output_all_block=True)

Parameters:

Scala example:

val layer = BERT[Float](vocab = 100,
    hiddenSize = 10,
    nBlock = 3,
    nHead = 2,
    intermediateSize = 64,
    hiddenPDrop = 0.1,
    attnPDrop = 0.1,
    maxPositionLen = 10,
    outputAllBlock = false,
    inputSeqLen = 10)

val shape = Shape(List(Shape(1, 10), Shape(1, 10), Shape(1, 10), Shape(1, 1, 1, 10)))
layer.build(shape)
val inputIds = Tensor[Float](Array[Float](7, 20, 39, 27, 10,
  39, 30, 21, 17, 15), Array(1, 10))
val segmentIds = Tensor[Float](Array[Float](0, 0, 0, 0, 0, 1, 1, 1, 1, 1), Array(1, 10))
val positionIds = Tensor[Float](Array[Float](0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(1, 10))
val masks = Tensor[Float](1, 1, 1, 10).fill(1.0f)

val output = layer.forward(T(inputIds, segmentIds, positionIds, masks))

Input is:

{
2: 0.0  0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0
   [com.intel.analytics.bigdl.tensor.DenseTensor of size 1x10]
4: (1,1,.,.) =
   1.0  1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0

   [com.intel.analytics.bigdl.tensor.DenseTensor of size 1x1x1x10]
1: 7.0  20.0    39.0    27.0    10.0    39.0    30.0    21.0    17.0    15.0
   [com.intel.analytics.bigdl.tensor.DenseTensor of size 1x10]
3: 0.0  1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0
   [com.intel.analytics.bigdl.tensor.DenseTensor of size 1x10]
}

Output is:

{
2: 0.5398573    0.08571402  -0.9461041  -0.35362077 -0.24374364 0.24349216  0.9587727   -0.03278971 -0.826852   -0.8808889
   [com.intel.analytics.bigdl.tensor.DenseTensor of size 1x10]
1: (1,.,.) =
   1.3381815    1.7575556   -1.1870699  0.8455374   -1.6000531  0.115945406 -0.33695826 -0.39254665 -0.33637434 -0.20421773
   -0.08370285  0.056055143 -0.91990083 1.6324282   -0.093128644    -0.4484297  -2.0828273  0.10244746  0.577287    1.2597716
   0.3563086    0.37092525  -0.5089354  0.4525072   1.7706354   0.65231055  -2.0269241  -0.2548585  0.3711578   -1.1831268
   0.2429675    -0.023419544    -0.28389466 0.6601246   -0.009858845    -0.028412571    -2.5104556  1.0338438   1.3621751   -0.44306967
   1.7147139    1.1627073   -0.19394834 0.8043055   -1.0080436  -1.7716306  -0.7668168  -0.19861369 0.45103902  -0.19371253
   0.077525005  0.0722655   1.0745171   0.07997274  0.06562643  1.6474637   0.18938908  -2.377528   -0.6107291  -0.21850263
   -1.3190242   1.7057956   0.32655835  0.5711799   -0.80318034 0.2776545   1.4860673   -0.676896   -0.39734793 -1.1708072
   -0.4327645   -0.19849697 0.3695452   -0.08213705 1.2378154   0.591234    -1.505518   1.684885    -1.6251724  -0.03939093
   0.6422535    -0.582018   1.6665243   -1.0995792  0.19488664  1.3563607   -0.60793823 -0.05846788 -1.7225715  0.21054967
   -1.0927358   -0.37666538 0.70802236  -2.0131714  0.94964516  1.4701655   0.053027537 0.051168486 -0.58528    0.83582383

   [com.intel.analytics.bigdl.tensor.DenseTensor of size 1x10x10]
}

Python example:

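import numpy as np

# Build a BERT layer with the default embedding.
layer = BERT.init(
    vocab=200, hidden_size=128, n_head=4, seq_len=20, intermediate_size=20)

# Two sequences of 20 random token ids; token types and positions are all 0,
# and the attention mask is all 1 with shape (batch, 1, 1, seq_len).
train_token = np.random.randint(20, size=(2, 20))
token_type_id = np.zeros((2, 20), dtype=np.int32)
train_pos = np.zeros((2, 20), dtype=np.int32)
mask_attention = np.ones((2, 1, 1, 20), dtype=np.int32)
input = [train_token, token_type_id, train_pos, mask_attention]
output = layer.forward(input)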

Input is:

<type 'list'>: [array([[ 8, 19,  5,  8,  4, 13, 13, 12,  1,  6, 16, 14, 19,  0, 11, 18,
 1, 17,  0,  0],
[17, 10, 15, 19, 15,  2, 18,  8,  1, 11, 10, 17,  7,  2,  0,  0,
 9, 14, 11,  6]]), array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
dtype=int32), array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
dtype=int32), array([[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],

[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]],
dtype=int32)]

Output is:

<type 'list'>: [array([[[ 1.01330066e+00,  4.74100798e-01,  3.81211847e-01, ...,
 -8.00989151e-01, -6.18815482e-01, -1.09804094e+00],
[-6.29726201e-02,  6.71391249e-01, -2.28019580e-01, ...,
  3.97501498e-01, -3.32217604e-01, -2.04850674e+00],
[ 5.04802346e-01,  1.00434709e+00, -7.63663530e-01, ...,
 -1.11675525e+00,  1.41550392e-01, -6.47688091e-01],
...,
[ 5.07664750e-04,  6.47843182e-01, -5.33694960e-02, ...,
 -2.01566055e-01, -6.62943959e-01, -2.93835902e+00],
[-1.49255514e-01, -5.47551095e-01, -3.36264402e-01, ...,
  1.11121520e-01, -4.42977905e-01, -2.13847613e+00],
[-1.33390293e-01, -5.50503194e-01, -3.49355727e-01, ...,
  9.34313685e-02, -4.41935956e-01, -8.25921223e-02]],

[[ 1.34967836e-02,  3.90778482e-02, -2.22317409e-02, ...,
 -9.81532633e-02, -6.08992100e-01, -2.77224326e+00],
[-5.21431923e-01, -6.74456656e-02, -4.66892511e-01, ...,
  1.05466165e-01, -2.67068297e-01, -9.00964141e-01],
[ 9.41378593e-01,  5.21076620e-01, -5.35079956e-01, ...,
 -3.15736473e-01,  8.08603615e-02, -2.44178265e-01],
...,
[ 4.16017324e-02, -5.65874994e-01,  6.68676615e-01, ...,
 -6.28256857e-01, -9.09847617e-02, -2.42878512e-01],
[-1.36971796e+00,  5.37231266e-01, -1.33729517e+00, ...,
 -1.47498712e-01,  8.00304264e-02, -5.09030581e-01],
[ 3.98404837e-01, -7.18296226e-03, -1.08256066e+00, ...,
 -5.17360926e-01,  5.50065935e-01, -2.32753420e+00]]],
dtype=float32), array([[[ 1.0795887 ,  0.44977495,  0.45561683, ..., -0.729603  ,
 -0.6098092 , -1.0323973 ],
[ 0.0154253 ,  0.6424524 , -0.15503715, ...,  0.45100495,
 -0.3161888 , -1.9826275 ],
[ 0.58491707,  0.9876782 , -0.69952184, ..., -1.0432141 ,
  0.1380458 , -0.5642554 ],
...,
[ 0.06806332,  0.61824507,  0.02341641, ..., -0.21342012,
 -0.63312817, -2.8557966 ],
[-0.06217962, -0.5528828 , -0.34740448, ...,  0.16651583,
 -0.41633344, -2.064906  ],
[-0.04626712, -0.57442385, -0.277238  , ...,  0.13806444,
 -0.43256086, -0.01180306]],

[[ 0.10532085,  0.01057051,  0.07536474, ..., -0.03406155,
 -0.572023  , -2.6935408 ],
[-0.42477775, -0.10768362, -0.37653154, ...,  0.17155378,
 -0.27841952, -0.8244427 ],
[ 1.0290473 ,  0.5059685 , -0.5359356 , ..., -0.25725254,
  0.1034779 , -0.16898313],
...,
[ 0.14118548, -0.5945706 ,  0.7681386 , ..., -0.55807835,
 -0.07778832, -0.15940095],
[-1.2648381 ,  0.50598496, -1.2431567 , ..., -0.06980868,
  0.10642368, -0.4181047 ],
[ 0.48330045, -0.05184587, -0.9985824 , ..., -0.5360492 ,
  0.56541353, -2.2607849 ]]], dtype=float32), array([[[ 1.1541002 ,  0.47630545,  0.40673187, ..., -0.7284888 ,
 -0.55945337, -1.0810231 ],
[ 0.10070852,  0.64252985, -0.2007717 , ...,  0.4489277 ,
 -0.24709189, -2.0173872 ],
[ 0.67520154,  0.9793912 , -0.7441366 , ..., -1.0376649 ,
  0.20359974, -0.6060102 ],
...,
[ 0.07809319,  0.63523245, -0.02464442, ..., -0.21328981,
 -0.5693355 , -2.8386393 ],
[ 0.02228299, -0.5229728 , -0.33483037, ...,  0.16430138,
 -0.40036577, -2.094183  ],
[ 0.02232744, -0.54113156, -0.3307599 , ...,  0.14321396,
 -0.3796677 , -0.04973204]],

[[ 0.20497712,  0.02804335,  0.028764  , ..., -0.01617111,
 -0.5416485 , -2.7333891 ],
[-0.3361876 , -0.08618001, -0.41299412, ...,  0.17708196,
 -0.23643918, -0.8763187 ],
[ 1.1118197 ,  0.5178778 , -0.57264006, ..., -0.2597192 ,
  0.15024357, -0.23373066],
...,
[ 0.23304611, -0.57528406,  0.71815467, ..., -0.5524511 ,
 -0.04103457, -0.15449452],
[-1.1629226 ,  0.5377656 , -1.2816569 , ..., -0.05795323,
  0.1603044 , -0.47194824],
[ 0.5773567 , -0.04114214, -1.0306932 , ..., -0.52537155,
  0.5703101 , -2.3124278 ]]], dtype=float32), array([[[ 1.1319652e+00,  4.2663044e-01,  3.9611375e-01, ...,
 -7.7264631e-01, -5.3006041e-01, -1.0942854e+00],
[ 7.3858641e-02,  6.1578143e-01, -2.0985913e-01, ...,
  4.0289888e-01, -2.2484708e-01, -2.0233095e+00],
[ 6.3545388e-01,  9.4610500e-01, -7.6165521e-01, ...,
 -1.0820770e+00,  2.2266804e-01, -6.0132843e-01],
...,
[ 3.9479308e-02,  6.0636342e-01, -2.8302141e-02, ...,
 -2.6316714e-01, -5.5309945e-01, -2.8510940e+00],
[-2.8412668e-03, -5.5100703e-01, -3.4540960e-01, ...,
  1.5979633e-01, -3.8844827e-01, -2.0994248e+00],
[-1.4572166e-02, -5.7526213e-01, -3.3382124e-01, ...,
  1.0289014e-01, -3.6059290e-01, -6.5041430e-02]],

[[ 1.5256011e-01,  3.3955947e-03,  1.7648729e-02, ...,
 -4.9600061e-02, -5.1613468e-01, -2.7417533e+00],
[-3.6988521e-01, -8.8330485e-02, -4.2416954e-01, ...,
  1.2959087e-01, -2.1623056e-01, -8.8821554e-01],
[ 1.0618008e+00,  5.0827748e-01, -5.8256608e-01, ...,
 -2.9023758e-01,  1.6930477e-01, -2.3869993e-01],
...,
[ 1.8475071e-01, -5.8594310e-01,  7.0973599e-01, ...,
 -5.9211296e-01, -1.7043589e-02, -1.5649734e-01],
[-1.2073172e+00,  5.1577950e-01, -1.2952001e+00, ...,
 -1.0562765e-01,  1.8499596e-01, -4.6483174e-01],
[ 5.8209622e-01, -5.3714752e-02, -1.0255412e+00, ...,
 -5.6718546e-01,  5.9832001e-01, -2.3260906e+00]]], dtype=float32), array([[[ 1.1358043 ,  0.38664085,  0.43075162, ..., -0.762137  ,
 -0.53836805, -1.1419276 ],
[ 0.0705715 ,  0.61749744, -0.1978054 , ...,  0.39686537,
 -0.23118263, -2.0863478 ],
[ 0.6286853 ,  0.9499371 , -0.75073713, ..., -1.0837915 ,
  0.20451419, -0.64585996],
...,
[ 0.04084783,  0.5485716 ,  0.02199897, ..., -0.265642  ,
 -0.54954815, -2.8985202 ],
[-0.00433184, -0.5782148 , -0.28893095, ...,  0.15305014,
 -0.3942154 , -2.1390564 ],
[-0.01938614, -0.6034715 , -0.3210429 , ...,  0.11286073,
 -0.3612479 , -0.12291119]],

[[ 0.13899514, -0.04281238,  0.05966739, ..., -0.05543021,
 -0.51721877, -2.7725601 ],
[-0.38874203, -0.13524944, -0.37960985, ...,  0.12579904,
 -0.23764463, -0.94251025],
[ 1.0436667 ,  0.4891924 , -0.5470476 , ..., -0.30531114,
  0.143379  , -0.28663573],
...,
[ 0.17597033, -0.6172772 ,  0.75050735, ..., -0.59396976,
 -0.02840331, -0.20918237],
[-1.2121452 ,  0.47985265, -1.2640744 , ..., -0.11457531,
  0.17777829, -0.5216857 ],
[ 0.5724598 , -0.08497301, -0.99838203, ..., -0.569392  ,
  0.5878865 , -2.3820512 ]]], dtype=float32), array([[[ 1.1923821 ,  0.40376186,  0.4216827 , ..., -0.7753511 ,
 -0.58085346, -1.1371452 ],
[ 0.13859609,  0.6558944 , -0.1899949 , ...,  0.37358993,
 -0.27038255, -2.0870223 ],
[ 0.7083764 ,  0.977537  , -0.76046735, ..., -1.101789  ,
  0.1981793 , -0.6461577 ],
...,
[ 0.10095584,  0.5967537 ,  0.02207649, ..., -0.28193793,
 -0.5789527 , -2.9001386 ],
[ 0.05901275, -0.53504837, -0.28481779, ...,  0.13802934,
 -0.41621858, -2.1443312 ],
[ 0.04528951, -0.5612042 , -0.31392562, ...,  0.09289672,
 -0.38395336, -0.12475596]],

[[ 0.19919002, -0.0038989 ,  0.06975131, ..., -0.05898362,
 -0.5476832 , -2.7802918 ],
[-0.32966843, -0.10950038, -0.3582222 , ...,  0.08165199,
 -0.2624505 , -0.93702954],
[ 1.119693  ,  0.5142474 , -0.5341173 , ..., -0.3251373 ,
  0.10789905, -0.30592436],
...,
[ 0.24882485, -0.5699529 ,  0.77695113, ..., -0.63034034,
 -0.07624292, -0.2281592 ],
[-1.1650497 ,  0.5175082 , -1.2281002 , ..., -0.14287077,
  0.15133552, -0.532626  ],
[ 0.64303875, -0.05680082, -0.9739305 , ..., -0.5787345 ,
  0.5447517 , -2.403577  ]]], dtype=float32), array([[[ 1.1747141 ,  0.44174162,  0.3848741 , ..., -0.8011676 ,
 -0.5708256 , -1.143519  ],
[ 0.11874287,  0.7037242 , -0.22899102, ...,  0.36200705,
 -0.22287843, -2.0832918 ],
[ 0.68290263,  1.0014081 , -0.8112288 , ..., -1.0980991 ,
  0.22316812, -0.637702  ],
...,
[ 0.07541095,  0.63492006,  0.02669529, ..., -0.27486983,
 -0.53397936, -2.8813968 ],
[ 0.07104072, -0.54481   , -0.33232585, ...,  0.12730087,
 -0.37563673, -2.1450465 ],
[ 0.03186123, -0.51601535, -0.35520643, ...,  0.09008651,
 -0.33910847, -0.11906879]],

[[ 0.19884765,  0.06095114,  0.00477777, ..., -0.0960753 ,
 -0.49155453, -2.7463722 ],
[-0.32222956, -0.08950429, -0.4053724 , ...,  0.05162536,
 -0.21072339, -0.9155606 ],
[ 1.1117101 ,  0.56429935, -0.59156317, ..., -0.3369357 ,
  0.14969075, -0.29045773],
...,
[ 0.25635508, -0.5126209 ,  0.7268977 , ..., -0.62107044,
 -0.01715574, -0.21087953],
[-1.1460723 ,  0.56120336, -1.2668271 , ..., -0.16734022,
  0.19381218, -0.517316  ],
[ 0.63978064, -0.01486263, -1.0128225 , ..., -0.56719303,
  0.58368987, -2.3722165 ]]], dtype=float32), array([[[ 1.20512652e+00,  4.72038895e-01,  3.60962778e-01, ...,
 -9.00623977e-01, -5.82258105e-01, -1.14907408e+00],
[ 1.55033678e-01,  7.28412509e-01, -2.72546947e-01, ...,
  2.74131984e-01, -2.24478737e-01, -2.06169677e+00],
[ 7.13116527e-01,  1.01666617e+00, -8.43635857e-01, ...,
 -1.18351495e+00,  2.21053749e-01, -6.17874563e-01],
...,
[ 1.19287886e-01,  6.56103075e-01, -7.38978712e-03, ...,
 -3.54864419e-01, -5.37513494e-01, -2.87535477e+00],
[ 1.16591401e-01, -5.55387378e-01, -3.71099353e-01, ...,
  2.78753694e-02, -3.70597601e-01, -2.16417289e+00],
[ 8.63942727e-02, -4.81025964e-01, -3.91344100e-01, ...,
 -1.84133966e-02, -3.39215338e-01, -1.10263892e-01]],

[[ 2.64633745e-01,  7.25395158e-02, -1.39633343e-02, ...,
 -1.85173869e-01, -5.20042717e-01, -2.70796871e+00],
[-2.58721560e-01, -6.35206550e-02, -4.14235502e-01, ...,
  4.62563671e-02, -2.47269630e-01, -8.77729058e-01],
[ 1.17845881e+00,  5.93900442e-01, -6.18097663e-01, ...,
 -4.23726231e-01,  1.19810022e-01, -2.55170494e-01],
...,
[ 3.06802571e-01, -4.83913153e-01,  7.05836833e-01, ...,
 -6.99279726e-01, -5.64565696e-02, -1.74492225e-01],
[-1.07746637e+00,  5.83848476e-01, -1.28454113e+00, ...,
 -2.29663596e-01,  1.96212217e-01, -5.23399591e-01],
[ 7.02956200e-01,  2.42653489e-03, -1.03614473e+00, ...,
 -6.54396653e-01,  5.55146933e-01, -2.35132337e+00]]],
dtype=float32), array([[[ 1.1732985 ,  0.48227564,  0.4141097 , ..., -0.9222415 ,
 -0.5581617 , -1.1467376 ],
[ 0.10594734,  0.75098526, -0.23052076, ...,  0.23048519,
 -0.23638739, -2.033264  ],
[ 0.6784295 ,  1.0418617 , -0.810519  , ..., -1.2120562 ,
  0.24596576, -0.60291487],
...,
[ 0.07915874,  0.6732479 ,  0.02875949, ..., -0.38714084,
 -0.5037479 , -2.8571    ],
[ 0.07176737, -0.52642834, -0.31701285, ..., -0.01229298,
 -0.34029967, -2.1321528 ],
[ 0.04126783, -0.46029058, -0.34176344, ..., -0.05429364,
 -0.31155083, -0.10451217]],

[[ 0.22600769,  0.09270672,  0.02146479, ..., -0.22232075,
 -0.48217994, -2.6969097 ],
[-0.29719839, -0.05968198, -0.37710896, ...,  0.02224515,
 -0.20888865, -0.872187  ],
[ 1.1335284 ,  0.60064685, -0.58743286, ..., -0.45202363,
  0.13883159, -0.2602308 ],
...,
[ 0.27142784, -0.47967467,  0.70926106, ..., -0.71909636,
 -0.01251143, -0.1811402 ],
[-1.095905  ,  0.6111897 , -1.2443895 , ..., -0.27054876,
  0.22430526, -0.5081292 ],
[ 0.7027022 ,  0.01059689, -1.0006222 , ..., -0.6746712 ,
  0.58800125, -2.352779  ]]], dtype=float32), array([[[ 1.2434555 ,  0.44558033,  0.4151337 , ..., -0.8851603 ,
 -0.5718673 , -1.1482117 ],
[ 0.11962966,  0.72577155, -0.25604928, ...,  0.2687037 ,
 -0.2457071 , -2.0307996 ],
[ 0.763747  ,  1.0119921 , -0.8592167 , ..., -1.1870402 ,
  0.2256221 , -0.6277423 ],
...,
[ 0.08759235,  0.64535457,  0.03834408, ..., -0.35554865,
 -0.5139612 , -2.8475935 ],
[ 0.13679026, -0.55156755, -0.32664305, ...,  0.01780019,
 -0.3558066 , -2.1313274 ],
[ 0.11818186, -0.4789727 , -0.3590175 , ..., -0.01446133,
 -0.32358617, -0.10938768]],

[[ 0.29418862,  0.09591774,  0.00587363, ..., -0.18236086,
 -0.49953887, -2.694323  ],
[-0.22257486, -0.07352418, -0.4024905 , ...,  0.05929026,
 -0.22622454, -0.8882016 ],
[ 1.19449   ,  0.5713768 , -0.6041051 , ..., -0.4231223 ,
  0.1122172 , -0.28642637],
...,
[ 0.3465784 , -0.4939726 ,  0.68308717, ..., -0.68765515,
 -0.04439323, -0.20094715],
[-1.0187647 ,  0.59667283, -1.2578517 , ..., -0.24298497,
  0.19871093, -0.53237087],
[ 0.71921486,  0.00342394, -1.026827  , ..., -0.63569874,
  0.5502706 , -2.3338537 ]]], dtype=float32), array([[[ 1.2440691 ,  0.477769  ,  0.40438044, ..., -0.8634442 ,
 -0.516493  , -1.2156196 ],
[ 0.09270077,  0.7480357 , -0.26808444, ...,  0.2696572 ,
 -0.19946648, -2.0294373 ],
[ 0.7727248 ,  1.0203358 , -0.8804002 , ..., -1.1491328 ,
  0.26056868, -0.70166045],
...,
[ 0.07362438,  0.6428618 ,  0.02992919, ..., -0.3656686 ,
 -0.47058144, -2.8904293 ],
[ 0.11636666, -0.52916086, -0.3162666 , ...,  0.0035085 ,
 -0.32273746, -2.211205  ],
[ 0.09450418, -0.46651682, -0.38872302, ..., -0.02868051,
 -0.284238  , -0.20115192]],

[[ 0.24763471,  0.10591529, -0.02833212, ..., -0.19653179,
 -0.44746324, -2.6951957 ],
[-0.2705241 , -0.05053078, -0.38580215, ...,  0.0737243 ,
 -0.18193349, -0.9320284 ],
[ 1.1450722 ,  0.5878723 , -0.6142542 , ..., -0.4155699 ,
  0.1653907 , -0.3516124 ],
...,
[ 0.31275982, -0.48012054,  0.6611108 , ..., -0.6956505 ,
  0.01540092, -0.20229349],
[-1.0758246 ,  0.6180846 , -1.2664212 , ..., -0.20339848,
  0.25243902, -0.5957573 ],
[ 0.6647455 ,  0.02233337, -1.0082287 , ..., -0.6396673 ,
  0.6092897 , -2.3386178 ]]], dtype=float32), array([[[ 1.16829467e+00,  4.73457277e-01,  3.55845094e-01, ...,
 -8.79491448e-01, -5.13881445e-01, -1.31691360e+00],
[ 2.06558462e-02,  7.50393629e-01, -2.97149330e-01, ...,
  2.75190771e-01, -1.88330606e-01, -2.13159895e+00],
[ 6.88944340e-01,  1.01597929e+00, -9.25147057e-01, ...,
 -1.15003026e+00,  2.47588441e-01, -7.96685874e-01],
...,
[ 2.88018938e-02,  6.88098967e-01, -2.68854816e-02, ...,
 -3.84741157e-01, -4.76365626e-01, -2.96942425e+00],
[ 5.18963784e-02, -4.60506737e-01, -3.59559625e-01, ...,
 -1.14138005e-02, -3.24753582e-01, -2.23346853e+00],
[ 7.39114806e-02, -4.01132107e-01, -4.24620807e-01, ...,
 -3.33240032e-02, -2.84682035e-01, -3.12762260e-01]],

[[ 2.14906067e-01,  1.72432721e-01, -8.47420394e-02, ...,
 -2.10198522e-01, -4.30923790e-01, -2.79578996e+00],
[-3.10599983e-01,  1.65538397e-03, -4.26579088e-01, ...,
  5.09980768e-02, -1.56658575e-01, -1.01907480e+00],
[ 1.09326482e+00,  6.40747488e-01, -6.59607410e-01, ...,
 -4.15965647e-01,  1.87072530e-01, -4.48307097e-01],
...,
[ 2.86490113e-01, -4.34108675e-01,  6.18547022e-01, ...,
 -7.17846394e-01,  3.06018218e-02, -2.91922033e-01],
[-1.13503909e+00,  6.14412308e-01, -1.32140326e+00, ...,
 -2.04109967e-01,  2.61365235e-01, -6.73554838e-01],
[ 6.21745884e-01,  8.54183882e-02, -1.06341887e+00, ...,
 -6.34258091e-01,  6.15509987e-01, -2.41539836e+00]]],
dtype=float32), array([[ 0.5620128 , -0.2001392 , -0.11440954, -0.04526514, -0.8816746 ,
-0.5549258 ,  0.51452374,  0.13439347,  0.53412014,  0.46277392,
 0.8692565 , -0.90509814,  0.31514823,  0.8086619 ,  0.58900446,
-0.3894673 , -0.45003602,  0.37346584,  0.69269675,  0.21574067,
-0.72299725,  0.528553  , -0.83846116,  0.98062813, -0.05183166,
 0.33388335, -0.63176596,  0.21661893,  0.43943346,  0.33758652,
-0.24407507, -0.17800584,  0.59364974,  0.47616154,  0.558793  ,
 0.27490366, -0.9666731 , -0.8721832 ,  0.743239  ,  0.04293209,
-0.5673905 , -0.14399827, -0.41138482,  0.8764746 , -0.11112919,
 0.21457899,  0.88060266,  0.88843846,  0.18521515, -0.84538144,
-0.57872075,  0.7840174 ,  0.8682007 , -0.5286343 ,  0.2563142 ,
 0.9634152 ,  0.03505438,  0.91062546,  0.3279442 , -0.61855054,
 0.22826263,  0.42789218, -0.48171976, -0.13283452, -0.86695194,
 0.9060679 ,  0.78916115,  0.16227603, -0.36374012, -0.5703023 ,
 0.19644596, -0.6927085 ,  0.19042683, -0.43984833, -0.7866716 ,
 0.9690585 , -0.42288277,  0.8037468 ,  0.70858365,  0.87470776,
 0.630474  ,  0.17134413,  0.99327976, -0.46532467, -0.00972999,
 0.9460259 ,  0.09055056,  0.7293024 , -0.9081666 ,  0.15192512,
-0.8813194 , -0.7241285 ,  0.11484392,  0.5220332 ,  0.6182944 ,
 0.5697724 ,  0.80298615, -0.916839  ,  0.8679731 ,  0.3047138 ,
-0.7162764 , -0.852553  ,  0.8317937 ,  0.6582049 , -0.06668244,
 0.36977607,  0.80465484,  0.10356631,  0.5558003 ,  0.29966184,
 0.93551975, -0.89290446,  0.15027076, -0.66376805, -0.6382408 ,
 0.6717352 ,  0.9509484 , -0.79286   , -0.18582785, -0.36172768,
-0.9791676 , -0.94657   , -0.47834975,  0.4030182 ,  0.8983884 ,
 0.74833804, -0.24173705, -0.6059107 ],
[ 0.01938096, -0.08230975,  0.2434453 , -0.8368162 , -0.31632444,
-0.137336  ,  0.8550923 , -0.51500845,  0.5093535 ,  0.7847338 ,
 0.2958318 , -0.3608949 ,  0.3377346 ,  0.7592404 , -0.10613706,
 0.45210105, -0.39598942, -0.32519925,  0.89480245,  0.6049605 ,
-0.81980604,  0.6129146 ,  0.5854233 ,  0.8875059 ,  0.8888534 ,
 0.38860276, -0.81435287, -0.8599018 , -0.7989145 ,  0.87724596,
-0.09401844,  0.8232204 ,  0.5973142 , -0.47759202, -0.2035289 ,
 0.86339366, -0.78711975, -0.8843788 ,  0.50736386,  0.7154904 ,
 0.02759624, -0.29022685, -0.21070601,  0.37119249, -0.93711466,
-0.41830346,  0.49852479,  0.7634121 , -0.73495114, -0.8023139 ,
-0.56360126,  0.7008394 ,  0.9837745 , -0.09430382,  0.35603583,
 0.98780483,  0.3609371 , -0.31916958, -0.48238578, -0.65934813,
 0.67085034, -0.43169144,  0.86363876, -0.1453511 , -0.8397705 ,
-0.35035503,  0.88129896,  0.16335464,  0.34733585,  0.24485897,
-0.5006221 , -0.9430847 , -0.80959797, -0.8578838 , -0.7431067 ,
 0.49626076, -0.03579912, -0.5582668 ,  0.9786438 ,  0.2536843 ,
 0.895339  , -0.42590025,  0.9813974 ,  0.4913268 , -0.95859706,
 0.5229873 , -0.75750285,  0.01685579, -0.37524623, -0.4403388 ,
-0.91602516, -0.63672376,  0.28235126,  0.5060775 ,  0.03505507,
 0.8782664 ,  0.06858374, -0.81789017,  0.41628596,  0.9114354 ,
 0.79067975, -0.76645094,  0.90893763,  0.95445615, -0.8870664 ,
 0.50881255,  0.30905575,  0.4437762 , -0.2528932 , -0.14799164,
 0.93950725, -0.7908481 ,  0.44684762, -0.9644589 ,  0.37588173,
 0.9690541 , -0.6058538 ,  0.2965665 , -0.07335383, -0.6774956 ,
-0.9477332 , -0.8670143 ,  0.03564278, -0.8282162 ,  0.24308446,
 0.5860108 , -0.93586445, -0.8312509 ]], dtype=float32)]
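In the Python output above, the list holds twelve state arrays of shape (2, 20, 128) followed by one pooled array of shape (2, 128). A small helper can separate the two; the sketch below uses dummy arrays of the same shapes instead of a real model, the helper name `split_outputs` is hypothetical, and it assumes the states are ordered from the first block to the last.

```python
import numpy as np

def split_outputs(outputs):
    """Separate the per-block states from the trailing pooled output."""
    *block_states, pooled = outputs
    return block_states, pooled

# Dummy stand-ins with the shapes shown above:
# 12 blocks, batch 2, seq_len 20, hidden size 128.
dummy = [np.zeros((2, 20, 128), dtype=np.float32) for _ in range(12)]
dummy.append(np.zeros((2, 128), dtype=np.float32))

block_states, pooled = split_outputs(dummy)
last_state = block_states[-1]  # assumed to be the final block's states
print(len(block_states), last_state.shape, pooled.shape)
```

The same helper also handles the case of a single state array followed by the pooled output, as in the TransformerLayer example.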