Generate the data and import packages
First, we need to create the data. I'll start by defining it as a dictionary and then convert it into a pandas DataFrame, since pandas is commonly used in many projects for data manipulation.
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd
import math
# Palette: one hex color per country, plus the color shared by axis labels
# and tick labels.
color_dict = {
    "Norway": "#2B314D",
    "Denmark": "#A54836",
    "Sweden": "#5375D4",
}
xy_ticklabel_color = xy_label_color = "#101628"

# Raw records in long format: one entry per (country, year) with its number
# of sites.
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Denmark", "Denmark", "Norway", "Norway", "Sweden", "Sweden"],
    "sites": [4, 10, 5, 8, 13, 15],
}

# Build the frame and order the rows by year, then by country name.
df = (
    pd.DataFrame(data)
    .sort_values(["year", "countries"], ascending=True)
    .reset_index(drop=True)
)

# Derived columns:
#   color     - hex color looked up from the country name
#   sub_total - total number of sites over all countries in the same year
df = df.assign(
    color=lambda frame: frame["countries"].map(color_dict),
    sub_total=lambda frame: frame.groupby("year")["sites"].transform("sum"),
)
df
| | year | countries | sites | color | sub_total |
|---|---|---|---|---|---|
| 0 | 2004 | Denmark | 4 | #A54836 | 22 |
| 1 | 2004 | Norway | 5 | #2B314D | 22 |
| 2 | 2004 | Sweden | 13 | #5375D4 | 22 |
| 3 | 2022 | Denmark | 10 | #A54836 | 33 |
| 4 | 2022 | Norway | 8 | #2B314D | 33 |
| 5 | 2022 | Sweden | 15 | #5375D4 | 33 |
# Grid geometry: a fixed number of rows, and just enough columns for the
# largest yearly total to fit. Every year is drawn on a grid of this size.
matrix_rows = 4
max_subtotal = df["sub_total"].max()
matrix_columns = math.ceil(max_subtotal / matrix_rows)
total_cells = matrix_columns * matrix_rows

# One sub-frame per year.
groups = df.groupby("year")
We want to create a flat list of colors, `color_squares`, where:
- Each country's `color` is repeated `sites` times.
- The beginning of the list is then padded with white (`#FFFFFF`) squares so that the total length is 36 (for a 4 × 9 matrix).
# One panel per year, side by side.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(4, 3.7))

# create the matrix
# Cell coordinates, listed one grid column at a time: `x` holds the grid-column
# number of every cell and `y` its row number inside that column (both 1-based).
x = np.repeat(np.arange(1, matrix_columns + 1), matrix_rows)
y = np.tile(np.arange(1, matrix_rows + 1), matrix_columns)

for ax, (year, group) in zip(axes.ravel(), groups):
    # Row numbers are drawn along the horizontal axis and grid-column numbers
    # along the vertical one, so the limits are built in that order.
    ax.set(xlim=(0, matrix_rows + 1), ylim=(0, matrix_columns + 2))
    # Flip the vertical axis so the first cells of the list end up at the top.
    ax.invert_yaxis()

    # Build the per-cell colors from plain NumPy arrays. Handing a pandas
    # Series to np.insert lets NumPy try to re-wrap the (longer) result with
    # the Series' original index, which can fail with a length mismatch on
    # some NumPy/pandas combinations.
    site_colors = np.repeat(group["color"].to_numpy(), group["sites"].to_numpy())
    # White cells pad the front so every year fills the same number of cells.
    n_padding = total_cells - int(group["sub_total"].iloc[0])
    color_squares = ["#FFFFFF"] * n_padding + site_colors.tolist()

    ax.scatter(
        y,
        x,
        marker="s",
        s=300,
        color=color_squares,
    )

    # The group key is the year itself, so use it directly as the panel title.
    ax.set_xlabel(
        year,
        color=xy_label_color,
        size=12,
        weight="bold",
    )
    # Hide tick marks, tick labels and the frame: only the squares remain.
    ax.tick_params(
        axis="both",
        which="both",
        length=0,
        labelleft=False,
        labelbottom=False,
    )
    ax.set_frame_on(False)
# add legend
# One (country, color) pair per country, in order of first appearance. Taking
# both from the same de-duplicated frame keeps each label attached to its own
# color; two separate `.unique()` calls would stop lining up if two countries
# ever shared a color.
legend_entries = df[["countries", "color"]].drop_duplicates()

# Marker-only proxy handles: a square in each country's color, no line.
lines = [
    Line2D(
        [0],
        [0],
        color=c,
        marker="s",
        linestyle="",
        markersize=12,
    )
    for c in legend_entries["color"]
]

plt.figlegend(
    lines,
    legend_entries["countries"].tolist(),
    labelcolor=xy_label_color,
    # Font size and weight both come from `prop`; a separate `fontsize`
    # argument is ignored when `prop` is given, so it is not passed.
    prop=dict(size=10, weight="light"),
    bbox_to_anchor=(0.5, -0.3),
    loc="lower center",
    # Keep all entries on a single row, however many countries there are.
    ncols=len(lines),
    frameon=False,
)
<matplotlib.legend.Legend at 0x1cdf328d820>