colosseum.utils.visualization
from typing import Dict, Tuple, Union, TYPE_CHECKING

import networkx as nx
import numpy as np
import seaborn as sns
import toolz
from matplotlib import pyplot as plt

from colosseum.mdp.utils import MDPCommunicationClass

if TYPE_CHECKING:
    from colosseum.mdp import ContinuousMDP, EpisodicMDP, NODE_TYPE


def _create_ax(layout, figsize: Tuple[int, int] = None):
    """
    Open a new white figure and return an axes whose four spines are hidden.

    Parameters
    ----------
    layout : dict
        Mapping from graph node to its position. It is only read when ``figsize``
        is None, in which case the largest Euclidean distance between two
        positions determines the side of a square figure.
    figsize : Tuple[int, int], optional
        Size forwarded to ``plt.figure``. By default, a square whose side is the
        largest pairwise distance divided by 70, clipped to the range [6, 20].

    Returns
    -------
    The axes of the newly created figure.
    """
    if figsize is None:
        positions = np.array(list(layout.values()), dtype=float)
        # Row-wise vectorised pairwise distances. ``default`` covers layouts with
        # fewer than two nodes, for which there is no pair to measure.
        max_distance = max(
            (
                np.sqrt(((positions[i + 1 :] - p) ** 2).sum(axis=1)).max()
                for i, p in enumerate(positions[:-1])
            ),
            default=0.0,
        )
        side = max(6, min(20, int(max_distance / 70)))
        figsize = (side, side)
    plt.figure(None, figsize)
    cf = plt.gcf()
    cf.set_facecolor("w")
    # Public ``Figure.axes`` instead of the private ``Figure._axstack``.
    ax = cf.gca() if cf.axes else cf.add_axes((0, 0, 1, 1))
    for side_name in ("top", "right", "bottom", "left"):
        ax.spines[side_name].set_visible(False)
    return ax


def _create_MDP_graph(mdp: Union["ContinuousMDP", "EpisodicMDP"]):
    """
    Build a directed graph with one node per state and one per state-action pair.

    Every state node ``n`` is linked to its action nodes ``(n, a)``, and every
    action node is linked to the states it reaches with positive probability
    according to the transition matrix of ``mdp``.

    Returns
    -------
    G : nx.DiGraph
        The state / state-action graph.
    probs : dict
        Maps each ``(action node, next state node)`` edge to its transition
        probability.
    """
    T, _ = mdp.transition_matrix_and_rewards

    probs = dict()
    G = nx.DiGraph()
    for s in range(mdp.n_states):
        n = mdp.index_to_node[s]
        # ``add_node`` is idempotent, so no membership guard is needed.
        G.add_node(n)
        for a in range(mdp.n_actions):
            an = (n, a)
            G.add_edge(n, an)
            for next_s in np.where(T[s, a] > 0)[0]:
                nn = mdp.index_to_node[next_s]
                G.add_edge(an, nn)
                probs[an, nn] = T[s, a, next_s]

    return G, probs


def _create_epi_MDP_graph(mdp: "EpisodicMDP"):
    """
    Build the state / state-action graph on top of the episodic graph of ``mdp``.

    Every node ``n`` of ``mdp.get_episodic_graph(False)`` is linked to its action
    nodes ``(n, a)``, and every action node is linked to all the successors of
    ``n`` in the episodic graph.

    Returns
    -------
    G : nx.DiGraph
        The episodic state / state-action graph.
    probs : dict
        Maps each ``(action node, next node)`` edge to the corresponding entry of
        the episodic transition matrix.
    """
    G_epi = mdp.get_episodic_graph(False)
    T, _ = mdp.episodic_transition_matrix_and_rewards

    probs = dict()
    G = nx.DiGraph()
    for n in G_epi.nodes:
        G.add_node(n)
        # Episodic nodes are pairs: ``n[0]`` indexes the first axis of ``T``
        # (presumably the in-episode time step - confirm) and ``n[1]`` is the
        # underlying state node.
        # NOTE(review): ``node_to_index`` is subscripted here, as everywhere else
        # in this module; the previous call syntax assumed it was callable.
        s = mdp.node_to_index[n[1]]
        for a in range(mdp.n_actions):
            an = (n, a)
            G.add_edge(n, an)
            for nn in G_epi.successors(n):
                G.add_edge(an, nn)
                probs[an, nn] = T[n[0], s, a, mdp.node_to_index[nn[1]]]

    return G, probs


