基于莎士比亚作品数据集的循环神经网络(RNN)文本生成(2)

3个月前 166次点击 来自 TensorFlow

收录专题: TensorFlow官方教程笔记

本文来自于官方教程循环神经网络(RNN)文本生成
许多更详细的细节请参考官方文档,本文只是笔者的阅读笔记。

上文已经处理完数据,并打包成batch可供模型处理。

3.创建模型

  • tf.keras.layers.Embedding:输入层。一个可训练的对照表,它会将每个字符的数字映射到一个 embedding_dim 维度的向量。
  • tf.keras.layers.GRU:一种 RNN 类型,其大小由 units=rnn_units 指定(这里你也可以使用一个 LSTM 层)。
  • tf.keras.layers.Dense:输出层,带有 vocab_size 个输出。
# 词集的长度
vocab_size = len(vocab)
# 嵌入的维度
embedding_dim = 256
# RNN 的单元数量
rnn_units = 1024

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab_size, embedding_dim,
                                  batch_input_shape=[batch_size, None]),
        tf.keras.layers.GRU(rnn_units,
                            return_sequences=True,
                            stateful=True,
                            recurrent_initializer='glorot_uniform'),
        tf.keras.layers.Dense(vocab_size)
    ])
    return model

model = build_model(
    vocab_size=len(vocab),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=BATCH_SIZE)

对于每个字符,模型会查找嵌入,把嵌入当作输入运行 GRU 一个时间步,并用密集层生成逻辑回归 (logits),预测下一个字符的对数可能性

4.初次运行模型

现在运行这个模型,看看它是否按预期运行。

首先检查输出的形状:

for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

输出

(64, 100, 65) # (batch_size, sequence_length, vocab_size)

每个batch会打包64个句子,每个句子的维度是100个字符,每个句子接下来预测的字符概率使用65维来标识

查看下模型结构

model.summary()

输出

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (64, None, 256)           16640     
_________________________________________________________________
gru (GRU)                    (64, None, 1024)          3938304   
_________________________________________________________________
dense (Dense)                (64, None, 65)            66625     
=================================================================
Total params: 4,021,569
Trainable params: 4,021,569
Non-trainable params: 0
_________________________________________________________________

为了获得模型的实际预测,我们需要从输出分布中抽样,以获得实际的字符索引。这个分布是根据对字符集的逻辑回归定义的。试试这个批次中的第一个样本:

sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

sampled_indices中的数据为

array([ 3, 19, 11,  8, 17, 50, 14,  5, 16, 57, 51, 53, 17, 54,  9, 11, 22,
       13, 36, 57, 57, 50, 47, 22,  5,  7,  1, 59,  3, 26, 52,  2, 62, 30,
       54, 18, 62,  9, 63,  2, 22, 11, 18, 12, 63,  0, 13, 16, 38, 49, 21,
       25, 22, 53, 39, 63,  3, 26, 39, 15, 21, 56, 49, 39, 20, 55,  5, 39,
       61, 29, 21, 39, 39, 63, 48, 11, 27, 42, 59,  0, 19, 58, 57, 27, 40,
       13, 53, 13,  7,  4, 21, 32, 10, 57, 18, 30, 54, 36, 12,  3])

解码它们,以查看此未经训练的模型预测的文本:

print("Input: \n", repr("".join(idx2char[input_example_batch[0]])))
print()
print("Output: \n", repr("".join(idx2char[target_example_batch[0]])))
print()
print("Predictions: \n", repr("".join(idx2char[sampled_indices])))

输出

Input: 
 '\nDeliver them this paper: having read it,\nBid them repair to the market place; where I,\nEven in thei'

Output: 
 'Deliver them this paper: having read it,\nBid them repair to the market place; where I,\nEven in their'

Predictions: 
 "JDwn,JgN$oyv$oSLNqqck&'Z$.bVmbx$q3UassTOVW-kGgwHu:D  hkvHlg'pZvcNveg-MScS!WUAcWXg'VkLc:JdrgphTLDRhQU"

可以察觉,输出效果并不好,因为模型并没有经过训练。

5.训练

5.1 定义损失函数与添加优化器

使用 tf.keras.Model.compile 方法配置训练步骤。我们将使用 tf.keras.optimizers.Adam 并采用默认参数,以及损失函数。

def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)


example_batch_loss = loss(target_example_batch, example_batch_predictions)
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("scalar_loss:      ", example_batch_loss.numpy().mean())

model.compile(optimizer='adam', loss=loss)

5.2 配置检查点

# 检查点保存至的目录
checkpoint_dir = './training_checkpoints'

# 检查点的文件名
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

5.3 执行训练

EPOCHS = 10
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])
Card image cap
开发者雷

尘世间一个小小的开发者,每天增加一些无聊的知识,就不会无聊了

要加油~~~

技术文档 >> 系列应用 >>
热推应用
Let'sLearnSwift
学习Swift的入门教程
PyPie
Python is as good as Pie
标签