基于莎士比亚作品数据集的循环神经网络(RNN)文本生成(3)

2周前 67次点击 来自 TensorFlow

收录专题: TensorFlow官方教程笔记

本文来自于官方教程循环神经网络(RNN)文本生成
许多更详细的细节请参考官方文档,本文只是笔者的阅读笔记。

6.生成文本

6.1 恢复最新的检查点

为保持此次预测步骤简单,将批大小设定为 1。

由于 RNN 状态从时间步传递到时间步的方式,模型建立好之后只接受固定的批大小。

若要使用不同的 batch_size 来运行模型,我们需要重建模型并从检查点中恢复权重。

tf.train.latest_checkpoint(checkpoint_dir)

model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))
model.summary()

输出

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_1 (Embedding)      (1, None, 256)            16640     
_________________________________________________________________
gru_1 (GRU)                  (1, None, 1024)           3938304   
_________________________________________________________________
dense_1 (Dense)              (1, None, 65)             66625     
=================================================================
Total params: 4,021,569
Trainable params: 4,021,569
Non-trainable params: 0
_________________________________________________________________

6.2 预测循环

下面的代码块生成文本:

  • 首先设置起始字符串,初始化 RNN 状态并设置要生成的字符个数。
  • 用起始字符串和 RNN 状态,获取下一个字符的预测分布。
  • 然后,用分类分布计算预测字符的索引。把这个预测字符当作模型的下一个输入。
  • 模型返回的 RNN 状态被输送回模型。现在,模型有更多上下文可以学习,而非只有一个字符。在预测出下一个字符后,更改过的 RNN 状态被再次输送回模型。模型就是这样,通过不断从前面预测的字符获得更多上下文,进行学习。
def generate_text(model, start_string):
    # 评估步骤(用学习过的模型生成文本)

    # 要生成的字符个数
    num_generate = 1000

    # 将起始字符串转换为数字(向量化)
    input_eval = [char2idx[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)

    # 空字符串用于存储结果
    text_generated = []

    # 低温度会生成更可预测的文本
    # 较高温度会生成更令人惊讶的文本
    # 可以通过试验以找到最好的设定
    temperature = 1.0

    # 这里批大小为 1
    model.reset_states()
    for i in range(num_generate):
        predictions = model(input_eval)
        # 删除批次的维度
        predictions = tf.squeeze(predictions, 0)

        # 用分类分布预测模型返回的字符
        predictions = predictions / temperature
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

        # 把预测字符和前面的隐藏状态一起传递给模型作为下一个输入
        input_eval = tf.expand_dims([predicted_id], 0)

        text_generated.append(idx2char[predicted_id])

    return (start_string + ''.join(text_generated))

print(generate_text(model, start_string=u"ROMEO: "))

若想改进结果,最简单的方式是延长训练时间 (试试 EPOCHS=30)。

你还可以试验使用不同的起始字符串,或者尝试增加另一个 RNN 层以提高模型的准确率,亦或调整温度参数以生成更多或者更少的随机预测。

7. 高级:自定义训练

你将使用 tf.GradientTape 跟踪梯度。

步骤如下:

  • 首先,初始化 RNN 状态,使用 tf.keras.Model.reset_states 方法。
  • 然后,迭代数据集(逐批次)并计算每次迭代对应的 预测。
  • 打开一个 tf.GradientTape 并计算该上下文时的预测和损失。
  • 使用 tf.GradientTape.grads 方法,计算当前模型变量情况下的损失梯度。
  • 最后,使用优化器的 tf.train.Optimizer.apply_gradients 方法向下迈出一步。
model = build_model(
  vocab_size = len(vocab),
  embedding_dim=embedding_dim,
  rnn_units=rnn_units,
  batch_size=BATCH_SIZE)

optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(inp, target):
  with tf.GradientTape() as tape:
    predictions = model(inp)
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            target, predictions, from_logits=True))
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))

  return loss

# 训练步骤
EPOCHS = 10

for epoch in range(EPOCHS):
  start = time.time()

  # 在每个训练周期开始时,初始化隐藏状态
  # 隐藏状态最初为 None
  hidden = model.reset_states()

  for (batch_n, (inp, target)) in enumerate(dataset):
    loss = train_step(inp, target)

    if batch_n % 100 == 0:
      template = 'Epoch {} Batch {} Loss {}'
      print(template.format(epoch+1, batch_n, loss))

  # 每 5 个训练周期,保存(检查点)1 次模型
  if (epoch + 1) % 5 == 0:
    model.save_weights(checkpoint_prefix.format(epoch=epoch))

  print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))
  print ('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

model.save_weights(checkpoint_prefix.format(epoch=epoch))
Card image cap
开发者雷

尘世间一个小小的开发者,每天增加一些无聊的知识,就不会无聊了

要加油~~~

技术文档 >> 系列应用 >>
热推应用
Let'sLearnSwift
学习Swift的入门教程
PyPie
Python is as good as Pie
标签