Generate the data and import packages
First, we import the packages and create the data. I'll define the data as a dictionary and then convert it into a pandas DataFrame, since pandas makes the sorting and grouping steps that follow straightforward.
import matplotlib.pyplot as plt
import pandas as pd
# Colors keyed by (year, country): lighter shade for 2004, darker for 2022
color_dict = {
    (2004, "Norway"): "#9194A3",
    (2022, "Norway"): "#2B314D",
    (2004, "Denmark"): "#E2AFA5",
    (2022, "Denmark"): "#A54836",
    (2004, "Sweden"): "#C4D6F8",
    (2022, "Sweden"): "#5375D4",
}
# Colors for the tick labels, x-axis titles, totals, grid and data labels
xy_ticklabel_color, xlabel_color, grand_totals_color, grid_color, datalabels_color = '#C8C9C9', "#101628", "#101628", "#C8C9C9", "#2B314D"
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Sweden", "Sweden", "Denmark", "Denmark", "Norway", "Norway"],
    "sites": [13, 15, 4, 10, 5, 8]
}
df = pd.DataFrame(data)
df
| | year | countries | sites |
|---|---|---|---|
| 0 | 2004 | Sweden | 13 |
| 1 | 2022 | Sweden | 15 |
| 2 | 2004 | Denmark | 4 |
| 3 | 2022 | Denmark | 10 |
| 4 | 2004 | Norway | 5 |
| 5 | 2022 | Norway | 8 |
Next, we custom-sort the DataFrame and add the following columns:
- year_lbl, the short year label ('04, '22) used on the x axis,
- pct_change, the change in sites from 2004 to 2022, shown in the x-axis title,
- color, the color of each lollipop.
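DataFrame.sort_values calls the key function once for each column listed in by, passing that column as a Series. That is why a single dictionary can rank both the countries and the years. A minimal sketch on a toy frame (the toy and ranks names are illustrative only):

```python
import pandas as pd

# Toy frame with the same columns as in the tutorial, deliberately unsorted
toy = pd.DataFrame({
    "year": [2022, 2004, 2022, 2004],
    "countries": ["Sweden", "Norway", "Norway", "Sweden"],
})

# One dict ranks both columns; the key function receives each "by" column in turn
ranks = {"Norway": 1, "Sweden": 2, 2004: 1, 2022: 2}
toy_sorted = toy.sort_values(by=["year", "countries"], key=lambda col: col.map(ranks))

# Rows come out by year rank first, then by country rank
print(toy_sorted.to_string())
```

Because the ranks are looked up per column, the country ranks and the year ranks can reuse the same numbers without clashing.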
# Custom sort order for the countries and the years
sort_order_dict = {
    "Denmark": 2,
    "Sweden": 3,
    "Norway": 1,
    2004: 4,
    2022: 5}
df = df.sort_values(by=['year', 'countries'], key=lambda x: x.map(sort_order_dict))
# Add the x-axis tick labels ('04, '22)
df['year_lbl'] = "'" + df['year'].astype(str).str[-2:]
# Percentage change per country from 2004 to 2022; the result is aligned on the
# index, so every row receives its own country's value (NaN for the 2004 rows)
df['pct_change'] = df.groupby('countries')['sites'].pct_change()
# Add the color based on the color dictionary
df['color'] = df.set_index(['year', 'countries']).index.map(color_dict.get)
df
| | year | countries | sites | year_lbl | pct_change | color |
|---|---|---|---|---|---|---|
| 4 | 2004 | Norway | 5 | '04 | NaN | #9194A3 |
| 2 | 2004 | Denmark | 4 | '04 | NaN | #E2AFA5 |
| 0 | 2004 | Sweden | 13 | '04 | NaN | #C4D6F8 |
| 5 | 2022 | Norway | 8 | '22 | 0.600000 | #2B314D |
| 3 | 2022 | Denmark | 10 | '22 | 1.500000 | #A54836 |
| 1 | 2022 | Sweden | 15 | '22 | 0.153846 | #5375D4 |
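groupby(...).pct_change() returns a Series aligned on the original index, so it can be assigned straight back as a column; within each country the change is computed in row order, which after the custom sort means 2004 then 2022. Assigning a positional array instead (for example .to_numpy() on a result ordered group by group) bypasses that alignment and can attach one country's change to another country's row. A minimal check on the same numbers, which also shows how the :.0% format used in the x-axis title renders each value:

```python
import pandas as pd

# Same values, row order and index as the custom-sorted frame
df = pd.DataFrame({
    "year": [2004, 2004, 2004, 2022, 2022, 2022],
    "countries": ["Norway", "Denmark", "Sweden", "Norway", "Denmark", "Sweden"],
    "sites": [5, 4, 13, 8, 10, 15],
}, index=[4, 2, 0, 5, 3, 1])

# Index-aligned: each row gets its own country's change (NaN for the first year)
df["pct_change"] = df.groupby("countries")["sites"].pct_change()

# One percentage per country, formatted like the x-axis title
changes = df.dropna().set_index("countries")["pct_change"]
labels = {ctry: f"{pct:.0%}" for ctry, pct in changes.items()}
print(labels)  # {'Norway': '60%', 'Denmark': '150%', 'Sweden': '15%'}
```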
Finally, we define some parameters that are reused in the plotting code.
x_coordinates is the list [0, 1], one x position per year.
x_coordinates = list(range(df.year.nunique()))
cnt_countries = df.countries.nunique()
# sort=False keeps the custom country order (Norway, Denmark, Sweden) for the subplots
groups = df.groupby('countries', sort=False)
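By default, groupby sorts the group keys alphabetically, so iterating over the groups yields Denmark, Norway, Sweden regardless of how the DataFrame itself was sorted. With sort=False the keys keep their order of first appearance in the frame, while the row order inside each group is preserved either way. A quick comparison on the sorted data:

```python
import pandas as pd

# Rows in the custom-sorted order: 2004 first, then Norway, Denmark, Sweden
df = pd.DataFrame({
    "year": [2004, 2004, 2004, 2022, 2022, 2022],
    "countries": ["Norway", "Denmark", "Sweden", "Norway", "Denmark", "Sweden"],
    "sites": [5, 4, 13, 8, 10, 15],
})

# Default: group keys are sorted alphabetically
default_order = [name for name, _ in df.groupby("countries")]
# sort=False: group keys follow their first appearance in the frame
appearance_order = [name for name, _ in df.groupby("countries", sort=False)]

print(default_order)     # ['Denmark', 'Norway', 'Sweden']
print(appearance_order)  # ['Norway', 'Denmark', 'Sweden']
```

Since the subplots are filled in group order, this choice decides which country appears in which panel.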
Plot the chart
This is a relatively straightforward build. We will use:
- ax.vlines() to draw the vertical lines of the lollipops.
- ax.scatter() to draw the circular markers at the top of each line.
- ax.text() to add the data labels above the markers.
Since we want one plot per country, we'll create a row of subplots and then:
- Use axes.ravel() to flatten the axes array.
- Loop over the country groups and the flattened axes together, plotting each country on its own axes.
This approach keeps the code clean and avoids nested loops when working with subplot grids.
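zip stops at the shorter of its inputs, so pairing the country groups with axes.ravel() hands each country its own subplot without any index bookkeeping. ravel() only matters once the grid has more than one row; with nrows=1, as in this chart, axes is already one-dimensional and ravel() is harmless. A minimal sketch on a hypothetical 2×2 grid with three panels to fill (the Agg backend is an assumption, used only so the sketch runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

names = ["Norway", "Denmark", "Sweden"]  # three panels to fill

fig, axes = plt.subplots(nrows=2, ncols=2)
print(axes.shape)  # (2, 2): a 2-D array of Axes
flat = axes.ravel()
print(flat.shape)  # (4,): flattened row by row

# zip stops after the three names, leaving the fourth Axes untouched
for name, ax in zip(names, flat):
    ax.set_title(name)

# Hide any panel that did not receive data
for ax in flat[len(names):]:
    ax.set_visible(False)

plt.close(fig)
```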
Parameters for the line
We use the ax.vlines() method to draw the vertical lines. It needs the following parameters:
| Parameter | Description | Value |
|---|---|---|
| x | The x position of each line | x_coordinates (0 and 1) |
| ymin | The bottom y position of each line | 0, so every line starts at the baseline |
| ymax | The top y position of each line | The site count, group.sites |
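A minimal sketch of ax.vlines() with the Norway values (5 sites in 2004, 8 in 2022) and its two colors from color_dict; the Agg backend is assumed so it runs headless. Each line runs from ymin to ymax at its x position, a scalar ymin is shared by all lines, and the returned LineCollection stores one segment per line:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt

x_coordinates = [0, 1]  # one position per year
sites = [5, 8]          # Norway: 2004 and 2022
colors = ["#9194A3", "#2B314D"]

fig, ax = plt.subplots()
# One vertical line per x value, from ymin=0 up to the site count
lines = ax.vlines(x_coordinates, 0, sites, colors=colors, linewidth=4, zorder=1)

# Each segment is a 2x2 array: [[x, ymin], [x, ymax]]
for seg in lines.get_segments():
    print(seg.tolist())
plt.close(fig)
```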
Parameters for the bubbles
We use the ax.scatter() method to draw the circles at the top of the lines. It needs the following parameters:
| Parameter | Description | Value |
|---|---|---|
| x | The x position of each dot | x_coordinates (0 and 1) |
| y | The y position of each dot | The site count, group.sites |
| s | The area of each dot, in points² | Hardcoded to 150 here, but it could be computed from the data |
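A minimal sketch of ax.scatter() with the same Norway values (headless Agg backend assumed). Since s is an area in points², a fixed s=150 gives every circle the same size. The second call shows one hypothetical way to compute s instead of hardcoding it: 20 points² per site, an arbitrary factor chosen only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt

x_coordinates = [0, 1]
sites = [5, 8]  # Norway: 2004 and 2022
colors = ["#9194A3", "#2B314D"]

fig, ax = plt.subplots()
# Fixed area of 150 points^2 for every marker, with a white edge
dots = ax.scatter(x_coordinates, sites, s=150, c=colors,
                  edgecolors="w", linewidths=2, zorder=2)
print(dots.get_offsets().tolist())  # marker centers as [x, y] pairs

# Hypothetical alternative: marker area proportional to the site count
scaled = ax.scatter(x_coordinates, sites, s=[20 * n for n in sites], c=colors)
print(scaled.get_sizes().tolist())  # one area per marker
plt.close(fig)
```

Scaling the area rather than the diameter keeps the visual weight of each circle proportional to the value it represents.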
Parameters for the annotation
We use the ax.text() method to add the data labels. It needs the following parameters:
| Parameter | Description | Value |
|---|---|---|
| x | The x position of the text | The loop index i (0 or 1) |
| y | The y position of the text | The site value plus an offset of 1 |
| s | The text to display | The site value itself |
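A short sketch that places the two Norway labels the same way and reads back where they landed (headless Agg backend assumed). Matplotlib converts the integer site value to its string form for the label text:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt

sites = [5, 8]  # Norway: 2004 and 2022
offset = 1      # vertical gap between marker and label, in data units

fig, ax = plt.subplots()
for i, site in enumerate(sites):
    # x = position index, y = top of the lollipop + offset, s = the label text
    ax.text(i, site + offset, site, ha="center", va="center")

# list of ((x, y), text) pairs for the labels just added
placed = [(t.get_position(), t.get_text()) for t in ax.texts]
print(placed)
plt.close(fig)
```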
fig, axes = plt.subplots(ncols=cnt_countries, nrows=1, figsize=(8,6), sharex=True, sharey=True, facecolor="white")

# loop over the countries, one subplot per country
for (ctry, group), ax in zip(groups, axes.ravel()):
    # one row per country is NaN, so max() returns the single computed change
    pct = group['pct_change'].max()
    # add the vertical lines of the lollipops
    ax.vlines(
        x_coordinates,
        0,
        group.sites,
        color = group.color,
        lw = 4,
        zorder = 1)
    # add the circles at the end of the lollipop lines
    ax.scatter(
        x_coordinates,
        group.sites,
        s = 150,
        c = group.color,
        edgecolors = "w",
        lw = 2,
        zorder = 2)
    offset = 1
    # add the data labels just above each circle
    for i, site in enumerate(group.sites):
        ax.text(
            i,
            site + offset,
            site,
            size = 13,
            color = datalabels_color,
            weight = "light",
            ha = "center",
            va = "center"
        )
    # add the x-axis title: arrow, percentage change and country name
    ax.set_xlabel(
        f'\u25B2\n{pct:.0%}\n\n{ctry}',
        color = xlabel_color,
        size = 12,
        weight = "light",
        labelpad = 12
    )
    # use the short year labels ('04, '22) as x tick labels (matplotlib >= 3.5)
    ax.set_xticks(x_coordinates, labels=group.year_lbl.tolist())
    # format the plots: the x limits depend on the number of years, not countries
    ax.set(
        xlim = (-0.5, len(x_coordinates) - 0.5),
        ylim = (0, df.sites.max() + 5)
    )
    ax.tick_params(
        axis = 'both',
        which = 'both',
        length = 0,
        labelleft = False,
        labelsize = 12,
        colors = xy_ticklabel_color
    )
    ax.spines[['top', 'left', 'right']].set_visible(False)
    ax.spines['bottom'].set_color(grid_color)

# apply the layout after plotting so the multi-line x-axis titles are taken into account
fig.tight_layout(pad=3.0)
plt.show()