Source code for kda.plotting

# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
#
# Author: Nikolaus C. Awtrey
#
"""
Diagram Plotting
=========================================================================
The :mod:`~kda.plotting` module contains code to plot partial diagrams
(undirected spanning trees), directional diagrams, and flux diagrams, as
well as cycles and :func:`~kda.ode.ode_solver` results.

The two main functions used for plotting KDA-generated diagrams are
:func:`~kda.plotting.draw_diagrams` and :func:`~kda.plotting.draw_cycles`.
``draw_diagrams`` is used for plotting kinetic diagrams, partial diagrams,
directional diagrams, and flux diagrams, while ``draw_cycles`` is used for
plotting cycles in the kinetic diagram.

For example, for a 4-state model we start by generating the
:class:`~kda.core.KineticModel` and plotting the kinetic diagram:

.. code-block:: python

  import os
  import numpy as np
  import kda
  from kda import plotting

  # define matrix with reaction rates set to 1
  K = np.array([
      [0, 1, 0, 1],
      [1, 0, 1, 1],
      [0, 1, 0, 1],
      [1, 1, 1, 0],
  ])
  # create a KineticModel from the rate matrix
  model = kda.KineticModel(K=K, G=None)
  # specify the positions of all nodes in NetworkX fashion
  node_positions = {0: [1, 1], 1: [-1, 1], 2: [-1, -1], 3: [1, -1]}
  # plot and save the input diagram in the current directory
  plotting.draw_diagrams(model.G, pos=node_positions, path=os.getcwd(), label="input")

The output kinetic diagram figure, ``input.png``:

|img_4wl_small|

We can also plot all cycles:

.. code-block:: python

  # build the cycles for the model
  model.build_cycles()
  # plot cycles with coral-colored nodes
  plotting.draw_cycles(
      G=model.G,
      cycles=model.cycles,
      pos=node_positions,
      path=os.getcwd(),
      # set color-by-target to label the target nodes
      cbt=True,
      label="cycles_panel",
  )

The output cycles figure, ``cycles_panel.png``:

|img_4wl_cycles_small|

Lastly, we can generate the directional and flux diagrams:

.. code-block:: python

  # generate the flux and directional diagrams
  model.build_flux_diagrams()
  model.build_directional_diagrams()
  # plot and save the directional diagrams as a panel
  plotting.draw_diagrams(
      model.directional_diagrams,
      pos=node_positions,
      rows=model.G.number_of_nodes(),
      path=os.getcwd(),
      # set color-by-target to label the target nodes
      cbt=True,
      label="directional_panel",
  )
  # flatten the flux diagrams since they are stored in nested lists
  flux_diagrams = [g for group in model.flux_diagrams if group is not None for g in group]
  # plot and save the flux diagrams as a panel
  plotting.draw_diagrams(
      flux_diagrams,
      pos=node_positions,
      path=os.getcwd(),
      # set color-by-target to label the target nodes
      cbt=True,
      label="flux_panel",
  )

The output directional and flux diagrams figures, ``directional_panel.png``
and ``flux_panel.png``:

**directional_panel.png**

|img_4wl_directional|

**flux_panel.png**

|img_4wl_flux_small|

**NOTE:** For more examples visit the
`KDA examples <https://github.com/Becksteinlab/kda-examples>`_ repository.

.. autofunction:: draw_diagrams
.. autofunction:: draw_cycles
.. autofunction:: draw_ode_results

"""
import os
import numpy as np
import networkx as nx
import matplotlib as mpl
import matplotlib.pyplot as plt

from kda.diagrams import _construct_cycle_edges, _append_reverse_edges


def _get_node_labels(node_list):
    """
    Builds the dictionary of node labels for NetworkX nodes.

    Parameters
    ----------
    node_list : list
        List of node indices (e.g. ``[0, 2, 3, 1]``) indicating
        which node labels should be made.

    Returns
    -------
    labels: dict
        Dictionary where keys are the node index (index-zero) and the
        keys are the node index string (index-one).

    """
    labels = {n:r"${:.0f}$".format(n + 1) for n in node_list}
    return labels


