Generate the data and import packages
First, we create the data. We define it as a dictionary and then convert it into a pandas DataFrame, since pandas is the standard tool for this kind of data manipulation.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Plot colour for each country.
color_dict = {"Norway": "#2B314D", "Denmark": "#A54836", "Sweden": "#5375D4"}
# Two-letter code shown as each country's label.
code_dict = {"Norway": "NO", "Denmark": "DK", "Sweden": "SE"}

data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Sweden", "Sweden", "Denmark", "Denmark", "Norway", "Norway"],
    "sites": [13, 15, 4, 10, 5, 8],
}
df = pd.DataFrame(data)

# Rank used to order the countries within each year (lower rank first).
sort_order_dict = {"Sweden": 1, "Denmark": 2, "Norway": 3}
# Years sort numerically; only the country column goes through the rank
# table, so adding a new year needs no change to the table.
df = df.sort_values(
    by=["year", "countries"],
    key=lambda col: col.map(sort_order_dict) if col.name == "countries" else col,
)

# Attach each country's label code and plot colour to its rows.
df["ctry_code"] = df["countries"].map(code_dict)
df["color"] = df["countries"].map(color_dict)
df
| | year | countries | sites | ctry_code | color |
|---|---|---|---|---|---|
| 0 | 2004 | Sweden | 13 | SE | #5375D4 |
| 2 | 2004 | Denmark | 4 | DK | #A54836 |
| 4 | 2004 | Norway | 5 | NO | #2B314D |
| 1 | 2022 | Sweden | 15 | SE | #5375D4 |
| 3 | 2022 | Denmark | 10 | DK | #A54836 |
| 5 | 2022 | Norway | 8 | NO | #2B314D |
def _triangle_edges(y, apex, slope, center):
    """Return the (left, right) x-coordinates of the triangle's sides at height *y*.

    The two sides meet at ``(center, apex)`` and move ``slope`` x-units
    outwards for every unit of height below the apex.
    """
    spread = slope * (apex - y)
    return center - spread, center + spread


fig, axes = plt.subplots(ncols=2, figsize=(12, 5), sharey=True)

center = 25  # x-coordinate of the triangle's vertical axis
total_width = 35  # width of the triangle's base at y = 0
half_width = total_width / 2
center_offset = 2  # horizontal gap between a segment's left side and its country label

# One panel per year, in the order the years first appear in `df`.
for ax, (year, group) in zip(axes.ravel(), df.groupby("year", sort=False)):
    # Stack the segments bottom-up in the row order of `group`.
    y_tops = group["sites"].cumsum().tolist()
    y_bottoms = [0] + y_tops[:-1]
    total_height = y_tops[-1]
    # Diagnostic: segment boundaries and total stack height for this year.
    print(y_bottoms, y_tops, total_height)

    # Horizontal spread per unit of height: the half-width shrinks linearly
    # from `half_width` at y = 0 to 0 at the apex. An all-zero year would
    # otherwise divide by zero.
    slope = half_width / total_height if total_height else 0.0

    for yb, yt, row in zip(y_bottoms, y_tops, group.itertuples()):
        # x-coordinates of the triangle's sides at the segment's bottom and top.
        left_bottom, right_bottom = _triangle_edges(yb, total_height, slope, center)
        left_top, right_top = _triangle_edges(yt, total_height, slope, center)
        ax.fill_betweenx(
            y=[yb, yt],
            x1=[left_bottom, left_top],
            x2=[right_bottom, right_top],
            color=row.color,
        )

        mid_segment = (yb + yt) / 2

        # Data label: the site count, centred inside the segment.
        ax.text(
            center,
            mid_segment,
            row.sites,
            ha="center",
            va="center",
            color="white",
            fontsize=10,
        )

        # Country label: just outside the segment's left side, in its colour.
        left_mid, _ = _triangle_edges(mid_segment, total_height, slope, center)
        ax.text(
            left_mid - center_offset,
            mid_segment,
            row.ctry_code,
            ha="center",
            va="center",
            color=row.color,
            fontsize=10,
        )

    # The groupby key is the year shown under each panel.
    ax.set_xlabel(year, size=12, labelpad=16)
    ax.tick_params(length=0, labelleft=False, labelbottom=False)
    ax.set_frame_on(False)
[0, 13, 17] [13, 17, 22] 22 [0, 15, 25] [15, 25, 33] 33