colosseum.analysis.visualization

  1from typing import Dict, Tuple, Union, TYPE_CHECKING, List
  2
  3import networkx as nx
  4import numpy as np
  5import seaborn as sns
  6import toolz
  7from matplotlib import pyplot as plt
  8from matplotlib.colors import LinearSegmentedColormap
  9
 10from colosseum.mdp.utils import MDPCommunicationClass
 11
 12if TYPE_CHECKING:
 13    from colosseum.mdp import ContinuousMDP, EpisodicMDP, NODE_TYPE
 14
 15custom_dark = sns.color_palette("dark")
 16p = custom_dark.pop(1)
 17
 18
 19def plot_MDP_graph(
 20    mdp: Union["ContinuousMDP", "EpisodicMDP"],
 21    node_palette: List[Tuple[float, float, float]] = sns.color_palette("bright"),
 22    action_palette: List[Tuple[float, float, float]] = custom_dark,
 23    save_file: str = None,
 24    ax: plt.Axes = None,
 25    figsize: Tuple[int, int] = None,
 26    node_labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = None,
 27    action_labels: Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]] = None,
 28    int_labels_offset_x: int = 10,
 29    int_labels_offset_y: int = 10,
 30    continuous_form: bool = True,
 31    prog: str = "neato",
 32    ncol: int = 4,
 33    title: str = None,
 34    legend_fontsize: int = None,
 35    font_color_state_labels: str = "k",
 36    font_color_state_actions_labels: str = "k",
 37    cm_state_labels: LinearSegmentedColormap = None,
 38    cm_state_actions_labels: LinearSegmentedColormap = None,
 39    no_written_state_labels=True,
 40    no_written_state_action_labels=True,
 41    node_size=150,
 42):
 43    """
 44    In the MDP graph representation, states are associated to round nodes and actions to square nodes. Each state is
 45    connected to action nodes for all the available actions. Each action node is connected to the state node such that
 46    the probability of transitioning to them when playing the corresponding action from the corresponding state is
 47    greater than one.
 48
 49    Parameters
 50    ----------
 51    mdp : Union["ContinuousMDP", "EpisodicMDP"]
 52        The MDP to be visualised.
 53    node_palette : List[Tuple[float, float, float]]
 54        The color palette for the nodes corresponding to states. By default, the seaborn 'bright' palette is used.
 55    action_palette : List[Tuple[float, float, float]]
 56        The color palette for the nodes corresponding to actions. By default, the seaborn 'dark' palette is used.
 57    save_file : str
 58        The path file where the visualization will be store. Be default, the visualization will not be stored.
 59    ax : plt.Axes
 60        The ax object where the plot will be put. By default, a new axis is created.
 61    figsize : Tuple[int, int]
 62        The size of the figures in the grid of plots. By default, None is provided.
 63    node_labels : Union[bool, Dict["NODE_TYPE", Union[float, str]]]
 64        The dictionary mapping state nodes to labels (either a float or a str). If provided as bool, then the default labels
 65        are used. By default, no label is plotted.
 66    action_labels : Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]]
 67        The dictionary mapping action nodes to labels (either a float or a str). If provided as bool, then the default labels
 68        are used. By default, no label is plotted.
 69    int_labels_offset_x : int
 70        x-axis offset for the node labels. By default, the offset is set to 10.
 71    int_labels_offset_y : int
 72        y-axis offset for the node labels. By default, the offset is set to 10.
 73    continuous_form : bool
 74        If True and the MDP is episodic, the continuous form of the MDP is used. By default, the continuous form is
 75        used.
 76    prog : str
 77        The default program to create the graph layout to be passed to the networkx library. By default, 'neato' is used.
 78    ncol : int
 79        The number of columns in the legend. By default, four columns are used.
 80    title : str
 81        The title to be given to the plot. By default, no title is used.
 82    legend_fontsize : int
 83        The font size of the legend. Vy default, it is not specified.
 84    font_color_state_labels : str
 85        The color code for the labels of the state nodes. By default, it is set to 'k'.
 86    font_color_state_actions_labels : str
 87        The color code for the labels of the action nodes. By default, it is set to 'k'.
 88    cm_state_labels : LinearSegmentedColormap
 89        The matplotlib color map for the labels of the state nodes. By default, it is not specified.
 90    cm_state_actions_labels : LinearSegmentedColormap
 91        The matplotlib color map for the labels of the action nodes. By default, it is not specified.
 92    no_written_state_labels : bool
 93        If True, the labels for the node states are not shown. By default, it is set to True.
 94    no_written_state_action_labels : bool
 95        If True, the labels for the node actions are not shown. By default, it is set to True.
 96    node_size : int
 97        The size of the node in the plot. By default, it is set to 100.
 98    """
 99    show = ax is None
