colosseum.utils.visualization

  1from typing import Dict, Tuple, Union, TYPE_CHECKING
  2
  3import networkx as nx
  4import numpy as np
  5import seaborn as sns
  6import toolz
  7from matplotlib import pyplot as plt
  8
  9from colosseum.mdp.utils import MDPCommunicationClass
 10
 11if TYPE_CHECKING:
 12    from colosseum.mdp import ContinuousMDP, EpisodicMDP, NODE_TYPE
 13
 14
 15def _create_ax(layout, figsize: Tuple[int, int] = None):
 16    if figsize is None:
 17        positions = np.array(list(layout.values()))
 18        max_distance = max(
 19            np.sqrt(np.sum((positions[i] - positions[j]) ** 2))
 20            for i in range(len(layout))
 21            for j in range(i + 1, len(layout))
 22        )
 23        figsize = max(6, min(20, int(max_distance / 70)))
 24        figsize = (figsize, figsize)
 25    plt.figure(None, figsize)
 26    cf = plt.gcf()
 27    cf.set_facecolor("w")
 28    ax = cf.add_axes((0, 0, 1, 1)) if cf._axstack() is None else cf.gca()
 29    ax.spines["top"].set_visible(False)
 30    ax.spines["right"].set_visible(False)
 31    ax.spines["bottom"].set_visible(False)
 32    ax.spines["left"].set_visible(False)
 33    return ax
 34
 35
 36def _create_MDP_graph(mdp: Union["ContinuousMDP", "EpisodicMDP"]):
 37    T, R = mdp.transition_matrix_and_rewards
 38
 39    probs = dict()
 40    G = nx.DiGraph()
 41    for s in range(mdp.n_states):
 42        n = mdp.index_to_node[s]
 43        if s not in G.nodes:
 44            G.add_node(n)
 45        for a in range(mdp.n_actions):
 46            an = tuple((n, a))
 47            G.add_edge(n, an)
 48            for nn in np.where(T[s, a] > 0)[0]:
 49                G.add_edge(an, mdp.index_to_node[nn])
 50                probs[an, mdp.index_to_node[nn]] = T[s, a, nn]
 51
 52    return G, probs
 53
 54
 55def _create_epi_MDP_graph(mdp: "EpisodicMDP"):
 56    G_epi = mdp.get_episodic_graph(False)
 57    T, R = mdp.episodic_transition_matrix_and_rewards
 58
 59    probs = dict()
 60    G = nx.DiGraph()
 61    for n in G_epi.nodes:
 62        if n not in G.nodes:
 63            G.add_node(n)
 64
 65        for a in range(mdp.n_actions):
 66            an = tuple((n, a))
 67            G.add_edge(n, an)
 68            for nn in G_epi.successors(n):
 69                G.add_edge(an, nn)
 70                probs[an, nn] = T[
 71                    n[0], mdp.node_to_index(n[1]), a, mdp.node_to_index(nn[1])
 72                ]
 73
 74    return G, probs
 75
 76
 77custom_dark = sns.color_palette("dark")
 78p = custom_dark.pop(1)
 79custom_dark.insert(6, p)
 80
 81
 82def plot_MDP_graph(
 83    mdp: Union["ContinuousMDP", "EpisodicMDP"],
 84    node_palette=sns.color_palette("bright"),
 85    action_palette=custom_dark,
 86    save_file: str = None,
 87    ax=None,
 88    figsize: Tuple[int, int] = None,
 89    node_labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = None,
 90    action_labels: Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]] = None,
 91    int_labels_offset_x: int = 10,
 92    int_labels_offset_y: int = 10,
 93    continuous_form: bool = True,
 94    prog="neato",
 95    ncol: int = 4,
 96    title: str = None,
 97    fontsize: int = None,
 98    font_color_state_labels="k",
 99    font_color_state_actions_labels="k",
