# Source code for physical_validation.util.plot

#                                                                         #
#    physical_validation,                                                 #
#    a python package to test the physical validity of MD results         #
#                                                                         #
#    Written by Pascal T. Merz <>                       #
#               Michael R. Shirts <>           #
#                                                                         #
#    Copyright (c) 2017-2021 University of Colorado Boulder               #
#              (c) 2012      The University of Virginia                   #
#                                                                         #

import warnings
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

def _freedman_diaconis_bins(y) -> int:
    """Return a histogram bin count for ``y`` using the Freedman-Diaconis rule.

    The bin width is ``2 * IQR / n**(1/3)``. When that width is zero or not
    finite (e.g. constant data, whose interquartile range vanishes), a single
    bin is used instead of dividing by zero. The result is always at least 1,
    since matplotlib rejects a bin count of 0.
    """
    inter_quartile_range = np.percentile(y, 75) - np.percentile(y, 25)
    bin_width = 2 * inter_quartile_range / (len(y) ** (1 / 3))
    if not np.isfinite(bin_width) or bin_width <= 0:
        return 1
    num_bins = int((np.max(y) - np.min(y)) / bin_width)
    return max(num_bins, 1)


def plot(
    res: List[Dict[str, Union[np.ndarray, str, Dict]]],
    legend: Optional[str] = None,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    inv_x: bool = False,
    inv_y: bool = False,
    sci_x: bool = False,
    sci_y: bool = False,
    axtext: Optional[Union[str, List[str]]] = None,
    annotation_location: Optional[
        Union[Tuple[float, float], List[Optional[Tuple[float, float]]]]
    ] = None,
    percent: bool = False,
    filename: Optional[str] = None,
    screen: bool = True,
) -> None:
    """Plot one or more data series on a single set of axes.

    If matplotlib cannot be imported, a warning is issued and nothing is
    plotted.

    Parameters
    ----------
    res
        One dict per series. Recognised keys:

        * ``"x"``, ``"y"`` -- data arrays (``"x"`` is not used for histograms);
        * ``"y_err"`` -- optional error bars, drawn with ``ax.errorbar``;
        * ``"hist"`` -- if truthy, ``"y"`` is drawn as a histogram whose bin
          count follows the Freedman-Diaconis rule (at least one bin);
        * ``"args"`` -- keyword arguments forwarded to the matplotlib plotting
          call (copied, never modified);
        * ``"name"`` -- legacy alias that sets the ``label`` keyword argument.
    legend
        Legend location passed to ``ax.legend(loc=...)``; no legend if None.
    title, xlabel, ylabel
        Optional title and axis labels; the axes box is shrunk to make room.
    xlim, ylim
        Axis limits. Non-histogram series are also clipped to ``xlim``.
        Without ``xlim``, the x range spans the plotted data.
    inv_x, inv_y
        Invert the respective axis.
    sci_x, sci_y
        Use scientific tick-label notation on the respective axis.
    axtext
        Text (or list of texts) drawn in rounded boxes on the axes.
    annotation_location
        Data coordinates for the ``axtext`` entries (a single tuple or a list,
        paired with ``axtext`` by position). ``None`` entries, or no value at
        all, place the text at the lower-right corner of the axes.
    percent
        Format y tick labels as percentages (tick value times 100).
    filename
        If given, the figure is saved to this path at 300 dpi.
    screen
        If True, interactive mode is switched back on after plotting;
        otherwise the figure is closed.
    """
    try:
        import matplotlib as mpl
        import matplotlib.pyplot as plt
        from matplotlib.ticker import AutoMinorLocator, FuncFormatter
    except ImportError:
        warnings.warn("Install matplotlib to enable plotting.")
        return

    def to_percent(y_ticks, _):
        # Tick formatter: show a fraction as a percentage (0.25 -> "25").
        return "{:g}".format(100 * y_ticks)

    font = {"family": "serif", "weight": "normal", "size": 16}
    mpl.rc("font", **font)

    plt.ioff()
    fig, ax = plt.subplots()

    xmin = float("inf")
    xmax = float("-inf")
    for r in res:
        # Copy so that adding "label" below never mutates the caller's dict.
        args = dict(r.get("args", {}))
        if "name" in r:
            # backwards compatibility
            args["label"] = r["name"]

        if r.get("hist"):
            y = r["y"]
            _, x, _ = ax.hist(y, _freedman_diaconis_bins(y), **args)
        else:
            x = np.asarray(r["x"])
            y = np.asarray(r["y"])
            dy = r["y_err"] if "y_err" in r else None
            if xlim is not None:
                in_range = (x >= xlim[0]) & (x <= xlim[1])
                x = x[in_range]
                y = y[in_range]
                if dy is not None:
                    dy = np.asarray(dy)[in_range]
            if dy is not None:
                ax.errorbar(x, y, yerr=dy, **args)
            else:
                ax.plot(x, y, **args)

        # Clipping to xlim can leave a series empty; np.min would raise on it.
        if np.size(x) > 0:
            xmin = min(np.min(x), xmin)
            xmax = max(np.max(x), xmax)

    if legend is not None:
        ax.legend(loc=legend)

    # Shrink the axes box to leave room for the title and axis labels.
    box = ax.get_position()
    if title is not None:
        ax.set_title(title, y=1.05)
        box = box.from_bounds(box.x0, box.y0, box.width, box.height * 0.95)
    if xlabel is not None:
        ax.set_xlabel(xlabel, labelpad=5)
        box = box.from_bounds(
            box.x0, box.y0 + 0.05 * box.height, box.width, box.height * 0.95
        )
    if ylabel is not None:
        ax.set_ylabel(ylabel, labelpad=10)
        box = box.from_bounds(
            box.x0 + 0.05 * box.width, box.y0, box.width * 0.95, box.height
        )
    ax.set_position([box.x0, box.y0, box.width, box.height])
    ax.axis("auto")

    def reset_xlim():
        # An explicit xlim wins; otherwise span the plotted data, if any.
        if xlim is not None:
            ax.set_xlim(xlim)
        elif np.isfinite(xmin) and np.isfinite(xmax):
            ax.set_xlim([xmin, xmax])

    reset_xlim()
    if ylim is not None:
        ax.set_ylim(ylim)

    # Hack to minimize chance of tick overlap: keep every other major x tick
    # when there are many, and show the dropped positions as minor ticks.
    lower, upper = ax.get_xlim()
    xticks = [tick for tick in ax.get_xticks() if lower < tick < upper]
    if len(xticks) > 5:
        ax.set_xticks(xticks[::2])
        # set_xticks may expand the view limits, so restore them.
        reset_xlim()
        ax.xaxis.set_minor_locator(AutoMinorLocator(2))

    if inv_x:
        ax.invert_xaxis()
    if inv_y:
        ax.invert_yaxis()

    if axtext is not None:
        texts = [axtext] if isinstance(axtext, str) else axtext
        if annotation_location is None:
            locations = [None for _ in texts]
        elif isinstance(annotation_location, tuple):
            locations = [annotation_location]
        else:
            locations = annotation_location
        for text, loc in zip(texts, locations):
            bbox = dict(boxstyle="round", fc="w", ec="0.5", alpha=0.9)
            if loc is None:
                ax.text(
                    0.95,
                    0.05,
                    text,
                    transform=ax.transAxes,
                    ha="right",
                    va="bottom",
                    bbox=bbox,
                )
            else:
                ax.text(loc[0], loc[1], text, bbox=bbox)

    if percent:
        ax.yaxis.set_major_formatter(FuncFormatter(to_percent))
    if sci_x:
        ax.ticklabel_format(style="sci", axis="x", scilimits=(-3, 4))
    if sci_y:
        ax.ticklabel_format(style="sci", axis="y", scilimits=(-3, 4))
    # Render the x-axis offset / exponent text with mathtext, via the public
    # ScalarFormatter API rather than its private attribute.
    x_formatter = ax.xaxis.get_major_formatter()
    if hasattr(x_formatter, "set_useMathText"):
        x_formatter.set_useMathText(True)

    if filename is not None:
        fig.savefig(filename, dpi=300)
    if screen:
        plt.ion()
    else:
        plt.close(fig)