【问题标题】:How to improve multiplication efficiency in big integer?如何提高大整数的乘法效率?
【发布时间】:2021-04-15 05:45:31
【问题描述】:

这个周末我跟着wiki实现了基本的大整数乘法。我使用Tom-3算法来实现。但是一开始花费的时间比长乘法(小学乘法)慢,而且一去不复返。我希望程序能在500位以内超过小学乘法,请问我该怎么做?

我尝试优化,我保留向量容量并去除多余的代码。但是效果不是很好。

我应该使用vector<long long> 作为我的基数吗?

Github中的完整源代码:

typedef long long BigIntBase;
typedef vector<BigIntBase> BigIntDigits;

// ceil(numeric_limits<BigIntBase>::digits10 / 2.0) - 1;
static const int digit_base_len = 9;
// b
static const BigIntBase digit_base = 1000000000;

class BigInt {

public:
  BigInt(int digits_capacity = 0, bool nega = false) {
    negative = nega;
    digits.reserve(digits_capacity);
  }

  BigInt(BigIntDigits _digits, bool nega = false) {
    negative = nega;
    digits = _digits;
  }

  BigInt(const span<const BigIntBase> &range, bool nega = false) {
    negative = nega;
    digits = BigIntDigits(range.begin(), range.end());
  }

  BigInt operator+(const BigInt &rhs) {
    if ((*this).negative == rhs.negative)
      return BigInt(plus((*this).digits, rhs.digits), (*this).negative);

    if (greater((*this).digits, rhs.digits))
      return BigInt(minus((*this).digits, rhs.digits), (*this).negative);

    return BigInt(minus(rhs.digits, (*this).digits), rhs.negative);
  }

  BigInt operator-(const BigInt &rhs) { return *this + BigInt(rhs.digits, !rhs.negative); }

  BigInt operator*(const BigInt &rhs) {
    if ((*this).digits.empty() || rhs.digits.empty()) {
      return BigInt();
    } else if ((*this).digits.size() == 1 && rhs.digits.size() == 1) {
      BigIntBase val = (*this).digits[0] * rhs.digits[0];
      return BigInt(val < digit_base ? BigIntDigits{val} : BigIntDigits{val % digit_base, val / digit_base}, (*this).negative ^ rhs.negative);
    } else if ((*this).digits.size() == 1)
      return BigInt(multiply(rhs, (*this).digits[0]).digits, (*this).negative ^ rhs.negative);
    else if (rhs.digits.size() == 1)
      return BigInt(multiply((*this), rhs.digits[0]).digits, (*this).negative ^ rhs.negative);

    return BigInt(toom3(span((*this).digits), span(rhs.digits)), (*this).negative ^ rhs.negative);
  }

  string to_string() {
    if (this->digits.empty())
      return "0";

    stringstream ss;
    if (this->negative)
      ss << "-";

    ss << std::to_string(this->digits.back());
    for (auto it = this->digits.rbegin() + 1; it != this->digits.rend(); ++it)
      ss << setw(digit_base_len) << setfill('0') << std::to_string(*it);

    return ss.str();
  }

  BigInt from_string(string s) {
    digits.clear();
    negative = s[0] == '-';
    for (int pos = max(0, (int)s.size() - digit_base_len); pos >= 0; pos -= digit_base_len)
      digits.push_back(stoll(s.substr(pos, digit_base_len)));

    if (s.size() % digit_base_len)
      digits.push_back(stoll(s.substr(0, s.size() % digit_base_len)));

    return *this;
  }

private:
  bool negative;
  BigIntDigits digits;

  const span<const BigIntBase> toom3_slice_num(const span<const BigIntBase> &num, const int &n, const int &i) {
    int begin = n * i;
    if (begin < num.size()) {
      const span<const BigIntBase> result = num.subspan(begin, min((int)num.size() - begin, i));
      return result;
    }

    return span<const BigIntBase>();
  }