100    cm_state_labels=None,
101    cm_state_actions_labels=None,
102    no_written_state_labels=True,
103    no_written_state_action_labels=True,
104    node_size=150
105):
106    show = ax is None
107
108    sns.reset_defaults()
109    G, probs = (
110        _create_epi_MDP_graph(mdp)
111        if mdp.is_episodic() and not continuous_form
112        else _create_MDP_graph(mdp)
113    )
114    T, R = mdp.transition_matrix_and_rewards
115
116    layout = nx.nx_agraph.graphviz_layout(G, prog=prog)
117
118    if ax is None:
119        ax = _create_ax(layout, figsize)
120
121    if mdp.is_episodic() and not continuous_form:
122        if node_labels is not None and cm_state_labels is not None:
123            node_color = [
124                cm_state_labels(node_labels[node] / max(node_labels.values()))
125                for node in mdp.get_episodic_graph(False).nodes
126            ]
127        else:
128            node_color = [
129                node_palette[5]  # brown
130                if node[1] in mdp.starting_nodes and node[0] == 0
131                else node_palette[2]  # green
132                if R[mdp.node_to_index(node[1])].max() == R.max()
133                else node_palette[-2]  # yellow
134                if node[1] in mdp.recurrent_nodes_set
135                else node_palette[-3]  # grey
136                for node in mdp.get_episodic_graph(False).nodes
137            ]
138    else:
139        if node_labels is not None and cm_state_labels is not None:
140            node_color = [
141                cm_state_labels(node_labels[node] / max(node_labels.values()))
142                for node in mdp.G.nodes
143            ]
144        else:
145            node_color = [
146                node_palette[5]  # brown
147                if node in mdp.starting_nodes
148                else node_palette[2]  # green
149                if R[mdp.node_to_index[node]].max() == R.max()
150                else node_palette[-2]  # yellow
151                if node in mdp.recurrent_nodes_set
152                else node_palette[-3]  # grey
153                for node in mdp.G.nodes
154            ]
155
156    # Lazy way to create nice legend handles
157    x, y = list(layout.values())[0]
158    if cm_state_labels is None:
159        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
160        ax.scatter(x, y, color=node_palette[-2], label="State")
161        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
162            ax.scatter(x, y, color=node_palette[-3], label="Transient state")
163        ax.scatter(x, y, color=node_palette[5], label="Starting state")
164    ax.plot(x, y, color=node_palette[-3], label="Transition probability")
165    if cm_state_actions_labels is None:
166        for a in range(mdp.n_actions):
167            ax.plot(x, y, color=action_palette[a], label=f"Action: {a}", marker="s")
168
169    G_nodes = (
170        mdp.get_episodic_graph(False).nodes
171        if mdp.is_episodic() and not continuous_form
172        else mdp.G.nodes
173    )
174    nx.draw_networkx_nodes(
175        G, layout, G_nodes, ax=ax, node_color=node_color, edgecolors="black",node_size=node_size
176    )
177    for a in range(mdp.n_actions):
178        na_nodes = (
179            [an for an in G.nodes if type(an[1]) == int and an[-1] == a]
180            if mdp.is_episodic() and not continuous_form
181            else [an for an in G.nodes if type(an) == tuple and an[-1] == a]
182        )
183        nx.draw_networkx_nodes(
184            G,
185            layout,
186            na_nodes,
187            node_shape="s",
188            ax=ax,
189            node_size=node_size,
190            node_color=[action_palette[a]]
191            if cm_state_actions_labels is None
192            else [
193                cm_state_actions_labels(action_labels[an] / max(action_labels.values()))
194                for an in na_nodes
195            ],
196            edgecolors="black",
197        )
198        nx.draw_networkx_edges(
199            G,
200            layout,
201            edgelist=[e for e in G.edges if type(e[0][0]) != tuple and e[1][1] == a]
202            if mdp.is_episodic() and not continuous_form
203            else [e for e in G.edges if type(e[0]) != tuple and e[1][1] == a],
204            ax=ax,
205            edge_color=action_palette[a],
206        )
207    nx.draw_networkx_edges(
208        G,
209        layout,
210        edgelist=[e for e in G.edges if type(e[0][0]) == tuple]
211        if mdp.is_episodic() and not continuous_form
212        else [e for e in G.edges if type(e[0]) == tuple],
213        ax=ax,
214        edge_color=action_palette[-3],
215        width=[probs[e] for e in G.edges if type(e[0][0]) == tuple]
216        if mdp.is_episodic() and not continuous_form
217        else [probs[e] for e in G.edges if type(e[0]) == tuple],
218    )
219    ax.legend(ncol=ncol, fontsize=fontsize)
220    if node_labels is not None and not no_written_state_labels:
221        if type(node_labels) == bool and node_labels:
222            node_labels = {
223                n: (
224                    f"h={n[0]},{n[1]}"
225                    if mdp.is_episodic() and not continuous_form
226                    else str(n)
227                )
228                for n in G_nodes
229            }
230        assert all(n in G.nodes for n in node_labels)
231        nx.draw_networkx_labels(
232            G,
233            toolz.valmap(
234                lambda x: [x[0] + int_labels_offset_x, x[1] + int_labels_offset_y],
235                layout,
236            ),
237            node_labels,
238            font_color=font_color_state_labels,
239            ax=ax,
240            verticalalignment="center_baseline",
241        )
242    if action_labels is not None and not no_written_state_action_labels:
243        if type(action_labels) == bool and action_labels:
244            action_labels = {
245                n: str(n[1])
246                for n in (
247                    (an for an in G.nodes if type(an[1]) == int)
248                    if mdp.is_episodic() and not continuous_form
249                    else (an for an in G.nodes if type(an) == tuple)
250                )
251            }
252        assert all(n in G.nodes and type(n[1]) == int for n in action_labels)
253        nx.draw_networkx_labels(
254            G,
255            toolz.valmap(
256                lambda x: [x[0] + int_labels_offset_x, x[1] + int_labels_offset_y],
257                layout,
258            ),
259            action_labels,
260            font_color=font_color_state_actions_labels,
261            ax=ax,
262            verticalalignment="center_baseline",
263        )
264
265    ax.axis("off")
266    if title is not None:
267        ax.set_title(title)
268    if save_file is not None:
269        plt.savefig(save_file, bbox_inches="tight")
270    if show:
271        plt.show()
272
273
def plot_MCGraph(
    mdp: Union["ContinuousMDP", "EpisodicMDP"],
    node_palette=sns.color_palette("deep"),
    labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = {},
    font_color_labels="k",
    save_file: str = None,
    ax=None,
    figsize: Tuple[int, int] = None,
    prog: str = None,
    fontsize: int = None,
    node_size=100,
    cm_state_labels=None,
    no_written_state_labels=True,
):
    """
    Plot the state graph ``mdp.G`` of the MDP.

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP to plot.
    node_palette
        Colours for the nodes. Indices 0, 2, 1 and -1 are used for starting
        states, states attaining the maximum reward, recurrent states and all
        remaining states respectively; index -3 is used for the edges.
    labels : Union[bool, Dict]
        Mapping from nodes to labels. ``True`` labels every node with its
        string representation and ``False`` draws no labels. When
        ``cm_state_labels`` is given, it must map nodes to numbers, which are
        divided by their maximum and turned into colours. The default mapping
        is never mutated.
    font_color_labels
        Font colour of the label texts.
    save_file : str, optional
        If given, the current pyplot figure is saved to this path.
    ax
        Axes to draw on. When ``None``, a new figure is created and
        ``plt.show()`` is called at the end.
    figsize : Tuple[int, int], optional
        Size of the figure created when ``ax`` is ``None``.
    prog : str, optional
        Graphviz program used to compute the layout. When ``None``,
        ``mdp.graph_layout`` is used.
    fontsize : int, optional
        Font size of the legend.
    node_size
        Marker size of the nodes.
    cm_state_labels
        Optional callable mapping a normalised label value to a colour.
    no_written_state_labels : bool
        When ``True`` and ``cm_state_labels`` is given, no label text is drawn.
    """
    show = ax is None

    _, R = mdp.transition_matrix_and_rewards

    # The annotation allows a bool, which ``nx.draw`` cannot handle directly.
    if isinstance(labels, bool):
        labels = {n: str(n) for n in mdp.G.nodes} if labels else {}

    if cm_state_labels is not None:
        # The maximum is computed once instead of once per node.
        max_label = max(labels.values())
        node_color = [cm_state_labels(labels[n] / max_label) for n in mdp.G.nodes]
    else:
        max_reward = R.max()

        def default_color(node):
            if node in mdp.starting_nodes:
                return node_palette[0]  # blue in the default palette
            if R[mdp.node_to_index[node]].max() == max_reward:
                return node_palette[2]  # green in the default palette
            if node in mdp.recurrent_nodes_set:
                return node_palette[1]  # orange in the default palette
            return node_palette[-1]  # light blue in the default palette

        node_color = [default_color(node) for node in mdp.G.nodes]

    # Resolve the layout once so that the automatic figure size and the legend
    # proxies below refer to the positions that are actually drawn.
    layout = (
        mdp.graph_layout
        if prog is None
        else nx.nx_agraph.graphviz_layout(mdp.G, prog=prog)
    )

    if ax is None:
        ax = _create_ax(layout, figsize)

    if cm_state_labels is None:
        # Legend proxies are drawn at the position of a node so that they end
        # up hidden below it.
        x, y = list(layout.values())[0]
        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
        ax.scatter(x, y, color=node_palette[1], label="State")
        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
            ax.scatter(x, y, color=node_palette[-1], label="Transient state")
        ax.scatter(x, y, color=node_palette[0], label="Starting state")

    hide_label_text = cm_state_labels is not None and no_written_state_labels
    nx.draw(
        mdp.G,
        layout,
        node_color=node_color,
        node_size=node_size,
        edgecolors="black",
        edge_color=node_palette[-3],
        labels={} if hide_label_text else labels,
        font_color=font_color_labels,
        ax=ax,
    )

    if cm_state_labels is None:
        ax.legend(fontsize=fontsize)
    if save_file is not None:
        plt.savefig(save_file)
    if show:
        plt.show()
