【问题标题】:Matplotlib Error TypeError: Cannot cast array data from dtype('float64') to dtype('<U32') according to the rule 'safe'Matplotlib 错误 TypeError:无法根据规则“安全”将数组数据从 dtype('float64') 转换为 dtype('<U32')
【发布时间】:2019-02-27 21:44:49
【问题描述】:

对不起,我对 python 知之甚少,但我试图在 3d 图形中输出一个 csv 文件(csvfile)数据集。到目前为止我的代码如下:

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import csv
   fig = plt.figure()
   ax = fig.add_subplot(111, projection='3d')

with open('new3.csv') as csvfile:
readCSV = csv.reader(csvfile, delimiter=',')
next(readCSV)
next(readCSV)
next(readCSV)
XS =[]
YS =[]
ZS =[]
for column in readCSV:
    xs = column[1]
    ys = column[2]
    zs = column[3]

    XS.append(xs)
    YS.append(ys)
    ZS.append(zs)
    ax.scatter(XS, YS, ZS, c='r', marker='o')
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')

    plt.show()

但我一直想出标题中的错误。任何帮助表示赞赏

【问题讨论】:

    标签: python matplotlib


    【解决方案1】:

    错误是因为您试图绘制三个str 类型对象的列表。它们需要是float 或类似类型,并且不能被隐式转换。您可以通过以下修改显式进行类型转换:

    for column in readCSV:
            xs = float(column[1])
            ys = float(column[2])
            zs = float(column[3])
    

    还要注意ax.scatter 应该在循环之外,像这样

        for column in readCSV:
            xs = float(column[1])
            ys = float(column[2])
            zs = float(column[3])
    
            XS.append(xs)
            YS.append(ys)
            ZS.append(zs)
    
    ax.scatter(XS, YS, ZS, c='r', marker='o')
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    

    否则,.csv 中的每一行都会得到一个新的散点图。我隔离了您数据的前 5 行,并用这些修改绘制了它们以给出

    【讨论】:

    • 非常感谢,这帮助我输出了完整的文件。
    【解决方案2】:

    只是为了好玩,使用 numpy 在默认情况下绕过了将字符串传递给 matplotlib 的原始问题,同时稍微压缩了代码。

    raw = """
    id,gx,gy,gz,ax,ay,az
    0,4.47,-33.23,-77,-106,94
    1,-129.04,4.48,-33.22,-78,-94,117
    2,-129.04,4.49,33.2,-70,-81,138
    3,-129.02,4.49,-33.18,-70,-64,157
    4,-129.02,4.5,-33.15,-64,-47,165
    """
    
    from mpl_toolkits.mplot3d import Axes3D
    import matplotlib.pyplot as plt
    from io import StringIO
    
    # read data
    csvfile = StringIO(raw)
    d = plt.np.loadtxt(csvfile, delimiter=',', skiprows=2, usecols=[1,2,3])
    # instead of csvfile just use filename when using the real file
    xyz = plt.np.split(d.T, indices_or_sections=len(d.T))
    
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(*xyz, c='r', marker='o')
    ax.set(**{'%slabel'%s: s.upper() + ' Label' for s in 'xyz'})
    

    【讨论】: