【发布时间】:2021-06-07 13:25:41
【问题描述】:
我目前正在研究 Aeriel 视频上的人类行为识别。我正在使用这个dataset。您可以看到videos 和labels 文件。我正在构建一个 SSD 模型来训练数据..
我在使用 model.fit 时遇到错误。
我认为主要问题出在 DataGenerator 类中,尽管我无法解决该错误。代码如下
import numpy as np
import cv2
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import os
import json
class DataGenerator(tf.keras.utils.Sequence):
def __init__(self, img_dir, ann_dir,
batch_size=32, dim=(300,300),
shuffle=True):
self.idx_to_name = ["None", '"Handshaking"\n', '"Hugging"\n', '"Reading"\n', '"Drinking"\n', '"Pushing/Pulling"\n', '"Carrying"\n', '"Calling"\n', '"Running"\n', '"Walking"\n', '"Lying"\n', '"Sitting"\n', '"Standing"\n']
self.name_to_idx = dict([(v, k) for k, v in enumerate(self.idx_to_name)])
self.img_dir = img_dir
self.ann_dir = ann_dir
# self.frame = frame
self.batch_size = batch_size
self.dim = dim
def _get_annotation(self, file, j):
frame_map = dict()
with open(file, 'r') as fp:
line = fp.readline()
while line:
line_split = line.split(' ')
frame_id = int(line_split[5])
if line_split[10] is not None:
label = line_split[10]
else:
label = "None"
val = (int(line_split[0]), list(map(int, line_split[1:5])), list(map(int, line_split[6:8])), line_split[10])
if frame_id not in frame_map:
frame_map[frame_id] = [val]
else:
frame_map[frame_id].append(val)
line = fp.readline()
for obj in frame_map[int(j)]:
xmin = float(obj[1][0]) / 3840.0
ymin = float(obj[1][1]) / 2160.0
xmax = float(obj[1][2]) / 3840.0
ymax = float(obj[1][3]) / 2160.0
name = obj[3]
boxes.append([xmin, ymin, xmax, ymax])
labels.append(self.name_to_idx[name] + 1)
return np.array(boxes, dtype=np.float32), np.array(labels, dtype=np.int64)
def __getitem__(self, index):
start_index = index * self.batch_size
x_train = []
y_train = []
i = start_index - 1
while len(x_train) < self.batch_size:
try:
for i in os.listdir(self.img_dir):
for j in os.listdir(self.img_dir + '/' + i):
img = cv2.imread(self.img_dir + '/' +i + '/' + j)
img = cv2.resize(img,(320,240))
img = np.array(img, dtype = np.float32)
img = img / 255.0
boxes, labels = self._get_annotation(self.ann_dir + '/' + i + '.txt', int(j[:-4]))
x_train.append(img)
y_train.append(boxes)
i += 1
except Exception as err:
print(err)
continue
x_train = np.array(x_train)
y_train = np.array(y_train)
return x_train, y_train
train_data = DataGenerator("/content/okutama_imgs", "/content/okutama_labels", batch_size=4)
model.compile(optimizer=optimizer, loss= SSD_loss, metrics=['accuracy'])
model.fit(train_data, epochs = 50, verbose = 1, callbacks = callbacks)
运行 model.fit 时出现此错误。我不明白为什么会这样,如果您需要更多信息,我很乐意提供给您。这是错误
NotImplementedError Traceback (most recent call last)
<ipython-input-64-7c278c6b3232> in <module>()
----> 1 model.fit(train_data, epochs = 50, verbose = 1, callbacks = callbacks)
4 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/utils/data_utils.py in __len__(self)
456 The number of batches in the Sequence.
457 """
--> 458 raise NotImplementedError
459
460 def on_epoch_end(self):
NotImplementedError:
【问题讨论】:
-
这个错误是因为你没有在继承的类中实现一些函数,在这种情况下我认为是
__len__函数。 -
嗨,你解决了这个问题吗?
标签: python tensorflow keras deep-learning computer-vision