custom_dark = sns.color_palette("dark")
p = custom_dark.pop(1)
custom_dark.insert(6, p)


def plot_MDP_graph(
    mdp: Union["ContinuousMDP", "EpisodicMDP"],
    node_palette=sns.color_palette("bright"),
    action_palette=custom_dark,
    save_file: str = None,
    ax=None,
    figsize: Tuple[int, int] = None,
    node_labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = None,
    action_labels: Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]] = None,
    int_labels_offset_x: int = 10,
    int_labels_offset_y: int = 10,
    continuous_form: bool = True,
    prog="neato",
    ncol: int = 4,
    title: str = None,
    fontsize: int = None,
    font_color_state_labels="k",
    font_color_state_actions_labels="k",
    cm_state_labels=None,
    cm_state_actions_labels=None,
    no_written_state_labels=True,
    no_written_state_action_labels=True,
    node_size=150
):
    """
    Draw the MDP as a graph of state nodes (circles) and action nodes (squares).

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP to draw.
    node_palette
        Colours of the state nodes. Entries 5, 2, -2 and -3 are used for starting,
        highly rewarding, recurrent and remaining states respectively.
    action_palette
        Entry ``a`` colours action ``a``; entry -3 colours the transition edges.
    save_file : str, optional
        If given, the figure is saved to this path.
    ax : optional
        Axes to draw on. If None, a new figure is created and shown at the end.
    figsize : Tuple[int, int], optional
        Size of the new figure; only used when ``ax`` is None.
    node_labels : Union[bool, dict], optional
        True to label each state node automatically, or a dictionary from state
        node to label. With ``cm_state_labels``, the dictionary values are divided
        by their maximum and mapped to node colours. False is treated as None.
    action_labels : Union[bool, dict], optional
        Same as ``node_labels`` but for the action nodes.
    int_labels_offset_x, int_labels_offset_y : int
        Offsets added to the node positions when placing the written labels.
    continuous_form : bool
        If False and the MDP is episodic, the episodic graph is drawn.
    prog : str
        Graphviz program used to compute the layout.
    ncol : int
        Number of columns of the legend.
    title : str, optional
        Title of the axes.
    fontsize : int, optional
        Font size of the legend.
    font_color_state_labels, font_color_state_actions_labels
        Font colours of the written state and action labels.
    cm_state_labels, cm_state_actions_labels : callable, optional
        Map a normalised label value to a colour for state and action nodes.
    no_written_state_labels, no_written_state_action_labels : bool
        If True, the corresponding labels are not written on the plot.
    node_size : int
        Marker size of state and action nodes.
    """
    show = ax is None

    sns.reset_defaults()
    episodic_form = mdp.is_episodic() and not continuous_form
    G, probs = (
        _create_epi_MDP_graph(mdp) if episodic_form else _create_MDP_graph(mdp)
    )
    _, R = mdp.transition_matrix_and_rewards

    layout = nx.nx_agraph.graphviz_layout(G, prog=prog)

    if ax is None:
        ax = _create_ax(layout, figsize)

    # ``False`` means "no labels"; without this it would be iterated further down.
    if node_labels is False:
        node_labels = None
    if action_labels is False:
        action_labels = None

    state_nodes = list(
        mdp.get_episodic_graph(False).nodes if episodic_form else mdp.G.nodes
    )

    # Exact type checks are deliberate: state nodes may themselves subclass tuple.
    def is_action_node(n):
        return type(n[1]) is int if episodic_form else type(n) is tuple

    def is_transition_edge(e):
        return type(e[0][0] if episodic_form else e[0]) is tuple

    if node_labels is not None and cm_state_labels is not None:
        top_value = max(node_labels.values())
        node_color = [
            cm_state_labels(node_labels[node] / top_value) for node in state_nodes
        ]
    else:
        max_reward = R.max()
        node_color = []
        for node in state_nodes:
            # In episodic form a node is a pair whose second entry is the state.
            state = node[1] if episodic_form else node
            if state in mdp.starting_nodes and (not episodic_form or node[0] == 0):
                color = node_palette[5]  # brown
            # NOTE(review): ``node_to_index`` is subscripted in both forms; the
            # episodic branch previously called it - confirm it is a mapping.
            elif R[mdp.node_to_index[state]].max() == max_reward:
                color = node_palette[2]  # green
            elif state in mdp.recurrent_nodes_set:
                color = node_palette[-2]  # yellow
            else:
                color = node_palette[-3]  # grey
            node_color.append(color)

    # Lazy way to create nice legend handles: the markers are plotted on the first
    # node position and are then covered by the node drawn there.
    x, y = list(layout.values())[0]
    if cm_state_labels is None:
        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
        ax.scatter(x, y, color=node_palette[-2], label="State")
        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
            ax.scatter(x, y, color=node_palette[-3], label="Transient state")
        ax.scatter(x, y, color=node_palette[5], label="Starting state")
    # Same colour as the transition edges drawn below.
    ax.plot(x, y, color=action_palette[-3], label="Transition probability")
    if cm_state_actions_labels is None:
        for a in range(mdp.n_actions):
            ax.plot(x, y, color=action_palette[a], label=f"Action: {a}", marker="s")

    nx.draw_networkx_nodes(
        G,
        layout,
        nodelist=state_nodes,
        ax=ax,
        node_color=node_color,
        edgecolors="black",
        node_size=node_size,
    )

    choice_edges = [e for e in G.edges if not is_transition_edge(e)]
    transition_edges = [e for e in G.edges if is_transition_edge(e)]
    if cm_state_actions_labels is not None:
        top_action_value = max(action_labels.values())
    for a in range(mdp.n_actions):
        na_nodes = [an for an in G.nodes if is_action_node(an) and an[-1] == a]
        if cm_state_actions_labels is None:
            na_colors = [action_palette[a]]
        else:
            na_colors = [
                cm_state_actions_labels(action_labels[an] / top_action_value)
                for an in na_nodes
            ]
        nx.draw_networkx_nodes(
            G,
            layout,
            nodelist=na_nodes,
            node_shape="s",
            ax=ax,
            node_size=node_size,
            node_color=na_colors,
            edgecolors="black",
        )
        nx.draw_networkx_edges(
            G,
            layout,
            edgelist=[e for e in choice_edges if e[1][1] == a],
            ax=ax,
            edge_color=action_palette[a],
        )
    # The width of a transition edge is its probability.
    nx.draw_networkx_edges(
        G,
        layout,
        edgelist=transition_edges,
        ax=ax,
        edge_color=action_palette[-3],
        width=[probs[e] for e in transition_edges],
    )
    ax.legend(ncol=ncol, fontsize=fontsize)

    write_state_labels = node_labels is not None and not no_written_state_labels
    write_action_labels = (
        action_labels is not None and not no_written_state_action_labels
    )
    if write_state_labels or write_action_labels:
        label_positions = {
            n: [pos[0] + int_labels_offset_x, pos[1] + int_labels_offset_y]
            for n, pos in layout.items()
        }
    if write_state_labels:
        if node_labels is True:
            node_labels = {
                n: f"h={n[0]},{n[1]}" if episodic_form else str(n)
                for n in state_nodes
            }
        assert all(n in G.nodes for n in node_labels)
        nx.draw_networkx_labels(
            G,
            label_positions,
            node_labels,
            font_color=font_color_state_labels,
            ax=ax,
            verticalalignment="center_baseline",
        )
    if write_action_labels:
        if action_labels is True:
            action_labels = {an: str(an[1]) for an in G.nodes if is_action_node(an)}
        assert all(n in G.nodes and type(n[1]) is int for n in action_labels)
        nx.draw_networkx_labels(
            G,
            label_positions,
            action_labels,
            font_color=font_color_state_actions_labels,
            ax=ax,
            verticalalignment="center_baseline",
        )

    ax.axis("off")
    if title is not None:
        ax.set_title(title)
    if save_file is not None:
        plt.savefig(save_file, bbox_inches="tight")
    if show:
        plt.show()


