【问题标题】:Finding a slope of 3D linear regression line寻找 3D 线性回归线的斜率
【发布时间】:2020-05-05 19:21:10
【问题描述】:

我试图在 3D 空间中找到一条线的斜率。 this post

中给出了绘制此类线的解决方案

这是上面链接中给定的代码:

import numpy as np

pts = np.add.accumulate(np.random.random((10,3)))
x,y,z = pts.T

# this will find the slope and x-intercept of a plane
# parallel to the y-axis that best fits the data
A_xz = np.vstack((x, np.ones(len(x)))).T
m_xz, c_xz = np.linalg.lstsq(A_xz, z)[0]

# again for a plane parallel to the x-axis
A_yz = np.vstack((y, np.ones(len(y)))).T
m_yz, c_yz = np.linalg.lstsq(A_yz, z)[0]

# the intersection of those two planes and
# the function for the line would be:
# z = m_yz * y + c_yz
# z = m_xz * x + c_xz
# or:
def lin(z):
    x = (z - c_xz)/m_xz
    y = (z - c_yz)/m_yz
    return x,y

#verifying:
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

fig = plt.figure()
ax = Axes3D(fig)
zz = np.linspace(0,5)
xx,yy = lin(zz)
ax.scatter(x, y, z)
ax.plot(xx,yy,zz)
plt.savefig('test.png')
plt.show()

在数学方面,我知道如何找到两个平面的交点和给定直线的斜率,但我无法将其放入代码中。

如何使用此解决方案找到所得回归线的斜率?

【问题讨论】:

    标签: python numpy matplotlib regression linear-regression


    【解决方案1】:

    3D 线的“斜率​​”通常被认为是“投影”到 x、y 和 z 平面上的线的斜率。见this question第二个回答

    如果这是您想要的,那么计算这些就很容易了;下面这个修改后的代码版本在 sxsysz 变量中执行此操作:

    import numpy as np
    import matplotlib.pyplot as plt
    from   mpl_toolkits.mplot3d import Axes3D
    from   math import pow, sqrt
    
    pts = np.add.accumulate(np.random.random((10,3)))
    
    x, y, z = pts.T
    
    # plane parallel to the y-axis
    A_xz = np.vstack((x, np.ones(len(x)))).T
    m_xz, c_xz = np.linalg.lstsq(A_xz, z, rcond=None)[0]
    
    # plane parallel to the x-axis
    A_yz = np.vstack((y, np.ones(len(y)))).T
    m_yz, c_yz = np.linalg.lstsq(A_yz, z, rcond=None)[0]
    
    # the intersection of those two planes and
    # the function for the line would be:
    # z = m_yz * y + c_yz
    # z = m_xz * x + c_xz
    # or:
    def lin(z):
        x = (z - c_xz)/m_xz
        y = (z - c_yz)/m_yz
        return x,y
    
    
    # get 2 points on the intersection line 
    za = z[0]
    zb = z[len(z) - 1]
    xa, ya = lin(za)
    xb, yb = lin(zb)
    
    # get distance between points
    len = sqrt(pow(xb - xa, 2) + pow(yb - ya, 2) + pow(zb - za, 2))
    
    # get slopes (projections onto x, y and z planes)
    sx = (xb - xa) / len  # x slope
    sy = (yb - ya) / len  # y slope
    sz = (zb - za) / len  # z slope
    
    # integrity check - the sum of squares of slopes should equal 1.0
    # print (pow(sx, 2) + pow(sy, 2) + pow(sz, 2))
    
    fig = plt.figure()
    ax = Axes3D(fig)
    ax.set_xlabel("x, slope: %.4f" %sx, color='blue')
    ax.set_ylabel("y, slope: %.4f" %sy, color='blue')
    ax.set_zlabel("z, slope: %.4f" %sz, color='blue')
    ax.scatter(x, y, z)
    ax.plot([xa], [ya], [za], markerfacecolor='k', markeredgecolor='k', marker = 'o')
    ax.plot([xb], [yb], [zb], markerfacecolor='k', markeredgecolor='k', marker = 'o')
    ax.plot([xa, xb], [ya, yb], [za, zb], color = 'r')
    
    plt.show()
    

    下面的输出图显示了有问题的线,它只是在两个极端 xyz 点之间绘制。

    我希望这可能会有所帮助

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-02-25
      • 2016-01-22
      • 1970-01-01
      • 1970-01-01
      • 2016-06-07
      • 2013-09-27
      相关资源
      最近更新 更多