【问题标题】:Set the number of iterations gpt-2设置迭代次数 gpt-2
【发布时间】:2020-01-06 23:51:34
【问题描述】:

我正在按照本教程微调 gpt-2 模型:

https://medium.com/@ngwaifoong92/beginners-guide-to-retrain-gpt-2-117m-to-generate-custom-text-content-8bb5363d8b7f

与其关联的 GitHub 存储库:

https://github.com/nshepperd/gpt-2

我已经能够复制这些示例,我的问题是我没有找到用于设置迭代次数的参数。 基本上,训练脚本每 100 次迭代显示一个样本,并每 1000 次迭代保存一个模型版本。但是我没有找到一个参数来训练它,比如 5000 次迭代然后关闭它。

训练脚本在这里: https://github.com/nshepperd/gpt-2/blob/finetuning/train.py

编辑:

正如 cronoik 所建议的,我正在尝试将 while 替换为 for 循环。

我正在添加这些更改:

  1. 添加一个额外的参数:

    parser.add_argument('--training_steps', metavar='STEPS', type=int, default=1000, help='一个代表模型训练步数的数字')

  2. 改变循环:

     try:
         for iter_count in range(training_steps):
             if counter % args.save_every == 0:
                 save()
    
  3. 使用新参数:

    python3 train.py --training_steps 300

但是我收到了这个错误:

  File "train.py", line 259, in main
    for iter_count in range(training_steps):
NameError: name 'training_steps' is not defined

【问题讨论】:

  • 它应该是for iter_count in range(args.training_steps) 而不是for iter_count in range(training_steps),因为您添加了另一个参数,它是args 的成员。

标签: python tensorflow nlp gpt-2


【解决方案1】:

您所要做的就是将while True 循环修改为for 循环:

try:
    #replaced
    #while True:
    for i in range(5000):
        if counter % args.save_every == 0:
            save()
        if counter % args.sample_every == 0:
            generate_samples()
        if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
            validation()

        if args.accumulate_gradients > 1:
            sess.run(opt_reset)
            for _ in range(args.accumulate_gradients):
                sess.run(
                    opt_compute, feed_dict={context: sample_batch()})
            (v_loss, v_summary) = sess.run((opt_apply, summaries))
        else:
            (_, v_loss, v_summary) = sess.run(
                (opt_apply, loss, summaries),
                feed_dict={context: sample_batch()})

        summary_log.add_summary(v_summary, counter)

        avg_loss = (avg_loss[0] * 0.99 + v_loss,
                    avg_loss[1] * 0.99 + 1.0)

        print(
            '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
            .format(
                counter=counter,
                time=time.time() - start_time,
                loss=v_loss,
                avg=avg_loss[0] / avg_loss[1]))

        counter += 1
except KeyboardInterrupt:
    print('interrupted')
    save()

【讨论】:

  • 我刚刚根据您的建议更新了问题
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2023-03-05
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多