def plot_MCGraph(
    mdp: Union["ContinuousMDP", "EpisodicMDP"],
    node_palette=sns.color_palette("deep"),
    labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = {},
    font_color_labels="k",
    save_file: str = None,
    ax=None,
    figsize: Tuple[int, int] = None,
    prog: str = None,
    fontsize: int = None,
    node_size=100,
    cm_state_labels=None,
    no_written_state_labels=True,
):
    """
    Draw the state graph ``mdp.G`` of the MDP.

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP whose graph is drawn.
    node_palette
        Colours of the nodes. Entries 0, 2, 1 and -1 are used for starting, highly
        rewarding, recurrent and remaining states; entry -3 colours the edges.
    labels : dict
        Labels of the nodes. With ``cm_state_labels``, the values are divided by
        their maximum and mapped to node colours. The default is never mutated.
    font_color_labels
        Font colour of the written labels.
    save_file : str, optional
        If given, the figure is saved to this path.
    ax : optional
        Axes to draw on. If None, a new figure is created and shown at the end.
    figsize : Tuple[int, int], optional
        Size of the new figure; only used when ``ax`` is None.
    prog : str, optional
        Graphviz program used to compute the layout. By default,
        ``mdp.graph_layout`` is used.
    fontsize : int, optional
        Font size of the legend.
    node_size : int
        Marker size of the nodes.
    cm_state_labels : callable, optional
        Maps a normalised label value to a node colour. When given, no legend is
        drawn.
    no_written_state_labels : bool
        Only effective together with ``cm_state_labels``: if True, the labels are
        used for the colours but not written on the plot.
    """
    show = ax is None

    _, R = mdp.transition_matrix_and_rewards

    # A single layout is used for the figure size, the legend handles and the
    # drawing, so that the legend markers always sit underneath a drawn node.
    layout = (
        mdp.graph_layout
        if prog is None
        else nx.nx_agraph.graphviz_layout(mdp.G, prog=prog)
    )

    if cm_state_labels is not None:
        top_value = max(labels.values())
        node_color = [cm_state_labels(labels[n] / top_value) for n in mdp.G.nodes]
    else:
        max_reward = R.max()
        node_color = []
        for node in mdp.G.nodes:
            if node in mdp.starting_nodes:
                color = node_palette[0]  # starting state
            elif R[mdp.node_to_index[node]].max() == max_reward:
                color = node_palette[2]  # highly rewarding state
            elif node in mdp.recurrent_nodes_set:
                color = node_palette[1]  # recurrent state
            else:
                color = node_palette[-1]  # remaining (transient) state
            node_color.append(color)

    if ax is None:
        ax = _create_ax(layout, figsize)

    if cm_state_labels is None:
        # Legend handles plotted on the first node position.
        x, y = list(layout.values())[0]
        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
        ax.scatter(x, y, color=node_palette[1], label="State")
        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
            ax.scatter(x, y, color=node_palette[-1], label="Transient state")
        ax.scatter(x, y, color=node_palette[0], label="Starting state")

    hide_labels = cm_state_labels is not None and no_written_state_labels
    nx.draw(
        mdp.G,
        layout,
        node_color=node_color,
        node_size=node_size,
        edgecolors="black",
        edge_color=node_palette[-3],
        labels={} if hide_labels else labels,
        font_color=font_color_labels,
        ax=ax,
    )

    if cm_state_labels is None:
        ax.legend(fontsize=fontsize)
    if save_file is not None:
        plt.savefig(save_file)
    if show:
        plt.show()
