Generate the data and import packages
First, we need to create the data. I will build it as a dictionary and then convert it to a pandas DataFrame, since a lot of projects use pandas to work with data.
import matplotlib.pyplot as plt
import pandas as pd
# Line colour for each series plotted (three countries plus the average).
color_dict = {
    "Norway": "#2B314D",
    "Denmark": "#A54836",
    "Sweden": "#5375D4",
    "AVG.": "#838B93",
}
xy_ticklabel_color, title_color, grid_color = "#C8C9C9", "#838B93", "#C8C9C9"

# Long-format data: one row per (country, year) with its site count.
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Denmark", "Denmark", "Norway", "Norway", "Sweden", "Sweden"],
    "sites": [4, 10, 5, 8, 13, 15],
}
df = pd.DataFrame(data)
# Upper-cased first two letters of the country name. `.str` slicing already
# yields strings, so no extra casts are needed.
# NOTE: these are not ISO 3166 codes (Denmark -> "DE", Sweden -> "SW");
# use an explicit mapping if real country codes are ever needed.
df["ctry_code"] = df.countries.str[:2].str.upper()
# Short year label, e.g. 2004 -> "'04".
df["year_lbl"] = "'" + df["year"].astype(str).str[-2:]
df
| | year | countries | sites | ctry_code | year_lbl |
|---|---|---|---|---|---|
| 0 | 2004 | Denmark | 4 | DE | '04 |
| 1 | 2022 | Denmark | 10 | DE | '22 |
| 2 | 2004 | Norway | 5 | NO | '04 |
| 3 | 2022 | Norway | 8 | NO | '22 |
| 4 | 2004 | Sweden | 13 | SW | '04 |
| 5 | 2022 | Sweden | 15 | SW | '22 |
Next, add an average row for each year to the dataset.
First, we need to calculate the averages:
# One row per year holding the mean site count across the countries.
# Round before casting: `.astype(int)` alone truncates toward zero (7.9 -> 7).
# Note that `Series.round()` rounds exact halves to even (6.5 -> 6).
df_avg = df.groupby("year")["sites"].mean().round().astype(int).reset_index()
# Label the average rows in both the code column and the name column.
df_avg[["ctry_code", "countries"]] = "AVG."
df_avg
| | year | sites | ctry_code | countries |
|---|---|---|---|---|
| 0 | 2004 | 7 | AVG. | AVG. |
| 1 | 2022 | 11 | AVG. | AVG. |
Now, append the average rows to the DataFrame, assign each row its colour, and sort the rows into plotting order:
# Stack the country rows and the per-year average rows into one frame.
final = pd.concat([df, df_avg], ignore_index=True)
final["color"] = final.countries.map(color_dict)
# custom sort: countries in panel order, years chronologically.
# Only the country names need an explicit rank. Years sort naturally, so a new
# year works without being listed (an unlisted key would map to NaN and
# silently sort last).
sort_order_dict = {"Sweden": 1, "Denmark": 2, "Norway": 3, "AVG.": 4}
final = final.sort_values(
    by=["countries", "year"],
    key=lambda col: col.map(sort_order_dict) if col.name == "countries" else col,
)
final
| | year | countries | sites | ctry_code | year_lbl | color |
|---|---|---|---|---|---|---|
| 4 | 2004 | Sweden | 13 | SW | '04 | #5375D4 |
| 5 | 2022 | Sweden | 15 | SW | '22 | #5375D4 |
| 0 | 2004 | Denmark | 4 | DE | '04 | #A54836 |
| 1 | 2022 | Denmark | 10 | DE | '22 | #A54836 |
| 2 | 2004 | Norway | 5 | NO | '04 | #2B314D |
| 3 | 2022 | Norway | 8 | NO | '22 | #2B314D |
| 6 | 2004 | AVG. | 7 | AVG. | NaN | #838B93 |
| 7 | 2022 | AVG. | 11 | AVG. | NaN | #838B93 |
Group the DataFrame by country (`sort=False` keeps the groups in the custom order above):
# Group rows per country. sort=False keeps the groups in order of first
# appearance, i.e. the custom order `final` was sorted into.
groups = final.groupby(by="countries", sort=False)
Plot the chart
# One panel per distinct value of `countries`, all sharing the y-axis.
# squeeze=False always returns a 2-D array, so .ravel() also works for a single panel.
fig, axes = plt.subplots(
    ncols=final.countries.nunique(),
    nrows=1,
    figsize=(6, 6),
    sharey=True,
    squeeze=False,
)
for (country, group), ax in zip(groups, axes.ravel()):
    # Site counts at this group's earliest and latest year.
    start = group.loc[group.year.idxmin(), "sites"]
    end = group.loc[group.year.idxmax(), "sites"]
    rising = end >= start
    first_year = group.year.min()
    for row in group.itertuples(index=False):
        ax.axhline(y=row.sites, lw=4, color=row.color)
        # Place each year label on the outer side of the pair of lines: below the
        # lower line, above the upper one. That keeps it clear of the arrow.
        below = (row.year == first_year) == rising
        ax.text(
            0.5,
            row.sites - 0.3 if below else row.sites + 0.3,
            row.year,
            ha="center",
            va="top" if below else "bottom",
            size=12,
            color=row.color,
            alpha=0.5,
        )
    # add the connecting arrow: from the first year's value to the last year's,
    # so a decrease points downward
    ax.annotate(
        "",
        xy=(0.5, end),
        xytext=(0.5, start),
        arrowprops=dict(arrowstyle="->", color=group.color.iloc[0], lw=1),
    )
