Generate the data and import packages
First, we need to create the data. I'll start by defining it as a dictionary and then convert it into a pandas DataFrame, since pandas is commonly used in many projects for data manipulation.
#tutorial pip install squarify
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import squarify
import numpy as np
import pandas as pd
# Fill colour for every (year, country) pair shown in the chart.
color_dict = {
    (2022, "Norway"): "#464B64",
    (2004, "Norway"): "#2B314D",
    (2022, "Denmark"): "#D57968",
    (2004, "Denmark"): "#CE5A43",
    (2022, "Sweden"): "#7296E9",
    (2004, "Sweden"): "#5375D4",
}
# Text colour used for the legend captions.
legend_color = "#C4C7CB"

# Raw data: number of sites per country for the two years.
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Denmark", "Denmark", "Norway", "Norway", "Sweden", "Sweden"],
    "sites": [4, 10, 5, 8, 13, 15],
}
df = pd.DataFrame(data)

# Look up each row's colour from its (year, country) pair.
df["colors"] = [
    color_dict.get(pair) for pair in zip(df["year"], df["countries"])
]

# Absolute change in sites from the previous row of the same country; the
# first row of each country has no predecessor, so it keeps its own count.
site_change = df.groupby("countries")["sites"].diff().abs()
df["diff"] = site_change.fillna(df["sites"])

# Order the countries Denmark -> Sweden -> Norway instead of alphabetically.
sort_order_dict = {"Denmark": 1, "Sweden": 2, "Norway": 3}
df = df.sort_values(
    by="countries",
    key=lambda names: names.map(sort_order_dict),
)
df
| | year | countries | sites | colors | diff |
|---|---|---|---|---|---|
| 0 | 2004 | Denmark | 4 | #CE5A43 | 4.0 |
| 1 | 2022 | Denmark | 10 | #D57968 | 6.0 |
| 5 | 2022 | Sweden | 15 | #7296E9 | 2.0 |
| 4 | 2004 | Sweden | 13 | #5375D4 | 13.0 |
| 3 | 2022 | Norway | 8 | #464B64 | 3.0 |
| 2 | 2004 | Norway | 5 | #2B314D | 5.0 |
# Base treemap: one tile per country, sized by its site count in the most
# recent year and filled with that year's colour.
fig, ax = plt.subplots(figsize=(6, 6))

# Keep only the rows of the latest year present in the data.
is_latest_year = df["year"] == df["year"].max()
df_max = df[is_latest_year]

# NOTE(review): no axes is passed, so this presumably draws on the current
# axes (the one created just above) -- confirm against squarify's docs.
squarify.plot(sizes=df_max["sites"], color=df_max["colors"], edgecolor="w")
<Axes: >
Add the 2004 values manually as rectangles, together with the country labels:
# Hand-placed overlays for the 2004 values, one per country, listed in the
# same order as the rows of df_max (Denmark, Sweden, Norway). Coordinates are
# in the axes' data units.
rectangles = [
    {'position': (0, 0), 'width': 40, 'height': 30, 'color': '#CE5A43'},    # Denmark
    {'position': (0, 45), 'width': 70, 'height': 55, 'color': '#5375D4'},   # Sweden
    {'position': (80, 15), 'width': 20, 'height': 85, 'color': '#2B314D'},  # Norway
]
# Anchor point of each country label, in the same order as `rectangles`.
ctry_labels = [
    {'pos': (0, -4)},
    {'pos': (0, 102)},
    {'pos': (78, 102)},
]

# Draw each 2004 rectangle and its country label. Country name, site count
# and label colour are all read from the SAME df_max row; indexing
# df.sites.unique() by position (as before) picked counts belonging to other
# rows, so the labels did not show each country's own total.
for rect_data, label, row in zip(
    rectangles, ctry_labels, df_max.itertuples(index=False)
):
    rect = patches.Rectangle(
        rect_data['position'],         # (x, y) of the lower-left corner
        rect_data['width'],
        rect_data['height'],
        linewidth=2,
        edgecolor='none',              # no visible border
        facecolor=rect_data['color'],
    )
    ax.add_patch(rect)

    ax.text(
        label['pos'][0],
        label['pos'][1],
        f"{row.countries} {row.sites}",
        color=row.colors,
        clip_on=False,                 # labels sit outside the 0-100 plot area
    )
fig
Add the data labels:
# Anchor points for the numeric labels. Entry j belongs to the row whose
# index label is j in df (the order in which the data was originally
# defined), not to the j-th row of the sorted frame.
data_positions = [
    {'posxy': (3, 25)},
    {'posxy': (3, 34)},
    {'posxy': (95, 18)},
    {'posxy': (95, 10)},
    {'posxy': (65, 95)},
    {'posxy': (72, 95)},
]

for j, data_pos in enumerate(data_positions):
    label_x, label_y = data_pos['posxy']
    # Label-based lookup: fetch the 'diff' value of the row labelled j.
    value = int(df['diff'].loc[j])
    ax.text(label_x, label_y, value, color="w")
fig
Add the legend
# Legend: two rows of colour swatches (one swatch per country), drawn on a
# small borderless inset axes placed below the main plot.
inset_ax = fig.add_axes([0.1, -0.15, 0.7, 0.1])  # [left, bottom, width, height]
inset_ax.axis('off')

# Select each group's colours by year directly. Sorting by year and slicing
# the unique colours by position relied on the default (non-stable) sort and
# on hard-coded slice bounds; filtering keeps the swatches in the same
# country order as the rows of df.
earliest_year = df["year"].min()
latest_year = df["year"].max()
legend_groups = [
    # (x offset of the group, swatch colours, caption)
    (0, df.loc[df["year"] == earliest_year, "colors"].tolist(), 'Before 2004'),
    (9, df.loc[df["year"] == latest_year, "colors"].tolist(), 'After 2004'),
]

y = 0
for x0, group_colors, caption in legend_groups:
    # One unit-long horizontal segment per colour, laid end to end.
    for i, color in enumerate(group_colors):
        inset_ax.plot([x0 + i, x0 + i + 1], [y, y], color=color, lw=4)
    inset_ax.text(
        x0 + 3.2, y, caption, va='center', fontsize=10, color=legend_color
    )

# Set x/y limits of the inset axes (x0 is the offset of the last group).
inset_ax.set_xlim(-1, x0 + 5)
inset_ax.set_ylim(-1, 1)
ax.set_axis_off()
fig