【问题标题】:Keras seems to hang after call to fit_generator调用 fit_generator 后,Keras 似乎挂起
【发布时间】:2018-07-04 15:07:22
【问题描述】:

我正在尝试将 SqueezeDet model 的 Keras 实现适应新数据集。在对我的配置文件进行适当的更改后,我尝试运行 train 脚本,但它似乎在调用 fit_generator() 后挂起。当我得到以下输出时:

/anaconda/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Number of images: 536
Number of epochs: 100
Number of batches: 53
Batch size: 10
2018-07-04 14:18:49.711606: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2018-07-04 14:18:54.080912: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1356] Found device 0 with properties:
name: Tesla K80 major: 3 minor: 7 memoryClockRate(GHz): 0.8235
pciBusID: 52a9:00:00.0
totalMemory: 11.17GiB freeMemory: 11.10GiB
2018-07-04 14:18:54.080958: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1435] Adding visible gpu devices: 0
2018-07-04 14:18:54.333214: I tensorflow/core/common_runtime/gpu/gpu_device.cc:923] Device interconnect StreamExecutor with strength 1 edge matrix:
2018-07-04 14:18:54.333270: I tensorflow/core/common_runtime/gpu/gpu_device.cc:929]      0
2018-07-04 14:18:54.333290: I tensorflow/core/common_runtime/gpu/gpu_device.cc:942] 0:   N
2018-07-04 14:18:54.333559: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10764 MB memory) -> physical GPU (device: 0, name: Tesla K80, pci bus id: 52a9:00:00.0, compute capability: 3.7)
Learning rate: 0.01
Weights initialized by name from ../main/model/imagenet.h5
Using single GPU
Backend Qt5Agg is interactive backend. Turning interactive mode on.
Epoch 1/100

然后,即使将其搁置一天,也不会发生任何事情。它似乎冻结的电话是:

squeeze.model.fit_generator(train_generator, epochs=EPOCHS, verbose=1,
                            steps_per_epoch=nbatches_train, callbacks=cb)

参数在哪里:

train_generator = generator_from_data_path(img_names, gt_names, config=cfg)
EPOCHS = 100
nbatches_train  = 53
callbacks = [# TensorBoard object, ReduceLROnPlateau object, ModelCheckpoint object #]

我的版本:

Python 3.5.4 :: Anaconda custom (64-bit)
tensorflow-gpu : 1.8.0
tensorflow : 1.8.0
Keras : 2.2.0

【问题讨论】:

  • 移除 TensorBoard 回调并重试。
  • train_generator[0]return 什么? len(train_generator) 会返回什么吗?
  • @wl2776 不起作用,因为 train_generator 是 threadsafe_iter 对象 train_generator[0] > TypeError("'threadsafe_iter' object does not support indexing",)len(train_generator) > TypeError("object of type 'threadsafe_iter' has no len()",)。但是,我认为您正在做某事,如果我尝试使用 next(train_generator) 访问第一个元素,它会挂在该行上。
  • @MatiasValdenegro 删除回调没有任何区别。
  • 好的,那么你应该添加更多信息(代码),比如什么是 train_generator。

标签: python python-3.x tensorflow keras


【解决方案1】:

在 cmets 中格式化对话以进行回答。

罪魁祸首是train_generator

前段时间我在 Keras 中查看了 model.fit_generator 的来源。它只是从生成器中检索一些数据并将其提交给后端,没什么神奇的:)

所以,我的假设是它无法从生成器中检索数据,因为生成器不生成任何东西。

@Barker 已确认,称对 next(train_generator) 的呼叫挂起。

我个人已经搬到keras.utils.Sequence,它支持索引和长度,比普通的生成器方便得多。尽管此注释与当前问题无关。

【讨论】:

  • 此外,keras.utils.Sequence 是线程安全的,这是一个巨大的优势。
猜你喜欢
  • 2015-06-19
  • 1970-01-01
  • 1970-01-01
  • 2018-11-27
  • 2018-02-22
  • 2016-05-15
  • 1970-01-01
  • 2018-05-11
  • 2020-04-25
相关资源
最近更新 更多