def _create_ax(layout, figsize: Tuple[int, int] = None):
def _create_ax(layout, figsize: Tuple[int, int] = None):
    """
    Open a new white-background figure and return an axes with hidden spines.

    Parameters
    ----------
    layout
        Mapping from graph nodes to 2D positions. Only its values are read,
        and only when ``figsize`` is ``None``.
    figsize : Tuple[int, int], optional
        Figure size forwarded to ``plt.figure``. When ``None``, a square size
        is derived from the largest pairwise distance between the layout
        positions, divided by 70 and clamped to the range [6, 20].

    Returns
    -------
    The axes of the newly created figure.
    """
    if figsize is None:
        positions = np.array(list(layout.values()))
        n_positions = len(positions)
        # ``default=0`` keeps layouts with fewer than two nodes from raising
        # ``ValueError`` on an empty sequence; they get the minimum size.
        max_distance = max(
            (
                np.sqrt(np.sum((positions[i] - positions[j]) ** 2))
                for i in range(n_positions)
                for j in range(i + 1, n_positions)
            ),
            default=0,
        )
        side = max(6, min(20, int(max_distance / 70)))
        figsize = (side, side)

    plt.figure(None, figsize)
    cf = plt.gcf()
    cf.set_facecolor("w")
    # Use the public ``Figure.axes`` list rather than the private
    # ``Figure._axstack`` to decide whether the figure already has an axes.
    ax = cf.gca() if cf.axes else cf.add_axes((0, 0, 1, 1))
    for spine in ("top", "right", "bottom", "left"):
        ax.spines[spine].set_visible(False)
    return ax