def _get_node_colors(cbt, obj):
    """
    Returns a list of color values (either grey or coral) depending
    on whether color by target is turned on.

    Parameters
    ----------
    cbt : bool
        'Color by target' option that paints target nodes with a
        coral red when ``True``. Typically used for plotting directional
        and flux diagrams.
    obj: object
        ``nx.Graph``, ``nx.MultiDiGraph``, or list of nodes to return color
        values for. If a graph object is input, only nodes whose
        ``is_target`` attribute is truthy will be colored coral red; nodes
        without the attribute keep the base color. If a list of nodes is
        input, every node receives the same color.

    Returns
    -------
    node_colors: list
        List of strings of color values (e.g. ``["0.8", "0.8",...]``),
        one per node, in the iteration order of ``obj``.

    """
    base_color = "0.8"
    target_color = "#FF8080"
    if isinstance(obj, (nx.Graph, nx.MultiDiGraph)):
        if not cbt:
            return [base_color for _ in obj.nodes]
        # look the attribute up node-by-node so the colors stay aligned
        # with `obj.nodes` even when some nodes lack `is_target` or when
        # the attribute is stored as 0/1 instead of a bool
        return [
            target_color if is_target else base_color
            for _, is_target in obj.nodes(data="is_target", default=False)
        ]

    # plain collection of nodes: all nodes share a single color
    color = target_color if cbt else base_color
    return [color for _ in obj]


def _get_axis_limits(pos, scale_factor=1.4):
    """
    Retrieves the x/y limits based on the node positions. Values are
    scaled by a constant factor to compensate for the size of the nodes.

    Parameters
    ----------
    pos : dict
        Dictionary where keys are the indexed states (e.g. 0, 1,
        2, ..., ``N``) and the values are the x, y coordinates for
        each node.
    scale_factor: float, optional
        Factor used to scale the x/y axis limits. Default is ``1.4``.

    Returns
    -------
    Tuple of the form ``(xlims, ylims)``, where ``xlims``
    and ``ylims`` are lists containing the scaled minimum
    and maximum x and y values, respectively.

    """
    x = np.zeros(len(pos))
    y = np.zeros(len(pos))
    for i, positions in pos.items():
        x[i] = positions[0]
        y[i] = positions[1]
    xlims = [scale_factor * x.min(), scale_factor * x.max()]
    ylims = [scale_factor * y.min(), scale_factor * y.max()]
    return xlims, ylims


def _get_panel_dimensions(n_diagrams, rows, cols=None):
    """
    Calculates the number of appropriate rows and columns based on the
    number of diagrams. Generally returns the most square-like shape
    that is feasible for a given number of diagrams. If rows are specified,
    the columns will be adjusted to fit.

    Parameters
    ----------
    n_diagrams: int
        Number of diagrams to plot in panel.
    rows : int
        Number of rows, typically based on the square
        root of the number of diagrams to generate.
    cols : int, optional
        Number of columns. Default is ``None``, which results in the number
        of rows being determined based on the number of diagrams input.

    Returns
    -------
    Tuple of the form ``(rows, cols, excess_plots)``, where
    ``rows`` and ``cols`` are the number of rows and columns in the
    panel, respectively, and ``excess_plots`` is the number of extra
    graphs available in the panel.

    """
    if rows is None:
        rows = int(np.sqrt(n_diagrams))
    if cols is None:
        cols = int(np.ceil(n_diagrams / rows))
    excess_plots = rows * cols - n_diagrams
    return (rows, cols, excess_plots)


def _plot_single_diagram(
    diagram,
    pos=None,
    node_labels=None,
    node_list=None,
    node_colors=None,
    edge_list=None,
    font_size=12,
    figsize=(3, 3),
    node_size=300,
    arrow_width=1.5,
    arrow_size=12,
    arrow_style="->",
    connection_style="arc3",
    ax=None,
    cbt=False,
):
    """
    Draws the nodes, edges and node labels of one diagram on one axis.

    Parameters
    ----------
    diagram : ``NetworkX.MultiDiGraph`` or ``NetworkX.Graph``
        Diagram to be plotted.
    pos : dict, optional
        Dictionary where keys are the indexed states (e.g. 0, 1, 2,
        ..., ``N``) and the values are the x, y coordinates for each
        node. If not specified, ``NetworkX.spring_layout()`` is used.
    node_labels: dict, optional
        Dictionary where the keys are the node indices (index-zero) and
        the values are the node label strings (index-one). If not
        specified, labels are created for all nodes in ``node_list``.
    node_list : list, optional
        List of node indices (e.g. ``[0, 2, 3, 1]``) indicating
        which nodes to plot. If not specified, all nodes in the input
        diagram will be plotted.
    node_colors: list, optional
        List of strings of color values (e.g. ``["0.8", "0.8",...]``)
        used to color the nodes. If not specified, node colors are
        determined from ``diagram`` using the ``cbt`` parameter.
    edge_list: list, optional
        List of edge tuples (e.g. ``[(1, 0), (1, 2), ...]``) to plot. If not
        specified, all edges will be plotted.
    font_size : int, optional
        Sets the font size for the figure. Default is ``12``.
    figsize: tuple, optional
        Tuple of the form ``(x, y)``, where ``x`` and ``y`` are the
        x and y-axis figure dimensions in inches. Only used when a new
        figure is created. Default is ``(3, 3)``.
    node_size: int, optional
        Size of nodes used for ``NetworkX`` diagram. Default is ``300``.
    arrow_width: float, optional
        Arrow width used for ``NetworkX`` diagram. Default is ``1.5``.
    arrow_size: int, optional
        Arrow size used for ``NetworkX`` diagram. Ignored for undirected
        diagrams. Default is ``12``.
    arrow_style: str, optional
        Style of arrows used for ``NetworkX`` diagram. Ignored for
        undirected diagrams. Default is ``"->"``.
    connection_style: str, optional
        Style of arrow connections for ``NetworkX`` diagram. Ignored for
        undirected diagrams. Default is ``"arc3"``.
    ax: ``matplotlib`` axis object, optional
        Axis to place diagrams on. If not specified, a new figure
        and axis will be created. Default is ``None``.
    cbt : bool, optional
        'Color by target' option that paints target nodes with a coral red.
        Typically used for plotting directional and flux diagrams.
        Default is ``False``.

    Returns
    -------
    fig: ``matplotlib.pyplot.figure`` object or ``None``
        The newly created figure holding the plotted diagram, or ``None``
        when the diagram was drawn on a caller-supplied ``ax``.

    """
    draw_arrows = bool(nx.is_directed(diagram))
    if not draw_arrows:
        # undirected graphs are drawn with the default parameters
        # of `nx.draw_networkx_edges()`
        # see https://github.com/Becksteinlab/kda/issues/80
        arrow_size = 10
        arrow_style = None
        connection_style = "arc3"

    # only create (and later return) a figure when no axis was supplied
    fig = None
    if ax is None:
        fig = plt.figure(figsize=figsize, tight_layout=True)
        ax = fig.add_subplot(111)

    # fill in every optional input that the caller left unspecified
    node_list = diagram.nodes() if node_list is None else node_list
    node_labels = _get_node_labels(node_list) if node_labels is None else node_labels
    pos = nx.spring_layout(diagram) if pos is None else pos
    if node_colors is None:
        node_colors = _get_node_colors(cbt=cbt, obj=diagram)

    nx.draw_networkx_nodes(
        diagram,
        pos,
        nodelist=node_list,
        node_color=node_colors,
        node_size=node_size,
        ax=ax,
    )
    edge_options = {
        "edgelist": edge_list,
        "node_size": node_size,
        "width": arrow_width,
        "arrowsize": arrow_size,
        "arrowstyle": arrow_style,
        "arrows": draw_arrows,
        "connectionstyle": connection_style,
        "ax": ax,
    }
    nx.draw_networkx_edges(diagram, pos, **edge_options)
    nx.draw_networkx_labels(diagram, pos, node_labels, font_size=font_size, ax=ax)
    ax.set_axis_off()
    return fig