def
_create_ax(layout, figsize: Tuple[int, int] = None):
def _create_ax(layout, figsize: Tuple[int, int] = None):
    """
    Open a new white figure and return an axes whose four spines are hidden.

    Parameters
    ----------
    layout : dict
        Mapping from graph node to its position. It is only read when ``figsize``
        is None, in which case the largest Euclidean distance between two
        positions determines the side of a square figure.
    figsize : Tuple[int, int], optional
        Size forwarded to ``plt.figure``. By default, a square whose side is the
        largest pairwise distance divided by 70, clipped to the range [6, 20].

    Returns
    -------
    The axes of the newly created figure.
    """
    if figsize is None:
        positions = np.array(list(layout.values()), dtype=float)
        # Row-wise vectorised pairwise distances. ``default`` covers layouts with
        # fewer than two nodes, for which there is no pair to measure.
        max_distance = max(
            (
                np.sqrt(((positions[i + 1 :] - p) ** 2).sum(axis=1)).max()
                for i, p in enumerate(positions[:-1])
            ),
            default=0.0,
        )
        side = max(6, min(20, int(max_distance / 70)))
        figsize = (side, side)
    plt.figure(None, figsize)
    cf = plt.gcf()
    cf.set_facecolor("w")
    # Public ``Figure.axes`` instead of the private ``Figure._axstack``.
    ax = cf.gca() if cf.axes else cf.add_axes((0, 0, 1, 1))
    for side_name in ("top", "right", "bottom", "left"):
        ax.spines[side_name].set_visible(False)
    return ax
