Problem and solution
After fussing over things like implicit none, it turned out that the problem was actually in the factorial function.
Even though facHelper had been defined as
recursive integer function facHelper(n, acc) result(returner)
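and the factorial function was defined after facHelper, factorial still did not know that facHelper returns an integer. Because facHelper is an external function compiled as a separate program unit, factorial cannot see its declaration and falls back on Fortran's default implicit typing, under which a name beginning with f is REAL.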
The solution was to add a line to the factorial function telling it that facHelper returns an integer:
function factorial(n)
integer::n
integer::factorial
integer::facHelper ! <-- Missing information now available
factorial = facHelper(n, 1)
end function factorial
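A minimal Python sketch of the default implicit typing rule (undeclared names starting with I through N are INTEGER, everything else is REAL), applied to the names in this example. `implicit_type` is a hypothetical helper for illustration only; it models just the default rule and ignores any IMPLICIT statements.

```python
def implicit_type(name):
    """Default Fortran implicit type of an undeclared name (no IMPLICIT statements)."""
    # Fortran names are case-insensitive; only the first letter decides the type.
    first = name[0].lower()
    return "INTEGER" if "i" <= first <= "n" else "REAL"


for name in ("facHelper", "factorial", "n", "acc", "returner"):
    print(name, implicit_type(name))
# facHelper REAL, factorial REAL, n INTEGER, acc REAL, returner REAL
```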
Answering the question
What does a tail-recursive factorial implementation look like in Fortran 95?
A tail-recursive factorial function in Fortran 95 can be implemented as:
fac.f95
recursive function facHelper(n, acc) result(returner)
integer::n
integer::acc
integer::returner
if (n <= 1) then
returner = acc
else
returner = facHelper(n - 1, n * acc)
endif
end function facHelper
function factorial(n)
integer::n
integer::factorial
integer::facHelper
factorial = facHelper(n, 1)
end function factorial
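To see why this form is tail-recursive, here is a Python mirror of facHelper that records the arguments of each call. The optional `trace` list is an illustrative addition, not part of the Fortran. The recursive call is the last thing each invocation does, and the running product travels down in acc instead of being multiplied in on the way back up.

```python
def fac_helper(n, acc, trace=None):
    """Python mirror of the Fortran facHelper, optionally recording each call."""
    if trace is not None:
        trace.append((n, acc))
    if n <= 1:
        return acc
    # Tail call: nothing remains to be done after this call returns.
    return fac_helper(n - 1, n * acc, trace)


calls = []
result = fac_helper(5, 1, calls)
print(result)  # 120
print(calls)   # [(5, 1), (4, 5), (3, 20), (2, 60), (1, 120)]
```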
Or, more elegantly (in my opinion):
fac.f95
recursive integer function facHelper(n, acc) result(returner)
integer::n
integer::acc
if (n <= 1) then
returner = acc
else
returner = facHelper(n - 1, n * acc)
endif
end function facHelper
integer function factorial(n)
integer::n
integer::facHelper
factorial = facHelper(n, 1)
end function factorial
Both versions now compile under GNU Fortran (GCC) 4.8.3 with
gfortran --std=f95 -c ./fac.f95
Bonus Easter eggs
Using f2py v2 with NumPy v1.8.0
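Although both versions of fac.f95 above compile with gfortran, the second version causes f2py to treat returner (the result of facHelper) as REAL. f2py does, however, handle the first version of fac.f95 correctly.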
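I wanted to benchmark (for run time) the tail-recursive factorial in Fortran. I added a non-tail-recursive version of factorial (named vanillaFac) and increased the integer size to kind=8.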
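The wider integers matter for n = 20: a signed 32-bit integer can hold at most 12!, while a signed 64-bit integer holds up to 20!. The sketch below assumes that kind=8 maps to a 64-bit integer and the default integer to 32 bits, which is how gfortran behaves; the standard leaves kind values processor dependent. `largest_factorial_arg` is a hypothetical helper for illustration.

```python
def largest_factorial_arg(bits):
    """Largest n such that n! fits in a signed two's-complement integer of `bits` bits."""
    max_value = 2 ** (bits - 1) - 1
    n, fact = 1, 1
    # Keep multiplying while the next factorial still fits.
    while fact * (n + 1) <= max_value:
        n += 1
        fact *= n
    return n


print(largest_factorial_arg(32))  # 12
print(largest_factorial_arg(64))  # 20
```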
fac.f95 now contains:
recursive function tailFacHelper(n, acc) result(returner)
integer (kind=8)::n
integer (kind=8)::acc
integer (kind=8)::returner
if (n <= 1) then
returner = acc
else
returner = tailFacHelper(n - 1, n * acc)
endif
end function tailFacHelper
function tailFac(n)
integer (kind=8)::n
integer (kind=8)::tailFac
integer (kind=8)::tailFacHelper
tailFac = tailFacHelper(n, 1_8)
end function tailFac
recursive function vanillaFac(n) result(returner)
integer (kind=8)::n
integer (kind=8)::returner
if (n <= 1) then
returner = 1
else
returner = n * vanillaFac(n - 1)
endif
end function vanillaFac
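The difference between the two versions is where the multiplication happens. In vanillaFac, n * vanillaFac(n - 1) is computed after the recursive call returns, so every level has to keep its frame; in tailFacHelper the recursive call is the final action, so a compiler is free to turn it into a loop. Whether gfortran actually does so at a given optimization level is not established by the measurements here. A Python sketch of that rewrite, for illustration only (the function names are hypothetical):

```python
def vanilla_fac(n):
    # Not a tail call: the multiplication runs after the recursive call returns.
    if n <= 1:
        return 1
    return n * vanilla_fac(n - 1)


def tail_fac_as_loop(n, acc=1):
    # The tail-recursive helper with its self-call replaced by parameter updates,
    # i.e. the transformation known as tail-call elimination.
    while n > 1:
        n, acc = n - 1, n * acc
    return acc


print(vanilla_fac(20), tail_fac_as_loop(20))  # 2432902008176640000 2432902008176640000
```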
The new fac.f95 was compiled with:
f2py --overwrite-signature --no-lower fac.f95 -m liboptfac -h fac.pyf;
f2py -c --f90flags=--std=f95 --opt=-O3 fac.pyf fac.f95;
f2py --overwrite-signature --no-lower fac.f95 -m libnooptfac -h fac.pyf;
f2py -c --f90flags=--std=f95 --noopt fac.pyf fac.f95;
I wrote a Python script, timer.py:
from __future__ import print_function  # print() behaves the same on Python 2 and 3
import timeit

import liboptfac    # built with --opt=-O3
import libnooptfac  # built with --noopt


def py_vanilla_fac(n):
    if n <= 1:
        return 1
    else:
        return n * py_vanilla_fac(n - 1)


def py_tail_fac_helper(n, acc):
    if n <= 1:
        return acc
    else:
        return py_tail_fac_helper(n - 1, n * acc)


def py_tail_fac(n):
    return py_tail_fac_helper(n, 1)


LOOPS = 10 ** 6

print("\n*****Fortran (optimizations level 03 enabled)*****")
print("\nliboptfac.vanillaFac(20)\n" + str(LOOPS) + " calls\nBest of ten: ", min(timeit.repeat(lambda: liboptfac.vanillaFac(20), repeat=10, number=LOOPS)))
print("\nliboptfac.tailFac(20)\n" + str(LOOPS) + " calls\nBest of ten: ", min(timeit.repeat(lambda: liboptfac.tailFac(20), repeat=10, number=LOOPS)))
print("\nliboptfac.tailFacHelper(20, 1)\n" + str(LOOPS) + " calls\nBest of ten: ", min(timeit.repeat(lambda: liboptfac.tailFacHelper(20, 1), repeat=10, number=LOOPS)))

print("\n\n*****Fortran (no optimizations enabled)*****")
print("\nlibnooptfac.vanillaFac(20)\n" + str(LOOPS) + " calls\nBest of ten: ", min(timeit.repeat(lambda: libnooptfac.vanillaFac(20), repeat=10, number=LOOPS)))
print("\nlibnooptfac.tailFac(20)\n" + str(LOOPS) + " calls\nBest of ten: ", min(timeit.repeat(lambda: libnooptfac.tailFac(20), repeat=10, number=LOOPS)))
print("\nlibnooptfac.tailFacHelper(20, 1)\n" + str(LOOPS) + " calls\nBest of ten: ", min(timeit.repeat(lambda: libnooptfac.tailFacHelper(20, 1), repeat=10, number=LOOPS)))

print("\n\n*****Python*****")
print("\npy_vanilla_fac(20)\n" + str(LOOPS) + " calls\nBest of ten: ", min(timeit.repeat(lambda: py_vanilla_fac(20), repeat=10, number=LOOPS)))
print("\npy_tail_fac(20)\n" + str(LOOPS) + " calls\nBest of ten: ", min(timeit.repeat(lambda: py_tail_fac(20), repeat=10, number=LOOPS)))
print("\npy_tail_fac_helper(20, 1)\n" + str(LOOPS) + " calls\nBest of ten: ", min(timeit.repeat(lambda: py_tail_fac_helper(20, 1), repeat=10, number=LOOPS)))
print("\n\n\n")
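One thing worth keeping in mind when reading the Python timings: CPython does not eliminate tail calls, so py_tail_fac_helper still uses one stack frame per step, exactly like py_vanilla_fac. A small sketch showing that the tail-recursive helper hits the interpreter's recursion limit for large n; `hits_recursion_limit` is a hypothetical helper, not part of timer.py.

```python
import sys


def py_tail_fac_helper(n, acc):
    # Same tail-recursive helper as in timer.py.
    if n <= 1:
        return acc
    return py_tail_fac_helper(n - 1, n * acc)


def hits_recursion_limit(n):
    """Return True if computing n! with the tail-recursive helper overflows the Python stack."""
    try:
        py_tail_fac_helper(n, 1)
    except RuntimeError:  # RecursionError (Python 3.5+) is a subclass of RuntimeError
        return True
    return False


print(hits_recursion_limit(20))                           # False
print(hits_recursion_limit(sys.getrecursionlimit() + 100))  # True
```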
And finally:
python timer.py
Output:
*****Fortran (optimizations level 03 enabled)*****
liboptfac.vanillaFac(20)
1000000 calls
Best of ten: 0.813575983047
liboptfac.tailFac(20)
1000000 calls
Best of ten: 0.843787193298
liboptfac.tailFacHelper(20, 1)
1000000 calls
Best of ten: 0.858899831772
*****Fortran (no optimizations enabled)*****
libnooptfac.vanillaFac(20)
1000000 calls
Best of ten: 1.00723600388
libnooptfac.tailFac(20)
1000000 calls
Best of ten: 0.975327014923
libnooptfac.tailFacHelper(20, 1)
1000000 calls
Best of ten: 0.982407093048
*****Python*****
py_vanilla_fac(20)
1000000 calls
Best of ten: 6.47849297523
py_tail_fac(20)
1000000 calls
Best of ten: 6.93045401573
py_tail_fac_helper(20, 1)
1000000 calls
Best of ten: 6.81205391884
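To put these numbers in proportion, a minimal Python sketch that converts the reported best-of-ten times into per-call cost and ratios relative to liboptfac.vanillaFac. It uses only the figures printed above and makes no new measurements; `relative_costs` is a hypothetical helper.

```python
# Best-of-ten times (seconds for 10**6 calls), copied from the output above.
LOOPS = 10 ** 6
TIMINGS = [
    ("Fortran -O3", "vanillaFac", 0.813575983047),
    ("Fortran -O3", "tailFac", 0.843787193298),
    ("Fortran -O3", "tailFacHelper", 0.858899831772),
    ("Fortran no-opt", "vanillaFac", 1.00723600388),
    ("Fortran no-opt", "tailFac", 0.975327014923),
    ("Fortran no-opt", "tailFacHelper", 0.982407093048),
    ("Python", "py_vanilla_fac", 6.47849297523),
    ("Python", "py_tail_fac", 6.93045401573),
    ("Python", "py_tail_fac_helper", 6.81205391884),
]


def relative_costs(timings, baseline_seconds):
    """Map (build, function) -> (microseconds per call, ratio to the baseline time)."""
    result = {}
    for build, func, seconds in timings:
        per_call_us = seconds / LOOPS * 1e6
        result[(build, func)] = (per_call_us, seconds / baseline_seconds)
    return result


costs = relative_costs(TIMINGS, baseline_seconds=0.813575983047)
# Print from fastest to slowest.
for (build, func), (us, ratio) in sorted(costs.items(), key=lambda kv: kv[1][1]):
    print("%-15s %-19s %6.3f us/call  %5.2fx" % (build, func, us, ratio))
```

By this arithmetic, the optimized Fortran calls cost about 0.81 to 0.86 µs each (a figure that includes the Python call and f2py wrapper overhead), the unoptimized build is about 1.2x the baseline, the pure-Python versions are roughly 8 to 8.5 times slower, and the tail-recursive variants show no clear advantage over vanillaFac in either Fortran build.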