def _create_MDP_graph( mdp: Union[colosseum.mdp.base_infinite.ContinuousMDP, colosseum.mdp.base_finite.EpisodicMDP]):
def _create_MDP_graph(mdp: Union["ContinuousMDP", "EpisodicMDP"]):
    """
    Build a directed graph alternating state nodes and state-action nodes.

    Every state node ``n`` gets one edge to each action node ``(n, a)``, and
    each action node gets an edge to every state whose transition probability
    under that state-action pair is strictly positive.

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP providing ``transition_matrix_and_rewards``, ``n_states``,
        ``n_actions`` and ``index_to_node``.

    Returns
    -------
    G : nx.DiGraph
        The state / state-action graph.
    probs : dict
        Maps each ``(action node, next state node)`` edge to the corresponding
        entry ``T[s, a, s']`` of the transition matrix.
    """
    # The rewards are not needed to build the graph.
    T, _ = mdp.transition_matrix_and_rewards

    probs = dict()
    G = nx.DiGraph()
    for s in range(mdp.n_states):
        n = mdp.index_to_node[s]
        # ``add_node`` is idempotent, so no membership test is needed (the
        # previous test looked up the index ``s`` rather than the node ``n``).
        G.add_node(n)
        for a in range(mdp.n_actions):
            an = (n, a)
            G.add_edge(n, an)
            for next_s in np.where(T[s, a] > 0)[0]:
                nn = mdp.index_to_node[next_s]
                G.add_edge(an, nn)
                probs[an, nn] = T[s, a, next_s]

    return G, probs
