【问题标题】:Unable to multi-thread a scalable method无法对可扩展方法进行多线程处理
【发布时间】:2012-03-27 01:40:08
【问题描述】:

更新:为了帮助澄清我的要求,我发布了一些 Java 代码,可以理解这个想法。

不久前,我向question 询问了如何让算法分解一组数字,我的想法是给它一个数字列表(1,2,3,4,5) 和总数(10),它会找出每个数字的所有倍数,加起来就是总数('1*10' 或 '1*1,1*2,1*3,1*4' 或 '2*5' 等)。这是我做过的第一次编程练习,所以我花了一段时间才让它工作,但现在我想看看我是否可以扩展它。原始问题中的人说它是可扩展的,但我对如何做到这一点有点困惑。递归部分是我坚持缩放组合所有结果的部分的区域(它所指的表不可缩放,但应用缓存我能够使其快速)

我有以下算法(伪代码):

//generates table
for i = 1 to k
    for z = 0 to sum:
        for c = 1 to z / x_i:
            if T[z - c * x_i][i - 1] is true:
                set T[z][i] to true

//uses table to bring all the parts together
function RecursivelyListAllThatWork(k, sum) // Using last k variables, make sum
    /* Base case: If we've assigned all the variables correctly, list this
     * solution.
     */
    if k == 0:
        print what we have so far
        return

    /* Recursive step: Try all coefficients, but only if they work. */
    for c = 0 to sum / x_k:
       if T[sum - c * x_k][k - 1] is true:
           mark the coefficient of x_k to be c
           call RecursivelyListAllThatWork(k - 1, sum - c * x_k)
           unmark the coefficient of x_k

我真的不知道如何线程/多处理 RecursivelyListAllThatWork 函数。我知道如果我向它发送一个较小的 K(它是列表中项目总数的整数),它将处理该子集,但我不知道如何将结果合并到子集中。例如,如果列表是[1,2,3,4,5,6,7,8,9,10],我发送它 K=3,那么只有 1,2,3 得到处理,这很好,但如果我需要包含 1 和 10 的结果呢?我试图修改表(变量 T),所以只有我想要的子集在那里,但仍然不起作用,因为就像上面的解决方案一样,它做了一个子集,但无法处理需要更广泛范围的答案。

我不需要任何代码,只要有人可以解释如何在概念上打破这个递归步骤,以便可以使用其他内核/机器。

更新:我似乎仍然无法弄清楚如何将 RecursivelyListAllThatWork 变成可运行的(我在技术上知道如何做到这一点,但我不明白如何更改 RecursivelyListAllThatWork 算法,以便它可以并行运行。其他部分只是为了让示例工作,我只需要在 RecursivelyListAllThatWork 方法上实现 runnable)。这是java代码:

import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;


public class main
{
    public static void main(String[] args)
    {
        System.out.println("starting..");
        int target_sum = 100;
        int[] data = new int[] { 10, 5, 50, 20, 25, 40 };
        List T = tableGeneator(target_sum, data);
        List<Integer> coeff = create_coeff(data.length);
        RecursivelyListAllThatWork(data.length, target_sum, T, coeff, data);
    }

    private static List<Integer> create_coeff(int i) {
        // TODO Auto-generated method stub
        Integer[] integers = new Integer[i];
        Arrays.fill(integers, 0);
        List<Integer> integerList = Arrays.asList(integers);
        return integerList;
    }


    private static void RecursivelyListAllThatWork(int k, int sum, List T, List<Integer> coeff, int[] data) {
        // TODO Auto-generated method stub
        if (k == 0) {
            //# print what we have so far
            for (int i = 0; i < coeff.size(); i++) {
                System.out.println(data[i] + " = " + coeff.get(i));
            }

            System.out.println("*******************");
            return;
        }

        Integer x_k = data[k-1];
        //  Recursive step: Try all coefficients, but only if they work. 
        for (int c = 0; c <= sum/x_k; c++) { //the c variable caps the percent
            if (T.contains(new Point((sum - c * x_k), (k-1))))
            {
                    // mark the coefficient of x_k to be c
                    coeff.set((k-1), c);
                    RecursivelyListAllThatWork((k - 1), (sum - c * x_k), T, coeff, data);
                    // unmark the coefficient of x_k
                    coeff.set((k-1), 0);
            }

        }

    }

    public static List tableGeneator(int target_sum, int[] data) {
        List T = new ArrayList();
        T.add(new Point(0, 0));

        float max_percent = 1;
        int R = (int) (target_sum * max_percent * data.length);
        for (int i = 0; i < data.length; i++)
        {
            for (int s = -R; s < R + 1; s++)
            {
                int max_value = (int) Math.abs((target_sum * max_percent)
                        / data[i]);
                for (int c = 0; c < max_value + 1; c++)
                {
                    if (T.contains(new Point(s - c * data[i], i)))
                    {
                        Point p = new Point(s, i + 1);
                        if (!T.contains(p))
                        {
                            T.add(p);
                        }
                    }
                }
            }
        }
        return T;
    }
} 