100
101    sns.reset_defaults()
102    G, probs = (
103        _create_epi_MDP_graph(mdp)
104        if mdp.is_episodic() and not continuous_form
105        else _create_MDP_graph(mdp)
106    )
107    T, R = mdp.transition_matrix_and_rewards
108
109    layout = nx.nx_agraph.graphviz_layout(G, prog=prog)
110
111    if ax is None:
112        ax = _create_ax(layout, figsize)
113
114    if mdp.is_episodic() and not continuous_form:
115        if node_labels is not None and cm_state_labels is not None:
116            node_color = [
117                cm_state_labels(node_labels[node] / max(node_labels.values()))
118                for node in mdp.get_episodic_graph(False).nodes
119            ]
120        else:
121            node_color = [
122                node_palette[5]  # brown
123                if node[1] in mdp.starting_nodes and node[0] == 0
124                else node_palette[2]  # green
125                if R[mdp.node_to_index(node[1])].max() == R.max()
126                else node_palette[-2]  # yellow
127                if node[1] in mdp.recurrent_nodes_set
128                else node_palette[-3]  # grey
129                for node in mdp.get_episodic_graph(False).nodes
130            ]
131    else:
132        if node_labels is not None and cm_state_labels is not None:
133            node_color = [
134                cm_state_labels(node_labels[node] / max(node_labels.values()))
135                for node in mdp.G.nodes
136            ]
137        else:
138            node_color = [
139                node_palette[5]  # brown
140                if node in mdp.starting_nodes
141                else node_palette[2]  # green
142                if R[mdp.node_to_index[node]].max() == R.max()
143                else node_palette[-2]  # yellow
144                if node in mdp.recurrent_nodes_set
145                else node_palette[-3]  # grey
146                for node in mdp.G.nodes
147            ]
148
149    # Lazy way to create nice legend handles
150    x, y = list(layout.values())[0]
151    if cm_state_labels is None:
152        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
153        ax.scatter(x, y, color=node_palette[-2], label="State")
154        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
155            ax.scatter(x, y, color=node_palette[-3], label="Transient state")
156        ax.scatter(x, y, color=node_palette[5], label="Starting state")
157    ax.plot(x, y, color=node_palette[-3], label="Transition probability")
158    if cm_state_actions_labels is None:
159        for a in range(mdp.n_actions):
160            ax.plot(x, y, color=action_palette[a], label=f"Action: {a}", marker="s")
161
162    G_nodes = (
163        mdp.get_episodic_graph(False).nodes
164        if mdp.is_episodic() and not continuous_form
165        else mdp.G.nodes
166    )
167    nx.draw_networkx_nodes(
168        G,
169        layout,
170        G_nodes,
171        ax=ax,
172        node_color=node_color,
173        edgecolors="black",
174        node_size=node_size,
175    )
176    for a in range(mdp.n_actions):
177        na_nodes = (
178            [an for an in G.nodes if type(an[1]) == int and an[-1] == a]
179            if mdp.is_episodic() and not continuous_form
180            else [an for an in G.nodes if type(an) == tuple and an[-1] == a]
181        )
182        nx.draw_networkx_nodes(
183            G,
184            layout,
185            na_nodes,
186            node_shape="s",
187            ax=ax,
188            node_size=node_size,
189            node_color=[action_palette[a]]
190            if cm_state_actions_labels is None
191            else [
192                cm_state_actions_labels(action_labels[an] / max(action_labels.values()))
193                for an in na_nodes
194            ],
195            edgecolors="black",
196        )
197        nx.draw_networkx_edges(
198            G,
199            layout,
200            edgelist=[e for e in G.edges if type(e[0][0]) != tuple and e[1][1] == a]
201            if mdp.is_episodic() and not continuous_form
202            else [e for e in G.edges if type(e[0]) != tuple and e[1][1] == a],
203            ax=ax,
204            edge_color=action_palette[a],
205        )
206    nx.draw_networkx_edges(
207        G,
208        layout,
209        edgelist=[e for e in G.edges if type(e[0][0]) == tuple]
210        if mdp.is_episodic() and not continuous_form
211        else [e for e in G.edges if type(e[0]) == tuple],
212        ax=ax,
213        edge_color=action_palette[-3],
214        width=[probs[e] for e in G.edges if type(e[0][0]) == tuple]
215        if mdp.is_episodic() and not continuous_form
216        else [probs[e] for e in G.edges if type(e[0]) == tuple],
217    )
218    ax.legend(ncol=ncol, fontsize=legend_fontsize)
219    if node_labels is not None and not no_written_state_labels:
220        if type(node_labels) == bool and node_labels:
221            node_labels = {
222                n: (
223                    f"h={n[0]},{n[1]}"
224                    if mdp.is_episodic() and not continuous_form
225                    else str(n)
226                )
227                for n in G_nodes
228            }
229        assert all(n in G.nodes for n in node_labels)
230        nx.draw_networkx_labels(
231            G,
232            toolz.valmap(
233                lambda x: [x[0] + int_labels_offset_x, x[1] + int_labels_offset_y],
234                layout,
235            ),
236            node_labels,
237            font_color=font_color_state_labels,
238            ax=ax,
239            verticalalignment="center_baseline",
240        )
241    if action_labels is not None and not no_written_state_action_labels:
242        if type(action_labels) == bool and action_labels:
243            action_labels = {
244                n: str(n[1])
245                for n in (
246                    (an for an in G.nodes if type(an[1]) == int)
247                    if mdp.is_episodic() and not continuous_form
248                    else (an for an in G.nodes if type(an) == tuple)
249                )
250            }
251        assert all(n in G.nodes and type(n[1]) == int for n in action_labels)
252        nx.draw_networkx_labels(
253            G,
254            toolz.valmap(
255                lambda x: [x[0] + int_labels_offset_x, x[1] + int_labels_offset_y],
256                layout,
257            ),
258            action_labels,
259            font_color=font_color_state_actions_labels,
260            ax=ax,
261            verticalalignment="center_baseline",
262        )
263
264    ax.axis("off")
265    if title is not None:
266        ax.set_title(title)
267    if save_file is not None:
268        plt.savefig(save_file, bbox_inches="tight")
269    if show:
270        plt.show()
271
272
273def plot_MCGraph(
274    mdp: Union["ContinuousMDP", "EpisodicMDP"],
275    node_palette: List[Tuple[float, float, float]] = sns.color_palette("deep"),
276    labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = {},
277    font_color_labels="k",
278    save_file: str = None,
279    ax: plt.Axes = None,
280    figsize: Tuple[int, int] = None,
281    prog: str = None,
282    legend_fontsize: int = None,
283    node_size=100,
284    cm_state_labels=None,
285    no_written_state_labels=True,
286):
287    """
288    In the Markov chain-based representation, states are associated to nodes and actions are not visualized. Each state
289    is connected to states such that there exists at least one action with corresponding transition probability greater
290    than zero.
291
292    Parameters
293    ----------
294    mdp : Union["ContinuousMDP", "EpisodicMDP"]
295        The MDP to be visualized.
296    node_palette : List[Tuple[float, float, float]]
297        The color palette for the nodes. By default, the seaborn 'bright' palette is used.
298    labels : Union[bool, Dict["NODE_TYPE", Union[float, str]]]
299        The dictionary mapping nodes to labels (either a float or a str). If provided as bool, then the default labels
300        are used. By default, no label is plotted.
301    font_color_labels
302    save_file : str
303        The path file where the visualization will be store. Be default, the visualization will not be stored.
304    ax : plt.Axes
305        The ax object where the plot will be put. By default, a new axis is created.
306    figsize : Tuple[int, int]
307        The size of the figures in the grid of plots. By default, None is provided.
308    prog : str
309        The default program to create the graph layout to be passed to the networkx library. By default, it is not
310        specified.
311    legend_fontsize : int
312        The font size of the legend. Vy default, it is not specified.
313    node_size : int
314        The size of the node in the plot. By default, it is set to 100.
315    cm_state_labels : LinearSegmentedColormap
316        The matplotlib color map for the labels of the state nodes. By default, it is not specified.
317    no_written_state_labels : bool
318        If True, the labels for the nodes are not shown. By default, it is set to True.
319
320    """
321
322    show = ax is None
323
324    _, R = mdp.transition_matrix_and_rewards
325
326    if cm_state_labels is not None:
327        node_color = [
328            cm_state_labels(labels[n] / max(labels.values())) for n in mdp.G.nodes
329        ]
330    else:
331        node_color = [
332            node_palette[0]  # brown
333            if node in mdp.starting_nodes
334            else node_palette[2]  # green
335            if R[mdp.node_to_index[node]].max() == R.max()
336            else node_palette[1]  # yellow
337            if node in mdp.recurrent_nodes_set
338            else node_palette[-1]  # grey
339            for node in mdp.G.nodes
340        ]
341
342    if ax is None:
343        ax = _create_ax(mdp.graph_layout, figsize)
344
345    if cm_state_labels is None:
346        x, y = list(mdp.graph_layout.values())[0]
347        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
348        ax.scatter(x, y, color=node_palette[1], label="State")
349        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
350            ax.scatter(x, y, color=node_palette[-1], label="Transient state")
351        ax.scatter(x, y, color=node_palette[0], label="Starting state")
352
353    nx.draw(
354        mdp.G,
355        mdp.graph_layout
356        if prog is None
357        else nx.nx_agraph.graphviz_layout(mdp.G, prog=prog),
358        node_color=node_color,
359        node_size=node_size,
360        edgecolors="black",
361        edge_color=node_palette[-3],
362        labels={}
363        if cm_state_labels is not None and no_written_state_labels
364        else labels,
365        font_color=font_color_labels,
366        ax=ax,
367    )
368
369    if cm_state_labels is None:
370        ax.legend(fontsize=legend_fontsize)
371    if save_file is not None:
372        plt.savefig(save_file)
373    if show:
374        plt.show()
375
376
377def _create_ax(layout, figsize: Tuple[int, int] = None):
378    if figsize is None:
379        positions = np.array(list(layout.values()))
380        max_distance = max(
381            np.sqrt(np.sum((positions[i] - positions[j]) ** 2))
382            for i in range(len(layout))
383            for j in range(i + 1, len(layout))
384        )
385        figsize = max(6, min(20, int(max_distance / 70)))
386        figsize = (figsize, figsize)
387    plt.figure(None, figsize)
388    cf = plt.gcf()
389    cf.set_facecolor("w")
390    ax = cf.add_axes((0, 0, 1, 1)) if cf._axstack() is None else cf.gca()
391    ax.spines["top"].set_visible(False)
392    ax.spines["right"].set_visible(False)
393    ax.spines["bottom"].set_visible(False)
394    ax.spines["left"].set_visible(False)
395    return ax
396
397
398def _create_MDP_graph(mdp: Union["ContinuousMDP", "EpisodicMDP"]):
399    T, R = mdp.transition_matrix_and_rewards
400
401    probs = dict()
402    G = nx.DiGraph()
403    for s in range(mdp.n_states):
404        n = mdp.index_to_node[s]
405        if s not in G.nodes:
406            G.add_node(n)
407        for a in range(mdp.n_actions):
408            an = tuple((n, a))
409            G.add_edge(n, an)
410            for nn in np.where(T[s, a] > 0)[0]:
411                G.add_edge(an, mdp.index_to_node[nn])
412                probs[an, mdp.index_to_node[nn]] = T[s, a, nn]
413
414    return G, probs
415
416
417def _create_epi_MDP_graph(mdp: "EpisodicMDP"):
418    G_epi = mdp.get_episodic_graph(False)
419    T, R = mdp.episodic_transition_matrix_and_rewards
420
421    probs = dict()
422    G = nx.DiGraph()
423    for n in G_epi.nodes:
424        if n not in G.nodes:
425            G.add_node(n)
426
427        for a in range(mdp.n_actions):
428            an = tuple((n, a))
429            G.add_edge(n, an)
430            for nn in G_epi.successors(n):
431                G.add_edge(an, nn)
432                probs[an, nn] = T[
433                    n[0], mdp.node_to_index(n[1]), a, mdp.node_to_index(nn[1])
434                ]
435
436    return G, probs
def plot_MDP_graph( mdp: Union[colosseum.mdp.base_infinite.ContinuousMDP, colosseum.mdp.base_finite.EpisodicMDP], node_palette: List[Tuple[float, float, float]] = [(0.00784313725490196, 0.24313725490196078, 1.0), (1.0, 0.48627450980392156, 0.0), (0.10196078431372549, 0.788235294117647, 0.2196078431372549), (0.9098039215686274, 0.0, 0.043137254901960784), (0.5450980392156862, 0.16862745098039217, 0.8862745098039215), (0.6235294117647059, 0.2823529411764706, 0.0), (0.9450980392156862, 0.2980392156862745, 0.7568627450980392), (0.6392156862745098, 0.6392156862745098, 0.6392156862745098), (1.0, 0.7686274509803922, 0.0), (0.0, 0.8431372549019608, 1.0)], action_palette: List[Tuple[float, float, float]] = [(0.0, 0.10980392156862745, 0.4980392156862745), (0.07058823529411765, 0.44313725490196076, 0.10980392156862745), (0.5490196078431373, 0.03137254901960784, 0.0), (0.34901960784313724, 0.11764705882352941, 0.44313725490196076), (0.34901960784313724, 0.1843137254901961, 0.050980392156862744), (0.6352941176470588, 0.20784313725490197, 0.5098039215686274), (0.23529411764705882, 0.23529411764705882, 0.23529411764705882), (0.7215686274509804, 0.5215686274509804, 0.0392156862745098), (0.0, 0.38823529411764707, 0.4549019607843137)], save_file: str = None, ax: matplotlib.axes._axes.Axes = None, figsize: Tuple[int, int] = None, node_labels: Union[bool, Dict[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], Union[float, str]]] = None, action_labels: Union[bool, Dict[Tuple[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, 
colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], int], Union[float, str]]] = None, int_labels_offset_x: int = 10, int_labels_offset_y: int = 10, continuous_form: bool = True, prog: str = 'neato', ncol: int = 4, title: str = None, legend_fontsize: int = None, font_color_state_labels: str = 'k', font_color_state_actions_labels: str = 'k', cm_state_labels: matplotlib.colors.LinearSegmentedColormap = None, cm_state_actions_labels: matplotlib.colors.LinearSegmentedColormap = None, no_written_state_labels=True, no_written_state_action_labels=True, node_size=150):
 20def plot_MDP_graph(
 21    mdp: Union["ContinuousMDP", "EpisodicMDP"],
 22    node_palette: List[Tuple[float, float, float]] = sns.color_palette("bright"),
 23    action_palette: List[Tuple[float, float, float]] = custom_dark,
 24    save_file: str = None,
 25    ax: plt.Axes = None,
 26    figsize: Tuple[int, int] = None,
 27    node_labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = None,
 28    action_labels: Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]] = None,
 29    int_labels_offset_x: int = 10,
 30    int_labels_offset_y: int = 10,
 31    continuous_form: bool = True,
 32    prog: str = "neato",
 33    ncol: int = 4,
 34    title: str = None,
 35    legend_fontsize: int = None,
 36    font_color_state_labels: str = "k",
 37    font_color_state_actions_labels: str = "k",
 38    cm_state_labels: LinearSegmentedColormap = None,
 39    cm_state_actions_labels: LinearSegmentedColormap = None,
 40    no_written_state_labels=True,
 41    no_written_state_action_labels=True,
 42    node_size=150,
 43):
 44    """
 45    In the MDP graph representation, states are associated to round nodes and actions to square nodes. Each state is
 46    connected to action nodes for all the available actions. Each action node is connected to the state node such that
 47    the probability of transitioning to them when playing the corresponding action from the corresponding state is
 48    greater than one.
 49
 50    Parameters
 51    ----------
 52    mdp : Union["ContinuousMDP", "EpisodicMDP"]
 53        The MDP to be visualised.
 54    node_palette : List[Tuple[float, float, float]]
 55        The color palette for the nodes corresponding to states. By default, the seaborn 'bright' palette is used.
 56    action_palette : List[Tuple[float, float, float]]
 57        The color palette for the nodes corresponding to actions. By default, the seaborn 'dark' palette is used.
 58    save_file : str
 59        The path file where the visualization will be store. Be default, the visualization will not be stored.
 60    ax : plt.Axes
 61        The ax object where the plot will be put. By default, a new axis is created.
 62    figsize : Tuple[int, int]
 63        The size of the figures in the grid of plots. By default, None is provided.
 64    node_labels : Union[bool, Dict["NODE_TYPE", Union[float, str]]]
 65        The dictionary mapping state nodes to labels (either a float or a str). If provided as bool, then the default labels
 66        are used. By default, no label is plotted.
 67    action_labels : Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]]
 68        The dictionary mapping action nodes to labels (either a float or a str). If provided as bool, then the default labels
 69        are used. By default, no label is plotted.
 70    int_labels_offset_x : int
 71        x-axis offset for the node labels. By default, the offset is set to 10.
 72    int_labels_offset_y : int
 73        y-axis offset for the node labels. By default, the offset is set to 10.
 74    continuous_form : bool
 75        If True and the MDP is episodic, the continuous form of the MDP is used. By default, the continuous form is
 76        used.
 77    prog : str
 78        The default program to create the graph layout to be passed to the networkx library. By default, 'neato' is used.
 79    ncol : int
 80        The number of columns in the legend. By default, four columns are used.
 81    title : str
 82        The title to be given to the plot. By default, no title is used.
 83    legend_fontsize : int
 84        The font size of the legend. Vy default, it is not specified.
 85    font_color_state_labels : str
 86        The color code for the labels of the state nodes. By default, it is set to 'k'.
 87    font_color_state_actions_labels : str
 88        The color code for the labels of the action nodes. By default, it is set to 'k'.
 89    cm_state_labels : LinearSegmentedColormap
 90        The matplotlib color map for the labels of the state nodes. By default, it is not specified.
 91    cm_state_actions_labels : LinearSegmentedColormap
 92        The matplotlib color map for the labels of the action nodes. By default, it is not specified.
 93    no_written_state_labels : bool
 94        If True, the labels for the node states are not shown. By default, it is set to True.
 95    no_written_state_action_labels : bool
 96        If True, the labels for the node actions are not shown. By default, it is set to True.
 97    node_size : int
 98        The size of the node in the plot. By default, it is set to 100.
 99    """
