Skip to content
Open
173 changes: 144 additions & 29 deletions forestplot/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from matplotlib.pyplot import Axes
from matplotlib.patches import Polygon

warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -47,9 +49,16 @@ def draw_ci(
if ll is not None:
lw = kwargs.get("lw", 1.4)
linecolor = kwargs.get("linecolor", ".6")
# 250808: originally, the y is specified as dataframe[yticklabel]. This works until there are duplicate values in the yticklabel column. In this case, pyplot skips the duplicated values without yielding any warning. This is very bad practice. When plotting, always specify numerical x-y coordinates!!!
# x = dataframe[estimate].to_numpy(copy=False) # float
# lo = dataframe[ll].to_numpy(copy=False)
# hi = dataframe[hl].to_numpy(copy=False)
# y = np.arange(len(x), dtype=float)
# xerr = np.vstack([x - lo, hi - x])

ax.errorbar(
x=dataframe[estimate],
y=dataframe[yticklabel],
y=np.arange(dataframe.shape[0]),
xerr=[dataframe[estimate] - dataframe[ll], dataframe[hl] - dataframe[estimate]],
ecolor=linecolor,
elinewidth=lw,
Expand All @@ -60,9 +69,12 @@ def draw_ci(
ax.set_xscale("log", base=10)
return ax

# Utility function, determine marker size.
def determine_marker_size(weight):
return np.power(2+0.125*weight,2)

def draw_est_markers(
dataframe: pd.core.frame.DataFrame, estimate: str, yticklabel: str, ax: Axes, **kwargs: Any
dataframe: pd.core.frame.DataFrame, estimate: str, yticklabel: str, ax: Axes, weight_col=None, total_col=None, **kwargs: Any
) -> Axes:
"""
Draws the markers of the estimates using the Matplotlib plt.scatter API.
Expand All @@ -79,26 +91,68 @@ def draw_est_markers(
Name of column in intermediate dataframe containing the formatted yticklabels.
ax (Matplotlib Axes)
Axes to operate on.
weight_col (str)
If specified, marker size will be drawn proportionally to the weight of the study.
total_col (str)
The column containing the indicator for whether a row is a subtotal row. If it is, a subtotal row, set the markersize to 0 despite the weight being 100.

Returns
-------
Matplotlib Axes object.
"""
marker = kwargs.get("marker", "s")
markersize = kwargs.get("markersize", 40)
# markersize = kwargs.get("markersize", 40)
markercolor = kwargs.get("markercolor", "darkslategray")
markeralpha = kwargs.get("markeralpha", 0.8)
ax.scatter(
y=yticklabel,
x=estimate,
data=dataframe,
marker=marker,
s=markersize,
color=markercolor,
alpha=markeralpha,
)

# 250715: draw marker sizes proportionally to study weights
if weight_col!=None:
if not pd.isnull(dataframe[estimate]).all():
dataframe["markersize"] = determine_marker_size(dataframe[weight_col])
if total_col!=None:
dataframe.loc[dataframe[total_col]==1,"markersize"]=0
markersize = "markersize"
if weight_col==None:
markersize = kwargs.get("markersize", 40)


# 250714: some dataframes are empty. In such cases, we still draw an empty graph. But of course we don't need markers on an empty graph!
if not pd.isnull(dataframe[estimate]).all():
ax.scatter(
y=range(len(dataframe)),
x=estimate,
data=dataframe,
marker=marker,
s=markersize,
color=markercolor,
alpha=markeralpha,
)
return ax

def draw_total_diamond(
dataframe: pd.core.frame.DataFrame,
total_col: str,
estimate: str,
ll: str,
hl: str,
ax: Axes,
**kwargs: Any
) -> Axes:
height = 0.8 # total height of the diamond from top to bottom
# print(height)
for ii,row in dataframe.iterrows():
if row[total_col]==1:
# print(f"Row {ii} is total!")
ci_low = row[ll]
ci_high = row[hl]
val = row[estimate]
diamond = Polygon(
# left, top, right, bottom
[(ci_low, ii), (val, ii+height/2), (ci_high, ii), (val, ii-height/2)],
closed = True, facecolor="black", zorder=10
)
ax.add_patch(diamond)
return ax

def draw_ref_xline(
ax: Axes,
Expand Down Expand Up @@ -140,7 +194,7 @@ def draw_ref_xline(


def right_flush_yticklabels(
dataframe: pd.core.frame.DataFrame, yticklabel: str, flush: bool, ax: Axes, **kwargs: Any
dataframe: pd.core.frame.DataFrame, yticklabel: str, flush: bool, ax: Axes, flag_col: str, **kwargs: Any
) -> float:
"""Flushes the formatted ytickers to the left. Also returns the amount of max padding in the window width.

Expand All @@ -167,6 +221,14 @@ def right_flush_yticklabels(
fontsize = kwargs.get("fontsize", 12)
# plt.draw()
fig = plt.gcf()
# fig.canvas.draw()
# fig = ax.figure

# 250804: solve the mystery where the plot sometimes drop the header row!
# the number of rows in dataframe[yticklabel] is sometimes not equal to the number of y-ticks automatically generated by matplotlib.
# TakuaLiu: I still don't understand what triggers this mismatch, but we can eliminate it by forcing the number of yticks be the same as the number of yticklabels!
yticks = range(len(dataframe))
ax.set_yticks(yticks)
if flush:
ax.set_yticklabels(
dataframe[yticklabel], fontfamily=fontfamily, fontsize=fontsize, ha="left"
Expand All @@ -175,20 +237,35 @@ def right_flush_yticklabels(
ax.set_yticklabels(
dataframe[yticklabel], fontfamily=fontfamily, fontsize=fontsize, ha="right"
)
if len(flag_col)>0 and (flag_col in dataframe.columns):
# print(dataframe.columns)
for label, fg in zip(ax.get_yticklabels(), dataframe[flag_col]):
if pd.isnull(fg):
continue
if fg>0:
# print(f"Should set the color of {label}")
label.set_color("red")
# nlast = len(ax.get_yticklabels())
# print(nlast)
# print(ax.get_yticklabels())
yax = ax.get_yaxis()
try:
pad = max(
T.label.get_window_extent(renderer=fig.canvas.get_renderer()).width
for T in yax.majorTicks
)

except AttributeError:
pad = max(
T.label1.get_window_extent(renderer=fig.canvas.get_renderer()).width
for T in yax.majorTicks
)
# for T in yax.majorTicks:
# print(T.label1)
# print(T.label1.get_window_extent(renderer=fig.canvas.get_renderer()).width)
pad = pad* 72.0 / fig.dpi
if flush:
yax.set_tick_params(pad=pad)

return pad


Expand Down Expand Up @@ -228,6 +305,7 @@ def draw_pval_right(
inv = ax.transData.inverted()
righttext_width = 0
fig = plt.gcf()
# fig = ax.figure
for _, row in dataframe.iterrows():
yticklabel1 = row[yticklabel]
yticklabel2 = row["formatted_pval"]
Expand Down Expand Up @@ -327,6 +405,7 @@ def draw_yticklabel2(
inv = ax.transData.inverted()
righttext_width = 0
fig = plt.gcf()
# fig = ax.figure
for ix, row in dataframe.iterrows():
yticklabel1 = row["yticklabel"]
yticklabel2 = row["yticklabel2"]
Expand Down Expand Up @@ -362,6 +441,23 @@ def draw_yticklabel2(
righttext_width = max(righttext_width, x1)
return ax, righttext_width

# utility for draw_ylabel1()
def shorten_label(txt, orig_txt, attempt=0):
if len(txt.split(":"))>1:
# have both analysis group and subgroup info
# in this case, keep only subgroup info
return txt.split(":")[-1]
if (len(txt.split(":"))==1 and len(orig_txt.split(":"))>1) or len(orig_txt.split(":"))==1:
# originally have both analysis group and subgroup info, tried keeping only the subgroup info, but still too long
# or, there is only analysis group to begin with
new_txt = ""
for tt in orig_txt.split(":"):
new_txt += " ".join(tt.split(" ")[:max(3,15-attempt)]) + " (...) "
txt = new_txt
if attempt > 10:
# unlikely, but if this happens, we don't care about what's being written, just reduce the length!
txt = orig_txt[:(max(20, len(orig_txt)-attempt))] + "(...)"
return txt

def draw_ylabel1(ylabel: str, pad: float, ax: Axes, **kwargs: Any) -> Axes:
"""
Expand All @@ -380,22 +476,35 @@ def draw_ylabel1(ylabel: str, pad: float, ax: Axes, **kwargs: Any) -> Axes:
-------
Matplotlib Axes object.
"""

ylabel_orig = ylabel
fig = plt.gcf()
# fig = ax.figure
fontsize = kwargs.get("fontsize", 12)
ax.set_ylabel("")
decent_length = False
attempt = 0
if ylabel is not None:
# Retrieve settings from kwargs
ylabel1_size = kwargs.get("ylabel1_size", 1 + fontsize)
ylabel1_fontweight = kwargs.get("ylabel1_fontweight", "bold")
ylabel_loc = kwargs.get("ylabel_loc", "top")
ylabel_angle = kwargs.get("ylabel_angle", "horizontal")
ax.set_ylabel(
ylabel,
loc=ylabel_loc,
labelpad=-pad,
rotation=ylabel_angle,
size=ylabel1_size,
fontweight=ylabel1_fontweight,
)
while not decent_length:
# Retrieve settings from kwargs
ylabel1_size = kwargs.get("ylabel1_size", 1 + fontsize)
ylabel1_fontweight = kwargs.get("ylabel1_fontweight", "bold")
ylabel_loc = kwargs.get("ylabel_loc", "top")
ylabel_angle = kwargs.get("ylabel_angle", "horizontal")
ax.set_ylabel(
ylabel,
loc=ylabel_loc,
labelpad=-pad,
rotation=ylabel_angle,
size=ylabel1_size,
fontweight=ylabel1_fontweight,
)
label_w = ax.yaxis.label.get_window_extent(renderer=fig.canvas.get_renderer()).width
label_w = label_w * 72.0/fig.dpi
decent_length = label_w <= pad
ylabel = shorten_label(ylabel, ylabel_orig, attempt)
attempt += 1
pos = ax.get_position()
ax.set_position([pos.x0, pos.y0+10, pos.width, pos.height]) # shrink height by 10%
return ax


Expand Down Expand Up @@ -587,7 +696,13 @@ def format_xticks(
else:
xlowerlimit = 1.1 * dataframe[estimate].min()
xupperlimit = 1.1 * dataframe[estimate].max()
ax.set_xlim(xlowerlimit, xupperlimit)

# 250714: handle the studies with unestimable CI
if not pd.isnull(xlowerlimit) and not pd.isnull(xupperlimit):
ax.set_xlim(xlowerlimit, xupperlimit)
else:
ax.set_xlim(-1, 1)

if xticks is not None:
ax.set_xticks(xticks)
ax.xaxis.set_tick_params(labelsize=xtick_size)
Expand Down
Loading