Import packages and generate the data
First, we create the data. We define it as a dictionary and then convert it into a pandas DataFrame, since pandas is the standard tool for this kind of data manipulation.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from matplotlib.colors import LinearSegmentedColormap
# Two-letter code shown next to each country's bar.
code_dict = {"Norway": "NO", "Denmark": "DK", "Sweden": "SE"}
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Sweden", "Sweden", "Denmark", "Denmark", "Norway", "Norway"],
    "sites": [13, 15, 4, 10, 5, 8],
}
df = pd.DataFrame(data)

by_country = df.groupby("countries")["sites"]
# Per-country relative change between consecutive years. The groupby result is
# aligned on the row index, so it stays correct even if a country's rows are
# not contiguous. The sign is flipped on purpose: the plotting code multiplies
# by -100 to get back the actual percentage change.
df["pct_change"] = by_country.pct_change() * -1
df["ctry_code"] = df.countries.map(code_dict)
# Per-country difference to the previous year. The first year has no
# predecessor, so it falls back to its absolute value; each country's pair is
# then (bar start, bar width) for broken_barh.
df["diff"] = by_country.diff().fillna(df.sites)

# Display order: Sweden, Norway, Denmark, and 2004 before 2022 within each
# country. One mapping serves as the sort key for both columns.
sort_order_dict = {"Denmark": 3, "Sweden": 1, "Norway": 2, 2004: 4, 2022: 5}
df = df.sort_values(by=["countries", "year"], key=lambda x: x.map(sort_order_dict))
df
|  | year | countries | sites | pct_change | ctry_code | diff |
|---|---|---|---|---|---|---|
| 0 | 2004 | Sweden | 13 | NaN | SE | 13.0 |
| 1 | 2022 | Sweden | 15 | -0.153846 | SE | 2.0 |
| 4 | 2004 | Norway | 5 | NaN | NO | 5.0 |
| 5 | 2022 | Norway | 8 | -0.600000 | NO | 3.0 |
| 2 | 2004 | Denmark | 4 | NaN | DK | 4.0 |
| 3 | 2022 | Denmark | 10 | -1.500000 | DK | 6.0 |
# Unique countries and country codes, in the row order of the sorted df.
countries = df["countries"].unique()
codes = df["ctry_code"].unique()
# Maximum of the pct_change column (which stores the negated change).
pct_changes = df["pct_change"].max()

# Diamond colors: the 2004 / 2022 pair, repeated once per country.
colors = [hex_code for _ in range(3) for hex_code in ("#CC5A43", "#5375D4")]
# Diverging red-blue colormap.
cmap = plt.cm.RdBu

# One entry per country (alphabetical, since groupby sorts keys by default).
# Each value is that country's diff values as a 1-D numpy array, so x_coord
# is an object-dtype Series, not a 2-D array.
x_coord = df.groupby("countries")["diff"].apply(lambda values: values.to_numpy())
print(x_coord)
countries
Denmark     [4.0, 6.0]
Norway      [5.0, 3.0]
Sweden     [13.0, 2.0]
Name: diff, dtype: object
Plot the chart
fig, ax = plt.subplots(figsize=(8, 5), facecolor="#FFFFFF")
# x spans 0 to one past the largest site count; y has one slot per country
# (0 .. n-1) plus a one-unit margin below.
ax.set(xlim=[0, df.sites.max() + 1], ylim=[-1, df.countries.nunique()])

bar_height = 0.3
center_bar = bar_height / 2
offset = 1.5  # x-distance from a bar's start to its country-code label
bars = []
# sort=False keeps the groups in the row order of the sorted df, so the
# i-th country is drawn at y = i.
for i, (_country, group) in enumerate(df.groupby("countries", sort=False)):
    # diff holds (first-year value, change to second year), which is the
    # (xmin, xwidth) pair broken_barh expects.
    diff = group["diff"].to_numpy()
    bar = ax.broken_barh([diff], (i - center_bar, bar_height))
    # keep the bars so a gradient can be added to them later
    bars.append(bar)

    # country code, placed to the left of the bar start
    ax.annotate(
        text=group.ctry_code.iloc[0],
        xy=(diff[0] - offset, i),
        color="#6F7780",
        fontsize=12,
        ha="left",
        va="center",
    )
    # percentage change, centered on the bar; pct_change is stored negated,
    # hence the * -100
    ax.annotate(
        text=f"{group['pct_change'].iloc[1] * -100:+.0f}%",
        xy=(diff[0] + diff[1] / 2, i),
        color="w",
        fontsize=12,
        ha="center",
        va="center",
    )
