Skip to content

Commit 2dbf1a7

Browse files
committed
assign reference values by var_name instead of var_label
1 parent 6fabf27 commit 2dbf1a7

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

arviz/plots/backends/bokeh/pairplot.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def plot_pair(
3737
diverging_mask,
3838
divergences_kwargs,
3939
flat_var_names,
40+
flat_var_labels,
4041
backend_kwargs,
4142
marginal_kwargs,
4243
show,
@@ -262,8 +263,8 @@ def get_width_and_height(jointplot, rotate):
262263
**marginal_kwargs,
263264
)
264265

265-
ax[j, i].xaxis.axis_label = flat_var_names[i]
266-
ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
266+
ax[j, i].xaxis.axis_label = flat_var_labels[i]
267+
ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
267268

268269
elif j + marginals_offset > i:
269270
if "scatter" in kind:
@@ -350,8 +351,8 @@ def get_width_and_height(jointplot, rotate):
350351
y = reference_values_copy[flat_var_names[i]]
351352
if x and y:
352353
ax[j, i].scatter(y, x, **reference_values_kwargs)
353-
ax[j, i].xaxis.axis_label = flat_var_names[i]
354-
ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
354+
ax[j, i].xaxis.axis_label = flat_var_labels[i]
355+
ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
355356

356357
show_layout(ax, show)
357358

arviz/plots/backends/matplotlib/pairplot.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def plot_pair(
3030
diverging_mask,
3131
divergences_kwargs,
3232
flat_var_names,
33+
flat_var_labels,
3334
backend_kwargs,
3435
marginal_kwargs,
3536
show,
@@ -215,8 +216,8 @@ def plot_pair(
215216
reference_values_copy[flat_var_names[1]],
216217
**reference_values_kwargs,
217218
)
218-
ax.set_xlabel(f"{flat_var_names[0]}", fontsize=ax_labelsize, wrap=True)
219-
ax.set_ylabel(f"{flat_var_names[1]}", fontsize=ax_labelsize, wrap=True)
219+
ax.set_xlabel(f"{flat_var_labels[0]}", fontsize=ax_labelsize, wrap=True)
220+
ax.set_ylabel(f"{flat_var_labels[1]}", fontsize=ax_labelsize, wrap=True)
220221
ax.tick_params(labelsize=xt_labelsize)
221222

222223
else:
@@ -344,12 +345,12 @@ def plot_pair(
344345
if j != vars_to_plot - 1:
345346
plt.setp(ax[j, i].get_xticklabels(), visible=False)
346347
else:
347-
ax[j, i].set_xlabel(f"{flat_var_names[i]}", fontsize=ax_labelsize, wrap=True)
348+
ax[j, i].set_xlabel(f"{flat_var_labels[i]}", fontsize=ax_labelsize, wrap=True)
348349
if i != 0:
349350
plt.setp(ax[j, i].get_yticklabels(), visible=False)
350351
else:
351352
ax[j, i].set_ylabel(
352-
f"{flat_var_names[j + not_marginals]}",
353+
f"{flat_var_labels[j + not_marginals]}",
353354
fontsize=ax_labelsize,
354355
wrap=True,
355356
)

arviz/plots/pairplot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ def plot_pair(
196196
get_coords(dataset, coords), var_names=var_names, skip_dims=combine_dims, combined=True
197197
)
198198
)
199-
flat_var_names = [
199+
flat_var_names = [var_name for var_name, _, _, _ in plotters]
200+
flat_var_labels = [
200201
labeller.make_label_vert(var_name, sel, isel) for var_name, sel, isel, _ in plotters
201202
]
202203

@@ -253,6 +254,7 @@ def plot_pair(
253254
diverging_mask=diverging_mask,
254255
divergences_kwargs=divergences_kwargs,
255256
flat_var_names=flat_var_names,
257+
flat_var_labels=flat_var_labels,
256258
backend_kwargs=backend_kwargs,
257259
marginal_kwargs=marginal_kwargs,
258260
show=show,

0 commit comments

Comments
 (0)