用简洁的语言解释算法对我来说有点困难T.T
所以在下面的代码cmets中说明了充分的解释。
基本思想是以递归方式完成(DFS,深度优先搜索)。该函数应该类似于recursion(percent_list, result_list, target)。
- 最初应该是
recursion([0.1, 0.1, 0.8], [], 9)
- 如果我们尝试将第一个值设为 3.25,那么我们将目标值更新为 9 - 3.25*0.1 = 8.675。所以我们接下来调用
recursion([0.1, 0.8], [3.25], 8.675);
- 然后,我们尝试将第二个值设为 4.00,然后将目标值更新为 8.675 - 4.0*0.1 = 8.275。所以打电话给
recursion([0.8], [3.25, 4.0], 8.275);
- 最后,我们尝试第三个值,只有 9.75,10 是有效的,因为相加的值分别是 8.525 和 8.725,并且可以四舍五入到 9。所以我们将结果
[3.25, 4.0, 9.75] 和 [3.25, 4.0, 10.0] 添加到结果列表。
- 之后,我们尝试将第二个值设为 0.25, ..., 3.75, 4.25, 4.5, ..., 10。
- 尝试将第一个值设为 0.25, ..., 3.0, 3.5, 3.75, ..., 10。
为了避免过多的递归调用,我们需要计算每次都可以附加到结果中的有效值,以削减不可能的分支。
实际的函数签名有些不同,要实现四舍五入。
import numpy as np
def recursion(percent_list, value_list, previous_results, target, tolerate_lower, tolerate_upper, result_list):
# change , 0.25 ~ 10 , change, , change, 0.5 = 9.5-9 , 0.4999 < 9-8.5, your answer
# init: [0.1,0.1,0.8] [] 9
# if reach the target within tolerate
if len(percent_list) == 0:
# print(previous_results)
result_dict.append(previous_results)
return
# otherwise, cut impossible branches, check minimum and maximum value acceptable for current percent_list
percent_sum = percent_list.sum() # sum up current percent list, **O(n)**, should be optimized by pre-generating a sum list
value_min = value_list[0] # minimum value from data list (this problem 0.25)
value_max = value_list[-1] # maximum value from data list (this problem 10.0)
search_min = (target - tolerate_lower - (percent_sum - percent_list[0]) * value_max) / percent_list[0] # minimum value acceptable as result
search_max = (target + tolerate_upper - (percent_sum - percent_list[0]) * value_min) / percent_list[0] # maximum value acceptable as result
idx_min = np.searchsorted(value_list, search_min, "left") # index of minimum value (for data list)
idx_max = np.searchsorted(value_list, search_max, "right") # index of maximum value (for data list)
# recursion step
for i in range(idx_min, idx_max):
# update result list
current_results = previous_results + [value_list[i]]
# remove the current state for variables `percent_list`, and update `target` for next step
recursion(percent_list[1:], value_list, current_results, target - percent_list[0] * value_list[i], tolerate_lower, tolerate_upper, result_list)
为了解决当前的问题,
result = []
recursion(np.array([0.1, 0.1, 0.8]), np.arange(0.25, 10.25, 0.25), [], 9, 0.5, 0.49999, result)
总共有 4806 个可能的结果。验证结果总计大约 9 个(但无法验证结果就足够了),
for l in result:
if not (8.5 <= (np.array([0.1, 0.1, 0.8]) * np.array(l)).sum() < 9.5):
print("Wrong code!")
我认为最复杂的情况仍然是 O(m^n * n),如果 m指数据列表长度 (0.25, 0.5, ..., 10),n 指百分比列表长度 (0.1, 0.1, 0.8)。应该进一步优化到O(m^n * log(m)),避免每次递归求和百分表;和O(m^n),如果我们能充分利用数据列表的等差数列的性质。