def _plot_panel(
    diagrams,
    rows=None,
    cols=None,
    pos=None,
    panel_scale=2,
    font_size=12,
    cbt=False,
    curved_arrows=False,
):
    """
    Plots a panel of diagrams of shape `(rows, cols)`.

    Parameters
    ----------
    diagrams : list of ``NetworkX`` graph objects
        List of diagrams to be plotted, one per subplot.
    rows : int, optional
        Number of rows. Default is ``None``, which results in the number
        of rows being determined based on the number of diagrams input.
    cols : int, optional
        Number of columns. Default is ``None``, which results in the number
        of columns being determined based on the number of diagrams input.
    pos : dict, optional
        Dictionary where keys are the indexed states (e.g. 0, 1, 2,
        ..., ``N``) and the values are the x, y coordinates for each
        node. If not specified, ``NetworkX.spring_layout()`` of the first
        diagram is used for all diagrams.
    panel_scale : float, optional
        Linearly scales the figure height and width as well as the
        node size. Default is ``2``.
    font_size : int, optional
        Sets the font size for the figure. Default is ``12``.
    cbt : bool, optional
        'Color by target' option that paints target nodes with a
        coral red. Typically used for plotting directional and flux
        diagrams. Default is ``False``.
    curved_arrows: bool, optional
        Switches on arrows with a slight curvature to separate double arrows
        for directional diagrams. Default is ``False``.

    Returns
    -------
    fig: ``matplotlib.pyplot.figure`` object
        A panel of figures, where each figure is a plotted diagram.

    Notes
    -----
    If the panel has more subplots than there are diagrams, the
    unused subplots at the end of the panel are hidden.

    """
    nrows, ncols, excess_plots = _get_panel_dimensions(
        n_diagrams=len(diagrams), rows=rows, cols=cols
    )
    # `squeeze=False` always yields a 2D array of axes; with the default
    # squeezing a 1x1 panel returns a bare Axes object, which has no
    # `.shape`/`.flat` and broke panels containing a single diagram
    fig, ax = plt.subplots(
        nrows=nrows, ncols=ncols, squeeze=False, tight_layout=True
    )
    fig.set_figheight(nrows * panel_scale)
    fig.set_figwidth(1.2 * ncols * panel_scale)

    node_size = 150 * panel_scale

    if pos is None:
        # one shared layout so nodes sit in the same place in every subplot
        pos = nx.spring_layout(diagrams[0])

    connection_style = "arc3, rad = 0.11" if curved_arrows else "arc3"

    for i, diag in enumerate(diagrams):
        # map the flat diagram index onto the (row, col) grid
        ix = np.unravel_index(i, ax.shape)
        plt.sca(ax[ix])
        _plot_single_diagram(
            diagram=diag,
            pos=pos,
            font_size=font_size,
            node_size=node_size,
            ax=ax[ix],
            cbt=cbt,
            connection_style=connection_style,
        )
    # hide the unused subplots, counting back from the end of the panel
    for i in range(excess_plots):
        ax.flat[-i - 1].set_visible(False)
    return fig


