【问题标题】:leetcode673:why did my solution run out of time?leetcode673:为什么我的解决方案超时了?
【发布时间】:2022-01-14 02:33:13
【问题描述】:

Leetcode 673: 给定一个整数数组 nums,返回最长递增子序列的数量。 请注意,序列必须严格递增。 这是我的代码:

class Solution {
//recursively compute the max increasing subsequence length and number of them at a given start point
//and a Min which means the next value must be at least Min.The result will be recorded on dp and values in
//dp will be checked first
    pair<int,int> length_and_count(int ptr,int Min,vector<int>& nums,vector<unordered_map<int,pair<int,int>>>& dp){
        if(ptr==nums.size())
        return make_pair(0,1);
        if(dp[ptr].find(Min)!=dp[ptr].end()){
            return dp[ptr][Min];
        } 
        else{
            if(nums[ptr]<Min)
            return dp[ptr][Min]=length_and_count(ptr+1,Min,nums,dp);
            else{
                auto pair1=length_and_count(ptr+1,Min,nums,dp),pair2=length_and_count(ptr+1,nums[ptr]+1,nums,dp);
                if(pair1.first>++pair2.first)
                return dp[ptr][Min]=pair1;
                else if(pair1.first<pair2.first)
                return dp[ptr][Min]=pair2;
                else return dp[ptr][Min]=make_pair(pair1.first,pair1.second+pair2.second);
            }
        }
    } 
public:
    int findNumberOfLIS(vector<int>& nums) {
        vector<unordered_map<int,pair<int,int>>> dp(nums.size());
        return length_and_count(0,INT_MIN,nums,dp).second;
    }
};

我认为我的解决方案的复杂度是O(n2),因为我的dp参数是nums的起点和当前的最大值,是从向量中得到的,所以dp的大小不能大于平方输入向量的大小。由于问题大小小于 2000,我的解决方案应该是 10 毫秒。那么我的解决方案有什么问题?

【问题讨论】:

  • 算法看起来不错,但还是要检查明显的东西:1)您是否使用优化进行编译; 2) 分析您的代码并查看您在哪里浪费时间
  • “我的解决方案应该是”你测试了吗?
  • 请注意,您的代码创建了几百万个 unordered_map 条目。它们很便宜,但不是免费的。
  • 自己尝试使用大小为 2000 的任意数组确实会在不到一秒的时间内返回。

标签: c++ algorithm time-complexity dynamic-programming


【解决方案1】:

感谢您提出非常有趣的问题!很高兴为此实施我自己的解决方案,并提高您的速度。

在测试 2000 个元素的随机输入时,我将您的解决方案的速度提高了 8x-9x 倍。为此,我做了以下事情:

  1. 通过使dp 结构unordered_map&lt;int, shared_ptr&lt;vector&lt;pair&lt;int, int&gt;&gt;&gt;&gt; 交换vector 和unordered_map 的顺序。这样我们以后就可以重用指向无序映射的Min 条目的指针,方法是将这个条目向下传递给递归函数。

  2. dp 中将向量包装成std::shared_ptr。这可以确保即使无序映射增长,向量也不会改变内存位置。

  3. 我没有使用dp[ptr][Min] 再次进行相同的无序映射搜索,而是通过.find(Min) 在单个函数调用中搜索了一次,然后在函数的所有其他位置重用了这个迭代器。

  4. Min_vec 指针向下传递给所有递归调用,以便其他函数调用不会再次在无序映射中搜索相同的Min 条目。只是一个地方没有重用这个指针,nums[ptr] + 1 作为 Min 传递。

  5. 创建了特殊的 lambda 函数GetMinVec(),它找到并返回指向 Min 条目的指针。如果此条目尚不存在,我会通过分配 shared_ptr 并将向量的大小调整为 nums.size() 的大小来初始化它,并用 pair(-1, -1) 填充表示空值。

以上所有步骤都将您的解决方案提高了8x-9x 倍。您的原始解决方案位于我的代码中 SolutionOriginal 类。而我的增强变体是SolutionBoosted

我还创建了自己的解决方案,该解决方案比您的原始解决方案更快100x,比提升原始解决方案更快11x-12x。这个新的解决方案我称之为SolutionSqrt。它的工作原理如下:

  1. 让我们创建K = sqrt(N) 存储桶。每个K 存储桶最后平均会有大约K 元素。

  2. 每个存储桶都将帮助我们缩小所有小于当前元素的元素的搜索范围。对于每个新数字,我们都会找到桶索引buck_i,它表示所有具有较小索引的桶都包含不大于当前的元素。

  3. 在当前存储桶buck_i 中,我们将进行线性搜索以查找所有小于当前存储桶的元素。

  4. 每个存储桶都包含一个 Heap 数组,该数组使用 std::push_heap 增长。这个堆是这样排序的,它的第一个元素(头)包含最大的元素。

  5. Bucket 的 Heap 以这样一种方式排序,即第一个元素在此堆中具有最长的数字链。

  6. 我们平均对每个新号码进行Sqrt(N) + Sqrt(N) 搜索。我们搜索如何通过添加这个新元素来扩展当前最长的链。

  7. 如果最长链被成功延长,那么我们会记住当前桶buck_i的新最大长度和新最大计数。然后新元素成为第一个 Heap 元素,因为它提供了最长的链。

  8. 每个桶都有自己的最长链长度和形成的这种链的数量。

  9. 此类算法的总体复杂度为O(N * Sqrt(N))。如果我们使用 3 级存储桶代替 2 级存储桶,这种复杂性可以进一步提高,在这种情况下,我们将实现复杂性O(N * N^(1/3))。这个O(N * Sqrt(N)) 的复杂性远低于原始OP 的代码O(N^2)