100    show = ax is None
101
102    sns.reset_defaults()
103    G, probs = (
104        _create_epi_MDP_graph(mdp)
105        if mdp.is_episodic() and not continuous_form
106        else _create_MDP_graph(mdp)
107    )
108    T, R = mdp.transition_matrix_and_rewards
109
110    layout = nx.nx_agraph.graphviz_layout(G, prog=prog)
111
112    if ax is None:
113        ax = _create_ax(layout, figsize)
114
115    if mdp.is_episodic() and not continuous_form:
116        if node_labels is not None and cm_state_labels is not None:
117            node_color = [
118                cm_state_labels(node_labels[node] / max(node_labels.values()))
119                for node in mdp.get_episodic_graph(False).nodes
120            ]
121        else:
122            node_color = [
123                node_palette[5]  # brown
124                if node[1] in mdp.starting_nodes and node[0] == 0
125                else node_palette[2]  # green
126                if R[mdp.node_to_index(node[1])].max() == R.max()
127                else node_palette[-2]  # yellow
128                if node[1] in mdp.recurrent_nodes_set
129                else node_palette[-3]  # grey
130                for node in mdp.get_episodic_graph(False).nodes
131            ]
132    else:
133        if node_labels is not None and cm_state_labels is not None:
134            node_color = [
135                cm_state_labels(node_labels[node] / max(node_labels.values()))
136                for node in mdp.G.nodes
137            ]
138        else:
139            node_color = [
140                node_palette[5]  # brown
141                if node in mdp.starting_nodes
142                else node_palette[2]  # green
143                if R[mdp.node_to_index[node]].max() == R.max()
144                else node_palette[-2]  # yellow
145                if node in mdp.recurrent_nodes_set
146                else node_palette[-3]  # grey
147                for node in mdp.G.nodes
148            ]
149
150    # Lazy way to create nice legend handles
151    x, y = list(layout.values())[0]
152    if cm_state_labels is None:
153        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
154        ax.scatter(x, y, color=node_palette[-2], label="State")
155        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
156            ax.scatter(x, y, color=node_palette[-3], label="Transient state")
157        ax.scatter(x, y, color=node_palette[5], label="Starting state")
158    ax.plot(x, y, color=node_palette[-3], label="Transition probability")
159    if cm_state_actions_labels is None:
160        for a in range(mdp.n_actions):
161            ax.plot(x, y, color=action_palette[a], label=f"Action: {a}", marker="s")
162
163    G_nodes = (
164        mdp.get_episodic_graph(False).nodes
165        if mdp.is_episodic() and not continuous_form
166        else mdp.G.nodes
167    )
168    nx.draw_networkx_nodes(
169        G,
170        layout,
171        G_nodes,
172        ax=ax,
173        node_color=node_color,
174        edgecolors="black",
175        node_size=node_size,
176    )
177    for a in range(mdp.n_actions):
178        na_nodes = (
179            [an for an in G.nodes if type(an[1]) == int and an[-1] == a]
180            if mdp.is_episodic() and not continuous_form
181            else [an for an in G.nodes if type(an) == tuple and an[-1] == a]
182        )
183        nx.draw_networkx_nodes(
184            G,
185            layout,
186            na_nodes,
187            node_shape="s",
188            ax=ax,
189            node_size=node_size,
190            node_color=[action_palette[a]]
191            if cm_state_actions_labels is None
192            else [
193                cm_state_actions_labels(action_labels[an] / max(action_labels.values()))
194                for an in na_nodes
195            ],
196            edgecolors="black",
197        )
198        nx.draw_networkx_edges(
199            G,
200            layout,
201            edgelist=[e for e in G.edges if type(e[0][0]) != tuple and e[1][1] == a]
202            if mdp.is_episodic() and not continuous_form
203            else [e for e in G.edges if type(e[0]) != tuple and e[1][1] == a],
204            ax=ax,
205            edge_color=action_palette[a],
206        )
207    nx.draw_networkx_edges(
208        G,
209        layout,
210        edgelist=[e for e in G.edges if type(e[0][0]) == tuple]
211        if mdp.is_episodic() and not continuous_form
212        else [e for e in G.edges if type(e[0]) == tuple],
213        ax=ax,
214        edge_color=action_palette[-3],
215        width=[probs[e] for e in G.edges if type(e[0][0]) == tuple]
216        if mdp.is_episodic() and not continuous_form
217        else [probs[e] for e in G.edges if type(e[0]) == tuple],
218    )
219    ax.legend(ncol=ncol, fontsize=legend_fontsize)
220    if node_labels is not None and not no_written_state_labels:
221        if type(node_labels) == bool and node_labels:
222            node_labels = {
223                n: (
224                    f"h={n[0]},{n[1]}"
225                    if mdp.is_episodic() and not continuous_form
226                    else str(n)
227                )
228                for n in G_nodes
229            }
230        assert all(n in G.nodes for n in node_labels)
231        nx.draw_networkx_labels(
232            G,
233            toolz.valmap(
234                lambda x: [x[0] + int_labels_offset_x, x[1] + int_labels_offset_y],
235                layout,
236            ),
237            node_labels,
238            font_color=font_color_state_labels,
239            ax=ax,
240            verticalalignment="center_baseline",
241        )
242    if action_labels is not None and not no_written_state_action_labels:
243        if type(action_labels) == bool and action_labels:
244            action_labels = {
245                n: str(n[1])
246                for n in (
247                    (an for an in G.nodes if type(an[1]) == int)
248                    if mdp.is_episodic() and not continuous_form
249                    else (an for an in G.nodes if type(an) == tuple)
250                )
251            }
252        assert all(n in G.nodes and type(n[1]) == int for n in action_labels)
253        nx.draw_networkx_labels(
254            G,
255            toolz.valmap(
256                lambda x: [x[0] + int_labels_offset_x, x[1] + int_labels_offset_y],
257                layout,
258            ),
259            action_labels,
260            font_color=font_color_state_actions_labels,
261            ax=ax,
262            verticalalignment="center_baseline",
263        )
264
265    ax.axis("off")
266    if title is not None:
267        ax.set_title(title)
268    if save_file is not None:
269        plt.savefig(save_file, bbox_inches="tight")
270    if show:
271        plt.show()