def draw_diagrams(
    diagrams,
    pos=None,
    panel=True,
    panel_scale=2,
    font_size=12,
    cbt=False,
    rows=None,
    cols=None,
    path=None,
    label=None,
    curved_arrows=False,
):
    """
    Plots any number of input diagrams. Typically used for plotting
    kinetic diagrams, or arrays of partial, directional, or flux diagrams.

    Parameters
    ----------
    diagrams : list of ``NetworkX`` graph objects
        List of diagrams or single diagram to be plotted.
    pos : dict, optional
        Dictionary where keys are the indexed states (e.g. 0, 1, 2,
        ..., ``N``) and the values are the x, y coordinates for each
        node. If not specified, ``NetworkX.spring_layout()`` is used.
    panel : bool, optional
        Tells the function to output diagrams as an ``NxM`` matrix of
        subplots, where ``N`` and ``M`` are the number of rows and
        columns, respectively. ``True`` will output a panel figure,
        ``False`` will output each figure individually. Only used when a
        list of diagrams is input. Default is ``True``.
    panel_scale : float, optional
        Parameter used to scale figure if ``panel=True``. Linearly
        scales figure height and width. Default is ``2``.
    font_size : int, optional
        Sets the font size for the figure. Default is ``12``.
    cbt : bool, optional
        'Color by target' option that paints target nodes with a
        coral red. Typically used for plotting directional and flux
        diagrams. Default is ``False``.
    rows : int, optional
        Number of rows. Default is ``None``, which results in the number
        of rows being determined based on the number of diagrams input.
    cols : int, optional
        Number of columns. Default is ``None``, which results in the number
        of columns being determined based on the number of diagrams input.
    path : str, optional
        String of save path for figure. If a path is specified the
        figure(s) will be saved at the specified location instead of
        being returned. Default is ``None``.
    label : str, optional
        Figure label used to create unique filename if ``path`` is input.
        The ``.png`` file extension is appended to the label (preceded by
        ``_<number>`` when figures are saved individually).
        Default is ``None``.
    curved_arrows: bool, optional
        Switches on arrows with a slight curvature to separate double arrows
        for directional diagrams. Default is ``False``.

    Returns
    -------
    ``None`` if ``path`` is specified. Otherwise the
    ``matplotlib`` figure (single diagram or panel), or the list of
    figures when ``panel=False``.

    Notes
    -----
    When using ``panel=True``, if the panel has more subplots than there
    are diagrams, the unused subplots are hidden.

    Examples
    --------
    The :func:`~kda.plotting.draw_diagrams` function allows for easy
    plotting of KDA-generated diagrams:

    .. code-block:: python

      import os
      import numpy as np
      import kda
      from kda import plotting

      # define matrix with reaction rates set to 1
      K = np.array([
          [0, 1, 1],
          [1, 0, 1],
          [1, 1, 0],
      ])
      # create a KineticModel from the rate matrix
      model = kda.KineticModel(K=K, G=None)
      # generate the directional diagrams
      model.build_directional_diagrams()
      # specify the positions of all nodes in NetworkX fashion
      node_positions = {0: [0, 1], 1: [-0.5, 0], 2: [0.5, 0]}
      # plot and save the input diagram in the current directory
      plotting.draw_diagrams(model.G, pos=node_positions, path=os.getcwd(), label="input")
      # plot and save the directional diagrams as a panel
      plotting.draw_diagrams(
          model.directional_diagrams,
          pos=node_positions,
          path=os.getcwd(),
          # set color-by-target to label the target nodes
          cbt=True,
          label="directional_panel",
      )

    This will save two files, ``input.png`` and ``directional_panel.png``,
    in your current working directory.

    """
    connection_style = "arc3, rad = 0.11" if curved_arrows else "arc3"

    if isinstance(diagrams, (nx.Graph, nx.MultiDiGraph)):
        # a lone diagram gets its own, larger figure
        fig = _plot_single_diagram(
            diagram=diagrams,
            pos=pos,
            font_size=font_size,
            figsize=(4, 4),
            node_size=500,
            arrow_width=2,
            cbt=cbt,
            connection_style=connection_style,
        )
        plt.close()
        if not path:
            return fig
        save_path = os.path.join(path, f"{label}.png")
        fig.savefig(save_path, dpi=300)
        return None

    # from here on `diagrams` is a sequence; all figures share one layout
    if pos is None:
        pos = nx.spring_layout(diagrams[0])

    if panel:
        fig = _plot_panel(
            diagrams=diagrams,
            pos=pos,
            rows=rows,
            cols=cols,
            font_size=font_size,
            panel_scale=panel_scale,
            cbt=cbt,
            curved_arrows=curved_arrows,
        )
        plt.close()
        if not path:
            return fig
        save_path = os.path.join(path, f"{label}.png")
        fig.savefig(save_path, dpi=300)
        return None

    # one figure per diagram; nodes and labels come from the first diagram
    node_list = list(diagrams[0].nodes)
    node_labels = _get_node_labels(node_list=node_list)
    fig_list = []
    for i, diag in enumerate(diagrams):
        fig = _plot_single_diagram(
            diagram=diag,
            pos=pos,
            node_list=node_list,
            node_labels=node_labels,
            font_size=font_size,
            cbt=cbt,
            connection_style=connection_style,
        )
        fig_list.append(fig)
        plt.close()
        if path:
            save_path = os.path.join(path, f"{label}_{i+1}.png")
            fig.savefig(save_path, dpi=300)
    if not path:
        return fig_list
    return None
def _get_cycle_draw_kwargs(cycle, cbt):
    """
    Builds the node/edge keyword arguments that
    ``_plot_single_diagram`` needs to draw a single cycle.
    """
    cycle_edges = _construct_cycle_edges(cycle)
    return {
        "edge_list": _append_reverse_edges(cycle_edges),
        "node_list": cycle,
        "node_labels": _get_node_labels(node_list=cycle),
        "node_colors": _get_node_colors(cbt=cbt, obj=cycle),
    }


