Explain ax.bar3d
To create the stacked 3D bars, we will use the ax.bar3d() method.
Once you understand its parameters, building the chart is straightforward:
| Parameter | Description |
|---|---|
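| x | The x coordinate of the bar's anchor point (the bar extends from it by dx, dy, dz) |
| y | The y coordinate of the bar's anchor point |
| z | The z coordinate of the bar's anchor point (where the bar starts vertically) |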
| dx | The width of the bar in the X direction |
| dy | The depth of the bar in the Y direction |
| dz | The height of the bar in the Z direction |
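To see how these parameters combine into a stack, here is a minimal sketch with made-up numbers (not the article's data). Two segments share the same x and y, and the upper segment's z equals the lower segment's height. The sketch assumes Matplotlib 3.2 or newer, where the "3d" projection works without importing mpl_toolkits explicitly. It uses the non-interactive Agg backend so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend: nothing is shown on screen
import matplotlib.pyplot as plt

# Two made-up segments stacked in a single column at (x, y) = (0, 0)
x = [0, 0]
y = [0, 0]
z = [0, 3]       # the upper segment starts where the lower one ends
dx, dy = 1, 1    # scalar width/depth, shared by both segments
dz = [3, 2]      # segment heights

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection="3d")
bars = ax.bar3d(x, y, z, dx, dy, dz, color=["#5375D4", "#A54836"])

# Top of the stack = base of the last segment + its height
stack_top = z[-1] + dz[-1]
print(stack_top)
```

The only thing that makes the bars "stacked" is the z list: each base is the sum of the heights below it.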
Generate the data and import packages
Next, we import the packages and generate the data. I'll define the data as a dictionary and then convert it into a pandas DataFrame, since pandas is widely used for data manipulation.
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd
color_dict = {
"Norway": "#2B314D",
"Denmark": "#A54836",
"Sweden": "#5375D4",
}
xy_ticklabel_color = xlabel_color = grand_totals_color = legend_color = "#101628"
grid_color = "#C8C9C9"
datalabels_color = "#757C85"
data = {
"year": [2004, 2022, 2004, 2022, 2004, 2022],
"countries": ["Sweden", "Sweden", "Denmark", "Denmark", "Norway", "Norway"],
"sites": [13, 15, 4, 10, 5, 8, ],
}
df = pd.DataFrame(data)
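# custom sort: years ascending, then Sweden, Norway, Denmark within each year
# (the key function is applied to each column in `by` separately)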
sort_order_dict = {"Denmark": 3, "Sweden": 1, "Norway": 2, 2004: 4, 2022: 5}
df = df.sort_values(by=["year","countries"], key=lambda x: x.map(sort_order_dict))
df["sub_total"] = df.groupby("year")["sites"].transform("sum")
df["colors"] = df.countries.map(color_dict)
df
| | year | countries | sites | sub_total | colors |
|---|---|---|---|---|---|
| 0 | 2004 | Sweden | 13 | 22 | #5375D4 |
| 4 | 2004 | Norway | 5 | 22 | #2B314D |
| 2 | 2004 | Denmark | 4 | 22 | #A54836 |
| 1 | 2022 | Sweden | 15 | 33 | #5375D4 |
| 5 | 2022 | Norway | 8 | 33 | #2B314D |
| 3 | 2022 | Denmark | 10 | 33 | #A54836 |
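The key passed to sort_values is applied to each column in by separately. That is why one dictionary can rank both the countries and the years. If you prefer to keep the order next to the data, an ordered pd.Categorical is a common alternative. Note that this is our substitution, not the approach used above. The sketch reproduces the same row order and subtotals:

```python
import pandas as pd

data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Sweden", "Sweden", "Denmark", "Denmark", "Norway", "Norway"],
    "sites": [13, 15, 4, 10, 5, 8],
}
df = pd.DataFrame(data)

# The category order defines the sort order within each year
df["countries"] = pd.Categorical(
    df["countries"], categories=["Sweden", "Norway", "Denmark"], ordered=True
)
df = df.sort_values(by=["year", "countries"])
df["sub_total"] = df.groupby("year")["sites"].transform("sum")

print(df.index.tolist())  # same row order as the table above
```

One side effect: with a categorical column, df.countries.unique() returns a Categorical rather than a NumPy array, which may matter in later steps.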
We define some variables that we will need later:
years = df.year.unique()
countries = df.countries.unique()
cnt_countries = len(countries)
sub_totals = df.sub_total.unique()
colors = df.colors
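One caveat with sub_totals: unique() drops duplicate values, so it only lines up with years as long as every year has a distinct total. The sketch below uses hypothetical numbers where both years total 20 to show the difference. Summing per year with groupby keeps exactly one total per year:

```python
import pandas as pd

# Hypothetical data: both years happen to total 20 sites
df = pd.DataFrame({
    "year": [2004, 2004, 2022, 2022],
    "sites": [12, 8, 15, 5],
})
df["sub_total"] = df.groupby("year")["sites"].transform("sum")

collapsed = df.sub_total.unique().tolist()                          # duplicates dropped
per_year = df.groupby("year", sort=False)["sites"].sum().tolist()   # one value per year
print(collapsed, per_year)
```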
Define x, y, z
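Let's first work out the x, y and z anchor coordinates by hand. For the first stack (2004), with one anchor per country segment:
(0, 0, 0): Sweden starts at the floor
(0, 0, 13): Norway sits on top of Sweden's 13 sites
(0, 0, 18): Denmark sits on top of 13 + 5 = 18 sites
and for the second stack (2022):
(8, 4, 0): Sweden
(8, 4, 15): Norway, on top of 15
(8, 4, 23): Denmark, on top of 15 + 8 = 23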
x_separation and y_separation are constants that set the distance between the two stacks along the x and y axes. We choose 8 and 4 to match the original chart.
We build the coordinates as flat lists with one entry per bar segment, because bar3d accepts array-like x, y, z, dx, dy and dz and draws one bar per element.
The z values are accumulated: each segment's base is the sum of the heights below it in the same stack. That is what the code below computes:
# x and y anchor points of all the bars
x_separation = 8
y_separation = 4
x = np.repeat(np.arange(len(years)) * x_separation, cnt_countries).tolist()  # one anchor per year, repeated per country
y = np.repeat(np.arange(len(years)) * y_separation, cnt_countries).tolist()
# accumulate the z values
z = np.array( df.groupby("year", sort=False).sites.apply(list).tolist() ) # convert to a numpy 2d array
z = np.cumsum(z, axis=1)[:, : cnt_countries - 1] # accumulate sum, remove last item
z = ( np.insert(z, 0, 0, axis=1).flatten().tolist() ) # add a zero at the beginning, flatten and convert to list
for x_, y_, z_ in zip(x, y, z):
    print(x_, y_, z_)
0 0 0
0 0 13
0 0 18
8 4 0
8 4 15
8 4 23
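The same anchors can also be computed with pandas alone. This sketch is an alternative to the NumPy reshaping above and assumes df is sorted by year and country as before. It relies on the fact that a segment's base is the running total of its year minus its own height, and it sizes x and y by the number of years:

```python
import numpy as np
import pandas as pd

# The sorted frame from the walkthrough (Sweden, Norway, Denmark per year)
df = pd.DataFrame({
    "year": [2004, 2004, 2004, 2022, 2022, 2022],
    "countries": ["Sweden", "Norway", "Denmark"] * 2,
    "sites": [13, 5, 4, 15, 8, 10],
})

# Base of each segment = running total within its year minus its own height
z = (df.groupby("year", sort=False)["sites"].cumsum() - df["sites"]).tolist()

# One x/y anchor per year, repeated once per country segment
n_years = df["year"].nunique()
n_countries = df["countries"].nunique()
x = np.repeat(np.arange(n_years) * 8, n_countries).tolist()
y = np.repeat(np.arange(n_years) * 4, n_countries).tolist()
print(x, y, z)
```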
Define the dx, dy and dz parameters
We choose a fixed value of 2 for dx and dy to match the original chart. dz is the number of sites, df.sites, converted to a list: it gives the height of each bar segment.
# width/depth/height of bars
dx = 2
dy = 2
dz = df.sites.to_list()
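Passing scalars for dx and dy works because bar3d broadcasts its six inputs against each other. The sketch below mimics that step with np.broadcast_arrays; it illustrates the idea and is not Matplotlib's exact code. It also checks that the top of the last segment in each stack equals that year's subtotal:

```python
import numpy as np

# Anchors and heights from the walkthrough (2004 stack, then 2022 stack)
x = [0, 0, 0, 8, 8, 8]
y = [0, 0, 0, 4, 4, 4]
z = [0, 13, 18, 0, 15, 23]
dx, dy = 2, 2
dz = [13, 5, 4, 15, 8, 10]

# Scalars dx and dy are stretched to one value per bar segment
x_b, y_b, z_b, dx_b, dy_b, dz_b = np.broadcast_arrays(x, y, z, dx, dy, dz)

# Top face of each segment; the last one in each stack is the year's subtotal
tops = (z_b + dz_b).tolist()
print(dx_b.tolist(), tops)
```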
Plot the chart
Let's plot the 3D bars first so you can see where they land.
To match the original chart, we set the aspect ratio to "equal" with ax.set_aspect.
fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(1, 1, 1, projection="3d")
ax.bar3d(x, y, z, dx, dy, dz, color=colors, label=countries)
ax.set_aspect("equal")
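Equal aspect on 3D axes is supported from Matplotlib 3.6. Older releases either reject "equal" for 3D axes or do not scale the axes equally. The sketch below is a fallback that assumes Matplotlib 3.3 or newer, where ax.set_box_aspect exists. Instead of setting the aspect, it scales the box by each axis's data extent. It renders to an in-memory PNG so it runs without a display:

```python
import io

import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt

# Anchors, heights and colors from the walkthrough
x = [0, 0, 0, 8, 8, 8]
y = [0, 0, 0, 4, 4, 4]
z = [0, 13, 18, 0, 15, 23]
dz = [13, 5, 4, 15, 8, 10]
colors = ["#5375D4", "#2B314D", "#A54836"] * 2  # Sweden, Norway, Denmark

fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(1, 1, 1, projection="3d")
ax.bar3d(x, y, z, 2, 2, dz, color=colors)

major, minor = (int(part) for part in matplotlib.__version__.split(".")[:2])
if (major, minor) >= (3, 6):
    ax.set_aspect("equal")
else:
    # Box proportional to the data extents: x 0..10, y 0..6, z 0..33
    ax.set_box_aspect((10, 6, 33))

buffer = io.BytesIO()
fig.savefig(buffer, format="png")
```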
Add the data labels
We will use the ax.text() method with the following parameters:
| Parameter | Description | Value |
|---|---|---|
| x | The x position of the text | x coordinate minus an offset of 1.5 |
| y | The y position of the text | y coordinate |
| z | The z position of the text | z coordinate plus half the segment height, to center the label on the segment |
| s | The text to display | the number of sites |
offset = 1.5
# annotate the data labels
for x__, y__, z__, site in zip(x, y, z, dz):
    print(x__, y__, z__ + site / 2)
    ax.text(
        x__ - offset,    # shift the label to the left of the bar
        y__,
        z__ + site / 2,  # vertical middle of the segment
        site,
        size = 14,
        ha = "right",
        color = datalabels_color
    )
fig
0 0 6.5
0 0 15.5
0 0 20.0
8 4 7.5
8 4 19.0
8 4 28.0
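The printed heights are simply z + site / 2. The sketch below packs them into a small helper; label_positions is a hypothetical function, not part of Matplotlib. It returns the full label anchors, including the 1.5 shift to the left:

```python
def label_positions(x, y, z, dz, offset=1.5):
    """Return one (x, y, z) text anchor per segment: shifted left by
    `offset` and centered vertically on the segment."""
    return [(xi - offset, yi, zi + h / 2) for xi, yi, zi, h in zip(x, y, z, dz)]


x = [0, 0, 0, 8, 8, 8]
y = [0, 0, 0, 4, 4, 4]
z = [0, 13, 18, 0, 15, 23]
dz = [13, 5, 4, 15, 8, 10]

for pos in label_positions(x, y, z, dz):
    print(pos)
```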
Add the axis labels and grand totals
To place the axis labels (the years) and the grand totals, we need the x, y anchor of each stack:
(0, 0) for the left stack
(8, 4) for the right stack
Each stack contributes cnt_countries consecutive entries to x and y, so the anchor of stack i sits at index i * cnt_countries. Enumerating the loop gives us i, and we read x[i * cnt_countries] and y[i * cnt_countries].
for i, (year, sub_total) in enumerate(zip(years, sub_totals)):
    # add year tick labels below each stack
    print(x[i * cnt_countries], y[i * cnt_countries], -3)
    ax.text(
        x[i * cnt_countries],
        y[i * cnt_countries],
        z = -3,
        s = year,
        color = xy_ticklabel_color,
        weight = "bold",
        fontsize = 12,
    )
    # add grand totals just above the top of each stack
    offset_gt = 0.5
    ax.text(
        x[i * cnt_countries] + offset_gt,
        y[i * cnt_countries],
        z = z[i * cnt_countries] + sub_total + 2,
        s = sub_total,
        fontsize = 12,
        weight = "bold",
        color = grand_totals_color,
        bbox = dict(facecolor="none", edgecolor="#EBEDEE", boxstyle="round,pad=0.3"),
    )
fig
0 0 -3
8 4 -3
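The bottom segment of every stack sits at index i * cnt_countries. A strided slice such as x[::cnt_countries] therefore returns the same anchors without the index arithmetic. The sketch below computes the year-label and grand-total positions that way, using the same offsets as the loop above: -3 for the years, and +0.5 and +2 for the totals.

```python
cnt_countries = 3
x = [0, 0, 0, 8, 8, 8]
y = [0, 0, 0, 4, 4, 4]
z = [0, 13, 18, 0, 15, 23]
sub_totals = [22, 33]
offset_gt = 0.5

# Every cnt_countries-th element is the bottom segment of a stack,
# i.e. the same values the loop reads with x[i * cnt_countries]
anchors = list(zip(x[::cnt_countries], y[::cnt_countries], z[::cnt_countries]))

year_label_pos = [(bx, by, -3) for bx, by, _ in anchors]
total_label_pos = [
    (bx + offset_gt, by, bz + total + 2)
    for (bx, by, bz), total in zip(anchors, sub_totals)
]
print(year_label_pos, total_label_pos)
```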
Only the styling and the legend are left, so let's put everything together!