Generate the data and import packages
First, we need to create the data. I will build it as a dictionary and then convert it to a pandas DataFrame, since many projects use pandas to work with data.
import matplotlib.pyplot as plt
from matplotlib.patches import Arc
import pandas as pd
import numpy as np
# Color and two-letter code for each country.
color_dict = {
    "Norway": "#2B314D",
    "Denmark": "#A54836",
    "Sweden": "#5375D4",
}
code_dict = {
    "Norway": "NO",
    "Denmark": "DK",
    "Sweden": "SE",
}
# Named colors for secondary text, grid lines and labels drawn on the dots.
xy_ticklabel_color, grid_color, datalabels_color = "#757C85", "#C8C9C9", "#FFFFFF"

# Number of sites per country in 2004 and 2022.
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Denmark", "Denmark", "Norway", "Norway", "Sweden", "Sweden"],
    "sites": [4, 10, 5, 8, 13, 15],
}
df = pd.DataFrame(data)
# A stable sort keeps the original country order within each year, so the
# row order is deterministic regardless of the sorting backend in use.
df = df.sort_values("year", kind="stable").reset_index(drop=True)
# Two-digit year label, e.g. 2004 -> "'04".
df["year_lbl"] = "'" + df["year"].astype(str).str[-2:]
# Map each country to its color and its code.
df["colors"] = df["countries"].map(color_dict)
df["ctry_codes"] = df["countries"].map(code_dict)
df
| | year | countries | sites | year_lbl | colors | ctry_codes |
|---|---|---|---|---|---|---|
| 0 | 2004 | Denmark | 4 | '04 | #A54836 | DK |
| 1 | 2004 | Norway | 5 | '04 | #2B314D | NO |
| 2 | 2004 | Sweden | 13 | '04 | #5375D4 | SE |
| 3 | 2022 | Denmark | 10 | '22 | #A54836 | DK |
| 4 | 2022 | Norway | 8 | '22 | #2B314D | NO |
| 5 | 2022 | Sweden | 15 | '22 | #5375D4 | SE |
Define some variables that will be reused:
# Shortcuts to the two columns the plotting cells reuse.
sites, colors = df["sites"], df["colors"]
Plot the chart
We will use the ax.scatter() method with the following parameters:
| Parameter | Description | Value |
|---|---|---|
| x | The x positions of each dot | sites |
| y | The y positions of each dot | sites |
| s | The marker area of each dot, in points² | 340 |
| c | The color of each dot | the colors column |
We will also set the axis limits right away: both axes start at zero and extend 3 units past the largest value.
# A single square figure with one Axes.
fig, ax = plt.subplots(figsize=(6, 6), facecolor="#FFFFFF")
# One dot per (country, year) at (sites, sites), so every dot lies on the diagonal.
ax.scatter(sites, sites, s=340, c=colors)
# Equal aspect: one data unit has the same on-screen length on both axes.
# Matplotlib builds "arc3" connection curves in display coordinates, so this
# keeps the diagonal at exactly 45 degrees and makes arc-label positions
# computed in data coordinates land exactly on the arcs.
ax.set_aspect("equal")
# Both axes start at zero and extend 3 units past the largest value.
ax.set(
    xlim=(0, sites.max() + 3),
    ylim=(0, sites.max() + 3),
)
[(0.0, 18.0), (0.0, 18.0)]
Add the diagonal line
We will use the ax.axline() method, which draws an infinitely long straight line through two points. It needs the following parameters:
| Parameter | Description | Value |
|---|---|---|
| xy1 | The first (x, y) point the line passes through | (start, start) |
| xy2 | The second (x, y) point the line passes through | (end, end) |
Alternatively, you can pass a single point together with a slope, e.g. ax.axline((0, 0), slope=1).
So to draw a line through 0 (start) and 18 (end), we need the points (0, 0) and (18, 18).
We will use ax.get_xlim() to get the first and last values of the axis range, which give the two points of the line. The x-limits are enough because they are identical to the y-limits.
# Inspect the current axis limits; both were set to the same range.
x_limits, y_limits = ax.get_xlim(), ax.get_ylim()
print(x_limits, y_limits)
(np.float64(0.0), np.float64(18.0)) (np.float64(0.0), np.float64(18.0))
# The x and y limits are identical, so the x-limits alone give both
# end points of the diagonal.
start, end = ax.get_xlim()
print(start, end)
# y = x reference line, drawn behind the dots (zorder=0).
ax.axline((start, start), (end, end), zorder=0, lw=1, color=xy_ticklabel_color)
fig
0.0 18.0
Add the labels
We will use the ax.text() method with the following parameters:
| Parameter | Description | Value |
|---|---|---|
| x | The x position of the text | sites |
| y | The y position of the text | sites |
| s | The text to display | year_lbl |
We need to loop over the rows of the dataframe. In the same loop, a second ax.text() call writes the number of sites just to the right of each dot.
# Horizontal offset (in data units) of the site-count label from its dot.
offset = 1
for row in df.itertuples():
    # Year label ('04 / '22) centered on the dot.
    ax.text(
        row.sites,
        row.sites,
        row.year_lbl,
        color=datalabels_color,
        va="center",
        ha="center",
    )
    # Number of sites, written just to the right of the dot.
    ax.text(
        row.sites + offset,
        row.sites,
        str(row.sites),
        color=xy_ticklabel_color,
        va="center",
        ha="center",
    )
fig
Add the arcs
We will use the ax.annotate() method with the following parameters:
| Parameter | Description | Value |
|---|---|---|
| text | The annotation text | Empty in this case |
| xy | The point being annotated, i.e. one end of the arrow | (x1, x1) |
| xytext | The position of the text, i.e. the other end of the arrow | (x2, x2) |
| arrowprops | The properties of the arrow drawn between xytext and xy | A headless arrow ("-") with an arc3 connection |
We will use groupby and apply to collect each country's two site counts into a pair, keyed by country code and color. We then feed each pair to ax.annotate.
# Collect each country's site counts into an array keyed by (code, color).
# Rows keep their DataFrame order inside each group; df was sorted by year,
# so each array reads [2004 value, 2022 value].
by_country = df.groupby(["ctry_codes", "colors"])
grouped_sites = by_country["sites"].apply(np.array)
grouped_sites
ctry_codes  colors 
DK          #A54836    [4, 10]
NO          #2B314D     [5, 8]
SE          #5375D4   [13, 15]
Name: sites, dtype: object
# Curvature of the arcs (arc3 "rad"); larger values bulge further from the diagonal.
rad = 0.6
# Arc appearance shared by every country; only the color changes per arc.
arc_props = {
    "arrowstyle": "-",
    "connectionstyle": f"arc3,rad={rad}",
    "linewidth": 2,
    "alpha": 0.5,
    "linestyle": "-",
    "antialiased": True,
}
for (code, color), site in grouped_sites.items():
    x1, x2 = site[0], site[1]
    # An empty annotation whose headless "arrow" is the curved line joining
    # the country's two dots.
    ax.annotate(
        "",
        xy=(x1, x1),
        xytext=(x2, x2),
        zorder=1,
        arrowprops={**arc_props, "color": color},
    )
fig
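Add the labels in the arcs
For the labels we need to find the midpoint of each arc.
connectionstyle="arc3,rad=r" draws a quadratic Bézier curve between the two points. Its control point lies on the perpendicular bisector of the chord connecting the points, at a distance of r × (chord length) from the chord. The midpoint of the curve lies halfway between the chord's midpoint and that control point. It is therefore r × (chord length) / 2 away from the chord along the perpendicular, and this is where we place each label. Matplotlib computes the curve in screen coordinates. Doing the calculation in data coordinates is therefore accurate when one data unit has (about) the same on-screen length on both axes, as it does here with identical limits on a square figure.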
def arc3_midpoint(p1, p2, rad):
    """Return the midpoint of the curve drawn by ``connectionstyle="arc3,rad=rad"``.

    Matplotlib draws arc3 as a quadratic Bezier curve whose control point
    lies on the perpendicular bisector of the chord p1-p2, ``rad`` chord
    lengths away from the chord. The curve's midpoint lies halfway between
    the chord midpoint and that control point. Matplotlib evaluates the
    curve in display coordinates, so the result is exact in data
    coordinates only when both axes have the same on-screen scale.

    Parameters
    ----------
    p1, p2 : pair of numbers
        The ``xy`` and ``xytext`` points passed to ``ax.annotate``, in that order.
    rad : float
        The arc3 ``rad`` value used to draw the arc.

    Returns
    -------
    numpy.ndarray of shape (2,)
        The (x, y) midpoint of the arc. When ``p1 == p2`` this is simply the
        point itself; no division by the chord length is performed.
    """
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    mid = (p1 + p2) / 2
    dx, dy = p2 - p1
    # Rotating the chord by +90 degrees gives the side the arc bulges
    # towards. The rotated vector already has the chord's length, so
    # scaling by rad / 2 gives the offset rad * length / 2 without normalizing.
    return mid + (rad / 2) * np.array([-dy, dx])


# Must match the rad used when the arcs were drawn.
rad = 0.6
# The arcs were already drawn by the previous cell. Only the labels are added
# here; drawing the arcs again would stack a second alpha=0.5 copy on each one.
for (code, color), site in grouped_sites.items():
    x1, x2 = site[0], site[1]
    label_x, label_y = arc3_midpoint((x1, x1), (x2, x2), rad)
    # Country code on a white rounded box so it sits cleanly on top of the arc.
    ax.text(
        label_x,
        label_y,
        code,
        color=color,
        ha="center",
        va="center",
        fontsize=9,
        bbox=dict(
            boxstyle="round,pad=0.2",
            ec="white",
            fc="white",
            lw=2,
        ),
    )
ax.set_axis_off()
fig
The final code
import matplotlib.pyplot as plt
from matplotlib.patches import Arc
import numpy as np
import pandas as pd

# --- Data -------------------------------------------------------------------
# Color and two-letter code for each country.
color_dict = {
    "Norway": "#2B314D",
    "Denmark": "#A54836",
    "Sweden": "#5375D4",
}
code_dict = {
    "Norway": "NO",
    "Denmark": "DK",
    "Sweden": "SE",
}
xy_ticklabel_color, grid_color, datalabels_color = "#757C85", "#C8C9C9", "#FFFFFF"
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Denmark", "Denmark", "Norway", "Norway", "Sweden", "Sweden"],
    "sites": [4, 10, 5, 8, 13, 15],
}
df = pd.DataFrame(data)
# Stable sort: deterministic country order within each year.
df = df.sort_values("year", kind="stable").reset_index(drop=True)
# Two-digit year label, e.g. 2004 -> "'04".
df["year_lbl"] = "'" + df["year"].astype(str).str[-2:]
df["colors"] = df["countries"].map(color_dict)
df["ctry_codes"] = df["countries"].map(code_dict)

# Each country's site counts as an array keyed by (code, color). Rows keep
# their df order inside each group and df is sorted by year, so every array
# reads [2004 value, 2022 value].
grouped_sites = df.groupby(["ctry_codes", "colors"])["sites"].apply(np.array)

# --- Settings reused below --------------------------------------------------
sites = df["sites"]
colors = df["colors"]
rad = 0.6    # arc3 curvature of the arcs
offset = 1   # x offset (data units) of the site-count labels


def arc3_midpoint(p1, p2, rad):
    """Return the midpoint of the curve drawn by ``connectionstyle="arc3,rad=rad"``.

    arc3 is a quadratic Bezier curve whose control point lies on the
    perpendicular bisector of the chord p1-p2, ``rad`` chord lengths away
    from it; the curve's midpoint is halfway between the chord midpoint and
    that control point. ``p1``/``p2`` are the ``xy``/``xytext`` points given
    to ``ax.annotate``. Returns the chord midpoint when ``p1 == p2``.
    """
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    dx, dy = p2 - p1
    return (p1 + p2) / 2 + (rad / 2) * np.array([-dy, dx])


# --- Chart ------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(6, 6), facecolor="#FFFFFF")
ax.scatter(sites, sites, s=340, c=colors)
# Equal aspect so the data-space arc-label math matches the display-space arcs.
ax.set_aspect("equal")
ax.set(
    xlim=(0, sites.max() + 3),
    ylim=(0, sites.max() + 3),
)

# Diagonal y = x line behind everything else; x- and y-limits are identical.
start, end = ax.get_xlim()
ax.axline((start, start), (end, end), zorder=0, lw=1, color=xy_ticklabel_color)

# Year label on each dot and the number of sites just to its right.
for row in df.itertuples():
    ax.text(
        row.sites,
        row.sites,
        row.year_lbl,
        color=datalabels_color,
        va="center",
        ha="center",
    )
    ax.text(
        row.sites + offset,
        row.sites,
        str(row.sites),
        color=xy_ticklabel_color,
        va="center",
        ha="center",
    )

# One arc per country joining its 2004 and 2022 dots, with the country code
# placed at the arc's midpoint.
for (code, color), site in grouped_sites.items():
    x1, x2 = site[0], site[1]
    ax.annotate(
        "",
        xy=(x1, x1),
        xytext=(x2, x2),
        zorder=1,
        arrowprops=dict(
            arrowstyle="-",
            connectionstyle=f"arc3,rad={rad}",
            color=color,
            linewidth=2,
            alpha=0.5,
            linestyle="-",
            antialiased=True,
        ),
    )
    label_x, label_y = arc3_midpoint((x1, x1), (x2, x2), rad)
    ax.text(
        label_x,
        label_y,
        code,
        color=color,
        ha="center",
        va="center",
        fontsize=9,
        bbox=dict(
            boxstyle="round,pad=0.2",
            ec="white",
            fc="white",
            lw=2,
        ),
    )

ax.set_axis_off()
plt.show()