def
_create_MDP_graph( mdp: Union[colosseum.mdp.base_infinite.ContinuousMDP, colosseum.mdp.base_finite.EpisodicMDP]):
def _create_MDP_graph(mdp: Union["ContinuousMDP", "EpisodicMDP"]):
    """
    Build a directed graph with one node per state and one per state-action pair.

    Every state node ``n`` is linked to its action nodes ``(n, a)``, and every
    action node is linked to the states it reaches with positive probability
    according to the transition matrix of ``mdp``.

    Returns
    -------
    G : nx.DiGraph
        The state / state-action graph.
    probs : dict
        Maps each ``(action node, next state node)`` edge to its transition
        probability.
    """
    T, _ = mdp.transition_matrix_and_rewards

    probs = dict()
    G = nx.DiGraph()
    for s in range(mdp.n_states):
        n = mdp.index_to_node[s]
        # ``add_node`` is idempotent, so no membership guard is needed (the
        # previous guard tested the state index instead of the node).
        G.add_node(n)
        for a in range(mdp.n_actions):
            an = (n, a)
            G.add_edge(n, an)
            for next_s in np.where(T[s, a] > 0)[0]:
                nn = mdp.index_to_node[next_s]
                G.add_edge(an, nn)
                probs[an, nn] = T[s, a, next_s]

    return G, probs
def _create_epi_MDP_graph(mdp: "EpisodicMDP"):
    """
    Build the state / state-action graph on top of the episodic graph of ``mdp``.

    Every node ``n`` of ``mdp.get_episodic_graph(False)`` is linked to its action
    nodes ``(n, a)``, and every action node is linked to all the successors of
    ``n`` in the episodic graph.

    Returns
    -------
    G : nx.DiGraph
        The episodic state / state-action graph.
    probs : dict
        Maps each ``(action node, next node)`` edge to the corresponding entry of
        the episodic transition matrix.
    """
    G_epi = mdp.get_episodic_graph(False)
    T, _ = mdp.episodic_transition_matrix_and_rewards

    probs = dict()
    G = nx.DiGraph()
    for n in G_epi.nodes:
        G.add_node(n)
        # Episodic nodes are pairs: ``n[0]`` indexes the first axis of ``T``
        # (presumably the in-episode time step - confirm) and ``n[1]`` is the
        # underlying state node.
        # NOTE(review): ``node_to_index`` is subscripted here, as everywhere else
        # in this module; the previous call syntax assumed it was callable.
        s = mdp.node_to_index[n[1]]
        for a in range(mdp.n_actions):
            an = (n, a)
            G.add_edge(n, an)
            for nn in G_epi.successors(n):
                G.add_edge(an, nn)
                probs[an, nn] = T[n[0], s, a, mdp.node_to_index[nn[1]]]

    return G, probs