def draw_cycles(
    G,
    cycles,
    pos=None,
    panel=True,
    panel_scale=2,
    rows=None,
    cols=None,
    font_size=12,
    cbt=False,
    curved_arrows=False,
    path=None,
    label=None,
):
    """
    Plots a diagram with a cycle labeled.

    Parameters
    ----------
    G : ``NetworkX.MultiDiGraph``
        Input diagram used for plotting the cycles.
    cycles : list of lists of int
        List of cycles or individual cycle to be plotted, index zero.
        Order of node indices does not matter.
    pos : dict, optional
        Dictionary where keys are the indexed states (e.g. 0, 1, 2,
        ..., ``N``) and the values are the x, y coordinates for each
        node. If not specified, ``NetworkX.spring_layout()`` is used.
    panel : bool, optional
        Tells the function to output diagrams as an ``NxM`` matrix of
        subplots, where ``N`` and ``M`` are the number of rows and
        columns, respectively. ``True`` will output a panel figure,
        ``False`` will output each figure individually. Only used when a
        list of cycles is input. Default is ``True``.
    panel_scale : float, optional
        Parameter used to scale figure if ``panel=True``. Linearly
        scales figure height and width. Default is ``2``.
    rows : int, optional
        Number of rows. Default is ``None``, which results in the number
        of rows being determined based on the number of cycles input.
    cols : int, optional
        Number of columns. Default is ``None``, which results in the number
        of columns being determined based on the number of cycles input.
    font_size : int, optional
        Sets the font size for the figure. Default is ``12``.
    cbt : bool, optional
        'Color by target' option that paints the nodes of each cycle
        with a coral red. Default is ``False``.
    curved_arrows: bool, optional
        Switches on arrows with a slight curvature to separate double arrows
        for directional diagrams. Default is ``False``.
    path : str, optional
        String of save path for figure. If a path is specified the
        figure(s) will be saved at the specified location instead of
        being returned. Default is ``None``.
    label : str, optional
        Figure label used to create unique filename if ``path`` is input.
        The ``.png`` file extension is appended to the label (preceded by
        ``_<number>`` when figures are saved individually).
        Default is ``None``.

    Returns
    -------
    ``None`` if ``path`` is specified. Otherwise the
    ``matplotlib`` figure (single cycle or panel), or the list of
    figures when ``panel=False``.

    Notes
    -----
    When using ``panel=True``, if the panel has more subplots than there
    are cycles, the unused subplots are hidden.

    Examples
    --------
    The :func:`~kda.plotting.draw_cycles` function allows for easy
    plotting of cycles in kinetic diagrams:

    .. code-block:: python

      import os
      import numpy as np
      import kda
      from kda import plotting

      # define matrix with reaction rates set to 1
      K = np.array([
          [0, 1, 0, 1],
          [1, 0, 1, 1],
          [0, 1, 0, 1],
          [1, 1, 1, 0],
      ])
      # create a KineticModel from the rate matrix
      model = kda.KineticModel(K=K, G=None)
      # specify the positions of all nodes in NetworkX fashion
      node_positions = {0: [1, 1], 1: [-1, 1], 2: [-1, -1], 3: [1, -1]}
      # build the cycles for the model
      model.build_cycles()
      # plot cycles with coral-colored nodes
      plotting.draw_cycles(
          G=model.G,
          cycles=model.cycles,
          pos=node_positions,
          path=os.getcwd(),
          # set color-by-target to label the target nodes
          cbt=True,
          label="cycles_panel",
      )

    This will save a file ``cycles_panel.png`` in your current working
    directory displaying all 3 cycles for the 4-state model.

    """
    connection_style = "arc3, rad = 0.11" if curved_arrows else "arc3"

    # a flat sequence of node indices is a single cycle; NumPy integers
    # are accepted as well since they are not instances of `int`
    if isinstance(cycles[0], (int, np.integer)):
        # single cycle case
        fig = _plot_single_diagram(
            diagram=G,
            pos=pos,
            node_size=500,
            font_size=font_size,
            figsize=(4, 4),
            arrow_width=2,
            cbt=False,
            connection_style=connection_style,
            **_get_cycle_draw_kwargs(cycles, cbt),
        )
        plt.close()
        if not path:
            return fig
        save_path = os.path.join(path, f"{label}.png")
        fig.savefig(save_path, dpi=300)
        return None

    # multiple cycles case: all figures share one layout
    if pos is None:
        pos = nx.spring_layout(G)

    if panel:
        # draw panel case
        nrows, ncols, excess_plots = _get_panel_dimensions(
            n_diagrams=len(cycles), rows=rows, cols=cols
        )
        # `squeeze=False` always yields a 2D array of axes; with the
        # default squeezing a 1x1 panel returns a bare Axes object, which
        # has no `.shape`/`.flat` and broke panels with a single cycle
        fig, ax = plt.subplots(
            nrows=nrows, ncols=ncols, squeeze=False, tight_layout=True
        )
        fig.set_figheight(nrows * panel_scale)
        fig.set_figwidth(1.2 * ncols * panel_scale)
        # the same limits, computed from all node positions, are applied
        # to every subplot
        xlims, ylims = _get_axis_limits(pos, scale_factor=1.4)
        for i, cycle in enumerate(cycles):
            ix = np.unravel_index(i, ax.shape)
            plt.sca(ax[ix])
            ax[ix].set_xlim(xlims)
            ax[ix].set_ylim(ylims)
            _plot_single_diagram(
                diagram=G,
                pos=pos,
                node_size=150 * panel_scale,
                font_size=font_size,
                arrow_width=1.5,
                cbt=False,
                connection_style=connection_style,
                ax=ax[ix],
                **_get_cycle_draw_kwargs(cycle, cbt),
            )
        # hide the unused subplots, counting back from the end of the panel
        for j in range(excess_plots):
            ax.flat[-j - 1].set_visible(False)
        plt.close()
        if not path:
            return fig
        save_path = os.path.join(path, f"{label}.png")
        fig.savefig(save_path, dpi=300)
        return None

    # draw individual plots case
    fig_list = []
    for i, cycle in enumerate(cycles):
        fig = _plot_single_diagram(
            diagram=G,
            pos=pos,
            node_size=500,
            font_size=font_size,
            arrow_width=2,
            cbt=False,
            connection_style=connection_style,
            **_get_cycle_draw_kwargs(cycle, cbt),
        )
        fig_list.append(fig)
        plt.close()
        if path:
            save_path = os.path.join(path, f"{label}_{i+1}.png")
            fig.savefig(save_path, dpi=300)
    if not path:
        return fig_list
    return None