【问题讨论】:

  • tableGenerator 方法中“max_percent = 1”的原因是什么?
  • @YvesMartin 它用于限制结果,因此没有任何内容超过 100%,但忽略 tableGenerator 方法的细节,它只是提供一个表格来向您展示 RecursivelyListAllThatWork 的工作原理。所以鉴于该表格, RecursivelyListAllThatWork 可以线程化吗?
  • 您是否知道表格生成器已经找到了问题的解决方案,而无需任何递归,而是通过尝试所有组合?这个工作真的很费钱,最后 RecursivelyListAllThatWork 也没用了。我已经发布了一个“真正的”递归实现作为原始问题的答案。您误解了伪代码中“T”的含义——如果答案是解决,它是一个谓词,它回答“真”。
  • @YvesMartin 我认为 T 表示该值是否可解......然后我需要 RecursivelyListAllThatWork 才能将所有结果汇总在一起?
  • 好的,我现在明白会发生什么了。生成的表用于削减大量没有解决方案的递归调用。

标签: java algorithm


【解决方案1】:

多线程的一般答案是通过堆栈(LIFO 或 FIFO)对递归实现进行反递归。在实现这样的算法时,线程数是算法的固定参数(例如内核数)。

为了实现它,当测试条件结束递归时,语言调用堆栈被一个堆栈替换,该堆栈存储最后一个上下文作为检查点。在您的情况下,k=0coeff 值与目标总和匹配。

在去递归之后,第一个实现是运行多个线程来消耗堆栈但是堆栈访问成为一个争用点,因为它可能需要同步。

更好的可扩展解决方案是为每个线程专用一个堆栈,但需要在堆栈中初始生成上下文。

我提出了一种混合方法,第一个线程对有限数量的k 递归工作作为最大递归深度:示例中的小数据集为 2,但如果更大,我建议使用 3。然后第一部分将生成的中间上下文委托给一个线程池,该线程池将以非递归实现处理剩余的k。此代码不是基于您使用的复杂算法,而是基于相当“基本”的实现:

import java.util.Arrays;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class MixedParallel
{
    // pre-requisite: sorted values !!
    private static final int[] data = new int[] { 5, 10, 20, 25, 40, 50 };

    // Context to store intermediate computation or a solution
    static class Context {
        int k;
        int sum;
        int[] coeff;
        Context(int k, int sum, int[] coeff) {
            this.k = k;
            this.sum = sum;
            this.coeff = coeff;
        }
    }

    // Thread pool for parallel execution
    private static ExecutorService executor;
    // Queue to collect solutions
    private static Queue<Context> solutions;

    static {
        final int numberOfThreads = 2;
        executor =
            new ThreadPoolExecutor(numberOfThreads, numberOfThreads, 1000, TimeUnit.SECONDS,
                                   new LinkedBlockingDeque<Runnable>());
        // concurrent because of multi-threaded insertions
        solutions = new ConcurrentLinkedQueue<Context>();
    }


    public static void main(String[] args)
    {
        int target_sum = 100;
        // result vector, init to 0
        int[] coeff = new int[data.length];
        Arrays.fill(coeff, 0);
        mixedPartialSum(data.length - 1, target_sum, coeff);

        executor.shutdown();
        // System.out.println("Over. Dumping results");
        while(!solutions.isEmpty()) {
            Context s = solutions.poll();
            printResult(s.coeff);
        }
    }

    private static void printResult(int[] coeff) {
        StringBuffer sb = new StringBuffer();
        for (int i = coeff.length - 1; i >= 0; i--) {
            if (coeff[i] > 0) {
                sb.append(data[i]).append(" * ").append(coeff[i]).append("   ");
            }
        }
        System.out.println(sb.append("from ").append(Thread.currentThread()));
    }

    private static void mixedPartialSum(int k, int sum, int[] coeff) {
        int x_k = data[k];
        for (int c = sum / x_k; c >= 0; c--) {
            coeff[k] = c;
            int[] newcoeff = Arrays.copyOf(coeff, coeff.length);
            if (c * x_k == sum) {
                //printResult(newcoeff);
                solutions.add(new Context(0, 0, newcoeff));
                continue;
            } else if (k > 0) {
                if (data.length - k < 2) {
                    mixedPartialSum(k - 1, sum - c * x_k, newcoeff);
                    // for loop on "c" goes on with previous coeff content
                } else {
                    // no longer recursive. delegate to thread pool
                    executor.submit(new ComputePartialSum(new Context(k - 1, sum - c * x_k, newcoeff)));
                }
            }
        }
    }

    static class ComputePartialSum implements Callable<Void> {
        // queue with contexts to process
        private Queue<Context> contexts;

        ComputePartialSum(Context request) {
            contexts = new ArrayDeque<Context>();
            contexts.add(request);
        }

        public Void call() {
            while(!contexts.isEmpty()) {
                Context current = contexts.poll();
                int x_k = data[current.k];
                for (int c = current.sum / x_k; c >= 0; c--) {
                    current.coeff[current.k] = c;
                    int[] newcoeff = Arrays.copyOf(current.coeff, current.coeff.length);
                    if (c * x_k == current.sum) {
                        //printResult(newcoeff);
                        solutions.add(new Context(0, 0, newcoeff));
                        continue;
                    } else if (current.k > 0) {
                        contexts.add(new Context(current.k - 1, current.sum - c * x_k, newcoeff));
                    }
                }
            }
            return null;
        }
    }
}

您可以检查哪个线程找到了输出结果并检查所有线程:递归模式的主线程和上下文堆栈模式的池中的两个线程。

现在当data.length 很高时,这个实现是可扩展的:

  • 最大递归深度限制在低级别的主线程
  • 池中的每个线程都使用自己的上下文堆栈,而不会与其他线程争用
  • 现在要调整的参数是numberOfThreadsmaxRecursionDepth

所以答案是肯定的,您的算法可以并行化。这是基于您的代码的完全递归实现:

import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class OriginalParallel
{
    static final int numberOfThreads = 2;
    static final int maxRecursionDepth = 3;

    public static void main(String[] args)
    {
        int target_sum = 100;
        int[] data = new int[] { 50, 40, 25, 20, 10, 5 };
        List T = tableGeneator(target_sum, data);
        int[] coeff = new int[data.length];
        Arrays.fill(coeff, 0);
        RecursivelyListAllThatWork(data.length, target_sum, T, coeff, data);
        executor.shutdown();
    }

    private static void printResult(int[] coeff, int[] data) {
        StringBuffer sb = new StringBuffer();
        for (int i = coeff.length - 1; i >= 0; i--) {
            if (coeff[i] > 0) {
                sb.append(data[i]).append(" * ").append(coeff[i]).append("   ");
            }
        }
        System.out.println(sb.append("from ").append(Thread.currentThread()));
    }

    // Thread pool for parallel execution
    private static ExecutorService executor;
    static {
        executor =
            new ThreadPoolExecutor(numberOfThreads, numberOfThreads, 1000, TimeUnit.SECONDS,
                                   new LinkedBlockingDeque<Runnable>());
    }

    private static void RecursivelyListAllThatWork(int k, int sum, List T, int[] coeff, int[] data) {
        if (k == 0) {
            printResult(coeff, data);
            return;
        }
        Integer x_k = data[k-1];
        //  Recursive step: Try all coefficients, but only if they work. 
        for (int c = 0; c <= sum/x_k; c++) { //the c variable caps the percent
            if (T.contains(new Point((sum - c * x_k), (k-1)))) {
                    // mark the coefficient of x_k to be c
                    coeff[k-1] = c;
                    if (data.length - k != maxRecursionDepth) {
                        RecursivelyListAllThatWork((k - 1), (sum - c * x_k), T, coeff, data);
                    } else {
                        // delegate to thread pool when reaching depth 3
                        int[] newcoeff = Arrays.copyOf(coeff, coeff.length);
                        executor.submit(new RecursiveThread(k - 1, sum - c * x_k, T, newcoeff, data)); 
                    }
                    // unmark the coefficient of x_k
                    coeff[k-1] = 0;
            }
        }
    }

    static class RecursiveThread implements Callable<Void> {
        int k;
        int sum;
        int[] coeff;
        int[] data;
        List T;

        RecursiveThread(int k, int sum, List T, int[] coeff, int[] data) {
            this.k = k;
            this.sum = sum;
            this.T = T;
            this.coeff = coeff;
            this.data = data;
            System.out.println("New job for k=" + k);
        }

        public Void call() {
            RecursivelyListAllThatWork(k, sum, T, coeff, data);
            return null;
        }
    }

    public static List tableGeneator(int target_sum, int[] data) {
        List T = new ArrayList();
        T.add(new Point(0, 0));

        float max_percent = 1;
        int R = (int) (target_sum * max_percent * data.length);
        for (int i = 0; i < data.length; i++) {
            for (int s = -R; s < R + 1; s++) {
                int max_value = (int) Math.abs((target_sum * max_percent) / data[i]);
                for (int c = 0; c < max_value + 1; c++) {
                    if (T.contains(new Point(s - c * data[i], i))) {
                        Point p = new Point(s, i + 1);
                        if (!T.contains(p)) {
                            T.add(p);
                        }
                    }
                }
            }
        }
        return T;
    }
}