我在 2000 个数字的随机测试集上对所有 3 种解决方案的变体进行了计时。代码后可以看到计时。

Try it online!

#include <cstdint>
#include <vector>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <limits>

class SolutionSqrt {
public:
    using u64 = uint64_t;
    using NumT = int;
    u64 findNumberOfLIS(std::vector<NumT> const & nums) {
        size_t const
            N = nums.size(),
            nbucks = std::max<size_t>(1, std::llround(std::sqrt(N))),
            buck_size = (N + nbucks - 1) / nbucks;
        std::vector<size_t> si(N), rsi(N), ll(N);
        std::vector<u64> lc(N);
        for (size_t i = 0; i < si.size(); ++i)
            si[i] = i;
        std::sort(si.begin(), si.end(), [&](size_t a, size_t b){
            return nums[a] < nums[b];
        });
        for (size_t i = 0; i < rsi.size(); ++i)
            rsi[si[i]] = i;
        struct Buck {
            std::vector<size_t> sj;
            size_t max_ll = 0;
            u64 max_lc = 0;
            NumT max_num = std::numeric_limits<NumT>::min(),
                 min_num = std::numeric_limits<NumT>::max();
        };
        std::vector<Buck> bucks(nbucks);
        for (size_t i = 0; i < N; ++i) {
            auto const num = nums[i];
            size_t const buck_i = rsi[i] / buck_size;
            size_t max_ll = 0;
            for (size_t j = 0; j <= buck_i; ++j) {
                auto const & buckj = bucks[j];
                if (buckj.sj.empty())
                    continue;
                if (j < buck_i && buckj.max_num < num)
                    max_ll = std::max(max_ll, buckj.max_ll);
                else {
                    if (buckj.min_num == buckj.max_num && buckj.max_num == num)
                        continue;
                    for (size_t k = 0; k < buckj.sj.size(); ++k)
                        if (nums[buckj.sj[k]] < num)
                            max_ll = std::max(max_ll, ll[buckj.sj[k]]);
                }
            }
            if (max_ll > 0) {
                u64 clc = 0;
                for (size_t j = 0; j <= buck_i; ++j) {
                    auto const & buckj = bucks[j];
                    if (buckj.sj.empty())
                        continue;
                    if (j < buck_i && buckj.max_num < num) {
                        if (ll[buckj.sj[0]] == max_ll)
                            clc += buckj.max_lc;
                    } else {
                        if (buckj.min_num == buckj.max_num && buckj.max_num == num)
                            continue;
                        for (size_t k = 0; k < buckj.sj.size(); ++k)
                            if (nums[buckj.sj[k]] < num && ll[buckj.sj[k]] == max_ll)
                                clc += lc[buckj.sj[k]];
                    }
                }
                ll[i] = max_ll + 1;
                lc[i] = clc;
            } else {
                ll[i] = 1;
                lc[i] = 1;
            }
            auto & buck = bucks[buck_i];
            bool const new_max = buck.sj.empty() || ll[i] > ll[buck.sj[0]];
            buck.sj.push_back(i);
            std::push_heap(buck.sj.begin(), buck.sj.end(), [&](size_t a, size_t b){
                return ll[a] < ll[b];
            });
            if (ll[i] >= buck.max_ll)
                buck.max_lc = lc[i] + (new_max ? u64(0) : buck.max_lc);
            buck.max_ll = std::max(buck.max_ll, ll[i]);
            buck.min_num = std::min(buck.min_num, num);
            buck.max_num = std::max(buck.max_num, num);
        }
        size_t total_max_ll = 0;
        for (size_t i = 0; i < bucks.size(); ++i)
            if (!bucks[i].sj.empty())
                total_max_ll = std::max(total_max_ll, bucks[i].max_ll);
        u64 total_lc = 0;
        for (size_t i = 0; i < bucks.size(); ++i)
            if (!bucks[i].sj.empty() && ll[bucks[i].sj[0]] == total_max_ll)
                total_lc += bucks[i].max_lc;

        //std::cout << "Length " << total_max_ll << std::endl;
        //std::cout << "Count " << total_lc << std::endl;

        return total_lc;
    }
};

#include <unordered_map>
#include <climits>
#include <memory>

using namespace std;

