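【Question Title】: What's wrong with my implementation of the nearest neighbour algorithm (for the TSP)?
【Posted】: 2021-10-21 23:17:16
【Question】:

I have been tasked with implementing the nearest neighbour algorithm for the travelling salesman problem. The method is supposed to try starting from each city and return the best tour found. According to the auto-marker, my implementation works correctly for the most basic case but only partially for all more advanced cases.

I don't understand where I went wrong and am looking for a review of my code's correctness. I'd like to know where the mistake is and what the correct approach would be.

My Java code is as follows: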

/*
 * Returns the shortest tour found by exercising the NN algorithm 
 * from each possible starting city in table.
 * table[i][j] == table[j][i] gives the cost of travel between City i and City j.
 */
 public static int[] tspnn(double[][] table) {
     
     // number of vertices 
     int numberOfVertices = table.length;
     // the Hamiltonian cycle built starting from vertex i
     int[] currentHamiltonianCycle = new int[numberOfVertices];
     // the lowest total cost Hamiltonian cycle
     double lowestTotalCost = Double.POSITIVE_INFINITY;
     //  the shortest Hamiltonian cycle
     int[] shortestHamiltonianCycle = new int[numberOfVertices];
     
     // consider each vertex i as a starting point
     for (int i = 0; i < numberOfVertices; i++) {
         /* 
          * Consider all vertices that are reachable from the starting point i,
          * thereby creating a new current Hamiltonian cycle.
          */
         for (int j = 0; j < numberOfVertices; j++) {
             /* 
              * The modulo of the sum of i and j allows us to account for the fact 
              * that Java indexes arrays from 0.
              */
             currentHamiltonianCycle[j] = (i + j) % numberOfVertices;   
         }
         for (int j = 1; j < numberOfVertices - 1; j++) {
             int nextVertex = j;
             for (int p = j + 1; p < numberOfVertices; p++) {
                 if (table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[p]] < table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[nextVertex]]) {
                           nextVertex = p;
                 }
             }
             
             int a = currentHamiltonianCycle[nextVertex];
             currentHamiltonianCycle[nextVertex] = currentHamiltonianCycle[j];
             currentHamiltonianCycle[j] = a;
         }
         
         /*
          * Find the total cost of the current Hamiltonian cycle.
          */
         double currentTotalCost = table[currentHamiltonianCycle[0]][currentHamiltonianCycle[numberOfVertices - 1]];
         for (int z = 0; z < numberOfVertices - 1; z++) {
             currentTotalCost += table[currentHamiltonianCycle[z]][currentHamiltonianCycle[z + 1]];
         }
         
         if (currentTotalCost < lowestTotalCost) {
             lowestTotalCost = currentTotalCost;
             shortestHamiltonianCycle = currentHamiltonianCycle;
         }
     }
     return shortestHamiltonianCycle;
 }

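EDIT

I have walked through this code with pen and paper for a simple example and could not find any problem with the implementation of the algorithm. On that basis, it seems to me that it should work in the general case.


EDIT 2

I have tested my implementation with the following mock example: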

double[][] table = {{0, 2.3, 1.8, 4.5}, {2.3, 0, 0.4, 0.1}, 
                {1.8, 0.4, 0, 1.3}, {4.5, 0.1, 1.3, 0}}; 

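It appears to produce the expected output for the nearest neighbour algorithm, namely 3 -> 1 -> 2 -> 0

I am now wondering whether the auto-marker is incorrect, or whether my implementation simply does not work in the general case.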
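
As a cross-check on the mock example above, here is a minimal sketch that builds the nearest neighbour tour from every starting city and prints each tour with its closed-cycle cost. The class name `NearestNeighbourCheck` and the helpers `tourFrom` and `cycleCost` are hypothetical names introduced for illustration; they are not part of the code being marked, and the sketch uses a `visited` array rather than the in-place swapping above.

```java
import java.util.Arrays;

class NearestNeighbourCheck {

    // Builds the nearest-neighbour tour that starts at `start` (hypothetical helper).
    static int[] tourFrom(double[][] table, int start) {
        int n = table.length;
        boolean[] visited = new boolean[n];
        int[] tour = new int[n];
        tour[0] = start;
        visited[start] = true;
        for (int step = 1; step < n; step++) {
            int current = tour[step - 1];
            int nearest = -1;
            // pick the cheapest city that has not been visited yet
            for (int candidate = 0; candidate < n; candidate++) {
                if (visited[candidate]) continue;
                if (nearest == -1 || table[current][candidate] < table[current][nearest]) {
                    nearest = candidate;
                }
            }
            tour[step] = nearest;
            visited[nearest] = true;
        }
        return tour;
    }

    // Cost of the closed tour, including the leg from the last city back to the first.
    static double cycleCost(double[][] table, int[] tour) {
        double cost = table[tour[tour.length - 1]][tour[0]];
        for (int i = 0; i + 1 < tour.length; i++) {
            cost += table[tour[i]][tour[i + 1]];
        }
        return cost;
    }

    public static void main(String[] args) {
        double[][] table = {{0, 2.3, 1.8, 4.5}, {2.3, 0, 0.4, 0.1},
                            {1.8, 0.4, 0, 1.3}, {4.5, 0.1, 1.3, 0}};
        for (int start = 0; start < table.length; start++) {
            int[] tour = tourFrom(table, start);
            System.out.println(Arrays.toString(tour) + " cost " + cycleCost(table, tour));
        }
    }
}
```

On this table the tour that starts at city 1 (1 -> 3 -> 2 -> 0) costs about 5.5, while 3 -> 1 -> 2 -> 0 costs about 6.8, so a best-over-all-starts result would be expected to begin at city 1 rather than city 3.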

【Comments】:

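  • First hint: have one method that selects the starting point, and another method that runs the algorithm from the chosen starting point. That way you keep two different concerns in two different methods. Second hint: use object-oriented code wherever possible. Juggling array indices makes things very hard to write, understand and maintain. Third: I assume you always pick the same sequence of vertices, just with a different starting index. Fourth: you always reuse the currentHamiltonianCycle array and thereby overwrite the old results. You should allocate it inside the for i loop.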
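  • @JayC667 I would like to use OOP principles with separate methods, but I was told it all has to be written in a single method for the auto-marker, which is why it is written this way.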
  • @JayC667 Hmm, I'm not sure I understand what you mean. currentHamiltonianCycle is allocated once at the start, and then inside the for i loop there is the assignment shortestHamiltonianCycle = currentHamiltonianCycle;
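  • Exactly. And the = operator does not copy the array contents; it only assigns a reference to the array, so you are always working on the same array and overwriting the previous path/result. If you want to copy its values, use something like System.arraycopy().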
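  • I don't know, tbh. Sorry. Without a deep analysis and running it in debug mode, this code is too complex for me to understand. It all looks like the work of a mathematician. From the setup and the way you do things, you seem to be using an auto-marker with odd restrictions, yet you decided to use Java (which is only about 50% as fast as C or C++ or any other language without array index checks). So all of this is far from what I know and from solutions that can be seen at first glance.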
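
The point raised in the comments above, that `=` on an array stores a reference rather than a copy, can be illustrated with a small sketch. The class `ArrayAliasingDemo` and its two helper methods are hypothetical names used only for this illustration; `System.arraycopy` is the standard library call.

```java
import java.util.Arrays;

class ArrayAliasingDemo {

    // Keeps a reference to the working array: later writes show through.
    static int[] keepByReference(int[] working) {
        return working;
    }

    // Keeps an independent snapshot of the working array.
    static int[] keepByCopy(int[] working) {
        int[] copy = new int[working.length];
        System.arraycopy(working, 0, copy, 0, working.length);
        return copy;
    }

    public static void main(String[] args) {
        int[] current = {3, 1, 2, 0};
        int[] aliased = keepByReference(current);
        int[] snapshot = keepByCopy(current);

        current[0] = 9; // simulate the next iteration overwriting the working array

        System.out.println(Arrays.toString(aliased));  // [9, 1, 2, 0]
        System.out.println(Arrays.toString(snapshot)); // [3, 1, 2, 0]
    }
}
```

Applied to the question's code, this would suggest either allocating a fresh array for every starting vertex or copying the current tour when it becomes the best one so far.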

Tags: java algorithm nearest-neighbor traveling-salesman


【Answer 1】:

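As I stated in the comments, I see one fundamental problem with the algorithm itself:

  • It does not permute the towns properly, but always works through them in sequence (A-B-C-D-A-B-C-D, start anywhere and take 4)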
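
The "start anywhere and take 4" pattern corresponds to the initial fill `(i + j) % numberOfVertices` in the question's code. A minimal sketch of what that fill alone produces, before any swapping step runs, is shown here; `RotationFillDemo` and `rotation` are hypothetical names used for illustration only.

```java
import java.util.Arrays;

class RotationFillDemo {

    // Returns the cities 0..n-1 rotated so that `start` comes first.
    static int[] rotation(int n, int start) {
        int[] order = new int[n];
        for (int j = 0; j < n; j++) {
            order[j] = (start + j) % n;
        }
        return order;
    }

    public static void main(String[] args) {
        // With 4 cities this prints the four rotations of 0,1,2,3.
        for (int start = 0; start < 4; start++) {
            System.out.println(Arrays.toString(rotation(4, start)));
        }
    }
}
```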

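To demonstrate this problem, I wrote the following code to set up and test both simple and advanced examples.

  • Please configure it via the static public final constants first, before changing the code itself.
  • Focus on the simple example: if the algorithm worked properly, the result would always be A-B-C-D or D-C-B-A.
  • But as you can observe from the output, the algorithm does not select the (globally) best tour, because its permutation of the towns under test is wrong.

I have added my own object-oriented implementation to show:

  • the selection problem, which is really hard to get right in ONE method in one go
  • the advantages of an OO style
  • that proper testing/development is quite easy to set up and carry out (I am not even using unit tests here, which would be the next step for verifying/validating the algorithm)

Code: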

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;

public class TSP_NearestNeighbour {



    static public final int NUMBER_OF_TEST_RUNS = 4;

    static public final boolean GENERATE_SIMPLE_TOWNS = true;

    static public final int NUMBER_OF_COMPLEX_TOWNS         = 10;
    static public final int DISTANCE_RANGE_OF_COMPLEX_TOWNS = 100;



    static private class Town {
        public final String Name;
        public final int    X;
        public final int    Y;
        public Town(final String pName, final int pX, final int pY) {
            Name = pName;
            X = pX;
            Y = pY;
        }
        public double getDistanceTo(final Town pOther) {
            final int dx = pOther.X - X;
            final int dy = pOther.Y - Y;
            return Math.sqrt(Math.abs(dx * dx + dy * dy));
        }
        @Override public int hashCode() { // not really needed here
            final int prime = 31;
            int result = 1;
            result = prime * result + X;
            result = prime * result + Y;
            return result;
        }
        @Override public boolean equals(final Object obj) {
            if (this == obj) return true;
            if (obj == null) return false;
            if (getClass() != obj.getClass()) return false;
            final Town other = (Town) obj;
            if (X != other.X) return false;
            if (Y != other.Y) return false;
            return true;
        }
        @Override public String toString() {
            return Name + " (" + X + "/" + Y + ")";
        }
    }

    static private double[][] generateDistanceTable(final ArrayList<Town> pTowns) {
        final double[][] ret = new double[pTowns.size()][pTowns.size()];
        for (int outerIndex = 0; outerIndex < pTowns.size(); outerIndex++) {
            final Town outerTown = pTowns.get(outerIndex);

            for (int innerIndex = 0; innerIndex < pTowns.size(); innerIndex++) {
                final Town innerTown = pTowns.get(innerIndex);

                final double distance = outerTown.getDistanceTo(innerTown);
                ret[outerIndex][innerIndex] = distance;
            }
        }
        return ret;
    }



    static private ArrayList<Town> generateTowns_simple() {
        final Town a = new Town("A", 0, 0);
        final Town b = new Town("B", 1, 0);
        final Town c = new Town("C", 2, 0);
        final Town d = new Town("D", 3, 0);
        return new ArrayList<>(Arrays.asList(a, b, c, d));
    }
    static private ArrayList<Town> generateTowns_complex() {
        final ArrayList<Town> allTowns = new ArrayList<>();
        for (int i = 0; i < NUMBER_OF_COMPLEX_TOWNS; i++) {
            final int randomX = (int) (Math.random() * DISTANCE_RANGE_OF_COMPLEX_TOWNS);
            final int randomY = (int) (Math.random() * DISTANCE_RANGE_OF_COMPLEX_TOWNS);
            final Town t = new Town("Town-" + (i + 1), randomX, randomY);
            if (allTowns.contains(t)) { // do not allow different towns at same location!
                System.out.println("Towns colliding at " + t);
                --i;
            } else {
                allTowns.add(t);
            }
        }
        return allTowns;
    }
    static private ArrayList<Town> generateTowns() {
        if (GENERATE_SIMPLE_TOWNS) return generateTowns_simple();
        else return generateTowns_complex();
    }



    static private void printTowns(final ArrayList<Town> pTowns, final double[][] pDistances) {
        System.out.println("Towns:");
        for (final Town town : pTowns) {
            System.out.println("\t" + town);
        }

        System.out.println("Distance Matrix:");
        for (int y = 0; y < pDistances.length; y++) {
            System.out.print("\t");
            for (int x = 0; x < pDistances.length; x++) {
                System.out.print(pDistances[y][x] + " (" + pTowns.get(y).Name + "-" + pTowns.get(x).Name + ")" + "\t");
            }
            System.out.println();
        }
    }



    private static void testAlgorithm() {
        final ArrayList<Town> towns = generateTowns();

        for (int i = 0; i < NUMBER_OF_TEST_RUNS; i++) {
            final double[][] distances = generateDistanceTable(towns);
            printTowns(towns, distances);

            {
                final int[] path = tspnn(distances);
                System.out.println("tspnn Path:");
                for (int pathIndex = 0; pathIndex < path.length; pathIndex++) {
                    final Town t = towns.get(pathIndex);
                    System.out.println("\t" + t);
                }
            }
            {
                final ArrayList<Town> path = tspnn_simpleNN(towns);
                System.out.println("tspnn_simpleNN Path:");
                for (final Town t : path) {
                    System.out.println("\t" + t);
                }
                System.out.println("\n");
            }

            // prepare for next run. We shuffle at the end of the loop so the first run prints the original order
            Collections.shuffle(towns);
        }

    }

    public static void main(final String[] args) {
        testAlgorithm();
    }



    /*
     * Returns the shortest tour found by exercising the NN algorithm
     * from each possible starting city in table.
     * table[i][j] == table[j][i] gives the cost of travel between City i and City j.
     */
    public static int[] tspnn(final double[][] table) {

        // number of vertices
        final int numberOfVertices = table.length;
        // the Hamiltonian cycle built starting from vertex i
        final int[] currentHamiltonianCycle = new int[numberOfVertices];
        // the lowest total cost Hamiltonian cycle
        double lowestTotalCost = Double.POSITIVE_INFINITY;
        //  the shortest Hamiltonian cycle
        int[] shortestHamiltonianCycle = new int[numberOfVertices];

        // consider each vertex i as a starting point
        for (int i = 0; i < numberOfVertices; i++) {
            /*
             * Consider all vertices that are reachable from the starting point i,
             * thereby creating a new current Hamiltonian cycle.
             */
            for (int j = 0; j < numberOfVertices; j++) {
                /*
                 * The modulo of the sum of i and j allows us to account for the fact
                 * that Java indexes arrays from 0.
                 */
                currentHamiltonianCycle[j] = (i + j) % numberOfVertices;
            }
            for (int j = 1; j < numberOfVertices - 1; j++) {
                int nextVertex = j;
                for (int p = j + 1; p < numberOfVertices; p++) {
                    if (table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[p]] < table[currentHamiltonianCycle[j - 1]][currentHamiltonianCycle[nextVertex]]) {
                        nextVertex = p;
                    }
                }

                final int a = currentHamiltonianCycle[nextVertex];
                currentHamiltonianCycle[nextVertex] = currentHamiltonianCycle[j];
                currentHamiltonianCycle[j] = a;
            }

            /*
             * Find the total cost of the current Hamiltonian cycle.
             */
            double currentTotalCost = table[currentHamiltonianCycle[0]][currentHamiltonianCycle[numberOfVertices - 1]];
            for (int z = 0; z < numberOfVertices - 1; z++) {
                currentTotalCost += table[currentHamiltonianCycle[z]][currentHamiltonianCycle[z + 1]];
            }

            if (currentTotalCost < lowestTotalCost) {
                lowestTotalCost = currentTotalCost;
                shortestHamiltonianCycle = currentHamiltonianCycle;
            }
        }
        return shortestHamiltonianCycle;
    }



    /**
     * Here come my basic implementations.
     * They can be heavily (heavily!) improved, but are verbose and direct to show the logic behind them
     */



    /**
     * <p>example how to implement the NN solution the OO way</p>
     * we could also implement
     * <ul>
     * <li>a recursive function</li>
     * <li>or one with running counters</li>
     * <li>or one with a real map/route objects, where further optimizations can take place</li>
     * </ul>
     */
    public static ArrayList<Town> tspnn_simpleNN(final ArrayList<Town> pTowns) {
        ArrayList<Town> bestRoute = null;
        double bestCosts = Double.MAX_VALUE;

        for (final Town startingTown : pTowns) {
            //setup
            final ArrayList<Town> visitedTowns = new ArrayList<>(); // ArrayList because we need a stable index
            final HashSet<Town> unvisitedTowns = new HashSet<>(pTowns); // all towns are available at start; we use HashSet because we need fast search; indexing plays no role here

            // step 1
            Town currentTown = startingTown;
            visitedTowns.add(currentTown);
            unvisitedTowns.remove(currentTown);

            // steps 2-n
            while (unvisitedTowns.size() > 0) {
                // find nearest town
                final Town nearestTown = findNearestTown(currentTown, unvisitedTowns);
                if (nearestTown == null) throw new IllegalStateException("Something in the code is wrong...");

                currentTown = nearestTown;
                visitedTowns.add(currentTown);
                unvisitedTowns.remove(currentTown);
            }

            // selection
            final double cost = getCostsOfRoute(visitedTowns);
            if (cost < bestCosts) {
                bestCosts = cost;
                bestRoute = visitedTowns;
            }
        }
        return bestRoute;
    }



    static private Town findNearestTown(final Town pCurrentTown, final HashSet<Town> pSelectableTowns) {
        double minDist = Double.MAX_VALUE;
        Town minTown = null;

        for (final Town checkTown : pSelectableTowns) {
            final double dist = pCurrentTown.getDistanceTo(checkTown);
            if (dist < minDist) {
                minDist = dist;
                minTown = checkTown;
            }
        }

        return minTown;
    }
    // Sums the legs of the open path in visiting order; the leg back to the first town is not added.
    static private double getCostsOfRoute(final ArrayList<Town> pTowns) {
        double costs = 0;
        for (int i = 1; i < pTowns.size(); i++) { // use pre-index
            final Town t1 = pTowns.get(i - 1);
            final Town t2 = pTowns.get(i);
            final double cost = t1.getDistanceTo(t2);
            costs += cost;
        }
        return costs;
    }



}

In its unchanged state, this gives output similar to the following:

Towns:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)
Distance Matrix:
    0.0 (A-A)   1.0 (A-B)   2.0 (A-C)   3.0 (A-D)   
    1.0 (B-A)   0.0 (B-B)   1.0 (B-C)   2.0 (B-D)   
    2.0 (C-A)   1.0 (C-B)   0.0 (C-C)   1.0 (C-D)   
    3.0 (D-A)   2.0 (D-B)   1.0 (D-C)   0.0 (D-D)   
tspnn Path:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)
tspnn_simpleNN Path:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)


Towns:
    C (2/0)
    D (3/0)
    B (1/0)
    A (0/0)
Distance Matrix:
    0.0 (C-C)   1.0 (C-D)   1.0 (C-B)   2.0 (C-A)   
    1.0 (D-C)   0.0 (D-D)   2.0 (D-B)   3.0 (D-A)   
    1.0 (B-C)   2.0 (B-D)   0.0 (B-B)   1.0 (B-A)   
    2.0 (A-C)   3.0 (A-D)   1.0 (A-B)   0.0 (A-A)   
tspnn Path:
    C (2/0)
    D (3/0)
    B (1/0)
    A (0/0)
tspnn_simpleNN Path:
    D (3/0)
    C (2/0)
    B (1/0)
    A (0/0)


Towns:
    D (3/0)
    B (1/0)
    C (2/0)
    A (0/0)
Distance Matrix:
    0.0 (D-D)   2.0 (D-B)   1.0 (D-C)   3.0 (D-A)   
    2.0 (B-D)   0.0 (B-B)   1.0 (B-C)   1.0 (B-A)   
    1.0 (C-D)   1.0 (C-B)   0.0 (C-C)   2.0 (C-A)   
    3.0 (A-D)   1.0 (A-B)   2.0 (A-C)   0.0 (A-A)   
tspnn Path:
    D (3/0)
    B (1/0)
    C (2/0)
    A (0/0)
tspnn_simpleNN Path:
    D (3/0)
    C (2/0)
    B (1/0)
    A (0/0)


Towns:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)
Distance Matrix:
    0.0 (A-A)   1.0 (A-B)   2.0 (A-C)   3.0 (A-D)   
    1.0 (B-A)   0.0 (B-B)   1.0 (B-C)   2.0 (B-D)   
    2.0 (C-A)   1.0 (C-B)   0.0 (C-C)   1.0 (C-D)   
    3.0 (D-A)   2.0 (D-B)   1.0 (D-C)   0.0 (D-D)   
tspnn Path:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)
tspnn_simpleNN Path:
    A (0/0)
    B (1/0)
    C (2/0)
    D (3/0)

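As you can see, your algorithm depends heavily on the order of the input/towns. If the algorithm were correct, the result would always be A-B-C-D or D-C-B-A.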

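So use this "test" framework to improve your code. The tspnn() method you provided does not depend on any other code, so once you have improved your code, you can comment out all of my stuff. Or put all of this into another class and call your real implementation across classes. Since it is static public, you can easily call it via YourClassName.tspnn(distances).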
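
A minimal sketch of such a cross-class call follows. It assumes the returned int[] holds town indices in visiting order, as the method's contract in the question suggests. `YourClassName` here is only a placeholder with a stub body, and `TourPrinter` and `describe` are hypothetical names introduced for illustration.

```java
import java.util.Arrays;
import java.util.List;

class TourPrinter {

    // Maps each index in the returned tour back to the town name stored at that index.
    static String describe(List<String> names, int[] tour) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < tour.length; i++) {
            if (i > 0) sb.append(" -> ");
            sb.append(names.get(tour[i]));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<String> names = Arrays.asList("A", "B", "C", "D");
        double[][] distances = new double[4][4]; // placeholder table; the stub ignores the values
        int[] tour = YourClassName.tspnn(distances); // static call across classes
        System.out.println(describe(names, tour));
    }
}

// Placeholder for the class that holds the real implementation; this body is only a stub.
class YourClassName {
    public static int[] tspnn(double[][] table) {
        int[] tour = new int[table.length];
        for (int i = 0; i < tour.length; i++) {
            tour[i] = tour.length - 1 - i; // stub: visit the towns in reverse index order
        }
        return tour;
    }
}
```

Looking names up through the tour's entries, rather than through the loop counter, keeps the printed route tied to whatever order the method actually returns.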

On another note, perhaps see whether you can get the auto-marker improved, so that you can use full Java without problems.

【Discussion】: