【问题标题】:How to modify xtick label of plt in Matplotlib如何在 Matplotlib 中修改 plt 的 xtick 标签
【发布时间】:2022-01-19 17:05:55
【问题描述】:

目标是在绘制pcolormeshscatter 时修改xticklabel

但是,我无法访问现有的 xtick 标签。

简单

ax = plt.axes()
labels_x = [item.get_text() for item in ax.get_xticklabels()]

生产的:

['', '', '', '', '', '']

fig.canvas.draw()
xticks = ax.get_xticklabels()

生产的:

['', '', '', '', '', '']

不返回对应的标签。

我能否知道如何正确访问 plt 案例的轴刻度标签。

为了可读性,我将代码分成两部分。

  1. 生成用于绘图的数据的第一部分
  2. 第二部分处理绘图

第 1 部分:生成用于绘图的数据

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math

np.random.seed(0)
increment=120
max_val=172800

aran=np.arange(0,max_val,increment).astype(int)
arr=np.concatenate((aran.reshape(-1,1), np.random.random((aran.shape[0],4))), axis=1)
df=pd.DataFrame(arr,columns=[('lapse',''),('a','i'),('a','j'),('b','k'),('c','')])

ridx=df.index[df[('lapse','')] == 3600].tolist()[0]+1 # minus 1 so to allow 3600 start at new row


df[('event','')]=0
df.loc[[1,2,3,10,20,30],[('event','')]]=1

arr=df[[('a','i'),('event','')]].to_numpy()
col_len=ridx
v=arr[:,0].view()

nrow_size=math.ceil(v.shape[0]/col_len)
X=np.pad(arr[:,0].astype(float), (0, nrow_size*col_len - arr[:,0].size),
       mode='constant', constant_values=np.nan).reshape(nrow_size,col_len)

mask_append_val=0  # This value must equal to 1 for masking
arrshape=np.pad(arr[:,1].astype(float), (0, nrow_size*col_len - arr[:,1].size),
       mode='constant', constant_values=mask_append_val).reshape(nrow_size,col_len)

第 2 节绘图

fig = plt.figure(figsize=(8,6))
plt.pcolormesh(X,cmap="plasma")

x,y = X.shape
xs,ys = np.ogrid[:x,:y]
# the non-zero coordinates
u = np.argwhere(arrshape)

plt.scatter(ys[:,u[:,1]].ravel()+.5,xs[u[:,0]].ravel()+0.5,marker='*', color='r', s=55)

plt.gca().invert_yaxis()

xlabels_to_use_this=df.loc[:30,[('lapse','')]].values.tolist()

# ax = plt.axes()
# labels_x = [item.get_text() for item in ax.get_xticklabels()]
# labels_y = [item.get_text() for item in ax.get_yticklabels()]

plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.title("Plot 2D array")
plt.colorbar()
plt.tight_layout()
plt.show()

预期输出

【问题讨论】:

  • 请描述具体问题是什么:您返回的刻度值是什么?提及您需要刻度值的目标也很有用。
  • 好吧,随着更多数据添加到绘图中,xtick 位置仍然会发生很大变化。您可以强制平局,但推荐的解决方案是您自己决定在哪里打勾,然后致电ax.set_xticks(your_ticks)。请注意,“plt 案例”与 ax = plt.gca() 案例完全相同。在内部两者都做同样的事情。例如,您可以使用plt.xticks(range(len(xlabels_to_use_this)), xlabels_to_use_this)。请注意,将刻度设置为尝试接收刻度都应在您的代码中尽可能晚地发生,因为许多绘图命令都会更改刻度。
  • 另请注意ax = plt.axes() 在您现有的绘图之上创建一个新的虚拟绘图。你真的应该避免这种情况。请改用ax = plt.gca() 来访问现有的子图。获得正确刻度的首选方法是使用plt.pcolormesh(x-values, y-values, color-values, ...)
  • 感谢@JohanC 的详细解释,也许我理解了一些东西,但仍然使用 ax = plt.gca() 后跟 labels_x = [item.get_text() 作为 ax.get_xticklabels()] 中的项目产量['', '', '', '', '', '', '', '']
  • 您的代码中没有设置任何 xtick 标签。它们仅在“显示”时间显示,并使用类似于代码格式化程序 (docs) 的东西。你可以得到plt.gca().get_xticks()。但更好的是,您将ax.set_xticks(...)ax.set_xticklabels(...) 一起使用。随着所有的重塑,您的代码非常难以理解。此外,颜色条需要一个参数。例如。 mesh = plt.pcolormesh(...)plt.colorbar(mesh, ....)

标签: python matplotlib


【解决方案1】:

这就是使用 matplotlib 的 pcolormeshscatter 生成绘图的方式:

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import pandas as pd
import numpy as np

np.random.seed(0)
increment = 120
max_val = 172800
aran = np.arange(0, max_val, increment).astype(int)
arr_df = np.concatenate((aran.reshape(-1, 1), np.random.random((aran.shape[0], 4))), axis=1)
df = pd.DataFrame(arr_df, columns=[('lapse', ''), ('a', 'i'), ('a', 'j'), ('b', 'k'), ('c', '')])
df[('event', '')] = 0
df.loc[[1, 2, 3, 10, 20, 30], [('event', '')]] = 1

col_len_lapse = 3600
col_len = df[df[('lapse', '')] == col_len_lapse].index[0]
nrow_size = int(np.ceil(v.shape[0] / col_len))

a_i_values = df[('a', 'i')].values
a_i_values_meshed = np.pad(a_i_values.astype(float), (0, nrow_size * col_len - len(a_i_values)),
                           mode='constant', constant_values=np.nan).reshape(nrow_size, col_len)

fig, ax = plt.subplots(figsize=(8, 6))
# the x_values indicate the mesh borders, subtract one half so the ticks can be at the centers
x_values = df[('lapse', '')][:col_len + 1].values - increment / 2
# divide lapses for y by col_len_lapse to get hours
y_values = df[('lapse', '')][::col_len].values / col_len_lapse - 0.5
y_values = np.append(y_values, 2 * y_values[-1] - y_values[-2])  # add the bottommost border (linear extension)

mesh = ax.pcolormesh(x_values, y_values, a_i_values_meshed, cmap="plasma")

event_lapses = df[('lapse', '')][df[('event', '')] == 1]
ax.scatter(event_lapses % col_len_lapse,
           np.floor(event_lapses / col_len_lapse),
           marker='*', color='red', edgecolor='white', s=55)

ax.xaxis.set_major_locator(MultipleLocator(increment * 5))
ax.yaxis.set_major_locator(MultipleLocator(5))
ax.invert_yaxis()
ax.set_xlabel('X-axis (s)')
ax.set_ylabel('Y-axis (hours)')
ax.set_title("Plot 2D array")
plt.colorbar(mesh)

plt.tight_layout()  # fit the labels nicely into the plot
plt.show()

使用 Seaborn 可以简化事情,添加新的小时和秒列,并使用 pandas 的pivot(它会自动用NaNs 填充不可用的数据)。添加 xtick_labels=5 每 5 个位置设置标签。 (lapse=3600 的星号位于 1 小时 0 秒)。

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# df created as before

df['hours'] = (df[('lapse', '')].astype(int) // 3600)
df['seconds'] = (df[('lapse', '')].astype(int) % 3600)

df_heatmap = df.pivot(index='hours', columns='seconds', values=('a', 'i'))
df_heatmap_markers = df.pivot(index='hours', columns='seconds', values=('event', '')).replace(
    {0: '', 1: '★', np.nan: ''})

fig, ax = plt.subplots(figsize=(8, 6))

sns.heatmap(df_heatmap, xticklabels=5, yticklabels=5,
            annot=df_heatmap_markers, fmt='s', annot_kws={'color': 'lime'}, ax=ax)
ax.tick_params(rotation=0)

plt.tight_layout()
plt.show()

除了“秒”列,“分钟”列也可能很有趣。

这里尝试按照 cmets 中的建议添加时间信息:

from matplotlib import patheffects # to add some outline effect

# df prepared as the other seaborn example

fig, ax = plt.subplots(figsize=(8, 6))

path_effect = patheffects.withStroke(linewidth=2, foreground='yellow')
sns.heatmap(df_heatmap, xticklabels=5, yticklabels=5,
            annot=df_heatmap_markers, fmt='s',
            annot_kws={'color': 'red', 'path_effects': [path_effect]},
            cbar=True, cbar_kws={'pad': 0.16}, ax=ax)
ax.tick_params(rotation=0)

ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim())
yticks = ax.get_yticks()
ax2.set_yticks(yticks)
ax2.set_yticklabels([str(pd.to_datetime('2019-01-15 7:00:00') + pd.to_timedelta(h, unit='h')).replace(' ', '\n')
                     for h in yticks])

【讨论】:

  • 这比最初的建议更好更紧凑。
【解决方案2】:

我最终使用 Seaborn 来解决这个问题。

具体来说,以下几行可以轻松调整 xticklabel

fig.canvas.draw()
new_ticks = [i.get_text() for i in g.get_xticklabels()]

i=[int(idx) for idx in new_ticks]
newlabel=xlabels_to_use_this[i]
newlabel=[np.array2string(x, precision=0) for x in newlabel]

完整的绘图代码如下

import seaborn as sns
fig, ax = plt.subplots()

sns.heatmap(X,ax=ax)
x,y = X.shape
xs,ys = np.ogrid[:x,:y]
# the non-zero coordinates
u = np.argwhere(arrshape)
g=sns.scatterplot(ys[:,u[:,1]].ravel()+.5,xs[u[:,0]].ravel()+0.5,marker='*', color='r', s=55)


fig.canvas.draw()
new_ticks = [i.get_text() for i in g.get_xticklabels()]

i=[int(idx) for idx in new_ticks]
newlabel=xlabels_to_use_this[i]
newlabel=[np.array2string(x, precision=0) for x in newlabel]

ax.set_xticklabels(newlabel)

ax.set_xticklabels(ax.get_xticklabels(),rotation = 90)

for ind, label in enumerate(g.get_xticklabels()):
    if ind % 2 == 0:  # every 10th label is kept
        label.set_visible(True)
    else:
        label.set_visible(False)


for ind, label in enumerate(g.get_yticklabels()):
    if ind % 4 == 0:  # every 10th label is kept
        label.set_visible(True)
    else:
        label.set_visible(False)
      
plt.xlabel('Elapsed (s)')
plt.ylabel('Hour (h)')
plt.title("Rastar Plot")
plt.tight_layout()
plt.show()

【讨论】:

  • 请注意,您将 31 次 120 秒分组为小时数(因为 172800 是 3600 的倍数,所以末尾不应有任何空白)。
  • 对于 seaborn,您应该使用 sns.heatmap(...., xticklabels=..., yticklabels=...),为未标记的刻度留下空字符串。
猜你喜欢
  • 1970-01-01
  • 2017-02-02
  • 1970-01-01
  • 1970-01-01
  • 2014-07-25
  • 2016-07-24
  • 1970-01-01
  • 2021-11-30
  • 1970-01-01
相关资源
最近更新 更多