【讨论】:

  • 哇..这看起来真不错。我会玩弄它并让你知道。结果实际上被转储到另一个程序,所以我现在只禁用打印并测试它编译所有内容的速度,然后我将使用打印来测试结果。非常感谢。
  • 到目前为止有 2 个问题(轻微),一个似乎不是从命令行编译的(我将其复制/粘贴到虚拟机中,然后在其上运行 javac,但我得到“注意:MixedParallel.java 使用未经检查或不安全的操作。”)当我在 eclipse executor.shutdown();不存在..执行器存在但没有关闭选项。
  • 我知道我的代码与您实现的算法不完全对应,但它应该可以帮助您在自己的逻辑中实现反递归和并行执行 - 但是这个新版本需要比现有的更多内存一个
  • 非常感谢您的发帖。它的定义。非常有趣且不同的方法(我将使用许多变量进行测试)。所以我发布的问题非常接近我正在使用的逻辑,但不完全是,所以在旧的解决方案中,有一行“s - c * data[i]”我修改了逻辑,有一个类似的区域?是mixparticalsum吗?例如,假设我想要百分比或限制结果(如果我有负值)等等。不知道在哪里放置我自己的逻辑。我有逻辑/代码只需要研究一下就可以了解如何整合它。
  • 您可以删除“解决方案”队列以减少内存占用...我真的认为您可以先将有限递归深度的概念应用到您的算法中,然后再并行执行。但是带有上下文的队列会消耗内存!而且我认为下一个问题将是并行化表生成器,我不知道如何做到这一点,因为每个新点可能取决于以前的点。
【解决方案2】:

1) 代替

if k == 0:
    print what we have so far
    return

您可以查看有多少系数是非零的;如果该计数大于某个阈值(在您的示例中为 3),则不要打印它。 (提示:这将与

mark the coefficient of x_k to be c

行。)

2) 递归函数本质上通常是指数型的,这意味着随着规模的扩大,运行时间会急剧增长。

考虑到这一点,您可以将多线程应用于计算表和递归函数。

在考虑表格的时候,想想循环的哪些部分相互影响,必须按顺序进行;反过来,当然是找出哪些部分不会相互影响并且可以并行运行。

至于递归函数,最好的办法可能是将多线程应用到分支部分。

【讨论】:

  • 感谢您的帮助。关于第一部分,我之前有类似的逻辑,但未能成功运行。每次将 x_k 更改为 C 时,我都添加了一个计数器来计数,然后如果 K ==0 我添加了或 count = 2。这似乎根本没有限制结果(尽管它确实影响了结果,因为结果只有缺少 1 个变量。关于第 2 点,我让表格工作得很快(但我无法缩放它,因为它依赖于前面的步骤)所以我正在研究递归函数。“分支部分”是什么意思? for 循环?
  • 任何时候递归函数调用自己,它都会创建另一个分支,所以是的,for循环就是分支部分。
  • 谢谢我错了,你的逻辑在限制结果方面是正确的。我尝试失败了,因为我没有意识到我的计数器在递归中没有被重置。我的解决方案并不完美,但我认为我有基础可以玩。感谢您清理它。
  • 我将问题更新为仅关注线程,因为我仍然不明白如何分解它。如果我必须对分支进行线程处理,但我的目标是缩放递归函数本身,并可能向它发送不同的数据或其他东西,以便每个函数在不同的部分上工作(但我的逻辑遇到了问题)。
【解决方案3】:

实现这种多线程的关键是确保您没有不必要的全局数据结构,例如系数上的“标记”。

假设您的表中有 K 个数字 n[0] ... n[K-1] 并且您想要达到的总和是 S。我在下面假设数组 n[] 从最小到最大排序号码。

这里有一个简单的枚举算法。 i 是数字列表的索引,s 是已经建立的当前总和,cs 是数字 0 .. i - 1 的系数列表:

function enumerate(i, s, cs):
  if (s == S):
     output_solution(cs)
  else if (i == K):
     return // dead end
  else if ((S - s) < n[i]):
     return // no solution can be found
  else:
     for c in 0 .. floor((S - s) / n[i]): // note: floor(...) > 0
        enumerate(i + 1, s + c * n[i], append(cs, c))

运行进程:

 enumerate(0, 0, make_empty_list())

现在这里不再有全局数据结构,除了表 n[](常量数据),并且“枚举”也不会返回任何内容,因此您可以随意更改递归调用以在其自己的线程中运行。例如。您可以为递归 enumerate() 调用生成一个新线程,除非您已经运行了太多线程,在这种情况下您需要等待。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2022-07-23
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-03-07
    • 2016-04-01
    相关资源
    最近更新 更多