  BigIntDigits toom3(const span<const BigIntBase> &num1, const span<const BigIntBase> &num2) {
    int i = ceil(max(num1.size() / 3.0, num2.size() / 3.0));
    const span<const BigIntBase> m0 = toom3_slice_num(num1, 0, i);
    const span<const BigIntBase> m1 = toom3_slice_num(num1, 1, i);
    const span<const BigIntBase> m2 = toom3_slice_num(num1, 2, i);
    const span<const BigIntBase> n0 = toom3_slice_num(num2, 0, i);
    const span<const BigIntBase> n1 = toom3_slice_num(num2, 1, i);
    const span<const BigIntBase> n2 = toom3_slice_num(num2, 2, i);

    BigInt pt0 = plus(m0, m2);
    BigInt pp0 = m0;
    BigInt pp1 = plus(pt0.digits, m1);
    BigInt pn1 = pt0 - m1;
    BigInt pn2 = multiply(pn1 + m2, 2) - m0;
    BigInt pin = m2;

    BigInt qt0 = plus(n0, n2);
    BigInt qp0 = n0;
    BigInt qp1 = plus(qt0.digits, n1);
    BigInt qn1 = qt0 - n1;
    BigInt qn2 = multiply(qn1 + n2, 2) - n0;
    BigInt qin = n2;

    BigInt rp0 = pp0 * qp0;
    BigInt rp1 = pp1 * qp1;
    BigInt rn1 = pn1 * qn1;
    BigInt rn2 = pn2 * qn2;
    BigInt rin = pin * qin;

    BigInt r0 = rp0;
    BigInt r4 = rin;
    BigInt r3 = divide(rn2 - rp1, 3);
    BigInt r1 = divide(rp1 - rn1, 2);
    BigInt r2 = rn1 - rp0;
    r3 = divide(r2 - r3, 2) + multiply(rin, 2);
    r2 = r2 + r1 - r4;
    r1 = r1 - r3;

    BigIntDigits result = r0.digits;
    if (!r1.digits.empty()) {
      shift_left(r1.digits, i);
      result = plus(result, r1.digits);
    }

    if (!r2.digits.empty()) {
      shift_left(r2.digits, i << 1);
      result = plus(result, r2.digits);
    }

    if (!r3.digits.empty()) {
      shift_left(r3.digits, i * 3);
      result = plus(result, r3.digits);
    }

    if (!r4.digits.empty()) {
      shift_left(r4.digits, i << 2);
      result = plus(result, r4.digits);
    }

    return result;
  }

  BigIntDigits plus(const span<const BigIntBase> &lhs, const span<const BigIntBase> &rhs) {
    if (lhs.empty())
      return BigIntDigits(rhs.begin(), rhs.end());

    if (rhs.empty())
      return BigIntDigits(lhs.begin(), lhs.end());

    int max_length = max(lhs.size(), rhs.size());
    BigIntDigits result;
    result.reserve(max_length + 1);

    for (int w = 0; w < max_length; ++w)
      result.push_back((lhs.size() > w ? lhs[w] : 0) + (rhs.size() > w ? rhs[w] : 0));

    for (int w = 0; w < result.size() - 1; ++w) {
      result[w + 1] += result[w] / digit_base;
      result[w] %= digit_base;
    }

    if (result.back() >= digit_base) {
      result.push_back(result.back() / digit_base);
      result[result.size() - 2] %= digit_base;
    }

    return result;
  }

  BigIntDigits minus(const span<const BigIntBase> &lhs, const span<const BigIntBase> &rhs) {
    if (lhs.empty())
      return BigIntDigits(rhs.begin(), rhs.end());

    if (rhs.empty())
      return BigIntDigits(lhs.begin(), lhs.end());

    BigIntDigits result;
    result.reserve(lhs.size() + 1);

    for (int w = 0; w < lhs.size(); ++w)
      result.push_back((lhs.size() > w ? lhs[w] : 0) - (rhs.size() > w ? rhs[w] : 0));

    for (int w = 0; w < result.size() - 1; ++w)
      if (result[w] < 0) {
        result[w + 1] -= 1;
        result[w] += digit_base;
      }

    while (!result.empty() && !result.back())
      result.pop_back();

    return result;
  }