In the MDP graph representation, states are associated to round nodes and actions to square nodes. Each state is connected to action nodes for all the available actions. Each action node is connected to the state nodes such that the probability of transitioning to them when playing the corresponding action from the corresponding state is greater than zero.

Parameters
  • mdp (Union["ContinuousMDP", "EpisodicMDP"]): The MDP to be visualised.
  • node_palette (List[Tuple[float, float, float]]): The color palette for the nodes corresponding to states. By default, the seaborn 'bright' palette is used.
  • action_palette (List[Tuple[float, float, float]]): The color palette for the nodes corresponding to actions. By default, the seaborn 'dark' palette is used.
  • save_file (str): The path of the file where the visualization will be stored. By default, the visualization is not stored.
  • ax (plt.Axes): The ax object where the plot will be put. By default, a new axis is created.
  • figsize (Tuple[int, int]): The size of the figure created when no ax is provided. By default, it is inferred from the spread of the graph layout.
  • node_labels (Union[bool, Dict["NODE_TYPE", Union[float, str]]]): The dictionary mapping state nodes to labels (either a float or a str). If provided as bool, then the default labels are used. By default, no label is plotted.
  • action_labels (Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]]): The dictionary mapping action nodes to labels (either a float or a str). If provided as bool, then the default labels are used. By default, no label is plotted.
  • int_labels_offset_x (int): x-axis offset for the node labels. By default, the offset is set to 10.
  • int_labels_offset_y (int): y-axis offset for the node labels. By default, the offset is set to 10.
  • continuous_form (bool): If True and the MDP is episodic, the continuous form of the MDP is used. By default, the continuous form is used.
  • prog (str): The default program to create the graph layout to be passed to the networkx library. By default, 'neato' is used.
  • ncol (int): The number of columns in the legend. By default, four columns are used.
  • title (str): The title to be given to the plot. By default, no title is used.
  • legend_fontsize (int): The font size of the legend. By default, it is not specified.
  • font_color_state_labels (str): The color code for the labels of the state nodes. By default, it is set to 'k'.
  • font_color_state_actions_labels (str): The color code for the labels of the action nodes. By default, it is set to 'k'.
  • cm_state_labels (LinearSegmentedColormap): The matplotlib color map for the labels of the state nodes. By default, it is not specified.
  • cm_state_actions_labels (LinearSegmentedColormap): The matplotlib color map for the labels of the action nodes. By default, it is not specified.
  • no_written_state_labels (bool): If True, the labels for the node states are not shown. By default, it is set to True.
  • no_written_state_action_labels (bool): If True, the labels for the node actions are not shown. By default, it is set to True.
  • node_size (int): The size of the nodes in the plot. By default, it is set to 150.
