【发布时间】:2014-07-22 20:07:58
【问题描述】:
我最近在Codeforces 开始使用 Scala 解决一些编程挑战,以锻炼函数式编程技能。这样做我遇到了一个特殊的挑战,我无法以尊重给定的 1000 毫秒执行时间限制的方式解决; Painting Fence 问题。
我尝试了各种不同的方法,从直接递归解决方案开始,尝试使用流而不是列表的类似方法,并最终尝试通过更多地使用索引来减少列表操作。我最终在较大的测试中遇到了堆栈溢出异常,我可以使用 Scala 的TailCall. 修复这些异常。但是,尽管该解决方案正确地解决了问题,但在 1000 毫秒内完成太慢了。除此之外,还有一个 C++ 实现,相比之下,它的速度可笑(
这是我的 scala 代码,您可以将其粘贴到 REPL 中,包括需要 >1000 毫秒的示例:
import scala.util.control.TailCalls._
def solve(l: List[(Int, Int)]): Int = {
def go(from: Int, to: Int, prevHeight: Int): TailRec[Int] = {
val max = to - from
val currHeight = l.slice(from, to).minBy(_._1)._1
val hStrokes = currHeight - prevHeight
val splits = l.slice(from, to).filter(_._1 - currHeight == 0).map(_._2)
val indices = from :: splits.flatMap(x => List(x, x+1)) ::: List(to)
val subLists = indices.grouped(2).filter(xs => xs.last - xs.head > 0)
val trampolines = subLists.map(xs => tailcall(go(xs.head, xs.last, currHeight)))
val sumTrampolines = trampolines.foldLeft(done(hStrokes))((b, a) => b.flatMap(bVal =>
a.map(aVal => aVal + bVal)))
sumTrampolines.flatMap(v => done(max).map(m => Math.min(m, v)))
}
go(0, l.size, 0).result
}
val lst = (1 to 5000).toList.zipWithIndex
val res = solve(lst)
为了比较,这里有一个 C++ 示例,实现了 Bugman 编写的相同内容(包括一些我在上面的 Scala 版本中没有包含的来自控制台的读/写):
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <string>
#include <set>
#include <map>
#include <cmath>
#include <memory.h>
using namespace std;
typedef long long ll;
const int N = 1e6+6;
const int T = 1e6+6;
int a[N];
int t[T], d;
int rmq(int i, int j){
int r = i;
for(i+=d,j+=d; i<=j; ++i>>=1,--j>>=1){
if(i&1) r=a[r]>a[t[i]]?t[i]:r;
if(~j&1) r=a[r]>a[t[j]]?t[j]:r;
}
return r;
}
int calc(int l, int r, int h){
if(l>r) return 0;
int m = rmq(l,r);
int mn = a[m];
int res = min(r-l+1, calc(l,m-1,mn)+calc(m+1,r,mn)+mn-h);
return res;
}
int main(){
//freopen("input.txt","r",stdin);// freopen("output.txt","w",stdout);
int n, m;
scanf("%d",&n);
for(int i=0;i<n;++i) scanf("%d",&a[i]);
a[n] = 2e9;
for(d=1;d<n;d<<=1);
for(int i=0;i<n;++i) t[i+d]=i;
for(int i=n+d;i<d+d;++i) t[i]=n;
for(int i=d-1;i;--i) t[i]=a[t[i*2]]<a[t[i*2+1]]?t[i*2]:t[i*2+1];
printf("%d\n",calc(0,n-1,0));
return 0;
}
至少在我介绍显式尾调用之前,在我看来,更实用的风格比更命令式的解决方案更自然地解决问题。所以我真的很高兴能更多地了解在编写函数式代码时应该注意什么才能获得可接受的性能。
【问题讨论】:
-
其中一个痛点是
val indices计算——once I optimized this line alone,在我的机器上总时间从1.360s下降到0.672s。我相信如果您以更明智的方式使用集合,您将获得更好的结果。 -
@om-nom-nom:这在一般情况下是行不通的。