class SolutionOriginal {
    // recursively compute the max increasing subsequence length and number of them at a given start point
    // and a Min which means the next value must be at least Min.The result will be recorded on dp and values in
    // dp will be checked first
    pair<int, int> length_and_count(int ptr, int Min, vector<int> & nums,
                                    vector<unordered_map<int, pair<int, int>>> & dp) {
        if (ptr == nums.size())
            return make_pair(0, 1);
        if (dp[ptr].find(Min) != dp[ptr].end()) {
            return dp[ptr][Min];
        } else {
            if (nums[ptr] < Min)
                return dp[ptr][Min] = length_and_count(ptr + 1, Min, nums, dp);
            else {
                auto pair1 = length_and_count(ptr + 1, Min, nums, dp),
                     pair2 = length_and_count(ptr + 1, nums[ptr] + 1, nums, dp);
                if (pair1.first > ++pair2.first)
                    return dp[ptr][Min] = pair1;
                else if (pair1.first < pair2.first)
                    return dp[ptr][Min] = pair2;
                else
                    return dp[ptr][Min] = make_pair(pair1.first, pair1.second + pair2.second);
            }
        }
    }

  public:
    int findNumberOfLIS(vector<int> nums) {
        vector<unordered_map<int, pair<int, int>>> dp(nums.size());
        auto p = length_and_count(0, INT_MIN, nums, dp);
        // std::cout << "Length " << p.first << std::endl;
        // std::cout << "Count " << p.second << std::endl;
        return p.second;
    }
};

class SolutionBoosted {
    // recursively compute the max increasing subsequence length and number of them at a given start point
    // and a Min which means the next value must be at least Min.The result will be recorded on dp and values in
    // dp will be checked first
    pair<int, int> length_and_count(int ptr, int Min, vector<int> & nums,
            unordered_map<int, shared_ptr<vector<pair<int, int>>>> & dp,
            vector<pair<int, int>> * Min_vec = nullptr) {
        if (ptr == nums.size())
            return make_pair(0, 1);
        auto GetMinVec = [&]{
            auto it0 = dp.find(Min);
            if (it0 != dp.end())
                return it0->second.get();

            auto it1 = dp.insert(make_pair(Min, make_shared<
                vector<pair<int, int>>>())).first;
            it1->second->resize(nums.size(), make_pair(-1, -1));
            return it1->second.get();
        };
        if (!Min_vec)
            Min_vec = GetMinVec();
        auto & e = (*Min_vec)[ptr];
        if (e.first != -1)
            return e;
        else {
            if (nums[ptr] < Min)
                return e = length_and_count(ptr + 1, Min, nums, dp, Min_vec);
            else {
                auto pair1 = length_and_count(ptr + 1, Min, nums, dp, Min_vec),
                     pair2 = length_and_count(ptr + 1, nums[ptr] + 1, nums, dp);
                if (pair1.first > ++pair2.first)
                    return e = pair1;
                else if (pair1.first < pair2.first)
                    return e = pair2;
                else
                    return e = make_pair(
                        pair1.first, pair1.second + pair2.second);
            }
        }
    }

  public:
    int findNumberOfLIS(vector<int> nums) {
        unordered_map<int, shared_ptr<vector<pair<int, int>>>> dp;
        auto p = length_and_count(0, INT_MIN, nums, dp);
        // std::cout << "Length " << p.first << std::endl;
        // std::cout << "Count " << p.second << std::endl;
        return p.second;
    }
};

#include <random>
#include <chrono>

int main() {
    std::mt19937_64 rng(123);
    std::uniform_int_distribution<int> distr(-100, 100);
    std::vector<int> v;
    for (size_t i = 0; i < (1 << 11); ++i)
        v.push_back(distr(rng));
    {
        auto tb = std::chrono::system_clock::now();
        std::cout << "Sqrt " << SolutionSqrt().findNumberOfLIS(v) << "  ";
        std::cout << "Time " << std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::system_clock::now() - tb).count() / 1000000.0 << " sec" << std::endl;
    }
    {
        auto tb = std::chrono::system_clock::now();
        std::cout << "Original "
            << SolutionOriginal().findNumberOfLIS(v) << "  ";
        std::cout << "Time " << std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::system_clock::now() - tb).count() / 1000000.0 << " sec" << std::endl;
    }
    {
        auto tb = std::chrono::system_clock::now();
        std::cout << "Original_Boosted "
            << SolutionBoosted().findNumberOfLIS(v) << "  ";
        std::cout << "Time " << std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::system_clock::now() - tb).count() / 1000000.0 << " sec" << std::endl;
    }
}

输出:

Sqrt 15240960  Time 0.000856 sec
Original 15240960  Time 0.081954 sec
Original_Boosted 15240960  Time 0.011343 sec

【讨论】:

    猜你喜欢
    • 2015-11-05
    • 2015-09-12
    • 1970-01-01
    • 2011-06-15
    • 2017-06-26
    • 1970-01-01
    • 2017-08-14
    • 2012-01-25
    相关资源
    最近更新 更多