【问题标题】:Function to check if linear combination of elements in list are equal to a certain sum检查列表中元素的线性组合是否等于某个总和的函数
【发布时间】:2020-08-14 15:05:50
【问题描述】:

问题是创建一个函数来检查该列表中某些项目的线性组合是否等于某个总和。结果将是一个带有元组的列表(与列表的长度相同)。 例如:给定列表:[3,7,10]sum= 60 结果:[(0, 0, 6), (1, 1, 5), (2, 2, 4), (3, 3, 3), (4, 4, 2), (5, 5, 1), (6, 6, 0), (10, 0, 3), (11, 1, 2), (12, 2, 1), (13, 3, 0), (20, 0, 0)] 问题是列表的长度会有所不同。我尝试使用一堆 if 语句而不是使用 for 循环来解决它,但必须有更有效的方法来解决它。

这是我使用的一些代码。

def get_index(l, s):
    res = []
    if len(l)==3:
        for i in range(s+1):
                for j in range(s+1):
                    for k in range(s+1):
                        if l[0]*i + l[1]*j + l[2]*k==s:
                            res.append((i,j,k))
    return res

已经谢谢了!!

注意: 如果我确实将范围更改为(s//l[i]+1),它会起作用。

【问题讨论】:

  • 展示你尝试过的东西会很好。
  • 是的,对不起,我是新来的。
  • @V_Buffalo 欢迎来到 SO。请查看stackoverflow.com/help/someone-answers 以了解如何接受关于 SO 的答案。谢谢。
  • 我的解决方案与 Ehsan 的解决方案非常相似,它使用蛮力。然而,这里解释了一种数学方法,我将把它用于非暴力方法。等待更新。

标签: python numpy linear-equation


【解决方案1】:

我觉得有更好的方法来做到这一点,但这里有一个使用数组的蛮横的方法:

A = np.array([3,7,10])
b = np.array([60])

from itertools import product
combin = [np.arange(i) for i in (b//A)+1]
d = np.stack(list(product(*combin)))
[tuple(i) for i in d[d@A==b]]

或等效地没有 itertools:

d = np.rollaxis(np.indices((b//A)+1),0,4).reshape(-1,3)
[tuple(i) for i in d[d@A==b]]

输出:

[(0, 0, 6), (1, 1, 5), (2, 2, 4), (3, 3, 3), (4, 4, 2), (5, 5, 1), (6, 6, 0), (10, 0, 3), (11, 1, 2), (12, 2, 1), (13, 3, 0), (20, 0, 0)]

比较

#@Ehsan's solution 1
def m1(b):
  combin = [np.arange(i) for i in (b//A)+1]
  d = np.stack(list(product(*combin)))
  return [tuple(i) for i in d[d@A==b]]

#@Ehsan's solution 2
def m2(b):
  d = np.rollaxis(np.indices((b//A)+1), 0, 4).reshape(-1,3)
  return [tuple(i) for i in d[d@A==b]]

#@mathfux solution
def m3(b):
  A, B, C = range(0, b+1, 3), range(0, b+1, 7), range(0, b+1, 10)
  triplets = list(product(A, B, C)) #all triplets
  suitable_triplets = list(filter(lambda x: sum(x)==b, triplets)) #triplets that adds up to 60
  return [(a//3, b//7, c//10) for a, b, c in suitable_triplets]

性能

in_ = [np.array([n]) for n in [10,100,1000]]

这使得 m2 是其中最快的。

【讨论】:

  • 是的,我认为使用数组是个好主意,但它并不能解决我的问题。谢谢
  • @V_Buffalo 它应该比循环快得多。是可扩展性的问题吗?如果是这样,你的数组有多大,你需要这个代码有多快?
  • 给定列表的长度在 2 到 5 之间。在 jupyter 笔记本中,您的代码和我的代码都很好。但是当我通过我的学校用来检查代码的程序运行它时,我的代码(使用 if 语句和 for 循环)达到了时间限制,而你的代码超出了内存限制。这是我试图解决的一个练习,为我的考试做准备,而不是作业或任何东西。
  • 哪种考试?如果和高中数学有关,我们需要一个完全不同的算法:)
  • @V_Buffalo 然后使用 Sympy。它可能有一个快速且内存高效的解决方案。此外,我希望数组比列表占用更多内存,但效率更高。
【解决方案2】:

这完全是一个数学问题。你需要找到线性丢番图方程3a+7b+10c=60的所有非负解三元组。

为这个方程寻找解的主要思想可以用生成函数(多项式)来说明。让我们取三个这样的多项式:

A=1+x^3+x^6+x^9+...+x^60
B=1+x^7+x^14+x^21+...+x^56
C=1+x^10+x^20+x^30+...+x^60

如果你将它们相乘,你会看到每个术语 x^n 可以表示为 x^ax^bx^c 的乘积,这些术语中的每一个都取自 ABC

蛮力方法。您需要定义这些多项式的乘法,以跟踪被相乘的项,就像这样:

[0, 3, 6] * [0, 7, 14] * [0, 10] = [(0,0,0), (0,0,10), (0,7,0), (0,7,10), (3,0,0), (3,0,10), (3,7,0), (3,7,10), (6,0,0), (6,0,10), (6,7,0), (6,7,10)]

列表在 Python 中没有 * 运算符,但幸运的是,您可以改用 itertools.product 方法。这是一个完整的解决方案:

from itertools import product
A, B, C = range(0, 61, 3), range(0, 61, 7), range(0, 61, 10)
triplets = list(product(A, B, C)) # all triplets
suitable_triplets = list(filter(lambda x: sum(x)==60, triplets)) #triplets that adds up to 60
print([[a//3, b//7, c//10] for a, b, c in suitable_triplets])

矢量化蛮力。这是基于之前的脚本,将所有循环替换为 numpy 操作:

import numpy as np
l = np.array([3,7,10])
s = 60
unknowns = [range(0, s+1, n) for n in l]
triplets = np.stack(np.meshgrid(*unknowns), axis=-1).reshape(-1, len(unknowns))
suitable_triplets = triplets[np.sum(triplets, axis=1) == s]
solutions = suitable_triplets//l

数学方法。一般来说,求解线性丢番图方程是困难的。看看这个SO answer。它说sympy 只能找到参数化解决方案,但无法识别域:

import sympy as sp
from sympy.solvers.diophantine.diophantine import diop_solve
x,y,z = sp.symbols('x y z')
solution = diop_solve(3*x + 7*y + 10*z-60)

解决方案的输出是(t_0, t_0 + 10*t_1 + 180, -t_0 - 7*t_1 - 120)。 使用 Sage 可以优化解决方案,但在这种情况下您需要下载 Linux 操作系统:D

【讨论】:

  • 这是相同的方法,但与使用数组相比效率较低。
  • 是的,它仍然是蛮力。它可以通过两种方式进行轻微优化:在每个产品之后使用filter,或者像你一样使用numpy(因为它使用C)。如果您需要进行必要的更新,则应应用生成函数的代数。
  • 感谢您的帮助。
  • 但是我应该以什么方式使用 Numpy?
  • 您应该尝试用基于numpy 的操作替换所有列表推导、产品、for 循环等。这并不总是那么容易。这是一个很好的例子:product(A, B, C) 很慢。可以替换为np.stack(np.meshgrid(A, B, C), axis=-1).reshape(-1, 3)
猜你喜欢
  • 2017-08-17
  • 2012-09-28
  • 2018-05-15
  • 2023-04-02
  • 1970-01-01
  • 2015-04-16
  • 1970-01-01
  • 2021-02-07
  • 1970-01-01
相关资源
最近更新 更多