Generate the data and import packages
First, we need to create the data.
We start by loading the flag images and defining the shape of the table.
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import numpy as np
# Flag images used both inside the cells and as row/column headers.
no = plt.imread("../flags/no-sq.png")
sw = plt.imread("../flags/sw-sq.png")
de = plt.imread("../flags/de-sq.png")

# create a 40x40 pixel empty image (transparent background)
height, width = 40, 40
empty_image = np.zeros((height, width, 4), dtype=np.uint8)  # RGBA, alpha = 0
# Save the transparent image as "empty.png" and read it back as the blank placeholder.
plt.imsave('empty.png', empty_image)
bk = plt.imread("empty.png")

# Header flags: one per column (drawn above the top row) and one per row
# (drawn left of the first column).
img_col = [de, sw, no, de, sw, no]
img_row = [no, sw, de, no, sw, de]

# One image and one label per mosaic cell, listed row by row in key order (a..q).
img = [
    no, sw, no, de, sw, bk,  # row 1: a-f
    sw, sw, sw, sw,          # row 2: g-j
    de, sw, de,              # row 3: k-m
    no, sw,                  # row 4: n-o
    sw,                      # row 5: p
    bk,                      # row 6: q
]
img_texts = [
    "+4", "+5", "+3", "+2", "+7", "",  # row 1
    "+11", "+2", "+10", "+5",          # row 2
    "+6", "+3", "+5",                  # row 3
    "+1", "+8",                        # row 4
    "+9",                              # row 5
    "",                                # row 6
]
# Triangular layout: each letter is one axes, 'X' marks an empty slot.
mosaic = [
    'abcdef',  # row 1
    'ghijXX',  # ...
    'klmXXX',
    'noXXXX',
    'pXXXXX',
    'qXXXXX',
]

# The per-cell and header lists are later paired with the axes via zip(),
# which silently drops any surplus items, so verify the counts line up here.
n_cells = sum(key != 'X' for row in mosaic for key in row)
if not len(img) == len(img_texts) == n_cells:
    raise ValueError(
        f"expected {n_cells} cell images/labels, got {len(img)} images "
        f"and {len(img_texts)} labels"
    )
if len(img_col) != len(mosaic[0]) or len(img_row) != len(mosaic):
    raise ValueError("header flag lists do not match the mosaic shape")
A helper function to place an image on an Axes:
def ax_flag(img, ax, xy, x, y, zoom=0.05):
    """
    Add an image to an axes, anchored at a point given in data coordinates.

    img: image array to draw (e.g. the result of ``plt.imread``).
    ax: axes to place the image on.
    xy: the (x, y) anchor point, in data coordinates.
    x: horizontal offset of the image centre from ``xy``, in points.
    y: vertical offset of the image centre from ``xy``, in points.
    zoom: scale factor applied to the image (default 0.05).

    Returns the ``AnnotationBbox`` artist that was added to ``ax``.
    """
    image_box = OffsetImage(img, zoom=zoom)  # container for the image
    ab = AnnotationBbox(
        image_box,
        xy,
        xybox=(x, y),
        xycoords='data',
        boxcoords="offset points",  # x/y are offsets in points from xy
        frameon=False,
    )
    ax.add_artist(ab)
    return ab
Then we use `subplot_mosaic` to lay out all the axes we need:
# Convert the rows list into a single multiline string
layout = "\n".join(mosaic)
fig = plt.figure(figsize=(8, 8))
# 'X' slots stay empty; zero spacing makes neighbouring cells share borders.
axes = fig.subplot_mosaic(
    layout,
    empty_sentinel="X",
    gridspec_kw={"wspace": 0, "hspace": 0},
)

# Blank cells (their image is the empty placeholder and their label is "")
# that only keep their left and top borders.
blank_keys = {'f', 'q'}

for img_text, im, (key, ax) in zip(img_texts, img, axes.items()):
    # add the flag, centred in the cell
    ax_flag(im, ax, (0.5, 0.5), 0.5, 0.5)
    # add the data label near the bottom of the cell
    ax.text(
        0.5,
        0.1,
        img_text,
        ha="center",
        va="center",
        fontsize=10,
        color="darkgrey",
    )
    # add the styling: hide tick marks and tick labels
    ax.tick_params(length=0, labelleft=False, labelbottom=False)
    if key in blank_keys:
        ax.spines[['left', 'top']].set_color("#D9DDDE")
        ax.spines[['bottom', 'right']].set_color("w")
    else:
        for spine in ax.spines.values():
            spine.set_color("#D9DDDE")
    print(ax)