Add a gradient to the bars
def gradientbars(bars, ax):
    """Replace each bar's flat fill with a red-to-blue horizontal gradient.

    Each bar's face is made transparent, and an ``imshow`` gradient is drawn
    beneath it (zorder 0) over the bar's extent. The gradient runs from the
    bar's starting x-value (red) to its ending x-value (blue), so it is also
    positioned correctly for bars with a negative width.

    Parameters
    ----------
    bars : list
        Collections returned by ``ax.broken_barh``; only the first rectangle
        (path) of each collection is used.
    ax : matplotlib.axes.Axes
        Axes holding the bars. Its x/y limits are restored afterwards because
        ``imshow`` rescales them.
    """
    gradient_colors = [(1, 0, 0), (0, 0, 1)]  # start color red, end color blue
    cm = LinearSegmentedColormap.from_list("Custom", gradient_colors, N=256)  # convert to a colormap
    # Every row is 0..9, so the image color varies along x only.
    mat = np.indices((10, 10))[1]
    # imshow autoscales the axes; remember the current limits to restore them.
    lim = ax.get_xlim() + ax.get_ylim()
    for bar in bars:
        bar.set_zorder(1)
        bar.set_facecolor("none")
        vertices = bar.get_paths()[0].vertices
        x_all, y_all = vertices[:, 0], vertices[:, 1]
        # The first vertex is at the bar's start x; its end is the vertex x
        # farthest from it (left of the start when the width is negative).
        x_start = x_all[0]
        x_end = x_all[np.argmax(np.abs(x_all - x_start))]
        y_bottom, y_top = y_all.min(), y_all.max()
        # A reversed extent (x_end < x_start) flips the image horizontally,
        # keeping red at the start value.
        ax.imshow(
            mat,
            extent=[x_start, x_end, y_bottom, y_top],
            aspect="auto",
            zorder=0,
            cmap=cm,
            alpha=0.2,
        )
    ax.axis(lim)


gradientbars(bars, ax)
fig
Add the bar ends
# One diamond color per row. The df is sorted 2004-then-2022 within each
# country, so the colors alternate 2004 (red) / 2022 (blue).
colors = ["#CC5A43", "#5375D4"] * 3
# Put each diamond at the same numeric y as its bar. The bars were drawn at
# y = i for the i-th country in order of first appearance, which
# unique() reproduces. This avoids relying on matplotlib's categorical
# string axis happening to assign the same positions.
y_pos = df.countries.map({country: i for i, country in enumerate(df.countries.unique())})
ax.scatter(
    df.sites,
    y_pos,
    marker="D",
    s=200,
    color=colors,
)
fig
Add styling and a legend
# Major x ticks every 5 units.
ax.xaxis.set_ticks(np.arange(0, 20, 5), labels=[0, 5, 10, 15])
# add minor ticks every 1 unit (these only show up as grid lines)
ax.xaxis.set_minor_locator(plt.MultipleLocator(1))
ax.tick_params(
    axis="x",
    which="both",
    length=0,
    labelsize=14,
    colors="#6F7780",
)
ax.grid(
    which="both",
    axis="x",
    linestyle="-",
    alpha=0.4,
    color="#C8C9C9",
)
ax.set_axisbelow(True)  # put grid in the back
ax.set_yticks([])
ax.set_frame_on(False)

# add legend: one diamond per year
labels = df.year.unique()
# De-duplicate while keeping first-seen order, so the i-th color pairs with
# the i-th year. A set has no guaranteed iteration order.
colors = list(dict.fromkeys(colors))
lines = [
    Line2D([0], [0], color=c, marker="D", linestyle="", markersize=12)
    for c in colors
]
fig.legend(
    lines,
    labels,
    labelcolor="#6F7780",
    bbox_to_anchor=(0.3, -0.1),
    loc="lower center",
    ncols=2,
    frameon=False,
    fontsize=12,
)
fig