[Transformer from Scratch] Transformer input/output details and a PyTorch implementation

Table of Contents

  • A worked example of the Transformer's inputs and outputs
    • Encoder
      • Padding
      • Padding Mask
      • Positional Embedding
      • Attention
      • FeedForward
      • Add/Norm
      • Encoder inputs and outputs
    • Decoder
      • Sequence Mask
      • Testing
  • Transformer PyTorch implementation
    • Data preparation
    • Hyperparameter settings
    • Positional encoding
    • Masking the padding tokens
    • Decoder input mask
    • Attention, residual connection and normalization
    • Feed-forward network
    • Encoder layer (block)
    • Encoder
    • Decoder layer (block)
    • Decoder
    • Transformer
    • Defining the network
    • Training the Transformer
    • Testing
  • References

A worked example of the Transformer's inputs and outputs

How data flows from the input, through the encoder, to the decoder output (machine translation as the running example):

Encoder

For machine translation, one training sample consists of a source sentence and its translation. For example, if the source sentence is "我爱机器学习", the translation is 'i love machine learning', and the pair "我爱机器学习" / "i love machine learning" forms one sample.

The source sentence of this sample has length = 4 tokens: '我', '爱', '机器', '学习'. After embedding, each token is a 512-dimensional vector, so the embedded sentence "我爱机器学习" has shape [4, 512] (with batched input, the shape is [batch, 4, 512]).

Padding

Suppose the maximum sentence length in the dataset is 10. Sentences shorter than 10 tokens are padded up to 10, so the shape becomes [10, 512]; the embeddings at the padded positions are simply zeros. A minimal padding sketch follows.
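As an illustration only (the tensor names here are hypothetical, not part of the implementation below), zero-padding an embedded sentence of length 4 up to the maximum length 10 could look like this:

import torch
import torch.nn.functional as F

d_model, max_len = 512, 10
x = torch.randn(4, d_model)                          # embedded sentence: [4, 512]
# append zero rows along the sequence dimension -> [10, 512]
x_padded = F.pad(x, (0, 0, 0, max_len - x.size(0)))
print(x_padded.shape)                                # torch.Size([10, 512])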

Padding Mask

Input sequences are generally padded to a fixed length N by appending zeros to the shorter ones. The attention mechanism should put no weight on those padded positions, so they need special handling: their attention scores are set to a very large negative number (effectively negative infinity), so that after softmax their weights are close to 0. In the Transformer the padding mask is a Boolean tensor; in the implementation below, positions whose value is True are the ones that get masked out.
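A minimal standalone illustration of this masked softmax (toy numbers, separate from the model code below):

import torch

scores = torch.tensor([[1.0, 2.0, 0.5, 0.3, 0.0]])             # raw attention scores; last position is padding
pad_mask = torch.tensor([[False, False, False, False, True]])  # True = padded position
scores = scores.masked_fill(pad_mask, -1e9)                    # push masked scores towards -inf
weights = torch.softmax(scores, dim=-1)
print(weights)                                                 # the weight at the padded position is ~0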

Positional Embedding

If the padded sentence embeddings were fed directly into the encoder, the model would have no notion of token order. We therefore add a position vector to each token embedding. The position vectors are built in a fixed, predefined way and can express the absolute position of each token or the distance between tokens; the core idea is to give the attention computation usable distance information.
For an introductory explanation, see my post 【初理解】Transformer中的Positional Encoding.

Attention

See my post (2021李宏毅)机器学习-Self-attention.

FeedForward

Omitted here; it is simply a position-wise network of two linear layers with an activation in between (see the implementation later in this post).

Add/Norm

After the residual addition and layer normalization, the hidden output still has shape [10, 512].

Encoder inputs and outputs

Starting from the input, the flow through a single encoder layer is (a code sketch follows this list):

  1. Input x
  2. Layer-normalize x: x1 = norm(x)
  3. Multi-head self-attention: x2 = self_attention(x1)
  4. Residual connection: x3 = x + x2
  5. Layer-normalize again: x4 = norm(x3)
  6. Feed-forward network: x5 = feed_forward(x4)
  7. Residual connection: x6 = x3 + x5
  8. Output x6

This is the work a single encoder layer does.
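A minimal sketch of this pre-norm flow, assuming a recent PyTorch where nn.MultiheadAttention supports batch_first (note that the full implementation later in this post uses the post-norm variant, where normalization comes after each residual addition; the class and tensor names here are illustrative):

import torch.nn as nn

class PreNormEncoderLayerSketch(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
    def forward(self, x):                  # x: [batch, seq_len, d_model]
        x1 = self.norm1(x)                 # 2. layer norm
        x2, _ = self.attn(x1, x1, x1)      # 3. multi-head self-attention
        x3 = x + x2                        # 4. residual
        x4 = self.norm2(x3)                # 5. layer norm
        x5 = self.ff(x4)                   # 6. feed-forward
        return x3 + x5                     # 7./8. residual, output x6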

Decoder

Note that the encoder output is not used as the decoder's direct input.

During training: 1. at the decoder's first time step (the first time it receives an input), the input is a special token, which may be the start-of-target-sequence token (e.g. <bos>), the end-of-source-sequence token (e.g. <eos>), or some other task-dependent token; different implementations differ slightly here. The target at this step is the first token of the translation. 2. The predicted first token is then appended and fed back into the decoder to predict the second token. 3. And so on for the remaining tokens.

A concrete example:

Sample: "我/爱/机器/学习" and "i/ love /machine/ learning"
Training:

  1. Embed "我/爱/机器/学习" and feed it into the encoder. The final encoder layer outputs a tensor of shape [10, 512] (assuming an embedding size of 512 and batch size = 1). Multiplying this output by projection matrices gives the K and V used in every decoder layer;

  2. Feed <bos> into the decoder; take the decoder's highest-probability output token A1 and compute the cross-entropy error against 'i'.

  3. Feed <bos>, "i" into the decoder; compute the cross entropy between the highest-probability output token A2 and 'love'.

  4. Feed <bos>, "i", "love" into the decoder; compute the cross entropy between the highest-probability output token A3 and 'machine'.

  5. Feed <bos>, "i", "love", "machine" into the decoder; compute the cross entropy between the highest-probability output token A4 and 'learning'.

  6. Feed <bos>, "i", "love", "machine", "learning" into the decoder; compute the cross entropy between the highest-probability output token A5 and the end-of-sequence token.

Sequence Mask

The training procedure above handles the tokens one by one, serially. Can it be parallelized? Of course. For this single sentence, the decoder inputs at the successive steps are

<bos>

<bos>, "i"

<bos>, "i", "love"

<bos>, "i", "love", "machine"

<bos>, "i", "love", "machine", "learning"

So why not stack these inputs into one matrix and feed them in together? Stacked, the inputs look like this:

【<bos>

<bos>, "i"

<bos>, "i", "love"

<bos>, "i", "love", "machine"

<bos>, "i", "love", "machine", "learning" 】

How do we obtain this matrix?

Replicate the full sentence once for each of steps 2-6 above:

【<bos>, "i", "love", "machine", "learning"
<bos>, "i", "love", "machine", "learning"
<bos>, "i", "love", "machine", "learning"
<bos>, "i", "love", "machine", "learning"
<bos>, "i", "love", "machine", "learning"】

Then mask this matrix with the lower-triangular matrix

【1 0 0 0 0
 1 1 0 0 0
 1 1 1 0 0
 1 1 1 1 0
 1 1 1 1 1 】

and we recover exactly

【<bos>

<bos>, "i"

<bos>, "i", "love"

<bos>, "i", "love", "machine"

<bos>, "i", "love", "machine", "learning" 】

This matrix is then fed into the decoder. (You can think of it as a kind of batch: each row is one sample, the rows just have different effective lengths; each row yields one output probability distribution, so one matrix pass produces all 5 output distributions at once.)
Training can therefore be run in parallel. A small sketch of the mask follows.
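As a minimal standalone sketch (separate from the get_attn_subsequence_mask function used later), such a lower-triangular mask can be built and applied to attention scores like this:

import torch

tgt_len = 5
causal = torch.tril(torch.ones(tgt_len, tgt_len))   # position i may attend only to positions 0..i
scores = torch.randn(tgt_len, tgt_len)              # toy attention scores
scores = scores.masked_fill(causal == 0, -1e9)      # block future positions
weights = torch.softmax(scores, dim=-1)             # each row attends only to itself and earlier tokens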

Testing

Once the model is trained, at test time we take, say, '机器学习很有趣' as a test sample and ask for its English translation.

The sentence is passed through the encoder to get an output tensor, which is handed to the decoder (again, not as the decoder's direct input):

  1. The start token <bos> is used as the decoder input, producing the output machine

  2. <bos> + machine is used as the input, producing learning

  3. <bos> + machine + learning is used as the input, producing is

  4. <bos> + machine + learning + is is used as the input, producing interesting

  5. <bos> + machine + learning + is + interesting is used as the input, producing the end token <eos>

This yields the complete translation 'machine learning is interesting'.

As you can see, at test time the output can only be produced one token at a time, serially.
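A minimal greedy-decoding sketch of this loop, assuming a trained model with the Encoder/Decoder/projection interface defined later in this post (the function name and the start/eos indices are illustrative and depend on the vocabulary):

import torch

def greedy_decode(model, enc_input, start_symbol, eos_symbol, max_len=10):
    """Decode one sentence token by token, feeding each prediction back in."""
    enc_outputs, _ = model.Encoder(enc_input)                      # encode the source once
    dec_input = torch.tensor([[start_symbol]], device=enc_input.device)
    for _ in range(max_len):
        dec_outputs, _, _ = model.Decoder(dec_input, enc_input, enc_outputs)
        logits = model.projection(dec_outputs)                     # [1, cur_len, tgt_vocab_size]
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)    # most probable next token
        dec_input = torch.cat([dec_input, next_token], dim=-1)
        if next_token.item() == eos_symbol:
            break
    return dec_input.squeeze(0)[1:]                                # drop the start symbol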

Transformer PyTorch implementation

Data preparation

import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
# hand-made toy dataset
             # Encoder_input    Decoder_input        Decoder_output
sentences = [['我 是 学 生 P' , 'S I am a student'   , 'I am a student E'],         # S: start symbol
             ['我 喜 欢 学 习', 'S I like learning P', 'I like learning P E'],      # E: end symbol
             ['我 是 男 生 P' , 'S I am a boy'       , 'I am a boy E']]             # P: placeholder; sentences shorter than the fixed length are padded with P (index 0)
src_vocab = {'P':0, '我':1, '是':2, '学':3, '生':4, '喜':5, '欢':6,'习':7,'男':8}   # source vocabulary  token: index
src_idx2word = {src_vocab[key]: key for key in src_vocab}
src_vocab_size = len(src_vocab)                 # number of tokens in the source vocabulary
tgt_vocab = {'S':0, 'E':1, 'P':2, 'I':3, 'am':4, 'a':5, 'student':6, 'like':7, 'learning':8, 'boy':9}
idx2word = {tgt_vocab[key]: key for key in tgt_vocab}                               # target vocabulary as index: token
tgt_vocab_size = len(tgt_vocab)                                                     # size of the target vocabulary
src_len = len(sentences[0][0].split(" "))                                           # maximum Encoder input length: 5
tgt_len = len(sentences[0][1].split(" "))                                           # maximum Decoder input/output length: 5
src_len,tgt_len
(5, 5)
# convert sentences to vocabulary indices
def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
      enc_input = [[src_vocab[n] for n in sentences[i][0].split()]]
      dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]]
      dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]]
      enc_inputs.extend(enc_input)
      dec_inputs.extend(dec_input)
      dec_outputs.extend(dec_output)
    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)
enc_inputs, dec_inputs, dec_outputs = make_data(sentences)
print(enc_inputs)
print(dec_inputs)
print(dec_outputs)
tensor([[1, 2, 3, 4, 0],
        [1, 5, 6, 3, 7],
        [1, 2, 8, 4, 0]])
tensor([[0, 3, 4, 5, 6],
        [0, 3, 7, 8, 2],
        [0, 3, 4, 5, 9]])
tensor([[3, 4, 5, 6, 1],
        [3, 7, 8, 2, 1],
        [3, 4, 5, 9, 1]])

sentences contains three Chinese→English training examples. Encoder_input, Decoder_input and Decoder_output are converted to vocabulary indices, e.g. "学"->3, "student"->6. The data are then grouped into batches of size 2: the 3 sentences split into two batches, one with 2 sentences and one with 1. src_len is the fixed maximum length of a Chinese sentence and tgt_len the fixed maximum length of an English sentence.

# custom Dataset class
class MyDataSet(Data.Dataset):
  def __init__(self, enc_inputs, dec_inputs, dec_outputs):
    super(MyDataSet, self).__init__()
    self.enc_inputs = enc_inputs
    self.dec_inputs = dec_inputs
    self.dec_outputs = dec_outputs
  def __len__(self):
    return self.enc_inputs.shape[0]
  def __getitem__(self, idx):
    return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True) 
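A quick sanity check of the loader (the printed shapes are what I would expect; the order depends on shuffling): iterating over it yields the two groups described above, with tensors of shape [batch, 5].

for enc, dec, out in loader:
    print(enc.shape, dec.shape, out.shape)
# e.g. torch.Size([2, 5]) torch.Size([2, 5]) torch.Size([2, 5])
#      torch.Size([1, 5]) torch.Size([1, 5]) torch.Size([1, 5])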

Hyperparameter settings

d_model = 512   # embedding dimension of each token
d_ff = 2048     # hidden dimension of the feed-forward network
d_k = d_v = 64  # dimension of K (=Q) and V per head
n_layers = 6    # number of encoder and decoder layers
n_heads = 8     # number of heads in Multi-Head Attention

Positional encoding

First, the formula from the paper:


$$
\left\{\begin{array}{l} \boldsymbol{p}_{k, 2 i}=\sin \left(k / 10000^{2 i / d}\right) \\ \boldsymbol{p}_{k, 2 i+1}=\cos \left(k / 10000^{2 i / d}\right) \end{array}\right.
$$

where $\boldsymbol{p}_{k, 2i}$ and $\boldsymbol{p}_{k, 2i+1}$ are the $2i$-th and $(2i+1)$-th components of the encoding vector for position $k$, and $d$ is the embedding dimension.


class PositionalEncoding(nn.Module):
    def __init__(self,d_model,dropout=0.1,max_len=5000):
        super(PositionalEncoding,self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pos_table = np.array([
        [pos / np.power(10000, 2 * (i // 2) / d_model) for i in range(d_model)]    # dimensions 2i and 2i+1 share the exponent 2i/d
        if pos != 0 else np.zeros(d_model) for pos in range(max_len)])
        pos_table[1:, 0::2] = np.sin(pos_table[1:, 0::2])                  # sin on the even embedding dimensions
        pos_table[1:, 1::2] = np.cos(pos_table[1:, 1::2])                  # cos on the odd embedding dimensions
        self.pos_table = torch.FloatTensor(pos_table).cuda()               # pos_table: [max_len, d_model]
    def forward(self,enc_inputs):                                          # enc_inputs: [batch_size, seq_len, d_model]
        enc_inputs += self.pos_table[:enc_inputs.size(1),:]
        return self.dropout(enc_inputs.cuda())

The positional matrix pos_table is generated once and simply added to the input enc_inputs, giving token vectors that carry position information; pos_table is a fixed (non-trainable) matrix. The addition relies on broadcasting, sketched below.
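A minimal shape-only illustration of the broadcasting involved (names here are illustrative):

import torch

batch, seq_len, d_model = 2, 5, 512
x = torch.randn(batch, seq_len, d_model)   # token embeddings
pos_table = torch.randn(5000, d_model)     # fixed positional table
out = x + pos_table[:seq_len, :]           # [5, 512] broadcasts over the batch dimension
print(out.shape)                           # torch.Size([2, 5, 512])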

Masking the padding tokens

We mask the placeholder tokens that carry no meaning, e.g. the P in '我 是 学 生 P': P contributes nothing to the sentence, so it has to be masked. The placeholders in both Encoder_input and Decoder_input need to be masked.
The reason is that sentences have different lengths while the input must have a fixed length; short sentences are padded, but the computation should ignore the padding, hence the mask.

The key line in the function below is seq_k.data.eq(0): it returns a tensor of the same size as seq_k containing only True and False, with True wherever the value in seq_k equals 0 and False everywhere else. For example, for seq_data = [1, 2, 3, 4, 0], seq_data.data.eq(0) returns [False, False, False, False, True].

def get_attn_pad_mask(seq_q,seq_k):
    batch_size, len_q = seq_q.size()                        # seq_q is only used to get len_q, so the mask can be expanded to the attention-score shape
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)           # True where the input is P (=0), [batch_size, 1, len_k]
    return pad_attn_mask.expand(batch_size,len_q,len_k)     # expand to [batch_size, len_q, len_k]

Decoder input mask

This mask hides future tokens; it is an upper-triangular matrix. For example, in Chinese-to-English translation we first feed the whole sentence "我是学生" into the Encoder and obtain the final layer's output, and only then feed "S I am a student" into the Decoder (S marks the start). But "S I am a student" is not revealed all at once: at time T0 only "S" is fed in to predict the first word "I"; at T1 "S" and "I" are fed in together to predict the next word "am"; at T2 "S, I, am" are fed in to predict "a"; and so on until the whole sentence has been fed in and "I am a student E" has been predicted.

np.triu(m, k=1) keeps the part of the matrix strictly above the main diagonal and zeroes the rest, which is exactly the "future positions" pattern used below.

def get_attn_subsequence_mask(seq):                               # seq: [batch_size, tgt_len]
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)          # upper-triangular matrix of ones, [batch_size, tgt_len, tgt_len]
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()  # [batch_size, tgt_len, tgt_len]
    return subsequence_mask

Attention, residual connection and normalization

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()
    def forward(self, Q, K, V, attn_mask):                             # Q: [batch_size, n_heads, len_q, d_k]
                                                                       # K: [batch_size, n_heads, len_k, d_k]
                                                                       # V: [batch_size, n_heads, len_v(=len_k), d_v]
                                                                       # attn_mask: [batch_size, n_heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)   # scores : [batch_size, n_heads, len_q, len_k]
        scores.masked_fill_(attn_mask, -1e9)                           # masked positions (padding P / future tokens) get -1e9, so softmax gives them ~0 weight
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)                                # [batch_size, n_heads, len_q, d_v]
        return context, attn


To compute multi-head attention, the weight matrices $W^{Q}, W^{K}, W^{V}$ are split into 8 per-head matrices. Note the arguments input_Q, input_K, input_V: in the Encoder, and in the Decoder's first (self-attention) call, the three tensors passed in are identical; in the Decoder's second call they are not, because that call computes cross attention: input_K equals input_V (both are the encoder output), while input_Q comes from the decoder.


class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
    def forward(self, input_Q, input_K, input_V, attn_mask):    # input_Q: [batch_size, len_q, d_model]
                                                                # input_K: [batch_size, len_k, d_model]
                                                                # input_V: [batch_size, len_v(=len_k), d_model]
                                                                # attn_mask: [batch_size, seq_len, seq_len]
        residual, batch_size = input_Q, input_Q.size(0)
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # Q: [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # K: [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)              # attn_mask : [batch_size, n_heads, seq_len, seq_len]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)          # context: [batch_size, n_heads, len_q, d_v]
                                                                                 # attn: [batch_size, n_heads, len_q, len_k]
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v) # context: [batch_size, len_q, n_heads * d_v]
        output = self.fc(context)                                                # [batch_size, len_q, d_model]
        return nn.LayerNorm(d_model).cuda()(output + residual), attn

Feed-forward network

The input passes through two fully connected layers; the result is added back to the input (residual connection) and then LayerNorm is applied. LayerNorm here normalizes over the feature dimension of each token, i.e. per sample rather than across the batch.

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False))
    def forward(self, inputs):                             # inputs: [batch_size, seq_len, d_model]
        residual = inputs
        output = self.fc(inputs)
        return nn.LayerNorm(d_model).cuda()(output + residual)   # [batch_size, seq_len, d_model]  

Encoder layer (block)


class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()                                     # multi-head self-attention
        self.pos_ffn = PoswiseFeedForwardNet()                                        # feed-forward network
    def forward(self, enc_inputs, enc_self_attn_mask):                                # enc_inputs: [batch_size, src_len, d_model]
                                                                                       # enc_self_attn_mask: [batch_size, src_len, src_len]
        # the same enc_inputs is passed in three times; inside, it is multiplied by W_q, W_k, W_v to produce Q, K, V
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs,    # enc_outputs: [batch_size, src_len, d_model],
                                               enc_self_attn_mask)                    # attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs = self.pos_ffn(enc_outputs)                                       # enc_outputs: [batch_size, src_len, d_model]
        return enc_outputs, attn

Encoder

Step 1: map the Chinese token indices to 512-dimensional embedding vectors. Step 2: add positional information to the token vectors. Step 3: build the mask for the placeholder tokens. Step 4: pass the result through 6 stacked encoder layers (each layer's output is the next layer's input).

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        '''
        enc_outputs = self.src_emb(enc_inputs) # [batch_size, src_len, d_model]
        enc_outputs = self.pos_emb(enc_outputs) # add positional encoding; the PositionalEncoding above expects [batch_size, src_len, d_model], so no transpose is needed
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) # [batch_size, src_len, src_len]
        enc_self_attns = []
        for layer in self.layers:
            # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

Decoder layer (block)

The decoder calls MultiHeadAttention twice. In the first call Q, K and V are all the same, equal to dec_inputs. In the second call the Q matrix comes from the decoder's own output, while the K and V matrices come from the encoder output enc_outputs.

class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()
    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask): # dec_inputs: [batch_size, tgt_len, d_model]
                                                                                       # enc_outputs: [batch_size, src_len, d_model]
                                                                                       # dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
                                                                                       # dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, 
                                                 dec_inputs, dec_self_attn_mask)   # dec_outputs: [batch_size, tgt_len, d_model]
                                                                                   # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, 
                                                enc_outputs, dec_enc_attn_mask)    # dec_outputs: [batch_size, tgt_len, d_model]
                                                                                   # dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        dec_outputs = self.pos_ffn(dec_outputs)                                    # dec_outputs: [batch_size, tgt_len, d_model]
        return dec_outputs, dec_self_attn, dec_enc_attn

Decoder

Step 1: map the English token indices to 512-dimensional embedding vectors. Step 2: add positional information to the token vectors. Step 3: build the masks for the placeholder tokens and for future positions. Step 4: pass the result through 6 stacked decoder layers (each layer's output is the next layer's input).

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])
    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        dec_inputs: [batch_size, tgt_len]
        enc_inputs: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        '''
        dec_outputs = self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_emb(dec_outputs).cuda() # add positional encoding (expects [batch_size, tgt_len, d_model], so no transpose is needed)
        # pad mask for the decoder input (this toy decoder input has no padding, but real data usually does)
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() # [batch_size, tgt_len, tgt_len]
        # masked self-attention: the current position must not see future positions
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() # [batch_size, tgt_len, tgt_len]
        # add the two masks (this hides both the padding and the future positions)
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0).cuda() # [batch_size, tgt_len, tgt_len]
        # this mask is used in the encoder-decoder attention layers
        # get_attn_pad_mask builds the pad mask of enc_inputs (the encoder side provides K and V; the attention output is a weighted sum of v1, v2, ..., vm,
        # so the coefficients of the v_i that correspond to padding must become 0, keeping attention away from pad vectors)
        # dec_inputs only provides the expand size (len_q)
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) # [batch_size, tgt_len, src_len]
        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

Transformer

The overall structure: the input first goes through the Encoder, then through the Decoder, and the decoder output is finally projected to a classification over the target vocabulary (the number of classes is the English vocabulary size), i.e. a probability for each token.

class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.Encoder = Encoder().cuda()
        self.Decoder = Decoder().cuda()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).cuda()
    def forward(self, enc_inputs, dec_inputs):                         # enc_inputs: [batch_size, src_len]  
                                                                       # dec_inputs: [batch_size, tgt_len]
        enc_outputs, enc_self_attns = self.Encoder(enc_inputs)         # enc_outputs: [batch_size, src_len, d_model], 
                                                                       # enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        dec_outputs, dec_self_attns, dec_enc_attns = self.Decoder(
            dec_inputs, enc_inputs, enc_outputs)                       # dec_outputs   : [batch_size, tgt_len, d_model],
                                                                       # dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len],
                                                                       # dec_enc_attns : [n_layers, batch_size, n_heads, tgt_len, src_len]
        dec_logits = self.projection(dec_outputs)                      # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

Defining the network

model = Transformer().cuda()
criterion = nn.CrossEntropyLoss(ignore_index=0)     # ignore the placeholder, index 0, when computing the loss
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

Training the Transformer

Since the batch size is 2, each epoch prints two loss values.

for epoch in range(1000):
    for enc_inputs, dec_inputs, dec_outputs in loader:
        enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda()
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        loss = criterion(outputs,dec_outputs.view(-1))
        print('Epoch:', '%04d' % (epoch+1), 'loss =', '{:.6f}'.format(loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Epoch: 0001 loss = 0.000002
Epoch: 0001 loss = 0.000002
Epoch: 0002 loss = 0.000002
Epoch: 0002 loss = 0.000002
Epoch: 0003 loss = 0.000002
Epoch: 0003 loss = 0.000002
...
Epoch: 0666 loss = 0.000002
Epoch: 0666 loss = 0.000002
Epoch: 0667 loss = 0.000003
Epoch: 0667 loss = 0.000002
Epoch: 0668 loss = 0.000002
Epoch: 0668 loss = 0.000001
Epoch: 0669 loss = 0.000001
Epoch: 0669 loss = 0.000002
Epoch: 0670 loss = 0.000002
Epoch: 0670 loss = 0.000002
Epoch: 0671 loss = 0.000005
Epoch: 0671 loss = 0.000002
Epoch: 0672 loss = 0.000001
Epoch: 0672 loss = 0.000001
Epoch: 0673 loss = 0.000002
Epoch: 0673 loss = 0.000002
Epoch: 0674 loss = 0.000002
Epoch: 0674 loss = 0.000002
Epoch: 0675 loss = 0.000002
Epoch: 0675 loss = 0.000002
Epoch: 0676 loss = 0.000003
Epoch: 0676 loss = 0.000002
Epoch: 0677 loss = 0.000002
Epoch: 0677 loss = 0.000003
Epoch: 0678 loss = 0.000003
Epoch: 0678 loss = 0.000003
Epoch: 0679 loss = 0.000002
Epoch: 0679 loss = 0.000002
Epoch: 0680 loss = 0.000002
Epoch: 0680 loss = 0.000002
Epoch: 0681 loss = 0.000003
Epoch: 0681 loss = 0.000001
Epoch: 0682 loss = 0.000002
Epoch: 0682 loss = 0.000002
Epoch: 0683 loss = 0.000003
Epoch: 0683 loss = 0.000001
Epoch: 0684 loss = 0.000003
Epoch: 0684 loss = 0.000002
Epoch: 0685 loss = 0.000002
Epoch: 0685 loss = 0.000002
Epoch: 0686 loss = 0.000001
Epoch: 0686 loss = 0.000001
Epoch: 0687 loss = 0.000001
Epoch: 0687 loss = 0.000001
Epoch: 0688 loss = 0.000002
Epoch: 0688 loss = 0.000002
Epoch: 0689 loss = 0.000002
Epoch: 0689 loss = 0.000001
Epoch: 0690 loss = 0.000002
Epoch: 0690 loss = 0.000001
Epoch: 0691 loss = 0.000002
Epoch: 0691 loss = 0.000002
Epoch: 0692 loss = 0.000003
Epoch: 0692 loss = 0.000002
Epoch: 0693 loss = 0.000001
Epoch: 0693 loss = 0.000001
Epoch: 0694 loss = 0.000001
Epoch: 0694 loss = 0.000001
Epoch: 0695 loss = 0.000002
Epoch: 0695 loss = 0.000003
Epoch: 0696 loss = 0.000002
Epoch: 0696 loss = 0.000001
Epoch: 0697 loss = 0.000002
Epoch: 0697 loss = 0.000001
Epoch: 0698 loss = 0.000002
Epoch: 0698 loss = 0.000002
Epoch: 0699 loss = 0.000003
Epoch: 0699 loss = 0.000001
Epoch: 0700 loss = 0.000001
Epoch: 0700 loss = 0.000002
Epoch: 0701 loss = 0.000002
Epoch: 0701 loss = 0.000002
Epoch: 0702 loss = 0.000004
Epoch: 0702 loss = 0.000002
Epoch: 0703 loss = 0.000002
Epoch: 0703 loss = 0.000002
Epoch: 0704 loss = 0.000002
Epoch: 0704 loss = 0.000002
Epoch: 0705 loss = 0.000001
Epoch: 0705 loss = 0.000002
Epoch: 0706 loss = 0.000002
Epoch: 0706 loss = 0.000002
Epoch: 0707 loss = 0.000002
Epoch: 0707 loss = 0.000002
Epoch: 0708 loss = 0.000002
Epoch: 0708 loss = 0.000001
Epoch: 0709 loss = 0.000002
Epoch: 0709 loss = 0.000001
Epoch: 0710 loss = 0.000001
Epoch: 0710 loss = 0.000002
Epoch: 0711 loss = 0.000003
Epoch: 0711 loss = 0.000002
Epoch: 0712 loss = 0.000001
Epoch: 0712 loss = 0.000002
Epoch: 0713 loss = 0.000003
Epoch: 0713 loss = 0.000002
Epoch: 0714 loss = 0.000003
Epoch: 0714 loss = 0.000002
Epoch: 0715 loss = 0.000002
Epoch: 0715 loss = 0.000002
Epoch: 0716 loss = 0.000002
Epoch: 0716 loss = 0.000002
Epoch: 0717 loss = 0.000001
Epoch: 0717 loss = 0.000002
Epoch: 0718 loss = 0.000001
Epoch: 0718 loss = 0.000001
Epoch: 0719 loss = 0.000002
Epoch: 0719 loss = 0.000001
Epoch: 0720 loss = 0.000002
Epoch: 0720 loss = 0.000002
Epoch: 0721 loss = 0.000003
Epoch: 0721 loss = 0.000001
Epoch: 0722 loss = 0.000002
Epoch: 0722 loss = 0.000002
Epoch: 0723 loss = 0.000002
Epoch: 0723 loss = 0.000002
Epoch: 0724 loss = 0.000003
Epoch: 0724 loss = 0.000002
Epoch: 0725 loss = 0.000001
Epoch: 0725 loss = 0.000002
Epoch: 0726 loss = 0.000003
Epoch: 0726 loss = 0.000002
Epoch: 0727 loss = 0.000002
Epoch: 0727 loss = 0.000002
Epoch: 0728 loss = 0.000002
Epoch: 0728 loss = 0.000002
Epoch: 0729 loss = 0.000002
Epoch: 0729 loss = 0.000001
Epoch: 0730 loss = 0.000002
Epoch: 0730 loss = 0.000002
Epoch: 0731 loss = 0.000001
Epoch: 0731 loss = 0.000001
Epoch: 0732 loss = 0.000002
Epoch: 0732 loss = 0.000002
Epoch: 0733 loss = 0.000002
Epoch: 0733 loss = 0.000002
Epoch: 0734 loss = 0.000001
Epoch: 0734 loss = 0.000002
Epoch: 0735 loss = 0.000003
Epoch: 0735 loss = 0.000002
Epoch: 0736 loss = 0.000002
Epoch: 0736 loss = 0.000002
Epoch: 0737 loss = 0.000002
Epoch: 0737 loss = 0.000002
Epoch: 0738 loss = 0.000002
Epoch: 0738 loss = 0.000002
Epoch: 0739 loss = 0.000002
Epoch: 0739 loss = 0.000001
Epoch: 0740 loss = 0.000002
Epoch: 0740 loss = 0.000001
Epoch: 0741 loss = 0.000002
Epoch: 0741 loss = 0.000003
Epoch: 0742 loss = 0.000003
Epoch: 0742 loss = 0.000001
Epoch: 0743 loss = 0.000004
Epoch: 0743 loss = 0.000002
Epoch: 0744 loss = 0.000002
Epoch: 0744 loss = 0.000003
Epoch: 0745 loss = 0.000002
Epoch: 0745 loss = 0.000002
Epoch: 0746 loss = 0.000002
Epoch: 0746 loss = 0.000004
Epoch: 0747 loss = 0.000002
Epoch: 0747 loss = 0.000002
Epoch: 0748 loss = 0.000002
Epoch: 0748 loss = 0.000003
Epoch: 0749 loss = 0.000002
Epoch: 0749 loss = 0.000001
Epoch: 0750 loss = 0.000002
Epoch: 0750 loss = 0.000003
Epoch: 0751 loss = 0.000002
Epoch: 0751 loss = 0.000001
Epoch: 0752 loss = 0.000002
Epoch: 0752 loss = 0.000003
Epoch: 0753 loss = 0.000003
Epoch: 0753 loss = 0.000002
Epoch: 0754 loss = 0.000003
Epoch: 0754 loss = 0.000002
Epoch: 0755 loss = 0.000002
Epoch: 0755 loss = 0.000002
Epoch: 0756 loss = 0.000001
Epoch: 0756 loss = 0.000002
Epoch: 0757 loss = 0.000003
Epoch: 0757 loss = 0.000003
Epoch: 0758 loss = 0.000002
Epoch: 0758 loss = 0.000001
Epoch: 0759 loss = 0.000003
Epoch: 0759 loss = 0.000002
Epoch: 0760 loss = 0.000002
Epoch: 0760 loss = 0.000002
Epoch: 0761 loss = 0.000001
Epoch: 0761 loss = 0.000002
Epoch: 0762 loss = 0.000003
Epoch: 0762 loss = 0.000002
Epoch: 0763 loss = 0.000002
Epoch: 0763 loss = 0.000002
Epoch: 0764 loss = 0.000002
Epoch: 0764 loss = 0.000002
Epoch: 0765 loss = 0.000003
Epoch: 0765 loss = 0.000002
Epoch: 0766 loss = 0.000002
Epoch: 0766 loss = 0.000002
Epoch: 0767 loss = 0.000001
Epoch: 0767 loss = 0.000002
Epoch: 0768 loss = 0.000003
Epoch: 0768 loss = 0.000002
Epoch: 0769 loss = 0.000003
Epoch: 0769 loss = 0.000002
Epoch: 0770 loss = 0.000003
Epoch: 0770 loss = 0.000002
Epoch: 0771 loss = 0.000002
Epoch: 0771 loss = 0.000003
Epoch: 0772 loss = 0.000001
Epoch: 0772 loss = 0.000002
Epoch: 0773 loss = 0.000003
Epoch: 0773 loss = 0.000003
Epoch: 0774 loss = 0.000002
Epoch: 0774 loss = 0.000001
Epoch: 0775 loss = 0.000002
Epoch: 0775 loss = 0.000002
Epoch: 0776 loss = 0.000005
Epoch: 0776 loss = 0.000002
Epoch: 0777 loss = 0.000002
Epoch: 0777 loss = 0.000002
Epoch: 0778 loss = 0.000003
Epoch: 0778 loss = 0.000001
Epoch: 0779 loss = 0.000002
Epoch: 0779 loss = 0.000002
Epoch: 0780 loss = 0.000001
Epoch: 0780 loss = 0.000002
Epoch: 0781 loss = 0.000002
Epoch: 0781 loss = 0.000002
Epoch: 0782 loss = 0.000002
Epoch: 0782 loss = 0.000002
Epoch: 0783 loss = 0.000002
Epoch: 0783 loss = 0.000001
Epoch: 0784 loss = 0.000003
Epoch: 0784 loss = 0.000002
Epoch: 0785 loss = 0.000002
Epoch: 0785 loss = 0.000002
Epoch: 0786 loss = 0.000002
Epoch: 0786 loss = 0.000002
Epoch: 0787 loss = 0.000003
Epoch: 0787 loss = 0.000002
Epoch: 0788 loss = 0.000001
Epoch: 0788 loss = 0.000002
Epoch: 0789 loss = 0.000002
Epoch: 0789 loss = 0.000002
Epoch: 0790 loss = 0.000002
Epoch: 0790 loss = 0.000002
Epoch: 0791 loss = 0.000002
Epoch: 0791 loss = 0.000002
Epoch: 0792 loss = 0.000002
Epoch: 0792 loss = 0.000001
Epoch: 0793 loss = 0.000002
Epoch: 0793 loss = 0.000001
Epoch: 0794 loss = 0.000002
Epoch: 0794 loss = 0.000001
Epoch: 0795 loss = 0.000002
Epoch: 0795 loss = 0.000001
Epoch: 0796 loss = 0.000002
Epoch: 0796 loss = 0.000002
Epoch: 0797 loss = 0.000003
Epoch: 0797 loss = 0.000002
Epoch: 0798 loss = 0.000002
Epoch: 0798 loss = 0.000002
Epoch: 0799 loss = 0.000003
Epoch: 0799 loss = 0.000003
Epoch: 0800 loss = 0.000003
Epoch: 0800 loss = 0.000001
Epoch: 0801 loss = 0.000002
Epoch: 0801 loss = 0.000001
Epoch: 0802 loss = 0.000002
Epoch: 0802 loss = 0.000002
Epoch: 0803 loss = 0.000002
Epoch: 0803 loss = 0.000002
Epoch: 0804 loss = 0.000003
Epoch: 0804 loss = 0.000002
Epoch: 0805 loss = 0.000002
Epoch: 0805 loss = 0.000002
Epoch: 0806 loss = 0.000001
Epoch: 0806 loss = 0.000002
Epoch: 0807 loss = 0.000002
Epoch: 0807 loss = 0.000002
Epoch: 0808 loss = 0.000002
Epoch: 0808 loss = 0.000002
Epoch: 0809 loss = 0.000002
Epoch: 0809 loss = 0.000002
Epoch: 0810 loss = 0.000002
Epoch: 0810 loss = 0.000002
Epoch: 0811 loss = 0.000001
Epoch: 0811 loss = 0.000002
Epoch: 0812 loss = 0.000004
Epoch: 0812 loss = 0.000002
Epoch: 0813 loss = 0.000002
Epoch: 0813 loss = 0.000002
Epoch: 0814 loss = 0.000002
Epoch: 0814 loss = 0.000001
Epoch: 0815 loss = 0.000002
Epoch: 0815 loss = 0.000001
Epoch: 0816 loss = 0.000003
Epoch: 0816 loss = 0.000002
Epoch: 0817 loss = 0.000002
Epoch: 0817 loss = 0.000002
Epoch: 0818 loss = 0.000002
Epoch: 0818 loss = 0.000002
Epoch: 0819 loss = 0.000003
Epoch: 0819 loss = 0.000002
Epoch: 0820 loss = 0.000002
Epoch: 0820 loss = 0.000001
Epoch: 0821 loss = 0.000002
Epoch: 0821 loss = 0.000001
Epoch: 0822 loss = 0.000001
Epoch: 0822 loss = 0.000002
Epoch: 0823 loss = 0.000002
Epoch: 0823 loss = 0.000002
Epoch: 0824 loss = 0.000002
Epoch: 0824 loss = 0.000002
Epoch: 0825 loss = 0.000003
Epoch: 0825 loss = 0.000002
Epoch: 0826 loss = 0.000002
Epoch: 0826 loss = 0.000002
Epoch: 0827 loss = 0.000002
Epoch: 0827 loss = 0.000002
Epoch: 0828 loss = 0.000002
Epoch: 0828 loss = 0.000001
Epoch: 0829 loss = 0.000001
Epoch: 0829 loss = 0.000002
Epoch: 0830 loss = 0.000002
Epoch: 0830 loss = 0.000002
Epoch: 0831 loss = 0.000001
Epoch: 0831 loss = 0.000002
Epoch: 0832 loss = 0.000001
Epoch: 0832 loss = 0.000002
Epoch: 0833 loss = 0.000002
Epoch: 0833 loss = 0.000002
Epoch: 0834 loss = 0.000001
Epoch: 0834 loss = 0.000002
Epoch: 0835 loss = 0.000002
Epoch: 0835 loss = 0.000002
Epoch: 0836 loss = 0.000001
Epoch: 0836 loss = 0.000002
Epoch: 0837 loss = 0.000002
Epoch: 0837 loss = 0.000002
Epoch: 0838 loss = 0.000002
Epoch: 0838 loss = 0.000002
Epoch: 0839 loss = 0.000002
Epoch: 0839 loss = 0.000002
Epoch: 0840 loss = 0.000001
Epoch: 0840 loss = 0.000002
Epoch: 0841 loss = 0.000003
Epoch: 0841 loss = 0.000002
Epoch: 0842 loss = 0.000001
Epoch: 0842 loss = 0.000002
Epoch: 0843 loss = 0.000002
Epoch: 0843 loss = 0.000002
Epoch: 0844 loss = 0.000002
Epoch: 0844 loss = 0.000002
Epoch: 0845 loss = 0.000001
Epoch: 0845 loss = 0.000002
Epoch: 0846 loss = 0.000002
Epoch: 0846 loss = 0.000001
Epoch: 0847 loss = 0.000002
Epoch: 0847 loss = 0.000002
Epoch: 0848 loss = 0.000002
Epoch: 0848 loss = 0.000002
Epoch: 0849 loss = 0.000002
Epoch: 0849 loss = 0.000002
Epoch: 0850 loss = 0.000002
Epoch: 0850 loss = 0.000002
Epoch: 0851 loss = 0.000002
Epoch: 0851 loss = 0.000001
Epoch: 0852 loss = 0.000002
Epoch: 0852 loss = 0.000002
Epoch: 0853 loss = 0.000002
Epoch: 0853 loss = 0.000003
Epoch: 0854 loss = 0.000002
Epoch: 0854 loss = 0.000002
Epoch: 0855 loss = 0.000002
Epoch: 0855 loss = 0.000002
Epoch: 0856 loss = 0.000002
Epoch: 0856 loss = 0.000002
Epoch: 0857 loss = 0.000003
Epoch: 0857 loss = 0.000001
Epoch: 0858 loss = 0.000002
Epoch: 0858 loss = 0.000001
Epoch: 0859 loss = 0.000003
Epoch: 0859 loss = 0.000002
Epoch: 0860 loss = 0.000003
Epoch: 0860 loss = 0.000002
Epoch: 0861 loss = 0.000002
Epoch: 0861 loss = 0.000003
Epoch: 0862 loss = 0.000002
Epoch: 0862 loss = 0.000002
Epoch: 0863 loss = 0.000002
Epoch: 0863 loss = 0.000002
Epoch: 0864 loss = 0.000001
Epoch: 0864 loss = 0.000002
Epoch: 0865 loss = 0.000001
Epoch: 0865 loss = 0.000001
Epoch: 0866 loss = 0.000002
Epoch: 0866 loss = 0.000002
Epoch: 0867 loss = 0.000002
Epoch: 0867 loss = 0.000002
Epoch: 0868 loss = 0.000002
Epoch: 0868 loss = 0.000002
Epoch: 0869 loss = 0.000001
Epoch: 0869 loss = 0.000002
Epoch: 0870 loss = 0.000002
Epoch: 0870 loss = 0.000002
Epoch: 0871 loss = 0.000003
Epoch: 0871 loss = 0.000002
Epoch: 0872 loss = 0.000001
Epoch: 0872 loss = 0.000002
Epoch: 0873 loss = 0.000002
Epoch: 0873 loss = 0.000001
Epoch: 0874 loss = 0.000001
Epoch: 0874 loss = 0.000002
Epoch: 0875 loss = 0.000002
Epoch: 0875 loss = 0.000002
Epoch: 0876 loss = 0.000002
Epoch: 0876 loss = 0.000002
Epoch: 0877 loss = 0.000001
Epoch: 0877 loss = 0.000002
Epoch: 0878 loss = 0.000002
Epoch: 0878 loss = 0.000002
Epoch: 0879 loss = 0.000002
Epoch: 0879 loss = 0.000002
Epoch: 0880 loss = 0.000002
Epoch: 0880 loss = 0.000002
Epoch: 0881 loss = 0.000002
Epoch: 0881 loss = 0.000001
Epoch: 0882 loss = 0.000003
Epoch: 0882 loss = 0.000002
Epoch: 0883 loss = 0.000002
Epoch: 0883 loss = 0.000002
Epoch: 0884 loss = 0.000002
Epoch: 0884 loss = 0.000002
Epoch: 0885 loss = 0.000001
Epoch: 0885 loss = 0.000002
Epoch: 0886 loss = 0.000001
Epoch: 0886 loss = 0.000002
Epoch: 0887 loss = 0.000002
Epoch: 0887 loss = 0.000002
Epoch: 0888 loss = 0.000001
Epoch: 0888 loss = 0.000002
Epoch: 0889 loss = 0.000001
Epoch: 0889 loss = 0.000002
Epoch: 0890 loss = 0.000002
Epoch: 0890 loss = 0.000002
Epoch: 0891 loss = 0.000004
Epoch: 0891 loss = 0.000002
Epoch: 0892 loss = 0.000002
Epoch: 0892 loss = 0.000001
Epoch: 0893 loss = 0.000001
Epoch: 0893 loss = 0.000002
Epoch: 0894 loss = 0.000003
Epoch: 0894 loss = 0.000002
Epoch: 0895 loss = 0.000002
Epoch: 0895 loss = 0.000002
Epoch: 0896 loss = 0.000002
Epoch: 0896 loss = 0.000002
Epoch: 0897 loss = 0.000001
Epoch: 0897 loss = 0.000002
Epoch: 0898 loss = 0.000001
Epoch: 0898 loss = 0.000002
Epoch: 0899 loss = 0.000003
Epoch: 0899 loss = 0.000003
Epoch: 0900 loss = 0.000002
Epoch: 0900 loss = 0.000002
Epoch: 0901 loss = 0.000002
Epoch: 0901 loss = 0.000002
Epoch: 0902 loss = 0.000002
Epoch: 0902 loss = 0.000002
Epoch: 0903 loss = 0.000002
Epoch: 0903 loss = 0.000002
Epoch: 0904 loss = 0.000001
Epoch: 0904 loss = 0.000002
Epoch: 0905 loss = 0.000001
Epoch: 0905 loss = 0.000001
Epoch: 0906 loss = 0.000002
Epoch: 0906 loss = 0.000002
Epoch: 0907 loss = 0.000002
Epoch: 0907 loss = 0.000002
Epoch: 0908 loss = 0.000002
Epoch: 0908 loss = 0.000001
Epoch: 0909 loss = 0.000001
Epoch: 0909 loss = 0.000002
Epoch: 0910 loss = 0.000001
Epoch: 0910 loss = 0.000001
Epoch: 0911 loss = 0.000003
Epoch: 0911 loss = 0.000002
Epoch: 0912 loss = 0.000001
Epoch: 0912 loss = 0.000002
Epoch: 0913 loss = 0.000002
Epoch: 0913 loss = 0.000002
Epoch: 0914 loss = 0.000003
Epoch: 0914 loss = 0.000002
Epoch: 0915 loss = 0.000002
Epoch: 0915 loss = 0.000002
Epoch: 0916 loss = 0.000002
Epoch: 0916 loss = 0.000001
Epoch: 0917 loss = 0.000002
Epoch: 0917 loss = 0.000001
Epoch: 0918 loss = 0.000002
Epoch: 0918 loss = 0.000001
Epoch: 0919 loss = 0.000002
Epoch: 0919 loss = 0.000002
Epoch: 0920 loss = 0.000002
Epoch: 0920 loss = 0.000003
Epoch: 0921 loss = 0.000001
Epoch: 0921 loss = 0.000002
Epoch: 0922 loss = 0.000002
Epoch: 0922 loss = 0.000002
Epoch: 0923 loss = 0.000002
Epoch: 0923 loss = 0.000002
Epoch: 0924 loss = 0.000001
Epoch: 0924 loss = 0.000002
Epoch: 0925 loss = 0.000001
Epoch: 0925 loss = 0.000001
Epoch: 0926 loss = 0.000002
Epoch: 0926 loss = 0.000001
Epoch: 0927 loss = 0.000002
Epoch: 0927 loss = 0.000001
Epoch: 0928 loss = 0.000002
Epoch: 0928 loss = 0.000002
Epoch: 0929 loss = 0.000001
Epoch: 0929 loss = 0.000002
Epoch: 0930 loss = 0.000002
Epoch: 0930 loss = 0.000002
Epoch: 0931 loss = 0.000002
Epoch: 0931 loss = 0.000002
Epoch: 0932 loss = 0.000003
Epoch: 0932 loss = 0.000002
Epoch: 0933 loss = 0.000002
Epoch: 0933 loss = 0.000002
Epoch: 0934 loss = 0.000001
Epoch: 0934 loss = 0.000001
Epoch: 0935 loss = 0.000001
Epoch: 0935 loss = 0.000002
Epoch: 0936 loss = 0.000001
Epoch: 0936 loss = 0.000002
Epoch: 0937 loss = 0.000002
Epoch: 0937 loss = 0.000002
Epoch: 0938 loss = 0.000001
Epoch: 0938 loss = 0.000002
Epoch: 0939 loss = 0.000002
Epoch: 0939 loss = 0.000001
Epoch: 0940 loss = 0.000002
Epoch: 0940 loss = 0.000002
Epoch: 0941 loss = 0.000002
Epoch: 0941 loss = 0.000002
Epoch: 0942 loss = 0.000001
Epoch: 0942 loss = 0.000002
Epoch: 0943 loss = 0.000002
Epoch: 0943 loss = 0.000002
Epoch: 0944 loss = 0.000001
Epoch: 0944 loss = 0.000001
Epoch: 0945 loss = 0.000002
Epoch: 0945 loss = 0.000001
Epoch: 0946 loss = 0.000001
Epoch: 0946 loss = 0.000002
Epoch: 0947 loss = 0.000002
Epoch: 0947 loss = 0.000002
Epoch: 0948 loss = 0.000002
Epoch: 0948 loss = 0.000002
Epoch: 0949 loss = 0.000002
Epoch: 0949 loss = 0.000002
Epoch: 0950 loss = 0.000002
Epoch: 0950 loss = 0.000002
Epoch: 0951 loss = 0.000001
Epoch: 0951 loss = 0.000001
Epoch: 0952 loss = 0.000002
Epoch: 0952 loss = 0.000001
Epoch: 0953 loss = 0.000003
Epoch: 0953 loss = 0.000002
Epoch: 0954 loss = 0.000001
Epoch: 0954 loss = 0.000001
Epoch: 0955 loss = 0.000002
Epoch: 0955 loss = 0.000002
Epoch: 0956 loss = 0.000002
Epoch: 0956 loss = 0.000001
Epoch: 0957 loss = 0.000003
Epoch: 0957 loss = 0.000001
Epoch: 0958 loss = 0.000002
Epoch: 0958 loss = 0.000001
Epoch: 0959 loss = 0.000002
Epoch: 0959 loss = 0.000002
Epoch: 0960 loss = 0.000001
Epoch: 0960 loss = 0.000001
Epoch: 0961 loss = 0.000002
Epoch: 0961 loss = 0.000002
Epoch: 0962 loss = 0.000002
Epoch: 0962 loss = 0.000001
Epoch: 0963 loss = 0.000001
Epoch: 0963 loss = 0.000002
Epoch: 0964 loss = 0.000003
Epoch: 0964 loss = 0.000002
Epoch: 0965 loss = 0.000002
Epoch: 0965 loss = 0.000002
Epoch: 0966 loss = 0.000002
Epoch: 0966 loss = 0.000002
Epoch: 0967 loss = 0.000001
Epoch: 0967 loss = 0.000002
Epoch: 0968 loss = 0.000002
Epoch: 0968 loss = 0.000001
Epoch: 0969 loss = 0.000002
Epoch: 0969 loss = 0.000002
Epoch: 0970 loss = 0.000002
Epoch: 0970 loss = 0.000002
Epoch: 0971 loss = 0.000002
Epoch: 0971 loss = 0.000001
Epoch: 0972 loss = 0.000002
Epoch: 0972 loss = 0.000001
Epoch: 0973 loss = 0.000001
Epoch: 0973 loss = 0.000002
Epoch: 0974 loss = 0.000002
Epoch: 0974 loss = 0.000002
Epoch: 0975 loss = 0.000001
Epoch: 0975 loss = 0.000002
Epoch: 0976 loss = 0.000002
Epoch: 0976 loss = 0.000002
Epoch: 0977 loss = 0.000002
Epoch: 0977 loss = 0.000002
Epoch: 0978 loss = 0.000002
Epoch: 0978 loss = 0.000001
Epoch: 0979 loss = 0.000002
Epoch: 0979 loss = 0.000002
Epoch: 0980 loss = 0.000001
Epoch: 0980 loss = 0.000002
Epoch: 0981 loss = 0.000002
Epoch: 0981 loss = 0.000001
Epoch: 0982 loss = 0.000003
Epoch: 0982 loss = 0.000001
Epoch: 0983 loss = 0.000002
Epoch: 0983 loss = 0.000002
Epoch: 0984 loss = 0.000002
Epoch: 0984 loss = 0.000001
Epoch: 0985 loss = 0.000002
Epoch: 0985 loss = 0.000002
Epoch: 0986 loss = 0.000002
Epoch: 0986 loss = 0.000001
Epoch: 0987 loss = 0.000003
Epoch: 0987 loss = 0.000002
Epoch: 0988 loss = 0.000002
Epoch: 0988 loss = 0.000003
Epoch: 0989 loss = 0.000002
Epoch: 0989 loss = 0.000002
Epoch: 0990 loss = 0.000001
Epoch: 0990 loss = 0.000001
Epoch: 0991 loss = 0.000002
Epoch: 0991 loss = 0.000002
Epoch: 0992 loss = 0.000002
Epoch: 0992 loss = 0.000003
Epoch: 0993 loss = 0.000002
Epoch: 0993 loss = 0.000001
Epoch: 0994 loss = 0.000002
Epoch: 0994 loss = 0.000001
Epoch: 0995 loss = 0.000003
Epoch: 0995 loss = 0.000001
Epoch: 0996 loss = 0.000001
Epoch: 0996 loss = 0.000001
Epoch: 0997 loss = 0.000002
Epoch: 0997 loss = 0.000002
Epoch: 0998 loss = 0.000002
Epoch: 0998 loss = 0.000002
Epoch: 0999 loss = 0.000002
Epoch: 0999 loss = 0.000002
Epoch: 1000 loss = 0.000002
Epoch: 1000 loss = 0.000002

Test

def test(model, enc_input, start_symbol):
    # Encode the source sentence once; the encoder outputs are reused at every decoding step
    enc_outputs, enc_self_attns = model.Encoder(enc_input)
    # Start from an all-zero decoder input of length tgt_len and fill it position by position
    dec_input = torch.zeros(1, tgt_len).type_as(enc_input.data)
    next_symbol = start_symbol
    for i in range(0, tgt_len):
        dec_input[0][i] = next_symbol
        dec_outputs, _, _ = model.Decoder(dec_input, enc_input, enc_outputs)
        projected = model.projection(dec_outputs)
        # Greedy decoding: pick the highest-scoring word at the current position
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[i]
        next_symbol = next_word.item()
    return dec_input
# Take one batch from the loader and greedily translate one of its sentences
enc_inputs, _, _ = next(iter(loader))
predict_dec_input = test(model, enc_inputs[1].view(1, -1).cuda(), start_symbol=tgt_vocab["S"])
predict, _, _, _ = model(enc_inputs[1].view(1, -1).cuda(), predict_dec_input)
predict = predict.data.max(1, keepdim=True)[1]
print([src_idx2word[int(i)] for i in enc_inputs[1]], '->',
      [idx2word[n.item()] for n in predict.squeeze()])
['我', '是', '男', '生', 'P'] -> ['I', 'am', 'a', 'boy', 'E']
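
The test() function above always decodes exactly tgt_len steps. As a minimal sketch of a common variant (assuming the same model.Encoder / model.Decoder / model.projection interface, the global tgt_len, and that tgt_vocab contains the end token "E" as in the output above), greedy decoding can also stop as soon as the end token is predicted:

def greedy_test_with_stop(model, enc_input, start_symbol, end_symbol):
    # Reuses torch, tgt_len and the model defined earlier in this post.
    # Encode the source sentence once; reuse enc_outputs at every decoding step.
    enc_outputs, _ = model.Encoder(enc_input)
    dec_input = torch.zeros(1, tgt_len).type_as(enc_input.data)
    next_symbol = start_symbol
    for i in range(tgt_len):
        dec_input[0][i] = next_symbol
        dec_outputs, _, _ = model.Decoder(dec_input, enc_input, enc_outputs)
        projected = model.projection(dec_outputs)          # [1, tgt_len, tgt_vocab_size]
        # Greedy choice at the current position
        next_symbol = projected.squeeze(0).argmax(dim=-1)[i].item()
        if next_symbol == end_symbol:                      # stop once "E" is predicted
            break
    return dec_input

# Hypothetical usage, mirroring the call to test() above:
# predict_dec_input = greedy_test_with_stop(model, enc_inputs[1].view(1, -1).cuda(),
#                                           start_symbol=tgt_vocab["S"],
#                                           end_symbol=tgt_vocab["E"])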

References

  • https://wmathor.com/index.php/archives/1455/
  • https://zhuanlan.zhihu.com/p/166608727
  • https://zhuanlan.zhihu.com/p/403433120
