diff --git a/forestplot/graph_utils.py b/forestplot/graph_utils.py index 81b1a07..5f39582 100644 --- a/forestplot/graph_utils.py +++ b/forestplot/graph_utils.py @@ -5,7 +5,9 @@ import matplotlib.pyplot as plt import pandas as pd +import numpy as np from matplotlib.pyplot import Axes +from matplotlib.patches import Polygon warnings.filterwarnings("ignore") @@ -47,9 +49,16 @@ def draw_ci( if ll is not None: lw = kwargs.get("lw", 1.4) linecolor = kwargs.get("linecolor", ".6") + # 250808: originally, the y is specified as dataframe[yticklabel]. This works until there are duplicate values in the yticklabel column. In this case, pyplot skips the duplicated values without yielding any warning. This is very bad practice. When plotting, always specify numerical x-y coordinates!!! + # x = dataframe[estimate].to_numpy(copy=False) # float + # lo = dataframe[ll].to_numpy(copy=False) + # hi = dataframe[hl].to_numpy(copy=False) + # y = np.arange(len(x), dtype=float) + # xerr = np.vstack([x - lo, hi - x]) + ax.errorbar( x=dataframe[estimate], - y=dataframe[yticklabel], + y=np.arange(dataframe.shape[0]), xerr=[dataframe[estimate] - dataframe[ll], dataframe[hl] - dataframe[estimate]], ecolor=linecolor, elinewidth=lw, @@ -60,9 +69,12 @@ def draw_ci( ax.set_xscale("log", base=10) return ax +# Utility function, determine marker size. +def determine_marker_size(weight): + return np.power(2+0.125*weight,2) def draw_est_markers( - dataframe: pd.core.frame.DataFrame, estimate: str, yticklabel: str, ax: Axes, **kwargs: Any + dataframe: pd.core.frame.DataFrame, estimate: str, yticklabel: str, ax: Axes, weight_col=None, total_col=None, **kwargs: Any ) -> Axes: """ Draws the markers of the estimates using the Matplotlib plt.scatter API. @@ -79,26 +91,68 @@ def draw_est_markers( Name of column in intermediate dataframe containing the formatted yticklabels. ax (Matplotlib Axes) Axes to operate on. + weight_col (str) + If specified, marker size will be drawn proportionally to the weight of the study. + total_col (str) + The column containing the indicator for whether a row is a subtotal row. If it is, a subtotal row, set the markersize to 0 despite the weight being 100. Returns ------- Matplotlib Axes object. """ marker = kwargs.get("marker", "s") - markersize = kwargs.get("markersize", 40) + # markersize = kwargs.get("markersize", 40) markercolor = kwargs.get("markercolor", "darkslategray") markeralpha = kwargs.get("markeralpha", 0.8) - ax.scatter( - y=yticklabel, - x=estimate, - data=dataframe, - marker=marker, - s=markersize, - color=markercolor, - alpha=markeralpha, - ) + + # 250715: draw marker sizes proportionally to study weights + if weight_col!=None: + if not pd.isnull(dataframe[estimate]).all(): + dataframe["markersize"] = determine_marker_size(dataframe[weight_col]) + if total_col!=None: + dataframe.loc[dataframe[total_col]==1,"markersize"]=0 + markersize = "markersize" + if weight_col==None: + markersize = kwargs.get("markersize", 40) + + + # 250714: some dataframes are empty. In such cases, we still draw an empty graph. But of course we don't need markers on an empty graph! + if not pd.isnull(dataframe[estimate]).all(): + ax.scatter( + y=range(len(dataframe)), + x=estimate, + data=dataframe, + marker=marker, + s=markersize, + color=markercolor, + alpha=markeralpha, + ) return ax +def draw_total_diamond( + dataframe: pd.core.frame.DataFrame, + total_col: str, + estimate: str, + ll: str, + hl: str, + ax: Axes, + **kwargs: Any +) -> Axes: + height = 0.8 # total height of the diamond from top to bottom + # print(height) + for ii,row in dataframe.iterrows(): + if row[total_col]==1: + # print(f"Row {ii} is total!") + ci_low = row[ll] + ci_high = row[hl] + val = row[estimate] + diamond = Polygon( + # left, top, right, bottom + [(ci_low, ii), (val, ii+height/2), (ci_high, ii), (val, ii-height/2)], + closed = True, facecolor="black", zorder=10 + ) + ax.add_patch(diamond) + return ax def draw_ref_xline( ax: Axes, @@ -140,7 +194,7 @@ def draw_ref_xline( def right_flush_yticklabels( - dataframe: pd.core.frame.DataFrame, yticklabel: str, flush: bool, ax: Axes, **kwargs: Any + dataframe: pd.core.frame.DataFrame, yticklabel: str, flush: bool, ax: Axes, flag_col: str, **kwargs: Any ) -> float: """Flushes the formatted ytickers to the left. Also returns the amount of max padding in the window width. @@ -167,6 +221,14 @@ def right_flush_yticklabels( fontsize = kwargs.get("fontsize", 12) # plt.draw() fig = plt.gcf() + # fig.canvas.draw() + # fig = ax.figure + + # 250804: solve the mystery where the plot sometimes drop the header row! + # the number of rows in dataframe[yticklabel] is sometimes not equal to the number of y-ticks automatically generated by matplotlib. + # TakuaLiu: I still don't understand what triggers this mismatch, but we can eliminate it by forcing the number of yticks be the same as the number of yticklabels! + yticks = range(len(dataframe)) + ax.set_yticks(yticks) if flush: ax.set_yticklabels( dataframe[yticklabel], fontfamily=fontfamily, fontsize=fontsize, ha="left" @@ -175,20 +237,35 @@ def right_flush_yticklabels( ax.set_yticklabels( dataframe[yticklabel], fontfamily=fontfamily, fontsize=fontsize, ha="right" ) + if len(flag_col)>0 and (flag_col in dataframe.columns): + # print(dataframe.columns) + for label, fg in zip(ax.get_yticklabels(), dataframe[flag_col]): + if pd.isnull(fg): + continue + if fg>0: + # print(f"Should set the color of {label}") + label.set_color("red") + # nlast = len(ax.get_yticklabels()) + # print(nlast) + # print(ax.get_yticklabels()) yax = ax.get_yaxis() try: pad = max( T.label.get_window_extent(renderer=fig.canvas.get_renderer()).width for T in yax.majorTicks ) + except AttributeError: pad = max( T.label1.get_window_extent(renderer=fig.canvas.get_renderer()).width for T in yax.majorTicks ) + # for T in yax.majorTicks: + # print(T.label1) + # print(T.label1.get_window_extent(renderer=fig.canvas.get_renderer()).width) + pad = pad* 72.0 / fig.dpi if flush: yax.set_tick_params(pad=pad) - return pad @@ -228,6 +305,7 @@ def draw_pval_right( inv = ax.transData.inverted() righttext_width = 0 fig = plt.gcf() + # fig = ax.figure for _, row in dataframe.iterrows(): yticklabel1 = row[yticklabel] yticklabel2 = row["formatted_pval"] @@ -327,6 +405,7 @@ def draw_yticklabel2( inv = ax.transData.inverted() righttext_width = 0 fig = plt.gcf() + # fig = ax.figure for ix, row in dataframe.iterrows(): yticklabel1 = row["yticklabel"] yticklabel2 = row["yticklabel2"] @@ -362,6 +441,23 @@ def draw_yticklabel2( righttext_width = max(righttext_width, x1) return ax, righttext_width +# utility for draw_ylabel1() +def shorten_label(txt, orig_txt, attempt=0): + if len(txt.split(":"))>1: + # have both analysis group and subgroup info + # in this case, keep only subgroup info + return txt.split(":")[-1] + if (len(txt.split(":"))==1 and len(orig_txt.split(":"))>1) or len(orig_txt.split(":"))==1: + # originally have both analysis group and subgroup info, tried keeping only the subgroup info, but still too long + # or, there is only analysis group to begin with + new_txt = "" + for tt in orig_txt.split(":"): + new_txt += " ".join(tt.split(" ")[:max(3,15-attempt)]) + " (...) " + txt = new_txt + if attempt > 10: + # unlikely, but if this happens, we don't care about what's being written, just reduce the length! + txt = orig_txt[:(max(20, len(orig_txt)-attempt))] + "(...)" + return txt def draw_ylabel1(ylabel: str, pad: float, ax: Axes, **kwargs: Any) -> Axes: """ @@ -380,22 +476,35 @@ def draw_ylabel1(ylabel: str, pad: float, ax: Axes, **kwargs: Any) -> Axes: ------- Matplotlib Axes object. """ + + ylabel_orig = ylabel + fig = plt.gcf() + # fig = ax.figure fontsize = kwargs.get("fontsize", 12) - ax.set_ylabel("") + decent_length = False + attempt = 0 if ylabel is not None: - # Retrieve settings from kwargs - ylabel1_size = kwargs.get("ylabel1_size", 1 + fontsize) - ylabel1_fontweight = kwargs.get("ylabel1_fontweight", "bold") - ylabel_loc = kwargs.get("ylabel_loc", "top") - ylabel_angle = kwargs.get("ylabel_angle", "horizontal") - ax.set_ylabel( - ylabel, - loc=ylabel_loc, - labelpad=-pad, - rotation=ylabel_angle, - size=ylabel1_size, - fontweight=ylabel1_fontweight, - ) + while not decent_length: + # Retrieve settings from kwargs + ylabel1_size = kwargs.get("ylabel1_size", 1 + fontsize) + ylabel1_fontweight = kwargs.get("ylabel1_fontweight", "bold") + ylabel_loc = kwargs.get("ylabel_loc", "top") + ylabel_angle = kwargs.get("ylabel_angle", "horizontal") + ax.set_ylabel( + ylabel, + loc=ylabel_loc, + labelpad=-pad, + rotation=ylabel_angle, + size=ylabel1_size, + fontweight=ylabel1_fontweight, + ) + label_w = ax.yaxis.label.get_window_extent(renderer=fig.canvas.get_renderer()).width + label_w = label_w * 72.0/fig.dpi + decent_length = label_w <= pad + ylabel = shorten_label(ylabel, ylabel_orig, attempt) + attempt += 1 + pos = ax.get_position() + ax.set_position([pos.x0, pos.y0+10, pos.width, pos.height]) # shrink height by 10% return ax @@ -587,7 +696,13 @@ def format_xticks( else: xlowerlimit = 1.1 * dataframe[estimate].min() xupperlimit = 1.1 * dataframe[estimate].max() - ax.set_xlim(xlowerlimit, xupperlimit) + + # 250714: handle the studies with unestimable CI + if not pd.isnull(xlowerlimit) and not pd.isnull(xupperlimit): + ax.set_xlim(xlowerlimit, xupperlimit) + else: + ax.set_xlim(-1, 1) + if xticks is not None: ax.set_xticks(xticks) ax.xaxis.set_tick_params(labelsize=xtick_size) diff --git a/forestplot/plot.py b/forestplot/plot.py index 4b2191b..35e61fa 100644 --- a/forestplot/plot.py +++ b/forestplot/plot.py @@ -17,6 +17,7 @@ draw_alt_row_colors, draw_ci, draw_est_markers, + draw_total_diamond, draw_pval_right, draw_ref_xline, draw_tablelines, @@ -80,6 +81,10 @@ def forestplot( preprocess: bool = True, table: bool = False, ax: Optional[Axes] = None, + weight_col = None, + total_col = None, + total_stats_col = None, + flag_col = "", **kwargs: Any, ) -> Axes: """ @@ -152,7 +157,14 @@ def forestplot( If True, in addition to the Matplotlib Axes object, returns the intermediate dataframe created from preprocess_dataframe(). A tuple of (preprocessed_dataframe, Ax) will be returned. - + weight_col (str) + Default is None. If specified, marker size will be proportaional to the weight of the study. + total_col (str) + Default is None. If specified, it should be the name of the column indicating which row is subtotal. The values in the column should be 0 (not a subtotal), or 1 (a subtotal row). A horizontal diamond will be drawn for subtotal rows rather than square&whiskers. + total_stats_col (str) + Default is None. If specified, it should be the name of the column indicating which row contains the stats info of the subtotal. The values in the column should be 0 (not such a row), or 1 (is such a row). In such a row, the stats info should be specified in the varlabel column using complete descriptions like "Test for overall effect: Z = 3.02 (P = 0.003)", "Heterogeneity: Tau² (DLb) = 0.00; Chi² = 2.86, df = 3 (P = 0.41); I² = 0%". Can add as many such rows as needed. + flag_col (str) + the column based on which we color the yticklables to flag suspicious rows. Returns ------- Matplotlib Axes object. @@ -198,9 +210,10 @@ def forestplot( sortby=sortby, flush=flush, decimal_precision=decimal_precision, + total_stats_col=total_stats_col, **kwargs, ) - ax = _make_forestplot( + fig, ax = _make_forestplot( dataframe=_local_df, yticklabel="yticklabel", estimate=estimate, @@ -221,9 +234,12 @@ def forestplot( color_alt_rows=color_alt_rows, table=table, ax=ax, + weight_col=weight_col, + total_col=total_col, + flag_col=flag_col, **kwargs, ) - return (_local_df, ax) if return_df else ax + return (_local_df, fig, ax) if return_df else (fig, ax) def _preprocess_dataframe( @@ -248,6 +264,7 @@ def _preprocess_dataframe( sortascend: bool = True, flush: bool = True, decimal_precision: int = 2, + total_stats_col: Optional[str] = None, **kwargs: Any, ) -> pd.core.frame.DataFrame: """ @@ -320,11 +337,11 @@ def _preprocess_dataframe( ) if groupvar is not None: # Make groups dataframe = normalize_varlabels( - dataframe=dataframe, varlabel=groupvar, capitalize=capitalize + dataframe=dataframe, varlabel=groupvar, capitalize=capitalize, total_stats_col=total_stats_col ) dataframe = insert_groups(dataframe=dataframe, groupvar=groupvar, varlabel=varlabel) dataframe = normalize_varlabels( - dataframe=dataframe, varlabel=varlabel, capitalize=capitalize + dataframe=dataframe, varlabel=varlabel, capitalize=capitalize, total_stats_col=total_stats_col ) dataframe = indent_nongroupvar(dataframe=dataframe, varlabel=varlabel, groupvar=groupvar) if form_ci_report: @@ -357,6 +374,7 @@ def _preprocess_dataframe( varlabel=varlabel, annote=annote, annoteheaders=annoteheaders, + total_stats_col=total_stats_col, **kwargs, ) if rightannote is not None: @@ -374,6 +392,7 @@ def _preprocess_dataframe( annoteheaders=annoteheaders, rightannote=rightannote, right_annoteheaders=right_annoteheaders, + total_stats_col=total_stats_col, **kwargs, ) return reverse_dataframe(dataframe) # since plotting starts from bottom @@ -401,7 +420,10 @@ def _make_forestplot( ax: Axes, despine: bool = True, table: bool = False, - **kwargs: Any, + weight_col = None, + total_col = None, + flag_col = "", + **kwargs: Any ) -> Axes: """ Create and draw a forest plot using the given DataFrame and specified parameters. @@ -451,6 +473,8 @@ def _make_forestplot( Whether to remove the top and right spines of the plot. table : bool, default=False Whether to draw a table-like structure on the plot. + weight_col: str, default=None + If weight column is specified, the marker size will be drawn proportionally to weight. **kwargs : Any Additional keyword arguments for further customization. @@ -460,7 +484,7 @@ def _make_forestplot( The matplotlib Axes object with the forest plot. """ if not ax: - _, ax = plt.subplots(figsize=figsize, facecolor="white") + fig, ax = plt.subplots(figsize=figsize, facecolor="white") ax = draw_ci( dataframe=dataframe, estimate=estimate, @@ -471,9 +495,15 @@ def _make_forestplot( ax=ax, **kwargs, ) + # 250715: draw marker sizes proportionally to study weights draw_est_markers( - dataframe=dataframe, estimate=estimate, yticklabel=yticklabel, ax=ax, **kwargs + dataframe=dataframe, estimate=estimate, yticklabel=yticklabel, ax=ax, weight_col=weight_col, total_col=total_col, **kwargs ) + + + if total_col is not None: + draw_total_diamond(dataframe=dataframe, total_col=total_col, ax=ax, estimate=estimate, ll=ll, hl=hl, **kwargs + ) format_xticks( dataframe=dataframe, estimate=estimate, ll=ll, hl=hl, xticks=xticks, ax=ax, **kwargs ) @@ -485,8 +515,9 @@ def _make_forestplot( **kwargs, ) pad = right_flush_yticklabels( - dataframe=dataframe, yticklabel=yticklabel, flush=flush, ax=ax, **kwargs + dataframe=dataframe, yticklabel=yticklabel, flush=flush, ax=ax, flag_col=flag_col, **kwargs ) + draw_ylabel1(ylabel=ylabel, pad=pad, ax=ax, **kwargs) if rightannote is None: ax, righttext_width = draw_pval_right( dataframe=dataframe, @@ -507,10 +538,10 @@ def _make_forestplot( ax=ax, **kwargs, ) - - draw_ylabel1(ylabel=ylabel, pad=pad, ax=ax, **kwargs) + remove_ticks(ax) format_grouplabels(dataframe=dataframe, groupvar=groupvar, ax=ax, **kwargs) + format_tableheader( annoteheaders=annoteheaders, right_annoteheaders=right_annoteheaders, ax=ax, **kwargs ) @@ -536,5 +567,6 @@ def _make_forestplot( ax=ax, ) negative_padding = 0.5 - ax.set_ylim(-0.5, ax.get_ylim()[1] - negative_padding) - return ax + # ax.set_ylim(-0.5, ax.get_ylim()[1] - negative_padding) # this doesn't reflect the number of actually required rows + ax.set_ylim(-0.5, dataframe.shape[0]) # 250713: added by Takua Liu + return fig, ax diff --git a/forestplot/text_utils.py b/forestplot/text_utils.py index ccbf13a..c7fdba1 100644 --- a/forestplot/text_utils.py +++ b/forestplot/text_utils.py @@ -167,6 +167,7 @@ def normalize_varlabels( dataframe: pd.core.frame.DataFrame, varlabel: str, capitalize: str = "capitalize", + total_stats_col = None, ) -> pd.core.frame.DataFrame: """ Normalize variable labels to capitalize or title form. @@ -181,22 +182,35 @@ def normalize_varlabels( capitalize (str) 'capitalize' or 'title' See https://pandas.pydata.org/docs/reference/api/pandas.Series.str.capitalize.html - + total_stats_col (str) + if such a column is specified, ignore the rows where total_stats_col is 1 Returns ------- pd.core.frame.DataFrame with the varlabel column normalized. """ if capitalize: - if capitalize == "title": - dataframe[varlabel] = dataframe[varlabel].str.title() - elif capitalize == "capitalize": - dataframe[varlabel] = dataframe[varlabel].str.capitalize() - elif capitalize == "lower": - dataframe[varlabel] = dataframe[varlabel].str.lower() - elif capitalize == "upper": - dataframe[varlabel] = dataframe[varlabel].str.upper() - elif capitalize == "swapcase": - dataframe[varlabel] = dataframe[varlabel].str.swapcase() + if total_stats_col != None: + if capitalize == "title": + dataframe[dataframe[total_stats_col]==0][varlabel] = dataframe[dataframe[total_stats_col]==0][varlabel].str.title() + elif capitalize == "capitalize": + dataframe[dataframe[total_stats_col]==0][varlabel] = dataframe[dataframe[total_stats_col]==0][varlabel].str.capitalize() + elif capitalize == "lower": + dataframe[dataframe[total_stats_col]==0][varlabel] = dataframe[dataframe[total_stats_col]==0][varlabel].str.lower() + elif capitalize == "upper": + dataframe[dataframe[total_stats_col]==0][varlabel] = dataframe[dataframe[total_stats_col]==0][varlabel].str.upper() + elif capitalize == "swapcase": + dataframe[dataframe[total_stats_col]==0][varlabel] = dataframe[dataframe[total_stats_col]==0][varlabel].str.swapcase() + else: + if capitalize == "title": + dataframe[varlabel] = dataframe[varlabel].str.title() + elif capitalize == "capitalize": + dataframe[varlabel] = dataframe[varlabel].str.capitalize() + elif capitalize == "lower": + dataframe[varlabel] = dataframe[varlabel].str.lower() + elif capitalize == "upper": + dataframe[varlabel] = dataframe[varlabel].str.upper() + elif capitalize == "swapcase": + dataframe[varlabel] = dataframe[varlabel].str.swapcase() return dataframe @@ -330,6 +344,7 @@ def prep_annote( annoteheaders: Optional[Union[Sequence[str], None]], varlabel: str, groupvar: str, + total_stats_col=None, **kwargs: Any, ) -> pd.core.frame.DataFrame: """Prepare the additional columns to be printed as annotations. @@ -364,16 +379,26 @@ def prep_annote( for ix, annotation in enumerate(annote): # Get max len for padding _pad = _get_max_varlen(dataframe=dataframe, varlabel=annotation, extrapad=0) + if total_stats_col is not None: + _pad = _get_max_varlen(dataframe=dataframe[dataframe[total_stats_col]==0], varlabel=annotation, extrapad=0) + + if annoteheaders is not None: # Check that max len exceeds header length _header = annoteheaders[ix] _pad = max(_pad, len(_header)) lookup_annote_len[ix] = _pad for iy, row in dataframe.iterrows(): # Make individual formatted_annotations + if total_stats_col is not None: + if row[total_stats_col]==1: + dataframe.loc[iy, f"formatted_{annotation}"] = "" + continue _annotation = str(row[annotation]).ljust(_pad) dataframe.loc[iy, f"formatted_{annotation}"] = _annotation # get max length for variables pad = _get_max_varlen(dataframe=dataframe, varlabel=varlabel, extrapad=0) + if total_stats_col is not None: + pad = _get_max_varlen(dataframe=dataframe[dataframe[total_stats_col]==0], varlabel=varlabel, extrapad=0) if groupvar is not None: groups = [gr.lower() for gr in dataframe[groupvar].unique()] @@ -382,6 +407,10 @@ def prep_annote( for ix, row in dataframe.iterrows(): yticklabel = row[varlabel] + if total_stats_col is not None: + if row[total_stats_col] == 1: + dataframe.loc[ix, "yticklabel"] = yticklabel + continue if yticklabel.lower().strip() in groups: dataframe.loc[ix, "yticklabel"] = yticklabel else: @@ -473,6 +502,7 @@ def make_tableheaders( annoteheaders: Optional[Union[Sequence[str], None]], rightannote: Optional[Union[Sequence[str], None]], right_annoteheaders: Optional[Union[Sequence[str], None]], + total_stats_col=None, **kwargs: Any, ) -> pd.core.frame.DataFrame: """Make the table headers from 'annoteheaders' and 'right_annoteheaders' as a row in the dataframe. @@ -517,6 +547,8 @@ def make_tableheaders( dataframe = insert_empty_row(dataframe) pad = _get_max_varlen(dataframe=dataframe, varlabel=varlabel, extrapad=0) + if total_stats_col is not None: + pad = _get_max_varlen(dataframe=dataframe[dataframe[total_stats_col]==0], varlabel=varlabel, extrapad=0) left_headers = variable_header.ljust(pad) dataframe.loc[0, "yticklabel"] = left_headers if annoteheaders is not None: