Import packages and generate the data
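First, we need to create the data. We define it as a dictionary and then convert it into a pandas DataFrame, the usual tool for this kind of data manipulation. We also add two derived columns: `diff`, the number of sites added since the previous year (the first year keeps its full count), and `color`, the colour of each (year, country) pair.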
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D
# Colour for every (year, country) pair: the darker shade marks the 2004
# count, the lighter shade the sites added between 2004 and 2022.
color_dict = {
    (2022, "Norway"): "#9194A3",
    (2004, "Norway"): "#2B314D",
    (2022, "Denmark"): "#E2AFA5",
    (2004, "Denmark"): "#A54836",
    (2022, "Sweden"): "#C4D6F8",
    (2004, "Sweden"): "#5375D4",
}
# Grey used for the legend labels.
legend_color = "#C4C7CB"
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Sweden", "Sweden", "Denmark", "Denmark", "Norway", "Norway"],
    "sites": [13, 15, 4, 10, 5, 8],
}
df = pd.DataFrame(data)
# Sites added since the previous year for the same country. The difference is
# taken on a year-sorted view, so it does not depend on the row order of `df`;
# assigning the resulting Series aligns it back to `df` by index, leaving the
# row order untouched.
df["diff"] = (
    df.sort_values("year", kind="stable")
    .groupby("countries")["sites"]
    .diff()
)
# The earliest year has no predecessor (NaN diff): it contributes its full count.
df["diff"] = df["diff"].fillna(df["sites"])
# Look up each row's colour from its (year, countries) pair; unknown pairs map to None.
df["color"] = df.set_index(["year", "countries"]).index.map(color_dict.get)
df
| | year | countries | sites | diff | color |
|---|---|---|---|---|---|
| 0 | 2004 | Sweden | 13 | 13.0 | #5375D4 |
| 1 | 2022 | Sweden | 15 | 2.0 | #C4D6F8 |
| 2 | 2004 | Denmark | 4 | 4.0 | #A54836 |
| 3 | 2022 | Denmark | 10 | 6.0 | #E2AFA5 |
| 4 | 2004 | Norway | 5 | 5.0 | #2B314D |
| 5 | 2022 | Norway | 8 | 3.0 | #9194A3 |
We will place each country's chart on its own 3D axes. Here are their positions in figure coordinates (left, bottom, width, height):
# (left, bottom) corner of each country's 3D axes, in figure coordinates.
# The order follows the groupby(sort=False) order used when plotting.
x, y = map(list, zip((0, 1),        # Sweden
                     (0.75, 1),     # Denmark
                     (0.35, 0.5)))  # Norway
# Every axes gets the same 0.5 x 0.5 footprint.
w = [0.5] * len(x)
h = [0.5] * len(y)
We also need to define the grid size (4×4), the dimensions of the 3D bars (unit cubes), and the number of cubes to draw for each country (its largest site count):
# Every bar is a unit cube: width (dx), depth (dy) and height (dz) all equal 1.
dx, dy, dz = 1, 1, 1
# Each country's cubes are laid out on a 4x4 grid (16 slots).
num_rows, num_cols = 4, 4
# Cubes per country = its largest site count. groupby(sort=False) keeps the
# order of first appearance, matching the grouping used in the plotting loop.
max_index = (
    df.groupby("countries", sort=False)["sites"]
    .agg("max")
    .tolist()
)
print(max_index)
[15, 10, 8]
Plot the chart
# NOTE: some axes extend past the figure's [0, 1] bounds (e.g. x + w = 1.25);
# when saving, pass bbox_inches="tight" or they are clipped.
fig = plt.figure(figsize=(6, 6))
# Per-country list of cube colours (one entry per cube).
grouped_colors = []
for i, (country, group) in enumerate(df.groupby("countries", sort=False)):
    # Repeat each year's colour once per site added that year; a negative
    # diff (fewer sites) contributes no cubes.
    combined_colors = []
    for color, added in zip(group["color"], group["diff"]):
        combined_colors.extend([color] * int(added))
    grouped_colors.append(combined_colors)

    # Cubes to draw: the country's largest site count (same value as max_index[i]).
    n_cubes = int(group["sites"].max())

    ax = fig.add_axes([x[i], y[i], w[i], h[i]], projection="3d")
    # Fill the grid row by row; anything beyond num_rows * num_cols cells
    # does not fit and is not drawn.
    n_drawn = min(n_cubes, num_rows * num_cols)
    for cube, color in enumerate(combined_colors[:n_drawn]):
        row, col = divmod(cube, num_cols)
        ax.bar3d(row * dx, col * dy, 0, dx, dy, dz,
                 color=color, edgecolor="w", alpha=1)

    ax.set_title(country, weight="bold")
    ax.set_axis_off()
    # Final limits must be in place BEFORE set_aspect("equal"): Axes3D derives
    # the box aspect from the limits current at call time, so calling it
    # earlier would leave the cubes stretched along z.
    ax.set_xlim([-1, num_rows * dx])
    ax.set_ylim([-1, num_cols * dy])
    ax.set_zlim([0, dz])
    ax.set_aspect("equal")
    ax.view_init(elev=30, azim=45)
Add the legend:
# Legend: one group of colour swatches per year, drawn on a hidden 2D axes.
inset_ax = fig.add_axes([0.1, 0.3, 1, 0.2])  # [left, bottom, width, height]
inset_ax.axis("off")

# A stable sort keeps the countries in their original row order within each
# year; the default sort kind does not guarantee that.
sorted_df = df.sort_values("year", kind="stable")
# All swatch colours, earliest year first.
colors = sorted_df["color"].unique()

# Vertical position of the swatches. Deliberately not named `y`: that name
# holds the 3D axes positions, and overwriting it breaks the plotting cell
# if it is re-run after this one.
legend_y = 1
group_gap = 9  # horizontal distance (data units) between the swatch groups
labels = ["Before 2004", "After 2004"]  # one label per year, earliest first
for g, ((_, year_df), label) in enumerate(zip(sorted_df.groupby("year"), labels)):
    x0 = g * group_gap
    year_colors = year_df["color"].unique()
    # One unit-wide thick line segment per country, side by side.
    for i, color in enumerate(year_colors):
        inset_ax.plot([x0 + i, x0 + i + 1], [legend_y, legend_y], color=color, lw=10)
    # The label sits just right of the group's last swatch.
    inset_ax.text(x0 + len(year_colors) + 0.2, legend_y, label,
                  va="center", fontsize=10, color=legend_color)
fig