def _create_epi_MDP_graph(mdp: colosseum.mdp.base_finite.EpisodicMDP):
def _create_epi_MDP_graph(mdp: "EpisodicMDP"):
    """
    Build the state / state-action graph of the episodic form of ``mdp``.

    Nodes of the episodic graph are indexed as ``n[0]`` and ``n[1]``; the
    first component is used as the leading index of the episodic transition
    matrix and the second one is mapped to a state index.

    Parameters
    ----------
    mdp : "EpisodicMDP"
        The MDP providing ``get_episodic_graph``,
        ``episodic_transition_matrix_and_rewards``, ``n_actions`` and
        ``node_to_index``.

    Returns
    -------
    G : nx.DiGraph
        Graph with an edge from every episodic node ``n`` to each action node
        ``(n, a)`` and from each action node to every successor of ``n`` in
        the episodic graph.
    probs : dict
        Maps each ``(action node, successor)`` edge to the matching entry of
        the episodic transition matrix.
    """
    G_epi = mdp.get_episodic_graph(False)
    # The rewards are not needed to build the graph.
    T, _ = mdp.episodic_transition_matrix_and_rewards

    probs = dict()
    G = nx.DiGraph()
    for n in G_epi.nodes:
        G.add_node(n)
        successors = list(G_epi.successors(n))

        for a in range(mdp.n_actions):
            an = (n, a)
            G.add_edge(n, an)
            for nn in successors:
                G.add_edge(an, nn)
                # NOTE(review): ``node_to_index`` is subscripted here to agree
                # with its other uses in this module (it was previously
                # called like a function); confirm it is a mapping.
                probs[an, nn] = T[
                    n[0], mdp.node_to_index[n[1]], a, mdp.node_to_index[nn[1]]
                ]

    return G, probs
def plot_MDP_graph( mdp: Union[colosseum.mdp.base_infinite.ContinuousMDP, colosseum.mdp.base_finite.EpisodicMDP], node_palette=[(0.00784313725490196, 0.24313725490196078, 1.0), (1.0, 0.48627450980392156, 0.0), (0.10196078431372549, 0.788235294117647, 0.2196078431372549), (0.9098039215686274, 0.0, 0.043137254901960784), (0.5450980392156862, 0.16862745098039217, 0.8862745098039215), (0.6235294117647059, 0.2823529411764706, 0.0), (0.9450980392156862, 0.2980392156862745, 0.7568627450980392), (0.6392156862745098, 0.6392156862745098, 0.6392156862745098), (1.0, 0.7686274509803922, 0.0), (0.0, 0.8431372549019608, 1.0)], action_palette=[(0.0, 0.10980392156862745, 0.4980392156862745), (0.07058823529411765, 0.44313725490196076, 0.10980392156862745), (0.5490196078431373, 0.03137254901960784, 0.0), (0.34901960784313724, 0.11764705882352941, 0.44313725490196076), (0.34901960784313724, 0.1843137254901961, 0.050980392156862744), (0.6352941176470588, 0.20784313725490197, 0.5098039215686274), (0.6941176470588235, 0.25098039215686274, 0.050980392156862744), (0.23529411764705882, 0.23529411764705882, 0.23529411764705882), (0.7215686274509804, 0.5215686274509804, 0.0392156862745098), (0.0, 0.38823529411764707, 0.4549019607843137)], save_file: str = None, ax=None, figsize: Tuple[int, int] = None, node_labels: Union[bool, dict[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], Union[float, str]]] = None, action_labels: Union[bool, dict[tuple[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, 
colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], int], Union[float, str]]] = None, int_labels_offset_x: int = 10, int_labels_offset_y: int = 10, continuous_form: bool = True, prog='neato', ncol: int = 4, title: str = None, fontsize: int = None, font_color_state_labels='k', font_color_state_actions_labels='k', cm_state_labels=None, cm_state_actions_labels=None, no_written_state_labels=True, no_written_state_action_labels=True, node_size=150):
 83def plot_MDP_graph(
 84    mdp: Union["ContinuousMDP", "EpisodicMDP"],
 85    node_palette=sns.color_palette("bright"),
 86    action_palette=custom_dark,
 87    save_file: str = None,
 88    ax=None,
 89    figsize: Tuple[int, int] = None,
 90    node_labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = None,
 91    action_labels: Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]] = None,
 92    int_labels_offset_x: int = 10,
 93    int_labels_offset_y: int = 10,
 94    continuous_form: bool = True,
 95    prog="neato",
 96    ncol: int = 4,
 97    title: str = None,
 98    fontsize: int = None,
 99    font_color_state_labels="k",
