Import packages and generate the data
First, we need to create the data. I'll define it as a dictionary and then convert it into a pandas DataFrame, since pandas is widely used for data manipulation.
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd
# Colour assigned to each country; used for the bar segments and the legend.
color_dict = {
    "Norway": "#2B314D",
    "Denmark": "#A54836",
    "Sweden": "#5375D4",
}
# Text colour of the labels drawn on top of the bars.
datalabels_color = "w"

# Raw data: number of sites per country for each of the two years.
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Sweden", "Sweden", "Denmark", "Denmark", "Norway", "Norway"],
    "sites": [13, 15, 4, 10, 5, 8],
}

# Build the frame and derive the plotting columns in one chain:
#   year_lbl  -- short year label, e.g. 2004 -> "'04"
#   sub_total -- total sites of the row's year, repeated on every row
#   pct_group -- the row's share of its year's total, in percent
#   colors    -- colour looked up from color_dict
# Rows end up ordered by year, then country, both descending.
df = (
    pd.DataFrame(data)
    .assign(
        year_lbl=lambda frame: "'" + frame["year"].astype(str).str[-2:],
        sub_total=lambda frame: frame.groupby("year")["sites"].transform("sum"),
        pct_group=lambda frame: 100 * frame["sites"] / frame["sub_total"],
        colors=lambda frame: frame["countries"].map(color_dict),
    )
    .sort_values(["year", "countries"], ascending=False)
)
df
| | year | countries | sites | year_lbl | sub_total | pct_group | colors |
|---|---|---|---|---|---|---|---|
| 1 | 2022 | Sweden | 15 | '22 | 33 | 45.454545 | #5375D4 |
| 5 | 2022 | Norway | 8 | '22 | 33 | 24.242424 | #2B314D |
| 3 | 2022 | Denmark | 10 | '22 | 33 | 30.303030 | #A54836 |
| 0 | 2004 | Sweden | 13 | '04 | 22 | 59.090909 | #5375D4 |
| 4 | 2004 | Norway | 5 | '04 | 22 | 22.727273 | #2B314D |
| 2 | 2004 | Denmark | 4 | '04 | 22 | 18.181818 | #A54836 |
Define the coordinates of the bars:
# Geometry for ax.bar3d: one box per row of df, listed in df's row order.
cnt_countries = df.countries.nunique()
# width of each bar along x (also reused as a label offset)
dx = 100
# Box height = that year's total * 10, taken in df's row order so every box is
# as tall as its OWN year's total. (The previous np.sort() reordered the totals
# ascending, which gave the first stack -- 2022 -- the 2004 height and vice versa.)
dz = df["sub_total"].to_numpy() * 10
# separation between consecutive year bars along x
separation = cnt_countries * dx
# every box starts on the floor
z = 0
# offset along y between consecutive year stacks (each stack's shares sum to 100)
dy_separation = 200

x = []   # x anchor of every box
y = []   # y start of every box
dy = []  # y extent of every box = the row's pct_group
# groupby(sort=False) visits the years in the order they appear in df. Because df
# is sorted by year, each year's rows are contiguous, so concatenating the groups
# keeps the lists aligned with df's rows (and therefore with df.colors).
# Works for any number of years, not just two.
for i, (_, group) in enumerate(df.groupby("year", sort=False)):
    pcts = group["pct_group"].tolist()
    offset = i * dy_separation
    # each segment starts where the earlier segments of the same year end
    ends = np.cumsum(pcts).tolist()
    x.extend([i * separation] * len(pcts))
    y.extend([offset] + [offset + end for end in ends[:-1]])
    dy.extend(pcts)
print(y, dy)
[0, 45.45454545454545, 69.69696969696969, 200, 259.0909090909091, 281.8181818181818] [45.45454545454545, 24.242424242424242, 30.303030303030305, 59.09090909090909, 22.727272727272727, 18.181818181818183]
Plot the chart
# Large canvas holding a single 3D axes.
fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(projection="3d")

# Draw one 3D box per segment; colours are taken positionally from df["colors"].
ax.bar3d(x, y, z, dx, dy, dz, color=df["colors"])
# Use the same scale on all three axes.
ax.set_aspect("equal")
Set the camera view:
# Camera: elevated 30 degrees above the x-y plane, rotated -50 degrees
# about the z-axis, and not tilted (roll 0).
ax.view_init(elev=30, azim=-50, roll=0)
fig
Add the data labels:
# Label the bars. Every segment gets its percentage; the first segment of each
# year's stack additionally gets the year and that year's total.
# The segment lists x/y/dy/dz follow df's ROW ORDER, so rows must be looked up
# positionally (.iloc). df was re-sorted, so label-based df.year[i] /
# df['sub_total'][i] returned some other row's year and total.
is_first_of_year = ~df["year"].duplicated().to_numpy()
for i, (x__, y__, width, height) in enumerate(zip(x, y, dy, dz)):
    # add the data label: segment share, rounded (int() truncated 22.7 -> 22)
    ax.text(
        x__ + dx / 2,
        y__ + width / 2,
        height,
        f"{width:.0f}%",
        size=8,
        ha="center",
        va="center",
        color=datalabels_color,
    )
    if not is_first_of_year[i]:
        continue
    print(x__, y__, width, height)
    # add the year label at z=0, dx/2 before the stack's start along y
    ax.text(
        x__ + dx + 30,
        y__ - dx / 2,
        0,
        df["year"].iloc[i],
        size=12,
        weight="bold",
    )
    # add the sub_total label at the top of the stack, in a black rounded box
    ax.text(
        x__ - dx / 2,
        y__ + dx + dx / 2,
        height,
        df["sub_total"].iloc[i],
        size=10,
        color=datalabels_color,
        ha="center",
        va="center",
        bbox=dict(
            facecolor="black",
            edgecolor="black",
            boxstyle="round,pad=0.5",
        ),
    )
    # black down-pointing triangle drawn just below the sub_total box
    ax.text(
        x__ - dx / 2,
        y__ + dx + dx / 2,
        height - 15,
        "\u25BC",
        size=10,
        color="k",
        ha="center",
        va="center",
    )
fig
0 0 45.45454545454545 220 300 200 59.09090909090909 330
Finally, add the legend and the final styling:
# Legend: one round marker per distinct colour. The colours and country names
# both come from .unique(), i.e. in the order they first appear in df, which
# pairs each colour with its country.
lines = [
    Line2D([0], [0], color=colour, marker="o", linestyle="", markersize=10)
    for colour in df["colors"].unique()
]
labels = df["countries"].unique()
fig.legend(
    lines,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.5, 0.05),
    ncols=3,
    fontsize=14,
    frameon=False,
)

# Hide the 3D axes panes, ticks and grid so only the bars remain.
ax.set_axis_off()
fig