【问题标题】:Plot normal distribution in 3D在 3D 中绘制正态分布
【发布时间】:2021-04-30 14:42:53
【问题描述】:

我正在尝试绘制两个正态分布变量的共同分布。

下面的代码绘制了一个正态分布变量。绘制两个正态分布变量的代码是什么?

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.mlab as mlab
import math

mu = 0
variance = 1
sigma = math.sqrt(variance)
x = np.linspace(-3, 3, 100)
plt.plot(x,mlab.normpdf(x, mu, sigma))

plt.show()

【问题讨论】:

标签: python matplotlib mplot3d


【解决方案1】:

听起来您正在寻找的是Multivariate Normal Distribution。这在 scipy 中实现为scipy.stats.multivariate_normal。请务必记住,您将协方差矩阵传递给函数。所以为了简单起见,保持非对角元素为零:

[X variance ,     0    ]
[     0     ,Y Variance]

这是一个使用此函数并生成结果分布的 3D 图的示例。我添加了颜色图以便更轻松地查看曲线,但可以随意删除它。

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from mpl_toolkits.mplot3d import Axes3D

#Parameters to set
mu_x = 0
variance_x = 3

mu_y = 0
variance_y = 15

#Create grid and multivariate normal
x = np.linspace(-10,10,500)
y = np.linspace(-10,10,500)
X, Y = np.meshgrid(x,y)
pos = np.empty(X.shape + (2,))
pos[:, :, 0] = X; pos[:, :, 1] = Y
rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])

#Make a 3D plot
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.plot_surface(X, Y, rv.pdf(pos),cmap='viridis',linewidth=0)
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')
plt.show()

给你这个情节:

编辑下面使用的方法在 Matplotlib v2.2 中已弃用并在 v3.1 中删除

可通过matplotlib.mlab.bivariate_normal 获得更简单的版本 它采用以下参数,因此您无需担心矩阵 matplotlib.mlab.bivariate_normal(X, Y, sigmax=1.0, sigmay=1.0, mux=0.0, muy=0.0, sigmaxy=0.0) 这里 X 和 Y 再次是网格网格的结果,因此使用它来重新创建上面的图:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.mlab import bivariate_normal
from mpl_toolkits.mplot3d import Axes3D

#Parameters to set
mu_x = 0
sigma_x = np.sqrt(3)

mu_y = 0
sigma_y = np.sqrt(15)

#Create grid and multivariate normal
x = np.linspace(-10,10,500)
y = np.linspace(-10,10,500)
X, Y = np.meshgrid(x,y)
Z = bivariate_normal(X,Y,sigma_x,sigma_y,mu_x,mu_y)

#Make a 3D plot
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.plot_surface(X, Y, Z,cmap='viridis',linewidth=0)
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')
plt.show()

给予:

【讨论】:

  • 应该是from matplotlib.mlab import bivariate_normal
  • 看来bivariate_normal 最终会被删除:MatplotlibDeprecationWarning: The bivariate_normal function was deprecated in Matplotlib 2.2 and will be removed in 3.1.
  • 假设高斯函数是我的可能性,但我想将它乘以其相关的先验 (Pa)。如何在此代码中包含 Pa?
【解决方案2】:

以下对@Ianhi 上面代码的改编返回了上面 3D 图的等高线图版本。

import matplotlib.pyplot as plt
from matplotlib import style
style.use('fivethirtyeight')
import numpy as np
from scipy.stats import multivariate_normal




#Parameters to set
mu_x = 0
variance_x = 3

mu_y = 0
variance_y = 15

x = np.linspace(-10,10,500)
y = np.linspace(-10,10,500)
X,Y = np.meshgrid(x,y)

pos = np.array([X.flatten(),Y.flatten()]).T



rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])


fig = plt.figure(figsize=(10,10))
ax0 = fig.add_subplot(111)
ax0.contour(X, Y, rv.pdf(pos).reshape(500,500))

plt.show()

【讨论】:

    【解决方案3】:

    虽然其他答案很好,但我希望获得类似的结果,同时还用样本的散点图说明分布。

    更多详情可以在这里找到:Python 3d plot of multivariate gaussian distribution

    结果如下:

    并使用以下代码生成:

    from mpl_toolkits.mplot3d import Axes3D
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from scipy.stats import multivariate_normal
    
    
    # Sample parameters
    mu = np.array([0, 0])
    sigma = np.array([[0.7, 0.2], [0.2, 0.3]])
    rv = multivariate_normal(mu, sigma)
    sample = rv.rvs(500)
    
    # Bounds parameters
    x_abs = 2.5
    y_abs = 2.5
    x_grid, y_grid = np.mgrid[-x_abs:x_abs:.02, -y_abs:y_abs:.02]
    
    pos = np.empty(x_grid.shape + (2,))
    pos[:, :, 0] = x_grid
    pos[:, :, 1] = y_grid
    
    levels = np.linspace(0, 1, 40)
    
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    
    # Removes the grey panes in 3d plots
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    
    # The heatmap
    ax.contourf(x_grid, y_grid, 0.1 * rv.pdf(pos),
                zdir='z', levels=0.1 * levels, alpha=0.9)
    
    # The wireframe
    ax.plot_wireframe(x_grid, y_grid, rv.pdf(
        pos), rstride=10, cstride=10, color='k')
    
    # The scatter. Note that the altitude is defined based on the pdf of the
    # random variable
    ax.scatter(sample[:, 0], sample[:, 1], 1.05 * rv.pdf(sample), c='k')
    
    ax.legend()
    ax.set_title("Gaussian sample and pdf")
    ax.set_xlim3d(-x_abs, x_abs)
    ax.set_ylim3d(-y_abs, y_abs)
    ax.set_zlim3d(0, 1)
    
    plt.show()
    

    【讨论】:

      最近更新 更多