主要问题是交替求和的公式极易出现数值精度问题。
避免右侧问题的一个技巧是假设分布是对称的,只计算一半。
一个直接的精度优化是通过调用scipy.special.comb 来替换combinaison 公式中的阶乘。这避免了需要划分非常大的数字。
一个较小的精度优化是同时计算偶数和奇数的g。但是乍一看公式不能减少多少,所以替换:
for k in range(0, int(floor(n * y[i] + 1))):
g += pow(-1, k) * combinaison(n, k) * pow(y[i] - k / n, n - 1)
作者:
last_k = int(floor(n * y[i]))
for k in range(0, last_k + 1, 2): # note that k increments in steps of 2
if k == last_k:
g += combinaison(n, k) * (pow(y[i] - k / n, n - 1))
else:
g += combinaison(n, k) * (pow(y[i] - k / n, n - 1) - pow(y[i] - (k + 1)/ n, n - 1) * (n - k) / (k + 1))
其他一些评论:
- 变量
samples仅用于告诉xaxis中的除法。一个小得多的数字就足够了。 (在下面的代码中,我将变量重命名为 xaxis_steps)。
- 将
append 用于F 会非常慢。最好创建一个正确大小的 numpy 数组,然后将其填充。(这也使得复制一半更容易。)
from matplotlib import pyplot as plt
import numpy as np
from scipy.special import comb
from math import factorial as fac
from math import floor
xaxis_steps = 500
def combinaison(n, k): # combination of K out of N
return comb(n, k)
def dens_probas(a, b, n):
x = np.linspace(a, b, num=xaxis_steps)
y = (x - a) / (b - a)
F = np.zeros_like(y)
for i in range(0, (len(y)+1) // 2):
g = 0
for k in range(0, int(floor(n * y[i] + 1))):
g += pow(-1, k) * combinaison(n, k) * pow(y[i] - k / n, n - 1)
F[i] = (n ** n / fac(n - 1)) * g
F[-i-1] = F[i] # symmetric graph
plt.plot(x, F, label=f'n={n}')
return F
for n in (5, 30, 50, 80, 90):
dens_probas(-1, 1, n)
plt.legend()
plt.show()
所有这些优化共同将准确度问题从n=30 转移到n=80 附近:
一种完全不同的方法是生成大量统一的样本并采取手段。从这些样本中可以生成kde 图。这种曲线的平滑度取决于样本的数量。可以通过seaborn's kdeplot 直接绘制 kde。您也可以单独calculate the kde function,然后将其应用于给定的 x 范围并通过标准 matplotlib 进行绘制。
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import gaussian_kde
num_samples = 10 ** 5
def dens_probas(a, b, n):
samples = np.random.uniform(a, b, size=(num_samples, n)).mean(axis=1)
samples = np.hstack([samples, a + b - samples]) # force symmetry; this is not strictly necessary
return gaussian_kde(samples)
for n in (5, 30, 50, 80, 90, 200):
kde = dens_probas(-1, 1, n)
xs = np.linspace(-1, 1, 1000)
F = kde(xs)
plt.plot(xs, F, label=f'n={n}')
plt.legend()
plt.show()