diff --git a/petab/v1/visualize/plot_residuals.py b/petab/v1/visualize/plot_residuals.py index 46e83fb9..a1f2ec9b 100644 --- a/petab/v1/visualize/plot_residuals.py +++ b/petab/v1/visualize/plot_residuals.py @@ -136,6 +136,7 @@ def plot_goodness_of_fit( size: tuple = (10, 7), color=None, ax: plt.Axes | None = None, + normalized_error: bool = True, ) -> matplotlib.axes.Axes: """ Plot goodness of fit. @@ -154,6 +155,10 @@ def plot_goodness_of_fit( `matplotlib.pyplot.scatter`. ax: Axis object. + normalized_error: + Type of error to display. + If True, mean of squared normalized residuals is shown, + otherwise mean of squared residuals. Returns ------- @@ -168,12 +173,26 @@ def plot_goodness_of_fit( "are needed for goodness_of_fit" ) - residual_df = calculate_residuals( - measurement_dfs=petab_problem.measurement_df, - simulation_dfs=simulations_df, - observable_dfs=petab_problem.observable_df, - parameter_dfs=petab_problem.parameter_df, - )[0] + if normalized_error: + residual_df = calculate_residuals( + measurement_dfs=petab_problem.measurement_df, + simulation_dfs=simulations_df, + observable_dfs=petab_problem.observable_df, + parameter_dfs=petab_problem.parameter_df, + normalize=True, + )[0] + error_name = "mean of squared\nnormalized residuals" + else: + residual_df = calculate_residuals( + measurement_dfs=petab_problem.measurement_df, + simulation_dfs=simulations_df, + observable_dfs=petab_problem.observable_df, + parameter_dfs=petab_problem.parameter_df, + normalize=False, + )[0] + error_name = "mean of squared residuals" + error = np.mean(np.power(residual_df["residual"], 2)) + slope, intercept, r_value, p_value, std_err = stats.linregress( simulations_df["simulation"], petab_problem.measurement_df["measurement"], @@ -199,7 +218,6 @@ def plot_goodness_of_fit( ax.plot(x, x, linestyle="--", color="gray") ax.plot(x, intercept + slope * x, "r", label="fitted line") - mse = np.mean(np.abs(residual_df["residual"])) ax.text( 0.1, 0.70, @@ -207,7 +225,7 @@ def plot_goodness_of_fit( f"slope: {slope:.2f}\n" f"intercept: {intercept:.2f}\n" f"p-value: {p_value:.2e}\n" - f"mean squared error: {mse:.2e}\n", + f"{error_name}: {error:.2e}\n", transform=ax.transAxes, )