100    font_color_state_actions_labels="k",
101    cm_state_labels=None,
102    cm_state_actions_labels=None,
103    no_written_state_labels=True,
104    no_written_state_action_labels=True,
105    node_size=150
106):
107    show = ax is None
108
109    sns.reset_defaults()
110    G, probs = (
111        _create_epi_MDP_graph(mdp)
112        if mdp.is_episodic() and not continuous_form
113        else _create_MDP_graph(mdp)
114    )
115    T, R = mdp.transition_matrix_and_rewards
116
117    layout = nx.nx_agraph.graphviz_layout(G, prog=prog)
118
119    if ax is None:
120        ax = _create_ax(layout, figsize)
121
122    if mdp.is_episodic() and not continuous_form:
123        if node_labels is not None and cm_state_labels is not None:
124            node_color = [
125                cm_state_labels(node_labels[node] / max(node_labels.values()))
126                for node in mdp.get_episodic_graph(False).nodes
127            ]
128        else:
129            node_color = [
130                node_palette[5]  # brown
131                if node[1] in mdp.starting_nodes and node[0] == 0
132                else node_palette[2]  # green
133                if R[mdp.node_to_index(node[1])].max() == R.max()
134                else node_palette[-2]  # yellow
135                if node[1] in mdp.recurrent_nodes_set
136                else node_palette[-3]  # grey
137                for node in mdp.get_episodic_graph(False).nodes
138            ]
139    else:
140        if node_labels is not None and cm_state_labels is not None:
141            node_color = [
142                cm_state_labels(node_labels[node] / max(node_labels.values()))
143                for node in mdp.G.nodes
144            ]
145        else:
146            node_color = [
147                node_palette[5]  # brown
148                if node in mdp.starting_nodes
149                else node_palette[2]  # green
150                if R[mdp.node_to_index[node]].max() == R.max()
151                else node_palette[-2]  # yellow
152                if node in mdp.recurrent_nodes_set
153                else node_palette[-3]  # grey
154                for node in mdp.G.nodes
155            ]
156
157    # Lazy way to create nice legend handles
158    x, y = list(layout.values())[0]
159    if cm_state_labels is None:
160        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
161        ax.scatter(x, y, color=node_palette[-2], label="State")
162        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
163            ax.scatter(x, y, color=node_palette[-3], label="Transient state")
164        ax.scatter(x, y, color=node_palette[5], label="Starting state")
165    ax.plot(x, y, color=node_palette[-3], label="Transition probability")
166    if cm_state_actions_labels is None:
167        for a in range(mdp.n_actions):
168            ax.plot(x, y, color=action_palette[a], label=f"Action: {a}", marker="s")
169
170    G_nodes = (
171        mdp.get_episodic_graph(False).nodes
172        if mdp.is_episodic() and not continuous_form
173        else mdp.G.nodes
174    )
175    nx.draw_networkx_nodes(
176        G, layout, G_nodes, ax=ax, node_color=node_color, edgecolors="black",node_size=node_size
177    )
178    for a in range(mdp.n_actions):
179        na_nodes = (
180            [an for an in G.nodes if type(an[1]) == int and an[-1] == a]
181            if mdp.is_episodic() and not continuous_form
182            else [an for an in G.nodes if type(an) == tuple and an[-1] == a]
183        )
184        nx.draw_networkx_nodes(
185            G,
186            layout,
187            na_nodes,
188            node_shape="s",
189            ax=ax,
190            node_size=node_size,
191            node_color=[action_palette[a]]
192            if cm_state_actions_labels is None
193            else [
194                cm_state_actions_labels(action_labels[an] / max(action_labels.values()))
195                for an in na_nodes
196            ],
197            edgecolors="black",
198        )
199        nx.draw_networkx_edges(
200            G,
201            layout,
202            edgelist=[e for e in G.edges if type(e[0][0]) != tuple and e[1][1] == a]
203            if mdp.is_episodic() and not continuous_form
204            else [e for e in G.edges if type(e[0]) != tuple and e[1][1] == a],
205            ax=ax,
206            edge_color=action_palette[a],
207        )
208    nx.draw_networkx_edges(
209        G,
210        layout,
211        edgelist=[e for e in G.edges if type(e[0][0]) == tuple]
212        if mdp.is_episodic() and not continuous_form
213        else [e for e in G.edges if type(e[0]) == tuple],
214        ax=ax,
215        edge_color=action_palette[-3],
216        width=[probs[e] for e in G.edges if type(e[0][0]) == tuple]
217        if mdp.is_episodic() and not continuous_form
218        else [probs[e] for e in G.edges if type(e[0]) == tuple],
219    )
220    ax.legend(ncol=ncol, fontsize=fontsize)
221    if node_labels is not None and not no_written_state_labels:
222        if type(node_labels) == bool and node_labels:
223            node_labels = {
224                n: (
225                    f"h={n[0]},{n[1]}"
226                    if mdp.is_episodic() and not continuous_form
227                    else str(n)
228                )
229                for n in G_nodes
230            }
231        assert all(n in G.nodes for n in node_labels)
232        nx.draw_networkx_labels(
233            G,
234            toolz.valmap(
235                lambda x: [x[0] + int_labels_offset_x, x[1] + int_labels_offset_y],
236                layout,
237            ),
238            node_labels,
239            font_color=font_color_state_labels,
240            ax=ax,
241            verticalalignment="center_baseline",
242        )
243    if action_labels is not None and not no_written_state_action_labels:
244        if type(action_labels) == bool and action_labels:
245            action_labels = {
246                n: str(n[1])
247                for n in (
248                    (an for an in G.nodes if type(an[1]) == int)
249                    if mdp.is_episodic() and not continuous_form
250                    else (an for an in G.nodes if type(an) == tuple)
251                )
252            }
253        assert all(n in G.nodes and type(n[1]) == int for n in action_labels)
254        nx.draw_networkx_labels(
255            G,
256            toolz.valmap(
257                lambda x: [x[0] + int_labels_offset_x, x[1] + int_labels_offset_y],
258                layout,
259            ),
260            action_labels,
261            font_color=font_color_state_actions_labels,
262            ax=ax,
263            verticalalignment="center_baseline",
264        )
265
266    ax.axis("off")
267    if title is not None:
268        ax.set_title(title)
269    if save_file is not None:
270        plt.savefig(save_file, bbox_inches="tight")
271    if show:
272        plt.show()
def plot_MCGraph( mdp: Union[colosseum.mdp.base_infinite.ContinuousMDP, colosseum.mdp.base_finite.EpisodicMDP], node_palette=[(0.2980392156862745, 0.4470588235294118, 0.6901960784313725), (0.8666666666666667, 0.5176470588235295, 0.3215686274509804), (0.3333333333333333, 0.6588235294117647, 0.40784313725490196), (0.7686274509803922, 0.3058823529411765, 0.3215686274509804), (0.5058823529411764, 0.4470588235294118, 0.7019607843137254), (0.5764705882352941, 0.47058823529411764, 0.3764705882352941), (0.8549019607843137, 0.5450980392156862, 0.7647058823529411), (0.5490196078431373, 0.5490196078431373, 0.5490196078431373), (0.8, 0.7254901960784313, 0.4549019607843137), (0.39215686274509803, 0.7098039215686275, 0.803921568627451)], labels: Union[bool, dict[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], Union[float, str]]] = {}, font_color_labels='k', save_file: str = None, ax=None, figsize: Tuple[int, int] = None, prog: str = None, fontsize: int = None, node_size=100, cm_state_labels=None, no_written_state_labels=True):
def plot_MCGraph(
    mdp: Union["ContinuousMDP", "EpisodicMDP"],
    node_palette=sns.color_palette("deep"),
    labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = {},
    font_color_labels="k",
    save_file: str = None,
    ax=None,
    figsize: Tuple[int, int] = None,
    prog: str = None,
    fontsize: int = None,
    node_size=100,
    cm_state_labels=None,
    no_written_state_labels=True,
):
    """
    Plot the state graph ``mdp.G`` of the MDP.

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP to plot.
    node_palette
        Colours for the nodes. Indices 0, 2, 1 and -1 are used for starting
        states, states attaining the maximum reward, recurrent states and all
        remaining states respectively; index -3 is used for the edges.
    labels : Union[bool, Dict]
        Mapping from nodes to labels. ``True`` labels every node with its
        string representation and ``False`` draws no labels. When
        ``cm_state_labels`` is given, it must map nodes to numbers, which are
        divided by their maximum and turned into colours. The default mapping
        is never mutated.
    font_color_labels
        Font colour of the label texts.
    save_file : str, optional
        If given, the current pyplot figure is saved to this path.
    ax
        Axes to draw on. When ``None``, a new figure is created and
        ``plt.show()`` is called at the end.
    figsize : Tuple[int, int], optional
        Size of the figure created when ``ax`` is ``None``.
    prog : str, optional
        Graphviz program used to compute the layout. When ``None``,
        ``mdp.graph_layout`` is used.
    fontsize : int, optional
        Font size of the legend.
    node_size
        Marker size of the nodes.
    cm_state_labels
        Optional callable mapping a normalised label value to a colour.
    no_written_state_labels : bool
        When ``True`` and ``cm_state_labels`` is given, no label text is drawn.
    """
    show = ax is None

    _, R = mdp.transition_matrix_and_rewards

    # The annotation allows a bool, which ``nx.draw`` cannot handle directly.
    if isinstance(labels, bool):
        labels = {n: str(n) for n in mdp.G.nodes} if labels else {}

    if cm_state_labels is not None:
        # The maximum is computed once instead of once per node.
        max_label = max(labels.values())
        node_color = [cm_state_labels(labels[n] / max_label) for n in mdp.G.nodes]
    else:
        max_reward = R.max()

        def default_color(node):
            if node in mdp.starting_nodes:
                return node_palette[0]  # blue in the default palette
            if R[mdp.node_to_index[node]].max() == max_reward:
                return node_palette[2]  # green in the default palette
            if node in mdp.recurrent_nodes_set:
                return node_palette[1]  # orange in the default palette
            return node_palette[-1]  # light blue in the default palette

        node_color = [default_color(node) for node in mdp.G.nodes]

    # Resolve the layout once so that the automatic figure size and the legend
    # proxies below refer to the positions that are actually drawn.
    layout = (
        mdp.graph_layout
        if prog is None
        else nx.nx_agraph.graphviz_layout(mdp.G, prog=prog)
    )

    if ax is None:
        ax = _create_ax(layout, figsize)

    if cm_state_labels is None:
        # Legend proxies are drawn at the position of a node so that they end
        # up hidden below it.
        x, y = list(layout.values())[0]
        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
        ax.scatter(x, y, color=node_palette[1], label="State")
        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
            ax.scatter(x, y, color=node_palette[-1], label="Transient state")
        ax.scatter(x, y, color=node_palette[0], label="Starting state")

    hide_label_text = cm_state_labels is not None and no_written_state_labels
    nx.draw(
        mdp.G,
        layout,
        node_color=node_color,
        node_size=node_size,
        edgecolors="black",
        edge_color=node_palette[-3],
        labels={} if hide_label_text else labels,
        font_color=font_color_labels,
        ax=ax,
    )

    if cm_state_labels is None:
        ax.legend(fontsize=fontsize)
    if save_file is not None:
        plt.savefig(save_file)
    if show:
        plt.show()