def
plot_MDP_graph( mdp: Union[colosseum.mdp.base_infinite.ContinuousMDP, colosseum.mdp.base_finite.EpisodicMDP], node_palette=[(0.00784313725490196, 0.24313725490196078, 1.0), (1.0, 0.48627450980392156, 0.0), (0.10196078431372549, 0.788235294117647, 0.2196078431372549), (0.9098039215686274, 0.0, 0.043137254901960784), (0.5450980392156862, 0.16862745098039217, 0.8862745098039215), (0.6235294117647059, 0.2823529411764706, 0.0), (0.9450980392156862, 0.2980392156862745, 0.7568627450980392), (0.6392156862745098, 0.6392156862745098, 0.6392156862745098), (1.0, 0.7686274509803922, 0.0), (0.0, 0.8431372549019608, 1.0)], action_palette=[(0.0, 0.10980392156862745, 0.4980392156862745), (0.07058823529411765, 0.44313725490196076, 0.10980392156862745), (0.5490196078431373, 0.03137254901960784, 0.0), (0.34901960784313724, 0.11764705882352941, 0.44313725490196076), (0.34901960784313724, 0.1843137254901961, 0.050980392156862744), (0.6352941176470588, 0.20784313725490197, 0.5098039215686274), (0.6941176470588235, 0.25098039215686274, 0.050980392156862744), (0.23529411764705882, 0.23529411764705882, 0.23529411764705882), (0.7215686274509804, 0.5215686274509804, 0.0392156862745098), (0.0, 0.38823529411764707, 0.4549019607843137)], save_file: str = None, ax=None, figsize: Tuple[int, int] = None, node_labels: Union[bool, dict[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], Union[float, str]]] = None, action_labels: Union[bool, dict[tuple[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, 
colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], int], Union[float, str]]] = None, int_labels_offset_x: int = 10, int_labels_offset_y: int = 10, continuous_form: bool = True, prog='neato', ncol: int = 4, title: str = None, fontsize: int = None, font_color_state_labels='k', font_color_state_actions_labels='k', cm_state_labels=None, cm_state_actions_labels=None, no_written_state_labels=True, no_written_state_action_labels=True, node_size=150):
def plot_MDP_graph(
    mdp: Union["ContinuousMDP", "EpisodicMDP"],
    node_palette=sns.color_palette("bright"),
    action_palette=custom_dark,
    save_file: str = None,
    ax=None,
    figsize: Tuple[int, int] = None,
    node_labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = None,
    action_labels: Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]] = None,
    int_labels_offset_x: int = 10,
    int_labels_offset_y: int = 10,
    continuous_form: bool = True,
    prog="neato",
    ncol: int = 4,
    title: str = None,
    fontsize: int = None,
    font_color_state_labels="k",
    font_color_state_actions_labels="k",
    cm_state_labels=None,
    cm_state_actions_labels=None,
    no_written_state_labels=True,
    no_written_state_action_labels=True,
    node_size=150
):
    """
    Draw the MDP as a graph of state nodes (circles) and action nodes (squares).

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP to draw.
    node_palette
        Colours of the state nodes. Entries 5, 2, -2 and -3 are used for starting,
        highly rewarding, recurrent and remaining states respectively.
    action_palette
        Entry ``a`` colours action ``a``; entry -3 colours the transition edges.
    save_file : str, optional
        If given, the figure is saved to this path.
    ax : optional
        Axes to draw on. If None, a new figure is created and shown at the end.
    figsize : Tuple[int, int], optional
        Size of the new figure; only used when ``ax`` is None.
    node_labels : Union[bool, dict], optional
        True to label each state node automatically, or a dictionary from state
        node to label. With ``cm_state_labels``, the dictionary values are divided
        by their maximum and mapped to node colours. False is treated as None.
    action_labels : Union[bool, dict], optional
        Same as ``node_labels`` but for the action nodes.
    int_labels_offset_x, int_labels_offset_y : int
        Offsets added to the node positions when placing the written labels.
    continuous_form : bool
        If False and the MDP is episodic, the episodic graph is drawn.
    prog : str
        Graphviz program used to compute the layout.
    ncol : int
        Number of columns of the legend.
    title : str, optional
        Title of the axes.
    fontsize : int, optional
        Font size of the legend.
    font_color_state_labels, font_color_state_actions_labels
        Font colours of the written state and action labels.
    cm_state_labels, cm_state_actions_labels : callable, optional
        Map a normalised label value to a colour for state and action nodes.
    no_written_state_labels, no_written_state_action_labels : bool
        If True, the corresponding labels are not written on the plot.
    node_size : int
        Marker size of state and action nodes.
    """
    show = ax is None

    sns.reset_defaults()
    episodic_form = mdp.is_episodic() and not continuous_form
    G, probs = (
        _create_epi_MDP_graph(mdp) if episodic_form else _create_MDP_graph(mdp)
    )
    _, R = mdp.transition_matrix_and_rewards

    layout = nx.nx_agraph.graphviz_layout(G, prog=prog)

    if ax is None:
        ax = _create_ax(layout, figsize)

    # ``False`` means "no labels"; without this it would be iterated further down.
    if node_labels is False:
        node_labels = None
    if action_labels is False:
        action_labels = None

    state_nodes = list(
        mdp.get_episodic_graph(False).nodes if episodic_form else mdp.G.nodes
    )

    # Exact type checks are deliberate: state nodes may themselves subclass tuple.
    def is_action_node(n):
        return type(n[1]) is int if episodic_form else type(n) is tuple

    def is_transition_edge(e):
        return type(e[0][0] if episodic_form else e[0]) is tuple

    if node_labels is not None and cm_state_labels is not None:
        top_value = max(node_labels.values())
        node_color = [
            cm_state_labels(node_labels[node] / top_value) for node in state_nodes
        ]
    else:
        max_reward = R.max()
        node_color = []
        for node in state_nodes:
            # In episodic form a node is a pair whose second entry is the state.
            state = node[1] if episodic_form else node
            if state in mdp.starting_nodes and (not episodic_form or node[0] == 0):
                color = node_palette[5]  # brown
            # NOTE(review): ``node_to_index`` is subscripted in both forms; the
            # episodic branch previously called it - confirm it is a mapping.
            elif R[mdp.node_to_index[state]].max() == max_reward:
                color = node_palette[2]  # green
            elif state in mdp.recurrent_nodes_set:
                color = node_palette[-2]  # yellow
            else:
                color = node_palette[-3]  # grey
            node_color.append(color)

    # Lazy way to create nice legend handles: the markers are plotted on the first
    # node position and are then covered by the node drawn there.
    x, y = list(layout.values())[0]
    if cm_state_labels is None:
        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
        ax.scatter(x, y, color=node_palette[-2], label="State")
        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
            ax.scatter(x, y, color=node_palette[-3], label="Transient state")
        ax.scatter(x, y, color=node_palette[5], label="Starting state")
    # Same colour as the transition edges drawn below.
    ax.plot(x, y, color=action_palette[-3], label="Transition probability")
    if cm_state_actions_labels is None:
        for a in range(mdp.n_actions):
            ax.plot(x, y, color=action_palette[a], label=f"Action: {a}", marker="s")

    nx.draw_networkx_nodes(
        G,
        layout,
        nodelist=state_nodes,
        ax=ax,
        node_color=node_color,
        edgecolors="black",
        node_size=node_size,
    )

    choice_edges = [e for e in G.edges if not is_transition_edge(e)]
    transition_edges = [e for e in G.edges if is_transition_edge(e)]
    if cm_state_actions_labels is not None:
        top_action_value = max(action_labels.values())
    for a in range(mdp.n_actions):
        na_nodes = [an for an in G.nodes if is_action_node(an) and an[-1] == a]
        if cm_state_actions_labels is None:
            na_colors = [action_palette[a]]
        else:
            na_colors = [
                cm_state_actions_labels(action_labels[an] / top_action_value)
                for an in na_nodes
            ]
        nx.draw_networkx_nodes(
            G,
            layout,
            nodelist=na_nodes,
            node_shape="s",
            ax=ax,
            node_size=node_size,
            node_color=na_colors,
            edgecolors="black",
        )
        nx.draw_networkx_edges(
            G,
            layout,
            edgelist=[e for e in choice_edges if e[1][1] == a],
            ax=ax,
            edge_color=action_palette[a],
        )
    # The width of a transition edge is its probability.
    nx.draw_networkx_edges(
        G,
        layout,
        edgelist=transition_edges,
        ax=ax,
        edge_color=action_palette[-3],
        width=[probs[e] for e in transition_edges],
    )
    ax.legend(ncol=ncol, fontsize=fontsize)

    write_state_labels = node_labels is not None and not no_written_state_labels
    write_action_labels = (
        action_labels is not None and not no_written_state_action_labels
    )
    if write_state_labels or write_action_labels:
        label_positions = {
            n: [pos[0] + int_labels_offset_x, pos[1] + int_labels_offset_y]
            for n, pos in layout.items()
        }
    if write_state_labels:
        if node_labels is True:
            node_labels = {
                n: f"h={n[0]},{n[1]}" if episodic_form else str(n)
                for n in state_nodes
            }
        assert all(n in G.nodes for n in node_labels)
        nx.draw_networkx_labels(
            G,
            label_positions,
            node_labels,
            font_color=font_color_state_labels,
            ax=ax,
            verticalalignment="center_baseline",
        )
    if write_action_labels:
        if action_labels is True:
            action_labels = {an: str(an[1]) for an in G.nodes if is_action_node(an)}
        assert all(n in G.nodes and type(n[1]) is int for n in action_labels)
        nx.draw_networkx_labels(
            G,
            label_positions,
            action_labels,
            font_color=font_color_state_actions_labels,
            ax=ax,
            verticalalignment="center_baseline",
        )

    ax.axis("off")
    if title is not None:
        ax.set_title(title)
    if save_file is not None:
        plt.savefig(save_file, bbox_inches="tight")
    if show:
        plt.show()
