【问题标题】:Scikit-learn: Understanding MeanShift fit_predict()Scikit-learn:理解 MeanShift fit_predict()
【发布时间】:2019-11-22 18:59:35
【问题描述】:

我正在使用 scikit-learn 的 Mean Shift 算法来执行图像分割。我有以下代码:

import cv2
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs
import matplotlib.pyplot as plt
from itertools import cycle
from PIL import Image

image = Image.open('sample_images/fruit.png').convert('RGB')
image = np.array(image)

red = image[:,:,0]
green = image[:,:,1]
blue = image[:,:,2]

flat_red = red.flatten()
flat_green = green.flatten()
flat_blue = blue.flatten()

flattened = np.stack((flat_red, flat_green, flat_blue))

ms_clf = MeanShift(bin_seeding=True)
ms_labels = ms_clf.fit_predict(flattened)
plt.imshow(np.reshape(ms_labels, [1001, 994]))

我有一个扁平的颜色矩阵,其尺寸为 3x994994,因此总共有 2984982 个样本。

print(flattened.shape)
(3, 994994)

print(flattened)
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]

这个展平矩阵用作 MeanShift fit_predict() 函数的输入。当我尝试打印 fit_predict() 返回的标签数组时,我得到以下输出:

print(ms_labels)
[0 1 2]

fit_predict() 函数不是为每个数据样本返回一个标签吗?为什么我只得到一个包含 3 个元素的数组?任何见解都值得赞赏。

【问题讨论】:

    标签: python numpy opencv scikit-learn computer-vision


    【解决方案1】:

    fit_predict() 的文档说它将形状为 (n_samples, n_features) 的 X 作为输入,并返回形状为 (n_samples,) 的标签。由于您输入的是 3x994994 数组,其中 n_samples=3 和 n_features=994994,这意味着标签将是 (3,) 数组,如您所见。它本质上是将“扁平化”中的每个图像通道视为一个数据。

    【讨论】:

      猜你喜欢
      • 2013-10-02
      • 2016-09-03
      • 2019-01-15
      • 1970-01-01
      • 2012-10-15
      • 2020-04-04
      • 2015-02-12
      • 2018-03-06
      • 1970-01-01
      相关资源
      最近更新 更多