def plot_MCGraph( mdp: Union[colosseum.mdp.base_infinite.ContinuousMDP, colosseum.mdp.base_finite.EpisodicMDP], node_palette: List[Tuple[float, float, float]] = [(0.2980392156862745, 0.4470588235294118, 0.6901960784313725), (0.8666666666666667, 0.5176470588235295, 0.3215686274509804), (0.3333333333333333, 0.6588235294117647, 0.40784313725490196), (0.7686274509803922, 0.3058823529411765, 0.3215686274509804), (0.5058823529411764, 0.4470588235294118, 0.7019607843137254), (0.5764705882352941, 0.47058823529411764, 0.3764705882352941), (0.8549019607843137, 0.5450980392156862, 0.7647058823529411), (0.5490196078431373, 0.5490196078431373, 0.5490196078431373), (0.8, 0.7254901960784313, 0.4549019607843137), (0.39215686274509803, 0.7098039215686275, 0.803921568627451)], labels: Union[bool, Dict[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], Union[float, str]]] = {}, font_color_labels='k', save_file: str = None, ax: matplotlib.axes._axes.Axes = None, figsize: Tuple[int, int] = None, prog: str = None, legend_fontsize: int = None, node_size=100, cm_state_labels=None, no_written_state_labels=True):
def plot_MCGraph(
    mdp: Union["ContinuousMDP", "EpisodicMDP"],
    node_palette: List[Tuple[float, float, float]] = sns.color_palette("deep"),
    labels: Union[bool, Dict["NODE_TYPE", Union[float, str]], None] = None,
    font_color_labels="k",
    save_file: str = None,
    ax: plt.Axes = None,
    figsize: Tuple[int, int] = None,
    prog: str = None,
    legend_fontsize: int = None,
    node_size=100,
    cm_state_labels=None,
    no_written_state_labels=True,
):
    """
    In the Markov chain-based representation, states are associated with nodes and actions are not visualized. Each
    state is connected to states such that there exists at least one action with corresponding transition probability
    greater than zero.

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP to be visualized.
    node_palette : List[Tuple[float, float, float]]
        The color palette for nodes and edges. Index 0 colors the starting states, index 2 the highly rewarding
        states, index 1 the recurrent states, index -1 the remaining states and index -3 the edges, so the palette
        must contain at least three colors. By default, the seaborn 'deep' palette is used.
    labels : Union[bool, Dict["NODE_TYPE", Union[float, str]]]
        The dictionary mapping nodes to labels (either a float or a str). If True, the default networkx labels (the
        node objects themselves) are used; if False or None, no label is plotted. When ``cm_state_labels`` is given,
        this must be a non-empty dictionary of numerical values. By default, no label is plotted.
    font_color_labels : str
        The color code for the node labels. By default, it is set to 'k'.
    save_file : str
        The path of the file where the visualization will be stored. By default, the visualization will not be stored.
    ax : plt.Axes
        The ax object where the plot will be put. By default, a new axis is created and the plot is shown.
    figsize : Tuple[int, int]
        The figure size used when a new axis is created (i.e. when ``ax`` is not provided). By default, None is
        provided.
    prog : str
        The graphviz program used by networkx to compute the graph layout. By default, it is not specified and the
        MDP's own graph layout is used.
    legend_fontsize : int
        The font size of the legend. By default, it is not specified.
    node_size : int
        The size of the nodes in the plot. By default, it is set to 100.
    cm_state_labels : LinearSegmentedColormap
        The matplotlib color map used to color each node by its label divided by the largest label. When given, no
        legend is drawn. By default, it is not specified.
    no_written_state_labels : bool
        If True and ``cm_state_labels`` is given, the labels are only used for coloring and are not written on the
        nodes. By default, it is set to True.

    Raises
    ------
    ValueError
        If ``cm_state_labels`` is given but ``labels`` is not a non-empty dictionary.
    """
    show = ax is None

    # Convert ``labels`` into what networkx accepts: ``None`` makes networkx use the node objects as labels, while an
    # empty dict draws no label at all. A bool passed straight to ``nx.draw`` would crash inside networkx.
    if isinstance(labels, bool):
        draw_labels = None if labels else {}
    else:
        draw_labels = {} if labels is None else labels

    if cm_state_labels is not None:
        if isinstance(labels, bool) or not labels:
            raise ValueError(
                "A non-empty dictionary of numerical labels is required when cm_state_labels is given."
            )
        # The normalization constant is the same for every node, so compute it once.
        max_label = max(labels.values())
        node_color = [cm_state_labels(labels[n] / max_label) for n in mdp.G.nodes]
    else:
        _, R = mdp.transition_matrix_and_rewards
        max_reward = R.max()
        node_color = []
        for node in mdp.G.nodes:
            if node in mdp.starting_nodes:
                node_color.append(node_palette[0])  # starting state
            elif R[mdp.node_to_index[node]].max() == max_reward:
                node_color.append(node_palette[2])  # highly rewarding state
            elif node in mdp.recurrent_nodes_set:
                node_color.append(node_palette[1])  # recurrent state
            else:
                node_color.append(node_palette[-1])  # transient state

    if ax is None:
        ax = _create_ax(mdp.graph_layout, figsize)

    if cm_state_labels is None:
        # Empty scatters act as legend proxies, so no extra marker is drawn on the plot itself
        # (the real node positions may come from graphviz rather than ``mdp.graph_layout``).
        ax.scatter([], [], color=node_palette[2], label="Highly rewarding state")
        ax.scatter([], [], color=node_palette[1], label="State")
        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
            ax.scatter([], [], color=node_palette[-1], label="Transient state")
        ax.scatter([], [], color=node_palette[0], label="Starting state")

    pos = (
        mdp.graph_layout
        if prog is None
        else nx.nx_agraph.graphviz_layout(mdp.G, prog=prog)
    )
    nx.draw(
        mdp.G,
        pos,
        node_color=node_color,
        node_size=node_size,
        edgecolors="black",
        edge_color=node_palette[-3],
        labels={}
        if cm_state_labels is not None and no_written_state_labels
        else draw_labels,
        font_color=font_color_labels,
        ax=ax,
    )

    if cm_state_labels is None:
        ax.legend(fontsize=legend_fontsize)
    if save_file is not None:
        plt.savefig(save_file)
    if show:
        plt.show()