# Lay out after drawing so the spacing accounts for everything on the axes.
fig.tight_layout(pad=2.0)
The final code
We add the styling and the country titles, and we are done!
import matplotlib.pyplot as plt
import pandas as pd

# Line colour for each series plotted (three countries plus the average).
color_dict = {
    "Norway": "#2B314D",
    "Denmark": "#A54836",
    "Sweden": "#5375D4",
    "AVG.": "#838B93",
}
xy_ticklabel_color, title_color, grid_color = "#C8C9C9", "#838B93", "#C8C9C9"

# Long-format data: one row per (country, year) with its site count.
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Denmark", "Denmark", "Norway", "Norway", "Sweden", "Sweden"],
    "sites": [4, 10, 5, 8, 13, 15],
}
df = pd.DataFrame(data)
# NOTE: the first two letters are not ISO 3166 codes (Denmark -> "DE",
# Sweden -> "SW"); use an explicit mapping if real country codes are needed.
df["ctry_code"] = df.countries.str[:2].str.upper()
df["year_lbl"] = "'" + df["year"].astype(str).str[-2:]

# Add the average row to the dataset: one row per year with the mean site count.
# Round before casting; `.astype(int)` alone truncates toward zero (7.9 -> 7).
df_avg = df.groupby("year")["sites"].mean().round().astype(int).reset_index()
df_avg[["ctry_code", "countries"]] = "AVG."

# add it back to the dataframe:
final = pd.concat([df, df_avg], ignore_index=True)
final["color"] = final.countries.map(color_dict)

# custom sort: countries in panel order, years chronologically (years sort
# naturally, so they need no entry in the rank table).
sort_order_dict = {"Sweden": 1, "Denmark": 2, "Norway": 3, "AVG.": 4}
final = final.sort_values(
    by=["countries", "year"],
    key=lambda col: col.map(sort_order_dict) if col.name == "countries" else col,
)

# sort=False keeps the groups (and hence the panels) in the custom order above.
groups = final.groupby("countries", sort=False)

# squeeze=False always returns a 2-D array, so .ravel() also works for one panel.
fig, axes = plt.subplots(
    ncols=final.countries.nunique(),
    nrows=1,
    figsize=(6, 6),
    sharey=True,
    squeeze=False,
)
# One unit of headroom above the tallest line for its year label.
y_max = final.sites.max() + 1

for (country, group), ax in zip(groups, axes.ravel()):
    # Site counts at this group's earliest and latest year.
    start = group.loc[group.year.idxmin(), "sites"]
    end = group.loc[group.year.idxmax(), "sites"]
    rising = end >= start
    first_year = group.year.min()
    for row in group.itertuples(index=False):
        ax.axhline(y=row.sites, lw=4, color=row.color)
        # Label below the lower line and above the upper one, clear of the arrow.
        below = (row.year == first_year) == rising
        ax.text(
            0.5,
            row.sites - 0.3 if below else row.sites + 0.3,
            row.year,
            ha="center",
            va="top" if below else "bottom",
            size=12,
            color=row.color,
            alpha=0.5,
        )
    # add the connecting arrow: first year's value -> last year's value, so a
    # decrease points downward
    ax.annotate(
        "",
        xy=(0.5, end),
        xytext=(0.5, start),
        arrowprops=dict(arrowstyle="->", color=group.color.iloc[0], lw=1),
    )
    ax.set_ylim(0, y_max)
    ax.set_frame_on(False)
    ax.tick_params(
        axis="both",
        which="major",
        length=0,
        labelbottom=False,
        labelsize=12,
        colors=xy_ticklabel_color,
        pad=10,
    )
    # horizontal grid lines at the y-axis tick positions
    ax.grid(True, axis="y", linestyle="solid", linewidth=1, color=grid_color)
    ax.set_title(
        country,
        color=title_color,
        x=0.5,
        y=1.05,
        size=12,
    )

# Lay out after all artists (ticks, titles) exist so the spacing accounts for them.
fig.tight_layout(pad=2.0)