Generate the data and import packages
First, create the data: define it as a dictionary and then convert it into a pandas DataFrame, since pandas is the standard tool for this kind of data manipulation.
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import matplotlib as mpl
import pandas as pd
color_dict = {(2022,"Norway"): "#9194A3", (2004,"Norway"): "#2B314D",
(2022,"Denmark"): "#E2AFA5", (2004,"Denmark"): "#A54836",
(2022,"Sweden"): "#C4D6F8", (2004,"Sweden"): "#5375D4",
}
code_dict = {"Norway": "NO", "Denmark": "DK", "Sweden": "SE", }
xy_ticklabel_color, xlabel_color, grand_totals_color, legend_color, grid_color, datalabels_color ='#101628',"#101628","#101628","#101628", "#C8C9C9", "#757C85"
data = {
"year": [2004, 2022, 2004, 2022, 2004, 2022],
"countries" : [ "Denmark", "Denmark", "Norway", "Norway","Sweden", "Sweden",],
"sites": [4,10,5,8,13,15]
}
df= pd.DataFrame(data)
df['ctry_code'] = df.countries.map(code_dict)
df = df.sort_values(['countries' ,'sites' ], ascending=False ).reset_index(drop=True)
#Add the color based on the color dictionary
df['color'] = df.set_index(['year', 'countries']).index.map(color_dict.get)
#To ensure that the areas are really proportional, use the square root values as the length and height of the rectangles.
df['sq_sites'] = (df['sites'])**.5
df['year_lbl'] ="'"+df['year'].astype(str).str[-2:].astype(str)
df
| | year | countries | sites | ctry_code | color | sq_sites | year_lbl |
|---|---|---|---|---|---|---|---|
| 0 | 2022 | Sweden | 15 | SE | #C4D6F8 | 3.872983 | '22 |
| 1 | 2004 | Sweden | 13 | SE | #5375D4 | 3.605551 | '04 |
| 2 | 2022 | Norway | 8 | NO | #9194A3 | 2.828427 | '22 |
| 3 | 2004 | Norway | 5 | NO | #2B314D | 2.236068 | '04 |
| 4 | 2022 | Denmark | 10 | DK | #E2AFA5 | 3.162278 | '22 |
| 5 | 2004 | Denmark | 4 | DK | #A54836 | 2.000000 | '04 |
# Generate the x and y coordinates: the (left, bottom) figure-fraction anchor
# of each country's axes, consumed in the order the countries appear in df.
x = [0, 0.4, 0.2]
y = [0.9, 1, 1.2]
fig = plt.figure(figsize=(5, 5))

# The earliest year's data label goes below the diamonds, the other year's
# label above them. Computed once instead of once per row.
first_year = df.year.min()

for x_, y_, (country, group) in zip(x, y, df.groupby("countries", sort=False)):
    # One 0.5 x 0.5 (figure-fraction) axes per country.
    ax = fig.add_axes([x_, y_, 0.5, 0.5])
    # Side length of the country's larger square; positions its labels.
    max_sites = group.sq_sites.max()

    for row in group.itertuples():
        sq_site = row.sq_sites
        # Square with side sqrt(sites) (so area == sites), rotated 45 degrees
        # about the data origin so it stands on one corner as a diamond.
        r1 = patches.Rectangle(
            (0, 0),
            sq_site,
            sq_site,
            color=row.color,
        )
        t2 = mpl.transforms.Affine2D().rotate_deg(45) + ax.transData
        r1.set_transform(t2)
        ax.add_patch(r1)

        # Data label: below the bottom corner for the earliest year,
        # above the largest diamond otherwise.
        label_y = -0.5 if row.year == first_year else max_sites + 2
        ax.text(
            0,
            label_y,
            row.sites,
            color=row.color,
            ha="center",
            va="center",
        )

    # Identical data window (and identical axes size) for every country, so
    # all diamonds share one scale. Set once per axes, not once per row.
    ax.set_xlim(-5, 5)
    ax.set_ylim(0, 10)

    # Country code label, drawn once per country. Previously it was added
    # inside the row loop, stacking two identical texts on every axes.
    ax.text(
        0,
        max_sites / 2,
        group.ctry_code.iloc[0],
        color="w",
        ha="center",
        va="center",
    )

    # Styling: hide the frame, tick marks and tick labels.
    ax.set_frame_on(False)
    ax.tick_params(
        axis="both",
        which="both",
        length=0,
        labelleft=False,
        labelbottom=False,
    )
Add the custom legend
First, define the legend's shapes and labels:
# Rectangle data (each is a dict of Rectangle keyword arguments). Both squares
# share one corner, so after the 45-degree rotation they appear as a nested
# pair of diamonds, mirroring one country's glyph in the chart.
x = .5
y = -.3
# NOTE: no "color" key. Passing color="w" together with the ec/fc aliases was
# dead code (the aliases are applied afterwards and win), and it would silently
# override the fill if the full edgecolor/facecolor names were ever used.
rectangle_data = [
    {"xy": (x, y), "width": 0.4, "height": 0.4, "ec": "#ECEFEF", "fc": "#ECEFEF"},
    {"xy": (x, y), "width": 0.25, "height": 0.25, "ec": grid_color, "fc": grid_color},
]

# Year labels are picked by the numeric year, not by comparing the "'YY"
# strings, which would misorder years across a century (e.g. "'99" > "'04").
latest_year_lbl = df.loc[df.year.idxmax(), "year_lbl"]
earliest_year_lbl = df.loc[df.year.idxmin(), "year_lbl"]

# Text data (position, text, alignment, color): latest year above the
# diamonds, earliest year below, matching the chart's data-label placement.
text_data = [
    {"x": 0.55, "y": 0.9, "s": latest_year_lbl, "ha": "center", "color": "#2B314D"},
    {"x": 0.55, "y": -0.2, "s": earliest_year_lbl, "ha": "center", "color": "#2B314D"},
]

legend_ax = fig.add_axes([0.85, 0.6, 0.1, 0.1])  # left, bottom, width, height in figure coordinates

for rect in rectangle_data:
    r1 = patches.Rectangle(**rect)
    # Same 45-degree rotation about the data origin as the main chart.
    t2 = mpl.transforms.Affine2D().rotate_deg(45) + legend_ax.transData
    r1.set_transform(t2)
    legend_ax.add_patch(r1)

for txt in text_data:
    legend_ax.text(**txt)

legend_ax.axis("off")
fig