【发布时间】:2019-11-22 18:59:35
【问题描述】:
我正在使用 scikit-learn 的 Mean Shift 算法来执行图像分割。我有以下代码:
import cv2
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs
import matplotlib.pyplot as plt
from itertools import cycle
from PIL import Image
image = Image.open('sample_images/fruit.png').convert('RGB')
image = np.array(image)
red = image[:,:,0]
green = image[:,:,1]
blue = image[:,:,2]
flat_red = red.flatten()
flat_green = green.flatten()
flat_blue = blue.flatten()
flattened = np.stack((flat_red, flat_green, flat_blue))
ms_clf = MeanShift(bin_seeding=True)
ms_labels = ms_clf.fit_predict(flattened)
plt.imshow(np.reshape(ms_labels, [1001, 994]))
我有一个扁平的颜色矩阵,其尺寸为 3x994994,因此总共有 2984982 个样本。
print(flattened.shape)
(3, 994994)
print(flattened)
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
这个展平矩阵用作 MeanShift fit_predict() 函数的输入。当我尝试打印 fit_predict() 返回的标签数组时,我得到以下输出:
print(ms_labels)
[0 1 2]
fit_predict() 函数不是为每个数据样本返回一个标签吗?为什么我只得到一个包含 3 个元素的数组?任何见解都值得赞赏。
【问题讨论】:
标签: python numpy opencv scikit-learn computer-vision