def
plot_MCGraph( mdp: Union[colosseum.mdp.base_infinite.ContinuousMDP, colosseum.mdp.base_finite.EpisodicMDP], node_palette=[(0.2980392156862745, 0.4470588235294118, 0.6901960784313725), (0.8666666666666667, 0.5176470588235295, 0.3215686274509804), (0.3333333333333333, 0.6588235294117647, 0.40784313725490196), (0.7686274509803922, 0.3058823529411765, 0.3215686274509804), (0.5058823529411764, 0.4470588235294118, 0.7019607843137254), (0.5764705882352941, 0.47058823529411764, 0.3764705882352941), (0.8549019607843137, 0.5450980392156862, 0.7647058823529411), (0.5490196078431373, 0.5490196078431373, 0.5490196078431373), (0.8, 0.7254901960784313, 0.4549019607843137), (0.39215686274509803, 0.7098039215686275, 0.803921568627451)], labels: Union[bool, dict[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], Union[float, str]]] = {}, font_color_labels='k', save_file: str = None, ax=None, figsize: Tuple[int, int] = None, prog: str = None, fontsize: int = None, node_size=100, cm_state_labels=None, no_written_state_labels=True):
def plot_MCGraph(
    mdp: Union["ContinuousMDP", "EpisodicMDP"],
    node_palette=sns.color_palette("deep"),
    labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = {},
    font_color_labels="k",
    save_file: str = None,
    ax=None,
    figsize: Tuple[int, int] = None,
    prog: str = None,
    fontsize: int = None,
    node_size=100,
    cm_state_labels=None,
    no_written_state_labels=True,
):
    """
    Draw the state graph ``mdp.G`` of the MDP.

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP whose graph is drawn.
    node_palette
        Colours of the nodes. Entries 0, 2, 1 and -1 are used for starting, highly
        rewarding, recurrent and remaining states; entry -3 colours the edges.
    labels : dict
        Labels of the nodes. With ``cm_state_labels``, the values are divided by
        their maximum and mapped to node colours. The default is never mutated.
    font_color_labels
        Font colour of the written labels.
    save_file : str, optional
        If given, the figure is saved to this path.
    ax : optional
        Axes to draw on. If None, a new figure is created and shown at the end.
    figsize : Tuple[int, int], optional
        Size of the new figure; only used when ``ax`` is None.
    prog : str, optional
        Graphviz program used to compute the layout. By default,
        ``mdp.graph_layout`` is used.
    fontsize : int, optional
        Font size of the legend.
    node_size : int
        Marker size of the nodes.
    cm_state_labels : callable, optional
        Maps a normalised label value to a node colour. When given, no legend is
        drawn.
    no_written_state_labels : bool
        Only effective together with ``cm_state_labels``: if True, the labels are
        used for the colours but not written on the plot.
    """
    show = ax is None

    _, R = mdp.transition_matrix_and_rewards

    # A single layout is used for the figure size, the legend handles and the
    # drawing, so that the legend markers always sit underneath a drawn node.
    layout = (
        mdp.graph_layout
        if prog is None
        else nx.nx_agraph.graphviz_layout(mdp.G, prog=prog)
    )

    if cm_state_labels is not None:
        top_value = max(labels.values())
        node_color = [cm_state_labels(labels[n] / top_value) for n in mdp.G.nodes]
    else:
        max_reward = R.max()
        node_color = []
        for node in mdp.G.nodes:
            if node in mdp.starting_nodes:
                color = node_palette[0]  # starting state
            elif R[mdp.node_to_index[node]].max() == max_reward:
                color = node_palette[2]  # highly rewarding state
            elif node in mdp.recurrent_nodes_set:
                color = node_palette[1]  # recurrent state
            else:
                color = node_palette[-1]  # remaining (transient) state
            node_color.append(color)

    if ax is None:
        ax = _create_ax(layout, figsize)

    if cm_state_labels is None:
        # Legend handles plotted on the first node position.
        x, y = list(layout.values())[0]
        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
        ax.scatter(x, y, color=node_palette[1], label="State")
        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
            ax.scatter(x, y, color=node_palette[-1], label="Transient state")
        ax.scatter(x, y, color=node_palette[0], label="Starting state")

    hide_labels = cm_state_labels is not None and no_written_state_labels
    nx.draw(
        mdp.G,
        layout,
        node_color=node_color,
        node_size=node_size,
        edgecolors="black",
        edge_color=node_palette[-3],
        labels={} if hide_labels else labels,
        font_color=font_color_labels,
        ax=ax,
    )

    if cm_state_labels is None:
        ax.legend(fontsize=fontsize)
    if save_file is not None:
        plt.savefig(save_file)
    if show:
        plt.show()