Generate the data and import packages
First, we need to create the data. I'll start by defining it as a dictionary and then convert it into a pandas DataFrame, since pandas is commonly used in many projects for data manipulation.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.colors
# Palette: one fixed colour per country, plus the colours used for axis
# labels and for the numbers drawn inside the heatmap cells.
color_dict = {"Norway": "#CC5A43", "Denmark": "#3B4D83", "Sweden": "#5375D4"}
label_color, datalabels_color = "#757C85", "#FFFFFF"

# Number of sites per country for two snapshot years, in long format
# (one row per country/year pair).
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Denmark", "Denmark", "Norway", "Norway", "Sweden", "Sweden"],
    "sites": [4, 10, 5, 8, 13, 15],
}
df = pd.DataFrame(data)

# Custom sort: by year first, then by country in the order Norway, Denmark,
# Sweden. pandas applies ``key`` to each ``by`` column independently, so this
# single mapping carries ranks for both the country names and the year values.
sort_order_dict = {"Denmark": 2, "Sweden": 3, "Norway": 1, 2004: 4, 2022: 5}
df = df.sort_values(
    by=["year", "countries"],
    key=lambda col: col.map(sort_order_dict),
)

# Two-letter upper-case label taken from the start of the country name.
# NOTE(review): this is not an ISO 3166 code ("DE" is Germany; Denmark is
# "DK", Sweden is "SE") -- confirm this is acceptable to whatever uses it.
df["ctry_code"] = df["countries"].str[:2].str.upper()
# Abbreviated year label, e.g. 2004 -> "'04".
df["year_lbl"] = "'" + df["year"].astype(str).str[-2:]
df["color"] = df["countries"].map(color_dict)
df
| | year | countries | sites | ctry_code | year_lbl | color |
|---|---|---|---|---|---|---|
| 2 | 2004 | Norway | 5 | NO | '04 | #CC5A43 |
| 0 | 2004 | Denmark | 4 | DE | '04 | #3B4D83 |
| 4 | 2004 | Sweden | 13 | SW | '04 | #5375D4 |
| 3 | 2022 | Norway | 8 | NO | '22 | #CC5A43 |
| 1 | 2022 | Denmark | 10 | DE | '22 | #3B4D83 |
| 5 | 2022 | Sweden | 15 | SW | '22 | #5375D4 |
# Reshape the long table into a 2-D grid: one row per year, one column per
# country, both in the order in which they appear in the sorted ``df``.
sites = np.array(df.groupby(["year"], sort=False).sites.apply(list).tolist())
countries = df.countries.unique()
years = df.year.unique()
colors = df.color.unique()

# Upper end of the colour scale, defined once so that the colormap's number
# of discrete levels, the imshow range and the colorbar's end tick agree.
sites_vmax = int(df.sites.max())

# Interpolate across the reversed country palette. Using one level per site
# count (N=sites_vmax) makes the colorbar stepped rather than continuous.
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
    "", colors[::-1], sites_vmax
)

fig, ax = plt.subplots()
# Pin the lower bound of the colour scale to 0 instead of the data minimum.
im = ax.imshow(sites, cmap=cmap, vmin=0, vmax=sites_vmax)

# Write each cell's value at the cell centre (x = column, y = row).
for (row, col), value in np.ndenumerate(sites):
    ax.text(
        col, row, value, size=12, ha="center", va="center", color=datalabels_color
    )

# Colorbar below the heatmap, labelled only at its two ends.
cbar = ax.figure.colorbar(
    im, ax=ax, shrink=0.9, ticks=[0, sites_vmax], location="bottom"
)
cbar.outline.set_visible(False)
# Hide the tick marks and recolour the existing labels in place. This avoids
# re-setting them from get_xticklabels(), whose strings are not populated
# until the figure is drawn and would be frozen as fixed labels.
cbar.ax.tick_params(size=0, labelcolor=label_color)

# Heatmap styling: country names along the top, years on the left, no tick
# marks and no frame.
ax.tick_params(
    axis="both",
    which="major",
    labeltop=True,
    labelbottom=False,
    length=0,
    labelsize=12,
    colors=label_color,
    pad=20,
)
ax.yaxis.set_ticks(range(len(years)), labels=years)
ax.xaxis.set_ticks(range(len(countries)), labels=countries)
ax.set_frame_on(False)