感谢您提出非常有趣的问题!很高兴为此实施我自己的解决方案,并提高您的速度。
在测试 2000 个元素的随机输入时,我将您的解决方案的速度提高了 8x-9x 倍。为此,我做了以下事情:
-
通过使dp 结构unordered_map<int, shared_ptr<vector<pair<int, int>>>> 交换vector 和unordered_map 的顺序。这样我们以后就可以重用指向无序映射的Min 条目的指针,方法是将这个条目向下传递给递归函数。
-
在dp 中将向量包装成std::shared_ptr。这可以确保即使无序映射增长,向量也不会改变内存位置。
-
我没有使用dp[ptr][Min] 再次进行相同的无序映射搜索,而是通过.find(Min) 在单个函数调用中搜索了一次,然后在函数的所有其他位置重用了这个迭代器。
-
将Min_vec 指针向下传递给所有递归调用,以便其他函数调用不会再次在无序映射中搜索相同的Min 条目。只是一个地方没有重用这个指针,nums[ptr] + 1 作为 Min 传递。
-
创建了特殊的 lambda 函数GetMinVec(),它找到并返回指向 Min 条目的指针。如果此条目尚不存在,我会通过分配 shared_ptr 并将向量的大小调整为 nums.size() 的大小来初始化它,并用 pair(-1, -1) 填充表示空值。
以上所有步骤都将您的解决方案提高了8x-9x 倍。您的原始解决方案位于我的代码中 SolutionOriginal 类。而我的增强变体是SolutionBoosted。
我还创建了自己的解决方案,该解决方案比您的原始解决方案更快100x,比提升原始解决方案更快11x-12x。这个新的解决方案我称之为SolutionSqrt。它的工作原理如下:
-
让我们创建K = sqrt(N) 存储桶。每个K 存储桶最后平均会有大约K 元素。
-
每个存储桶都将帮助我们缩小所有小于当前元素的元素的搜索范围。对于每个新数字,我们都会找到桶索引buck_i,它表示所有具有较小索引的桶都包含不大于当前的元素。
-
在当前存储桶buck_i 中,我们将进行线性搜索以查找所有小于当前存储桶的元素。
-
每个存储桶都包含一个 Heap 数组,该数组使用 std::push_heap 增长。这个堆是这样排序的,它的第一个元素(头)包含最大的元素。
-
Bucket 的 Heap 以这样一种方式排序,即第一个元素在此堆中具有最长的数字链。
-
我们平均对每个新号码进行Sqrt(N) + Sqrt(N) 搜索。我们搜索如何通过添加这个新元素来扩展当前最长的链。
-
如果最长链被成功延长,那么我们会记住当前桶buck_i的新最大长度和新最大计数。然后新元素成为第一个 Heap 元素,因为它提供了最长的链。
-
每个桶都有自己的最长链长度和形成的这种链的数量。
-
此类算法的总体复杂度为O(N * Sqrt(N))。如果我们使用 3 级存储桶代替 2 级存储桶,这种复杂性可以进一步提高,在这种情况下,我们将实现复杂性O(N * N^(1/3))。这个O(N * Sqrt(N)) 的复杂性远低于原始OP 的代码O(N^2)。
我在 2000 个数字的随机测试集上对所有 3 种解决方案的变体进行了计时。代码后可以看到计时。
Try it online!
#include <cstdint>
#include <vector>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <limits>
class SolutionSqrt {
public:
using u64 = uint64_t;
using NumT = int;
u64 findNumberOfLIS(std::vector<NumT> const & nums) {
size_t const
N = nums.size(),
nbucks = std::max<size_t>(1, std::llround(std::sqrt(N))),
buck_size = (N + nbucks - 1) / nbucks;
std::vector<size_t> si(N), rsi(N), ll(N);
std::vector<u64> lc(N);
for (size_t i = 0; i < si.size(); ++i)
si[i] = i;
std::sort(si.begin(), si.end(), [&](size_t a, size_t b){
return nums[a] < nums[b];
});
for (size_t i = 0; i < rsi.size(); ++i)
rsi[si[i]] = i;
struct Buck {
std::vector<size_t> sj;
size_t max_ll = 0;
u64 max_lc = 0;
NumT max_num = std::numeric_limits<NumT>::min(),
min_num = std::numeric_limits<NumT>::max();
};
std::vector<Buck> bucks(nbucks);
for (size_t i = 0; i < N; ++i) {
auto const num = nums[i];
size_t const buck_i = rsi[i] / buck_size;
size_t max_ll = 0;
for (size_t j = 0; j <= buck_i; ++j) {
auto const & buckj = bucks[j];
if (buckj.sj.empty())
continue;
if (j < buck_i && buckj.max_num < num)
max_ll = std::max(max_ll, buckj.max_ll);
else {
if (buckj.min_num == buckj.max_num && buckj.max_num == num)
continue;
for (size_t k = 0; k < buckj.sj.size(); ++k)
if (nums[buckj.sj[k]] < num)
max_ll = std::max(max_ll, ll[buckj.sj[k]]);
}
}
if (max_ll > 0) {
u64 clc = 0;
for (size_t j = 0; j <= buck_i; ++j) {
auto const & buckj = bucks[j];
if (buckj.sj.empty())
continue;
if (j < buck_i && buckj.max_num < num) {
if (ll[buckj.sj[0]] == max_ll)
clc += buckj.max_lc;
} else {
if (buckj.min_num == buckj.max_num && buckj.max_num == num)
continue;
for (size_t k = 0; k < buckj.sj.size(); ++k)
if (nums[buckj.sj[k]] < num && ll[buckj.sj[k]] == max_ll)
clc += lc[buckj.sj[k]];
}
}
ll[i] = max_ll + 1;
lc[i] = clc;
} else {
ll[i] = 1;
lc[i] = 1;
}
auto & buck = bucks[buck_i];
bool const new_max = buck.sj.empty() || ll[i] > ll[buck.sj[0]];
buck.sj.push_back(i);
std::push_heap(buck.sj.begin(), buck.sj.end(), [&](size_t a, size_t b){
return ll[a] < ll[b];
});
if (ll[i] >= buck.max_ll)
buck.max_lc = lc[i] + (new_max ? u64(0) : buck.max_lc);
buck.max_ll = std::max(buck.max_ll, ll[i]);
buck.min_num = std::min(buck.min_num, num);
buck.max_num = std::max(buck.max_num, num);
}
size_t total_max_ll = 0;
for (size_t i = 0; i < bucks.size(); ++i)
if (!bucks[i].sj.empty())
total_max_ll = std::max(total_max_ll, bucks[i].max_ll);
u64 total_lc = 0;
for (size_t i = 0; i < bucks.size(); ++i)
if (!bucks[i].sj.empty() && ll[bucks[i].sj[0]] == total_max_ll)
total_lc += bucks[i].max_lc;
//std::cout << "Length " << total_max_ll << std::endl;
//std::cout << "Count " << total_lc << std::endl;
return total_lc;
}
};
#include <unordered_map>
#include <climits>
#include <memory>
using namespace std;
class SolutionOriginal {
// recursively compute the max increasing subsequence length and number of them at a given start point
// and a Min which means the next value must be at least Min.The result will be recorded on dp and values in
// dp will be checked first
pair<int, int> length_and_count(int ptr, int Min, vector<int> & nums,
vector<unordered_map<int, pair<int, int>>> & dp) {
if (ptr == nums.size())
return make_pair(0, 1);
if (dp[ptr].find(Min) != dp[ptr].end()) {
return dp[ptr][Min];
} else {
if (nums[ptr] < Min)
return dp[ptr][Min] = length_and_count(ptr + 1, Min, nums, dp);
else {
auto pair1 = length_and_count(ptr + 1, Min, nums, dp),
pair2 = length_and_count(ptr + 1, nums[ptr] + 1, nums, dp);
if (pair1.first > ++pair2.first)
return dp[ptr][Min] = pair1;
else if (pair1.first < pair2.first)
return dp[ptr][Min] = pair2;
else
return dp[ptr][Min] = make_pair(pair1.first, pair1.second + pair2.second);
}
}
}
public:
int findNumberOfLIS(vector<int> nums) {
vector<unordered_map<int, pair<int, int>>> dp(nums.size());
auto p = length_and_count(0, INT_MIN, nums, dp);
// std::cout << "Length " << p.first << std::endl;
// std::cout << "Count " << p.second << std::endl;
return p.second;
}
};
class SolutionBoosted {
// recursively compute the max increasing subsequence length and number of them at a given start point
// and a Min which means the next value must be at least Min.The result will be recorded on dp and values in
// dp will be checked first
pair<int, int> length_and_count(int ptr, int Min, vector<int> & nums,
unordered_map<int, shared_ptr<vector<pair<int, int>>>> & dp,
vector<pair<int, int>> * Min_vec = nullptr) {
if (ptr == nums.size())
return make_pair(0, 1);
auto GetMinVec = [&]{
auto it0 = dp.find(Min);
if (it0 != dp.end())
return it0->second.get();
auto it1 = dp.insert(make_pair(Min, make_shared<
vector<pair<int, int>>>())).first;
it1->second->resize(nums.size(), make_pair(-1, -1));
return it1->second.get();
};
if (!Min_vec)
Min_vec = GetMinVec();
auto & e = (*Min_vec)[ptr];
if (e.first != -1)
return e;
else {
if (nums[ptr] < Min)
return e = length_and_count(ptr + 1, Min, nums, dp, Min_vec);
else {
auto pair1 = length_and_count(ptr + 1, Min, nums, dp, Min_vec),
pair2 = length_and_count(ptr + 1, nums[ptr] + 1, nums, dp);
if (pair1.first > ++pair2.first)
return e = pair1;
else if (pair1.first < pair2.first)
return e = pair2;
else
return e = make_pair(
pair1.first, pair1.second + pair2.second);
}
}
}
public:
int findNumberOfLIS(vector<int> nums) {
unordered_map<int, shared_ptr<vector<pair<int, int>>>> dp;
auto p = length_and_count(0, INT_MIN, nums, dp);
// std::cout << "Length " << p.first << std::endl;
// std::cout << "Count " << p.second << std::endl;
return p.second;
}
};
#include <random>
#include <chrono>
int main() {
std::mt19937_64 rng(123);
std::uniform_int_distribution<int> distr(-100, 100);
std::vector<int> v;
for (size_t i = 0; i < (1 << 11); ++i)
v.push_back(distr(rng));
{
auto tb = std::chrono::system_clock::now();
std::cout << "Sqrt " << SolutionSqrt().findNumberOfLIS(v) << " ";
std::cout << "Time " << std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now() - tb).count() / 1000000.0 << " sec" << std::endl;
}
{
auto tb = std::chrono::system_clock::now();
std::cout << "Original "
<< SolutionOriginal().findNumberOfLIS(v) << " ";
std::cout << "Time " << std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now() - tb).count() / 1000000.0 << " sec" << std::endl;
}
{
auto tb = std::chrono::system_clock::now();
std::cout << "Original_Boosted "
<< SolutionBoosted().findNumberOfLIS(v) << " ";
std::cout << "Time " << std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now() - tb).count() / 1000000.0 << " sec" << std::endl;
}
}
输出:
Sqrt 15240960 Time 0.000856 sec
Original 15240960 Time 0.081954 sec
Original_Boosted 15240960 Time 0.011343 sec