以下只是您可以尝试的概念。是的,你也可以为此训练一个神经网络,但如果你的设置足够简单,一些工程就可以了。
我们首先创建一个“玩具视频”来使用:
import numpy as np
import matplotlib.pyplot as plt
# Create a toy "video"
image = np.asarray([
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 1, 2, 2, 1],
[0, 0, 2, 4, 4, 2],
[0, 0, 2, 4, 4, 2],
[0, 0, 1, 2, 2, 1],
])
signal = np.asarray([0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0])
x = list(range(len(signal)))
signal = np.interp(np.linspace(0, len(signal), 100), x, signal)[..., None]
frames = np.einsum('tk,xy->txyk', signal, image)[..., 0]
绘制几帧:
fig, axes = plt.subplots(1, 12, sharex='all', sharey='all')
for i, ax in enumerate(axes):
ax.matshow(frames[i], vmin=0, vmax=1)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax.set_title(i)
plt.show()
现在您有了这种玩具视频,将其转换回某种二进制信号非常简单。您只需计算每帧的平均亮度:
reconstructed = frames.mean(1).mean(1)
reconstructed_bin = reconstructed > 0.5
plt.plot(reconstructed, label='original')
plt.plot(reconstructed_bin, label='binary')
plt.title('Reconstructed Signal')
plt.legend()
plt.show()
从这里我们只需要确定每次闪光的长度。
# This is ugly, I know. Just for understanding though:
# 1. Splits the binary signal on zero-values
# 2. Filters out the garbage (accept only lists where len(e) > 1)
# 3. Gets the length of the remaining list == the duration of each flash
tmp = np.split(reconstructed_bin, np.where(reconstructed_bin == 0)[0][1:])
flashes = list(map(len, filter(lambda e: len(e) > 1, tmp)))
我们现在可以看看闪烁需要多长时间:
print(flashes)
给我们
[5, 5, 5, 10, 9, 9, 5, 5, 5]
所以..“短”闪光似乎需要 5 帧,“长”大约需要 10 帧。这样,我们可以通过定义一个合理的阈值 7 来将每个闪光分类为“长”或“短”,如下所示:
# Classify each flash-duration
flashes_classified = list(map(lambda f: 'long' if f > 7 else 'short', flashes))
让我们重复暂停
# Repeat for pauses
tmp = np.split(reconstructed_bin, np.where(reconstructed_bin != False)[0][1:])
pauses = list(map(len, filter(lambda e: len(e) > 1, tmp)))
pauses_classified = np.asarray(list(map(lambda f: 'w' if f > 6 else 'c', pauses)))
pauses_indices, = np.where(np.asarray(pauses_classified) == 'w')
现在我们可以将结果可视化。
fig = plt.figure()
ax = fig.gca()
ax.bar(range(len(flashes)), flashes, label='Flash duration')
ax.set_xticks(list(range(len(flashes_classified))))
ax.set_xticklabels(flashes_classified)
[ax.axvline(idx-0.5, ls='--', c='r', label='Pause' if i == 0 else None) for i, idx in enumerate(pauses_indices)]
plt.legend()
plt.show()