Axes(0.125,0.751667;0.129167x0.128333) Axes(0.254167,0.751667;0.129167x0.128333) Axes(0.383333,0.751667;0.129167x0.128333) Axes(0.5125,0.751667;0.129167x0.128333) Axes(0.641667,0.751667;0.129167x0.128333) Axes(0.770833,0.751667;0.129167x0.128333) Axes(0.125,0.623333;0.129167x0.128333) Axes(0.254167,0.623333;0.129167x0.128333) Axes(0.383333,0.623333;0.129167x0.128333) Axes(0.5125,0.623333;0.129167x0.128333) Axes(0.125,0.495;0.129167x0.128333) Axes(0.254167,0.495;0.129167x0.128333) Axes(0.383333,0.495;0.129167x0.128333) Axes(0.125,0.366667;0.129167x0.128333) Axes(0.254167,0.366667;0.129167x0.128333) Axes(0.125,0.238333;0.129167x0.128333) Axes(0.125,0.11;0.129167x0.128333)
Add the column headers:
# Column headers: one flag above each cell of the first mosaic row. The axes
# are looked up by key rather than by dict position, so this does not depend
# on the ordering of the dict returned by subplot_mosaic.
top_row_axes = [axes[key] for key in mosaic[0] if key != 'X']
for im_col, ax in zip(img_col, top_row_axes):
    # anchor at (0.5, 1) in data coordinates (top centre of the cell with the
    # default 0-1 limits), shifted 30 points upwards
    ax_flag(im_col, ax, (0.5, 1), 0, 30)
fig
Get the first Axes in each row, to which we attach the row headers:
# Collect the left-most axes of every mosaic row; rows with no axes are skipped.
first_axes_each_row = []
for row in mosaic:
    leftmost_key = next((k for k in row if k != 'X' and k in axes), None)
    if leftmost_key is not None:
        first_axes_each_row.append(axes[leftmost_key])

# Row headers: anchor each flag at the left-middle of the row's first cell,
# shifted 35 points to the left.
for im_row, ax in zip(img_row, first_axes_each_row):
    ax_flag(im_row, ax, (0, 0.5), -35, 0)
fig
Add the lines and years:
# define the pairs of axes to connect with lines, the year shown on each line,
# and the background colour of that year's label box
pairs = [('a', 'c'), ('d', 'f'), ('a', 'k'), ('n', 'q')]
year = [2004, 2022, 2022, 2004]
box_color = ["w", "#ECEFEF", "#ECEFEF", "w"]

line_color = '#C0C6CA'
offset_hline = 0.12  # gap (figure fraction) above the top row for horizontal lines
offset_vline = 0.2   # gap (figure fraction) left of the cell centres for vertical lines

# add the lines and annotations
for i, ((start, end), label, fc) in enumerate(zip(pairs, year, box_color)):
    text_params = {
        "va": "center",
        "ha": "center",
        "fontsize": 12,
        "color": line_color,
        "bbox": dict(boxstyle="round,pad=0.4", fc=fc, ec="w", lw=0.8),
    }
    # positions of the axes in figure coordinates
    start_pos = axes[start].get_position()
    end_pos = axes[end].get_position()
    # centres of the two axes
    x_start = start_pos.x0 + start_pos.width / 2
    x_end = end_pos.x0 + end_pos.width / 2
    y_start = start_pos.y0 + start_pos.height / 2
    y_end = end_pos.y0 + end_pos.height / 2

    if i < 2:
        # the first two pairs span columns: horizontal line above the top row
        y_line = start_pos.y0 + start_pos.height + offset_hline
        fig.add_artist(plt.Line2D([x_start, x_end], [y_line, y_line],
                                  transform=fig.transFigure, color=line_color, lw=0.5))
        fig.text((x_start + x_end) / 2, y_line, label, **text_params)
    else:
        # the remaining pairs span rows: vertical line left of the first column
        x_line = x_start - offset_vline
        fig.add_artist(plt.Line2D([x_line, x_line], [y_start, y_end],
                                  transform=fig.transFigure, color=line_color, lw=0.5))
        fig.text(x_line, (y_start + y_end) / 2, label, rotation=90, **text_params)
fig