In the Markov chain-based representation, states are associated to nodes and actions are not visualized. Each state is connected to states such that there exists at least one action with corresponding transition probability greater than zero.

Parameters
  • mdp (Union["ContinuousMDP", "EpisodicMDP"]): The MDP to be visualized.
  • node_palette (List[Tuple[float, float, float]]): The color palette for the nodes. By default, the seaborn 'deep' palette is used.
  • labels (Union[bool, Dict["NODE_TYPE", Union[float, str]]]): The dictionary mapping nodes to labels (either a float or a str). If provided as bool, then the default labels are used. By default, no label is plotted.
  • font_color_labels (str): The color code for the node labels. By default, it is set to 'k'.
  • save_file (str): The path of the file where the visualization will be stored. By default, the visualization will not be stored.
  • ax (plt.Axes): The ax object where the plot will be put. By default, a new axis is created.
  • figsize (Tuple[int, int]): The figure size used when a new axis is created (i.e. when ax is not provided). By default, None is provided.
  • prog (str): The default program to create the graph layout to be passed to the networkx library. By default, it is not specified.
  • legend_fontsize (int): The font size of the legend. By default, it is not specified.
  • node_size (int): The size of the node in the plot. By default, it is set to 100.
  • cm_state_labels (LinearSegmentedColormap): The matplotlib color map for the labels of the state nodes. By default, it is not specified.
  • no_written_state_labels (bool): If True and cm_state_labels is provided, the labels for the nodes are not shown. By default, it is set to True.