  void shift_left(BigIntDigits &lhs, const int n) {
    if (!lhs.empty()) {
      BigIntDigits zeros(n, 0);
      lhs.insert(lhs.begin(), zeros.begin(), zeros.end());
    }
  }

  BigInt divide(const BigInt &lhs, const int divisor) {
    BigIntDigits reminder(lhs.digits);
    BigInt result(lhs.digits.capacity(), lhs.negative);

    for (int w = reminder.size() - 1; w >= 0; --w) {
      result.digits.insert(result.digits.begin(), reminder[w] / divisor);
      reminder[w - 1] += (reminder[w] % divisor) * digit_base;
    }

    while (!result.digits.empty() && !result.digits.back())
      result.digits.pop_back();

    return result;
  }

  BigInt multiply(const BigInt &lhs, const int multiplier) {
    BigInt result(lhs.digits, lhs.negative);

    for (int w = 0; w < result.digits.size(); ++w)
      result.digits[w] *= multiplier;

    for (int w = 0; w < result.digits.size(); ++w)
      if (result.digits[w] >= digit_base) {
        if (w + 1 == result.digits.size())
          result.digits.push_back(result.digits[w] / digit_base);
        else
          result.digits[w + 1] += result.digits[w] / digit_base;
        result.digits[w] %= digit_base;
      }

    return result;
  }

  bool greater(const BigIntDigits &lhs, const BigIntDigits &rhs) {
    if (lhs.size() == rhs.size()) {
      int w = lhs.size() - 1;
      while (w >= 0 && lhs[w] == rhs[w])
        --w;

      return w >= 0 && lhs[w] > rhs[w];
    } else
      return lhs.size() > rhs.size();
  }
};
Digits Grade-school Toom-3
10 4588 10003
50 24147 109084
100 52165 286535
150 92405 476275
200 172156 1076570
250 219599 1135946
300 320939 1530747
350 415655 1689745
400 498172 1937327
450 614467 2629886
500 863116 3184277

【问题讨论】:

  • 您是在安排发布还是优化构建?
  • 发布版本是什么?
  • 好吧,因为release_capacity 的默认参数是0,这最终会调用reserve(0),这并没有完成任何事情。在同一个构造函数中,digits = BigIntDigits(); 也绝对什么都不做,因为digits 已经是默认构造的。不是很浪费,但还是很浪费。另一个构造函数也不保留任何东西。因此,乍一看会发现一些明显的疏忽,所以我预计其他算法很可能也缺少明显的优化。
  • 根据你的编译器,它可能需要一些像-O3这样的标志(至少在发行版中)
  • @Cliff 您必须始终在基准测试上进行优化。否则代码只对调试有用

标签: c++ algorithm performance multiplication biginteger


【解决方案1】:

问题是您在 toom3_slice_num 中进行了一百万次分配,在这里您可以使用 std::span(或实际部分的 std::pair 迭代器),因为您给出的数字是 const。 toom3 也是分配器地狱。

multiply 可能会多分配 1 个时间。计算所需的位数或将大小加 1。

对于几乎无锁的分配,vectors 应该是 pmr(带有适当的分配器)。

如果不使用-O2-O3 编译,所有这些都是浪费的。

【讨论】:

  • 我听取了您的一些建议。我使用 span 更新代码以切片 toom3 nums 并使用 -O3 标志进行编译。它确实提高了效率。原来花费的时间是总时间的5%,现在是1%,谢谢。
猜你喜欢
  • 1970-01-01
  • 2013-07-28
  • 1970-01-01
  • 2019-04-27
  • 1970-01-01
  • 2013-07-03
  • 2012-06-28
  • 1970-01-01
  • 2013-10-21
相关资源
最近更新 更多