def draw_ode_results(
    results, figsize=(5, 4), legendloc="best", bbox_coords=None, path=None, label=None
):
    """
    Plots probability time series for all states generated
    by :func:`~kda.ode.ode_solver`.

    Parameters
    ----------
    results : ``Bunch`` object
        Contains time information (``results.t``) and function
        information at time ``t`` (``results.y``), as well as
        various other fields.
    figsize: tuple, optional
        Tuple of the form ``(x, y)``, where ``x`` and ``y`` are the
        x and y-axis figure dimensions in inches. Default is ``(5, 4)``.
    legendloc : str, optional
        String passed to determine where to place the legend for the
        figure. Default is ``"best"``.
    bbox_coords : tuple, optional
        Tuple of the form ``(x, y)``, where ``x`` and ``y`` are the x and
        y-axis coordinates for the legend. Default is ``None``.
    path : str, optional
        String of save path for figure. If a path is specified the figure
        will be saved at the specified location instead of being
        returned. Default is ``None``.
    label : str, optional
        Figure label used to create unique filename if ``path`` is input.
        The ``.png`` file extension is appended to the label.
        Default is ``None``.

    Returns
    -------
    ``None`` if ``path`` is specified, otherwise the ``matplotlib``
    figure containing one curve per state plus the total probability.

    """
    n_states = int(len(results.y))
    times = results.t
    probabilities = results.y[:n_states]
    # summed over all states at each time point
    total_probability = probabilities.sum(axis=0)

    fig = plt.figure(figsize=figsize, tight_layout=True)
    ax = fig.add_subplot(111)

    # one solid curve per state; the legend entry reports the final value
    for state_number, series in enumerate(probabilities, start=1):
        state_label = r"$p_{%d, %s}$" % (state_number, "final")
        state_val = " = {:.3f}".format(series[-1])
        ax.plot(times, series, "-", lw=2, label=state_label + state_val)

    # dashed black curve for the total probability
    ptot_label = r"$p_{tot, final}$" + " = {:.2f}".format(total_probability[-1])
    ax.plot(times, total_probability, "--", lw=2, color="black", label=ptot_label)

    ax.set_title("State Probabilities for {} State Model".format(n_states))
    ax.set_ylabel(r"Probability")
    ax.set_xlabel(r"Time (s)")

    # only anchor the legend when coordinates were supplied
    legend_options = {"loc": legendloc}
    if bbox_coords is not None:
        legend_options["bbox_to_anchor"] = bbox_coords
    ax.legend(**legend_options)

    plt.close()
    if not path:
        return fig
    save_path = os.path.join(path, f"{label}.png")
    fig.savefig(save_path, dpi=300)
    return None