【问题标题】:How to plot multiple regression 3D plot in python如何在python中绘制多元回归3D图
【发布时间】:2016-03-09 06:16:43
【问题描述】:

我不是科学家,所以请假设我不知道有经验的程序员的行话,或者科学绘图技术的复杂性。 Python 是我唯一知道的语言(初学者+,也许是中级)。

任务:将多元回归 (z = f(x, y) ) 的结果绘制为 3D 图上的二维平面(我可以使用 OSX 的绘图实用程序,例如,或在此处实现Plot Regression Surface 与 R)。

经过一周的搜索Stackoverflow并阅读了ma​​tplotlibseabornma​​yavi的各种文档,我终于找到了@ 987654322@ 听起来很有希望。所以这是我的数据和代码:

首先尝试使用 matplotlib:

shape: (80, 3) 
type: <type 'numpy.ndarray'> 
zmul: 

[[  0.00000000e+00   0.00000000e+00   5.52720000e+00]
 [  5.00000000e+02   5.00000000e-01   5.59220000e+00]
 [  1.00000000e+03   1.00000000e+00   5.65720000e+00]
 [  1.50000000e+03   1.50000000e+00   5.72220000e+00]
 [  2.00000000e+03   2.00000000e+00   5.78720000e+00]
 [  2.50000000e+03   2.50000000e+00   5.85220000e+00]
 ……]

import matplotlib
from matplotlib.ticker import MaxNLocator
from matplotlib import cm

from numpy.random import randn
from scipy import array, newaxis
Xs = zmul[:,0]
Ys = zmul[:,1]
Zs = zmul[:,2]


surf = ax.plot_trisurf(Xs, Ys, Zs, cmap=cm.jet, linewidth=0)
fig.colorbar(surf)

ax.xaxis.set_major_locator(MaxNLocator(5))
ax.yaxis.set_major_locator(MaxNLocator(6))
ax.zaxis.set_major_locator(MaxNLocator(5))

fig.tight_layout()

plt.show()

我得到的只是一个空的 3D 坐标系,并带有以下错误消息:

RuntimeError: qhull Delaunay 三角剖分计算出错:奇异输入数据(exitcode=2);使用 python 详细选项 (-v) 查看原始 qhull 错误。

我试图看看我是否可以使用绘图参数并检查了这个网站http://www.qhull.org/html/qh-impre.htm#delaunay,但我真的无法理解我应该做什么。

第二次尝试使用 mayavi:

同样的数据,分成3个numpy数组:

type: <type 'numpy.ndarray'> 
X: [    0   500  1000  1500  2000  2500  3000 ….]

type: <type 'numpy.ndarray'> 
Y: [  0.    0.5   1.    1.5   2.    2.5   3.  ….]

type: <type 'numpy.ndarray'> 
Z: [  5.5272   5.5922   5.6572   5.7222   5.7872   5.8522   5.9172  ….] 

代码:

from mayavi import mlab
def multiple3_triple(tpl_lst):

X = xs
Y = ys
Z = zs


# Define the points in 3D space
# including color code based on Z coordinate.
pts = mlab.points3d(X, Y, Z, Z)

# Triangulate based on X, Y with Delaunay 2D algorithm.
# Save resulting triangulation.
mesh = mlab.pipeline.delaunay2d(pts)

# Remove the point representation from the plot
pts.remove()

# Draw a surface based on the triangulation
surf = mlab.pipeline.surface(mesh)

# Simple plot.
mlab.xlabel("x")
mlab.ylabel("y")
mlab.zlabel("z")
mlab.show()

我得到的只有这个:

如果这很重要,我在 OSX 10.9.3 上使用 64 位版本的 Enthought's Canopy

将不胜感激任何关于我做错了什么的意见。

编辑:发布有效的最终代码,以防它帮助某人。

'''After the usual imports'''
def multiple3(tpl_lst):
    mul = []
    for tpl in tpl_lst:
        calc = (.0001*tpl[0]) + (.017*tpl[1])+ 6.166
        mul.append(calc)
    return mul

fig = plt.figure()
ax = fig.gca(projection='3d')
'''some skipped code for the scatterplot'''
X = np.arange(0, 40000, 500)
Y = np.arange(0, 40, .5)
X, Y = np.meshgrid(X, Y)
Z = multiple3(zip(X,Y))

surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1,cmap=cm.autumn,
                       linewidth=0, antialiased=False, alpha =.1)
ax.set_zlim(1.01, 11.01)
ax.set_xlabel(' x = IPP')
ax.set_ylabel('y = UNRP20')
ax.set_zlabel('z = DI')

ax.zaxis.set_major_locator(LinearLocator(10))
ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
fig.colorbar(surf, shrink=0.5, aspect=5)
plt.show()

【问题讨论】:

  • 对于 xy 值的二维网格,您的数据不是 z=f(x,y),而是沿着 y=x 线。而且你不需要 Mayavi 来绘制 data = f(x,y,z)。对于 z=f(x,y) 数据,Matplotlib 就可以了。
  • 您能否解释一下我的数据在结构上与链接页面“Python 绘制 3d 表面的最简单方法”上的数据有何不同?我试图复制该数据的结构和类型...
  • 对于matplotlib,你可以基于surface example(你缺少plt.meshgrid
  • @FelipeLema 非常感谢,您的建议最终对我有用。您可以将其发布为答案,以便我投票并接受它吗?
  • 哦,太好了!我以为这只是一个远射!

标签: python matplotlib 3d regression mayavi


【解决方案1】:

对于 matplotlib,您可以基于 surface example(您缺少 plt.meshgrid):

from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax = fig.gca(projection='3d')
X = np.arange(-5, 5, 0.25)
Y = np.arange(-5, 5, 0.25)
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)
ax.set_zlim(-1.01, 1.01)

ax.zaxis.set_major_locator(LinearLocator(10))
ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

fig.colorbar(surf, shrink=0.5, aspect=5)

plt.show()

【讨论】:

    猜你喜欢
    • 2019-12-21
    • 2019-10-05
    • 2022-01-15
    • 1970-01-01
    • 2021-07-21
    • 2013-07-17
    • 2019-07-01
    • 2019-11-30
    • 1970-01-01
    相关资源
    最近更新 更多