Generate the data and import packages
First, we create the data. We define it as a dictionary and convert it into a pandas DataFrame, since pandas is the usual tool for this kind of data manipulation. We also define the chart colors up front so they can be reused later.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
color_dict = {
    "Norway": "#2B314D",
    "Denmark": "#A54836",
    "Sweden": "#5375D4",
}
# colors for the tick labels, grand totals, grid lines and data labels
xy_ticklabel_color, grand_totals_color, grid_color, datalabels_color = "#757C85", "#101628", "#C8C9C9", "#FFFFFF"
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Sweden", "Sweden", "Denmark", "Denmark", "Norway", "Norway"],
    "sites": [13, 15, 4, 10, 5, 8],
}
df = pd.DataFrame(data)
df
| | year | countries | sites |
|---|---|---|---|
| 0 | 2004 | Sweden | 13 |
| 1 | 2022 | Sweden | 15 |
| 2 | 2004 | Denmark | 4 |
| 3 | 2022 | Denmark | 10 |
| 4 | 2004 | Norway | 5 |
| 5 | 2022 | Norway | 8 |
Next, we apply a custom sort by year and country with sort_order_dict, so that the rows follow the same order as the original chart (Norway, Denmark, Sweden within each year):
# custom sort
sort_order_dict = {"Denmark": 2, "Sweden": 3, "Norway": 1, 2004: 4, 2022: 5}
df = df.sort_values(by=["year","countries"], key=lambda x: x.map(sort_order_dict))
df
| | year | countries | sites |
|---|---|---|---|
| 4 | 2004 | Norway | 5 |
| 2 | 2004 | Denmark | 4 |
| 0 | 2004 | Sweden | 13 |
| 5 | 2022 | Norway | 8 |
| 3 | 2022 | Denmark | 10 |
| 1 | 2022 | Sweden | 15 |
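The single sort_order_dict works for both columns because sort_values() applies key to each column in by separately: years are mapped to 4 and 5, countries to 1, 2 and 3. A quick way to inspect those ranks, rebuilding the same DataFrame (ranks and sorted_df are illustrative names, not part of the walkthrough):

```python
import pandas as pd

# Same data and sort order as above
df = pd.DataFrame({
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Sweden", "Sweden", "Denmark", "Denmark", "Norway", "Norway"],
    "sites": [13, 15, 4, 10, 5, 8],
})
sort_order_dict = {"Denmark": 2, "Sweden": 3, "Norway": 1, 2004: 4, 2022: 5}

# sort_values() calls `key` once per column listed in `by`, so one dict
# can hold ranks for both the years and the country names.
ranks = df[["year", "countries"]].apply(lambda col: col.map(sort_order_dict))
print(ranks)

sorted_df = df.sort_values(by=["year", "countries"], key=lambda col: col.map(sort_order_dict))
print(sorted_df["countries"].tolist())
```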
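Now we define the variables that are used more than once.
grouped_sites packs each country's site counts into one array (one value per year), so that each ax.bar() call can draw one country for all years. Passing sort=False keeps the countries in the custom order set above; by default, groupby sorts the keys alphabetically (Denmark, Norway, Sweden), which would stack Denmark at the bottom.
years = df.year.unique()
grouped_sites = df.groupby("countries", sort=False)["sites"].apply(np.array)
grouped_sites
countries
Norway       [5, 8]
Denmark     [4, 10]
Sweden     [13, 15]
Name: sites, dtype: object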
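The order of grouped_sites matters, because the bars are stacked in that order. By default groupby() sorts the group keys alphabetically; with sort=False the groups keep the order in which they first appear in the DataFrame, while the rows inside each group keep their original order. A small check on rows that are already custom-sorted (alphabetical and in_row_order are illustrative names):

```python
import numpy as np
import pandas as pd

# Rows already in the custom order: Norway, Denmark, Sweden within each year
df = pd.DataFrame({
    "year": [2004, 2004, 2004, 2022, 2022, 2022],
    "countries": ["Norway", "Denmark", "Sweden", "Norway", "Denmark", "Sweden"],
    "sites": [5, 4, 13, 8, 10, 15],
})

# Default: group keys are sorted alphabetically
alphabetical = df.groupby("countries")["sites"].apply(np.array)
# sort=False: groups appear in the order they first occur in df
in_row_order = df.groupby("countries", sort=False)["sites"].apply(np.array)

print(alphabetical.index.tolist())
print(in_row_order.index.tolist())
print(in_row_order["Norway"])
```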
Plot the stacked bar chart
We draw the bars with ax.bar(), passing the following parameters:
| Parameter | Description | Value |
|---|---|---|
| x | The x position of each bar | 0 and 1, generated from range(len(years)) |
| height | The height of each bar | The country's site counts, as a NumPy array aligned with x |
| bottom | The y coordinate where each bar starts | Explained in more detail below |
| width | The width of each bar | 0.6, hardcoded to match the original chart |
Explaining the bottom parameter
ax.bar() does not stack bars by itself. To stack them, each country's bars are drawn on top of a bottom array that starts at zero and accumulates the heights of the countries already drawn.
To match the original chart, the countries are stacked from the bottom up in the order Norway, Denmark, Sweden, so the bottom values are calculated as follows:
| Step | Calculation | Result | Explanation |
|---|---|---|---|
| Initial bottom | np.zeros(len(years)) | [0, 0] | Start at zero for all years |
| Add Norway | [0 + 5, 0 + 8] | [5, 8] | Add Norway's values (5, 8) |
| Add Denmark | [5 + 4, 8 + 10] | [9, 18] | Add Denmark's values (4, 10) |
| Add Sweden | [9 + 13, 18 + 15] | [22, 33] | Add Sweden's values (13, 15) |
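The running bottom in the table is a cumulative sum down the stack. Instead of updating bottom inside the plotting loop, NumPy can compute every country's bottom at once; a minimal sketch with the heights copied from the table (rows in stacking order, columns per year; heights, tops and bottoms are illustrative names):

```python
import numpy as np

# Site counts per country (rows, in stacking order) and per year (columns)
heights = np.array([
    [5, 8],    # Norway
    [4, 10],   # Denmark
    [13, 15],  # Sweden
])

# Running totals down the stack: row k is the top of country k
tops = np.cumsum(heights, axis=0)
# Each country's bottom is the top of the country below it (zeros for the first)
bottoms = np.vstack([np.zeros(heights.shape[1]), tops[:-1]])

print(bottoms)
print(tops[-1])  # top of the full stack, i.e. the grand total per year
```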
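This way, each country's bars start where the previous one ended, creating the stacked effect. The bottom used for each country is the result of the row above it, and after the last country, bottom holds the stack totals [22, 33].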
fig, ax = plt.subplots(figsize=(7, 10), facecolor="#FFFFFF")
# set the bottom of the stacked bars
bottom = np.zeros(len(years))
for country, sites in grouped_sites.items():
    color = color_dict.get(country, "#CCCCCC")  # fall back to grey for unknown countries
    ax.bar(
        range(len(years)),
        sites,
        bottom=bottom,
        width=0.6,
        color=color,
        label=country,
    )
    bottom += sites  # the next country starts where this one ends
    print(sites, list(range(len(years))), bottom)
[5 8] [0, 1] [5. 8.]
[ 4 10] [0, 1] [ 9. 18.]
[13 15] [0, 1] [22. 33.]
Add the grand totals to the chart
To calculate the grand totals, we sum the number of sites per year and store the totals in a list:
grand_total = df.groupby("year")["sites"].sum().tolist()
grand_total
[22, 33]
To add the grand totals to the plot, we use ax.text() with the following parameters:
| Parameter | Description | Value |
|---|---|---|
| x | The x position of the text | The index i from enumerate(grand_total), i.e. 0 and 1, matching the bar positions |
| y | The y position of the text | grand_total plus an offset of 1 |
| s | The text itself | The grand total |
offset_label = 1  # offset the grand total label by 1 (in data units)
# show the sum on top of each stacked bar
for i, total in enumerate(grand_total):
    print(i, total)
    ax.text(
        i,  # 0 and 1, the positions of the bars on the x axis
        total + offset_label,  # add some space between the bar and the label
        total,  # the label
        ha="center",
        size=20,
        color=grand_totals_color,
    )
fig
0 22
1 33
Add the data labels to the stacked bars
Each bar drawn by ax.bar() is a matplotlib.patches.Rectangle stored in ax.patches. To find the position of each stacked bar, we read the following properties of its rectangle:
| Property | Description |
|---|---|
| xy | The coordinates of the lower-left corner of the rectangle, read with get_x() and get_y() |
| width | The width of the rectangle, read with get_width() |
| height | The height of the rectangle, read with get_height() |
| angle | The rotation angle in degrees; 0 for the unrotated bars created by ax.bar() |
Now, we can use that information to feed the parameters of ax.text():
| Parameter | Description | Value |
|---|---|---|
| x | The x-position of the text. | get_x() + get_width()/2 to reach the midpoint horizontally. |
| y | The y-position of the text. | get_y() + get_height()/2 to reach the vertical center |
# add data labels
for bar in ax.patches:
    print(bar)
    ax.text(
        bar.get_x() + bar.get_width() / 2,  # horizontal midpoint of the bar
        bar.get_y() + bar.get_height() / 2,  # vertical midpoint of the bar
        round(bar.get_height()),
        ha="center",
        va="center",  # center the text on that point rather than on its baseline
        color=datalabels_color,
        size=16,
    )
fig
Rectangle(xy=(-0.3, 0), width=0.6, height=5, angle=0)
Rectangle(xy=(0.7, 0), width=0.6, height=8, angle=0)
Rectangle(xy=(-0.3, 5), width=0.6, height=4, angle=0)
Rectangle(xy=(0.7, 8), width=0.6, height=10, angle=0)
Rectangle(xy=(-0.3, 9), width=0.6, height=13, angle=0)
Rectangle(xy=(0.7, 18), width=0.6, height=15, angle=0)
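Since Matplotlib 3.4, Axes.bar_label() can place both the grand totals and the data labels without computing rectangle midpoints by hand. The sketch below is an alternative to the ax.text() loops above, not the method used in this walkthrough; it hardcodes the site counts in the Norway, Denmark, Sweden stacking order, uses default bar colors, and the names stack, segment_labels and total_labels are illustrative:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs as a script; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np

years = [2004, 2022]
# Site counts per year, in stacking order from the bottom up
stack = {"Norway": [5, 8], "Denmark": [4, 10], "Sweden": [13, 15]}

fig, ax = plt.subplots(figsize=(7, 10))
bottom = np.zeros(len(years))
for country, sites in stack.items():
    ax.bar(range(len(years)), sites, bottom=bottom, width=0.6, label=country)
    bottom = bottom + np.array(sites)

# Data labels: each ax.bar() call added one BarContainer (one per country);
# label_type="center" puts each label in the middle of its segment
segment_labels = []
for container in ax.containers:
    segment_labels += ax.bar_label(container, label_type="center", color="#FFFFFF", size=16)

# Grand totals: label the outer edge of the top segment of each stack;
# padding is in points, so the gap does not depend on the y-axis scale
grand_total = bottom.astype(int).tolist()
total_labels = ax.bar_label(
    ax.containers[-1], labels=[str(t) for t in grand_total], padding=5, size=20
)
```

Because padding is measured in points, the gap above each stack stays the same if the y limits change later, unlike a fixed offset in data units.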
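Add country labels to the most recent year only
To avoid repeating the same information, we label the countries only next to the bars on the right, which belong to the most recent year. We loop through the bars again, but this time only over last_year_bars. Because ax.patches stores the bars country by country, with one bar per year, each country's bar for the last year sits at every len(years)-th position, starting at index len(years) - 1.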
last_year_bars = ax.patches[len(years) - 1 :: len(years)]  # only the bars of the last year
for bar, country in zip(last_year_bars, grouped_sites.index):  # same order in which the bars were drawn
    color = color_dict.get(country, "#CCCCCC")  # map the colors to the country names
    ax.text(
        bar.get_x() + bar.get_width() + offset_label,  # to the right of the bar
        bar.get_y() + bar.get_height() / 2,  # vertical center of the bar
        country,
        ha="left",
        va="center",
        color=color,
        size=16,
    )
fig
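To see why the slice ax.patches[len(years) - 1 :: len(years)] picks the right bars, the sketch below rebuilds the stack with hardcoded counts in the Norway, Denmark, Sweden order and compares the slice with ax.containers, where each ax.bar() call stores one BarContainer labelled with its country. Reading the country from container.get_label() is an alternative that keeps each name attached to its own bars; stack and by_container are illustrative names:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs as a script; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np

years = [2004, 2022]
stack = {"Norway": [5, 8], "Denmark": [4, 10], "Sweden": [13, 15]}

fig, ax = plt.subplots()
bottom = np.zeros(len(years))
for country, sites in stack.items():
    ax.bar(range(len(years)), sites, bottom=bottom, width=0.6, label=country)
    bottom = bottom + np.array(sites)

# ax.patches is filled country by country:
# [Norway 2004, Norway 2022, Denmark 2004, Denmark 2022, Sweden 2004, Sweden 2022]
# Starting at len(years) - 1 and stepping by len(years) picks the last year of each country.
last_year_bars = ax.patches[len(years) - 1 :: len(years)]

# Same bars via the containers: the last element of each BarContainer
# is that country's bar for the last year
by_container = {c.get_label(): c[-1] for c in ax.containers}

print([bar.get_height() for bar in last_year_bars])
print(list(by_container))
```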
Style the chart
Finally, we style the chart to match the original.
First, let's replace the x tick labels with the years using ax.xaxis.set_ticks(), and style the tick labels with ax.tick_params().
ax.xaxis.set_ticks(range(len(years)), labels=years)
ax.tick_params(
    axis="both",  # change both x and y
    which="major",  # major ticks
    length=0,  # remove the ticks
    labelsize=16,  # set the label size
    colors=xy_ticklabel_color,  # set the color (defined at the beginning of the script)
    pad=15,  # increase the distance between the labels and the axis
)
fig
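Axis.set_ticks() accepts the labels argument only from Matplotlib 3.5 onward. On older releases, a two-step equivalent is set_xticks() followed by set_xticklabels(); a minimal sketch on a placeholder bar chart (tick_texts is an illustrative name):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs as a script; omit in a notebook
import matplotlib.pyplot as plt

years = [2004, 2022]
fig, ax = plt.subplots()
ax.bar(range(len(years)), [22, 33], width=0.6)

# Two-step equivalent of ax.xaxis.set_ticks(range(len(years)), labels=years)
ax.set_xticks(range(len(years)))
ax.set_xticklabels([str(y) for y in years])

fig.canvas.draw()  # make sure the tick labels are populated before reading them
tick_texts = [t.get_text() for t in ax.get_xticklabels()]
print(tick_texts)
```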
Then a few more styling changes:
ax.set_axisbelow(True) # set the grid lines in the BACK
ax.grid(
True, # show grid
axis="y", # but only the y axis
linestyle="solid",
linewidth=1,
color=grid_color,
)
ax.set_xlim(-1, len(years)) # set the limits of the x axis
ax.set_ylim(0, max(grand_total) + 4)  # set the limits of the y axis, leaving room for the grand totals
ax.set_frame_on(False) # Hide box around the axis
fig