【Question title】: Understanding the standardisation process of KNN
【Posted】: 2023-03-26 03:29:01
【Question】:

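So I'm having some trouble understanding the standardisation process of this KNN classifier. Basically, I need to know what happens during standardisation. If anyone can help, it would be greatly appreciated. I know there are mean and standard deviation variables computed from the "training examples", but what actually happens after that is where I'm struggling.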

classdef myknn
methods(Static)

    % fit returns a model m that stores the training examples,
    % the training labels and k, the number of nearest neighbours.
    function m = fit(train_examples, train_labels, k)

        % start of standardisation process
        m.mean = mean(train_examples{:,:});  % mean of each feature (column)
        m.std = std(train_examples{:,:});    % standard deviation of each feature (column)
        for i=1:size(train_examples,1)
            train_examples{i,:} = train_examples{i,:} - m.mean;
            train_examples{i,:} = train_examples{i,:} ./ m.std;
        end
        % end of standardisation process

        m.train_examples = train_examples;
        m.train_labels = train_labels;
        m.k = k;

    end

    function predictions = predict(m, test_examples)

        predictions = categorical;

        for i=1:size(test_examples,1)

            fprintf('classifying example %i/%i\n', i, size(test_examples,1));

            this_test_example = test_examples{i,:};

            % start of standardisation process
            this_test_example = this_test_example - m.mean;
            this_test_example = this_test_example ./ m.std;
            % end of standardisation process

            this_prediction = myknn.predict_one(m, this_test_example);
            predictions(end+1) = this_prediction;

        end

    end

    function prediction = predict_one(m, this_test_example)

        distances = myknn.calculate_distances(m, this_test_example);
        neighbour_indices = myknn.find_nn_indices(m, distances);
        prediction = myknn.make_prediction(m, neighbour_indices);

    end

    function distances = calculate_distances(m, this_test_example)

        distances = [];

        for i=1:size(m.train_examples,1)

            this_training_example = m.train_examples{i,:};
            this_distance = myknn.calculate_distance(this_training_example, this_test_example);
            distances(end+1) = this_distance;
        end

    end

    function distance = calculate_distance(p, q)

        differences = q - p;
        squares = differences .^ 2;
        total = sum(squares);
        distance = sqrt(total);

    end

    function neighbour_indices = find_nn_indices(m, distances)

        [~, indices] = sort(distances);  % only the sort order is needed
        neighbour_indices = indices(1:m.k);

    end

    function prediction = make_prediction(m, neighbour_indices)

        neighbour_labels = m.train_labels(neighbour_indices);
        prediction = mode(neighbour_labels);

    end

end

end

【Question discussion】:

    Tags: matlab machine-learning knn


    【Solution 1】:

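    Standardisation is the process of normalising each feature in the training examples so that every feature has zero mean and a standard deviation of one. To do this, you first find the mean and the standard deviation of each feature (each column of train_examples). You then subtract the corresponding mean from every value of that feature and divide the result by the corresponding standard deviation.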
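Written out as formulas, for $n$ training examples (the rows of train_examples) and feature $j$ (column $j$), with $x_{ij}$ the value of feature $j$ in example $i$, the per-feature statistics and the standardised value $z_{ij}$ are as follows. The symbols are introduced here only for illustration; the $n-1$ in $\sigma_j$ matches the default normalisation used by MATLAB's `std`.

```latex
\mu_j = \frac{1}{n}\sum_{i=1}^{n} x_{ij}, \qquad
\sigma_j = \sqrt{\frac{1}{n-1}\sum_{i=1}^{n}\left(x_{ij} - \mu_j\right)^2}, \qquad
z_{ij} = \frac{x_{ij} - \mu_j}{\sigma_j}
```

In the code, `m.mean` holds the row vector of all $\mu_j$ and `m.std` holds the row vector of all $\sigma_j$.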

    You can see this clearly in this code:

        m.mean = mean(train_examples{:,:});  %mean variable
        m.std = std(train_examples{:,:}); %standard deviation variable
        for i=1:size(train_examples,1)
            train_examples{i,:} = train_examples{i,:} - m.mean;
            train_examples{i,:} = train_examples{i,:} ./ m.std;
        end
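As a small sketch, assuming a plain numeric matrix instead of the table used in the question (the data and the names X, mu, sigma, Z_loop and Z_vec are hypothetical), the row-by-row loop in fit can be checked against a vectorised one-liner. The vectorised line relies on implicit expansion, available from MATLAB R2016b; on older releases bsxfun(@minus, ...) and bsxfun(@rdivide, ...) do the same job.

```matlab
% Hypothetical toy data: 4 training examples (rows) x 2 features (columns)
X = [1 200; 2 400; 3 600; 4 1000];

mu    = mean(X);   % 1x2 row vector: mean of each column (feature)
sigma = std(X);    % 1x2 row vector: std of each column (N-1 normalisation)

% Row-by-row version, mirroring the loop in fit
Z_loop = X;
for i = 1:size(X, 1)
    Z_loop(i,:) = X(i,:) - mu;          % centre each feature at zero
    Z_loop(i,:) = Z_loop(i,:) ./ sigma; % scale each feature to unit std
end

% Vectorised version: mu and sigma are expanded across every row
Z_vec = (X - mu) ./ sigma;

disp(isequal(Z_loop, Z_vec))  % 1 (true): both versions give identical results
disp(mean(Z_vec))             % approximately [0 0]
disp(std(Z_vec))              % approximately [1 1]
```

After this, every column of Z_vec has (numerically) zero mean and unit standard deviation, which is what the loop in fit produces for train_examples.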
    

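    m.mean stores the mean of each feature and m.std stores the standard deviation of each feature. Note that you must keep both of these, because you need them again when you classify at test time. You can see this in your predict method: for each test example, it subtracts each feature's training mean and then divides by that feature's training standard deviation.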

    function predictions = predict(m, test_examples)
    
        predictions = categorical;
    
        for i=1:size(test_examples,1)
    
            fprintf('classifying example %i/%i\n', i, size(test_examples,1));
    
            this_test_example = test_examples{i,:};
    
            % start of standardisation process
            this_test_example = this_test_example - m.mean;
            this_test_example = this_test_example ./ m.std;
            % end of standardisation process
    
            this_prediction = myknn.predict_one(m, this_test_example);
            predictions(end+1) = this_prediction;
    
        end

    end

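    Note that we apply m.mean and m.std to the test examples, and these quantities come from the training examples. The test data are never used to compute their own mean or standard deviation, so training and test examples are put on the same scale before any distances are computed.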
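The following sketch, again on hypothetical numeric data rather than the question's tables, shows a test example being standardised with the training mean and standard deviation, and how that changes which training example is nearest. The distance line is a vectorised stand-in for calculate_distances/calculate_distance, not the question's code itself; X_train, x_test, nn_raw and nn_std are illustrative names.

```matlab
% Hypothetical training data: feature 1 is small-scale, feature 2 large-scale
X_train = [1 200; 2 400; 3 600; 4 1000];
mu      = mean(X_train);             % plays the role of m.mean
sigma   = std(X_train);              % plays the role of m.std
Z_train = (X_train - mu) ./ sigma;   % what fit stores as the training examples

% Hypothetical test example, standardised with the TRAINING statistics,
% as predict does with m.mean and m.std
x_test = [4 300];
z_test = (x_test - mu) ./ sigma;

% Euclidean distance from the test example to every training example (one per row)
d_raw = sqrt(sum((X_train - x_test).^2, 2));
d_std = sqrt(sum((Z_train - z_test).^2, 2));

[~, nn_raw] = min(d_raw);   % index of nearest neighbour on raw data
[~, nn_std] = min(d_std);   % index of nearest neighbour on standardised data
fprintf('nearest neighbour without standardisation: row %d\n', nn_raw);
fprintf('nearest neighbour with standardisation:    row %d\n', nn_std);
```

With these numbers the second feature, measured in hundreds, dominates the raw distance, so the nearest neighbour is row 2; after standardisation both features contribute on a comparable scale and the nearest neighbour becomes row 3.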

    My post on standardisation should provide more background. It also achieves the same effect as the code you provided, but in a more vectorised way: How does this code for standardizing data work?

    【Discussion】:
