colosseum.analysis.visualization
from typing import Dict, Tuple, Union, TYPE_CHECKING, List

import networkx as nx
import numpy as np
import seaborn as sns
import toolz
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

from colosseum.mdp.utils import MDPCommunicationClass

if TYPE_CHECKING:
    from colosseum.mdp import ContinuousMDP, EpisodicMDP, NODE_TYPE

# Default palette for action nodes and edges: seaborn's 'dark' palette without its second color.
# NOTE(review): the reason for dropping that color is not stated; presumably it clashes with a state color — confirm.
custom_dark = sns.color_palette("dark")
p = custom_dark.pop(1)


def plot_MDP_graph(
    mdp: Union["ContinuousMDP", "EpisodicMDP"],
    node_palette: List[Tuple[float, float, float]] = sns.color_palette("bright"),
    action_palette: List[Tuple[float, float, float]] = custom_dark,
    save_file: str = None,
    ax: plt.Axes = None,
    figsize: Tuple[int, int] = None,
    node_labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = None,
    action_labels: Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]] = None,
    int_labels_offset_x: int = 10,
    int_labels_offset_y: int = 10,
    continuous_form: bool = True,
    prog: str = "neato",
    ncol: int = 4,
    title: str = None,
    legend_fontsize: int = None,
    font_color_state_labels: str = "k",
    font_color_state_actions_labels: str = "k",
    cm_state_labels: LinearSegmentedColormap = None,
    cm_state_actions_labels: LinearSegmentedColormap = None,
    no_written_state_labels=True,
    no_written_state_action_labels=True,
    node_size=150,
):
    """
    Plots the graph representation of an MDP.

    In the MDP graph representation, states are associated with round nodes and actions with square nodes. Each
    state node is connected to one action node for every available action. Each action node is connected to the
    state nodes that are reached with probability greater than zero when playing that action from that state; the
    width of these edges equals the transition probability.

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP to be visualised.
    node_palette : List[Tuple[float, float, float]]
        The color palette for the nodes corresponding to states. Entries 5 (starting states), 2 (highly rewarding
        states), -2 (recurrent states) and -3 (transient states) are used. By default, the seaborn 'bright' palette
        is used.
    action_palette : List[Tuple[float, float, float]]
        The color palette for the nodes corresponding to actions. Entry a is used for action a and entry -3 for the
        transition edges. By default, the seaborn 'dark' palette without its second color is used.
    save_file : str
        The path of the file where the visualization will be stored. By default, the visualization is not stored.
    ax : plt.Axes
        The ax object where the plot will be put. By default, a new axis is created and the plot is shown.
    figsize : Tuple[int, int]
        The size of the figure created when ax is not given. By default, it is inferred from the spread of the
        graph layout.
    node_labels : Union[bool, Dict["NODE_TYPE", Union[float, str]]]
        The dictionary mapping state nodes to labels (either a float or a str). If True, default labels are used.
        When cm_state_labels is given, the labels must be numeric and are used to color the state nodes. By default,
        no label is plotted.
    action_labels : Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]]
        The dictionary mapping action nodes to labels (either a float or a str). If True, default labels (the action
        indices) are used. When cm_state_actions_labels is given, the labels must be numeric and are used to color
        the action nodes. By default, no label is plotted.
    int_labels_offset_x : int
        x-axis offset for the node labels. By default, the offset is set to 10.
    int_labels_offset_y : int
        y-axis offset for the node labels. By default, the offset is set to 10.
    continuous_form : bool
        If False and the MDP is episodic, the episodic form of the MDP, in which states are paired with the time
        step h, is used. By default, the continuous form is used.
    prog : str
        The Graphviz program used by networkx to compute the graph layout. By default, 'neato' is used.
    ncol : int
        The number of columns in the legend. By default, four columns are used.
    title : str
        The title to be given to the plot. By default, no title is used.
    legend_fontsize : int
        The font size of the legend. By default, it is not specified.
    font_color_state_labels : str
        The color code for the labels of the state nodes. By default, it is set to 'k'.
    font_color_state_actions_labels : str
        The color code for the labels of the action nodes. By default, it is set to 'k'.
    cm_state_labels : LinearSegmentedColormap
        The matplotlib color map used to color the state nodes according to node_labels, normalised by the largest
        label. By default, it is not specified and the state nodes are colored with node_palette.
    cm_state_actions_labels : LinearSegmentedColormap
        The matplotlib color map used to color the action nodes according to action_labels, normalised by the
        largest label. By default, it is not specified and the action nodes are colored with action_palette.
    no_written_state_labels : bool
        If True, the labels for the node states are not shown. By default, it is set to True.
    no_written_state_action_labels : bool
        If True, the labels for the node actions are not shown. By default, it is set to True.
    node_size : int
        The size of the nodes in the plot. By default, it is set to 150.

    Raises
    ------
    ValueError
        If cm_state_actions_labels is given but action_labels is not a dictionary.
    """
    show = ax is None
    # Whether the episodic graph (nodes are (h, state) pairs) is drawn instead of the continuous one.
    episodic_form = mdp.is_episodic() and not continuous_form

    sns.reset_defaults()
    G, probs = _create_epi_MDP_graph(mdp) if episodic_form else _create_MDP_graph(mdp)
    _, R = mdp.transition_matrix_and_rewards
    max_reward = R.max()

    layout = nx.nx_agraph.graphviz_layout(G, prog=prog)

    if ax is None:
        ax = _create_ax(layout, figsize)

    # The state nodes, in drawing order.
    G_nodes = mdp.get_episodic_graph(False).nodes if episodic_form else mdp.G.nodes

    if node_labels is not None and cm_state_labels is not None:
        max_node_label = max(node_labels.values())
        node_color = [cm_state_labels(node_labels[node] / max_node_label) for node in G_nodes]
    elif episodic_form:
        node_color = [
            node_palette[5]  # starting state at h = 0
            if node[1] in mdp.starting_nodes and node[0] == 0
            else node_palette[2]  # highly rewarding state
            if R[mdp.node_to_index[node[1]]].max() == max_reward
            else node_palette[-2]  # recurrent state
            if node[1] in mdp.recurrent_nodes_set
            else node_palette[-3]  # transient state
            for node in G_nodes
        ]
    else:
        node_color = [
            node_palette[5]  # starting state
            if node in mdp.starting_nodes
            else node_palette[2]  # highly rewarding state
            if R[mdp.node_to_index[node]].max() == max_reward
            else node_palette[-2]  # recurrent state
            if node in mdp.recurrent_nodes_set
            else node_palette[-3]  # transient state
            for node in G_nodes
        ]

    # Lazy way to create nice legend handles: artists drawn at the position of an existing node.
    x, y = next(iter(layout.values()))
    if cm_state_labels is None:
        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
        ax.scatter(x, y, color=node_palette[-2], label="State")
        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
            ax.scatter(x, y, color=node_palette[-3], label="Transient state")
        ax.scatter(x, y, color=node_palette[5], label="Starting state")
    # Same color as the transition edges drawn below.
    ax.plot(x, y, color=action_palette[-3], label="Transition probability")
    if cm_state_actions_labels is None:
        for a in range(mdp.n_actions):
            ax.plot(x, y, color=action_palette[a], label=f"Action: {a}", marker="s")

    nx.draw_networkx_nodes(
        G,
        layout,
        G_nodes,
        ax=ax,
        node_color=node_color,
        edgecolors="black",
        node_size=node_size,
    )

    if cm_state_actions_labels is not None:
        if not isinstance(action_labels, dict):
            raise ValueError(
                "cm_state_actions_labels requires action_labels to be a dictionary mapping action nodes to numbers."
            )
        max_action_label = max(action_labels.values())

    for a in range(mdp.n_actions):
        # Exact type checks: action nodes are plain tuples, state nodes are not.
        if episodic_form:
            na_nodes = [an for an in G.nodes if type(an[1]) is int and an[-1] == a]
            action_edges = [e for e in G.edges if type(e[0][0]) is not tuple and e[1][1] == a]
        else:
            na_nodes = [an for an in G.nodes if type(an) is tuple and an[-1] == a]
            action_edges = [e for e in G.edges if type(e[0]) is not tuple and e[1][1] == a]

        if cm_state_actions_labels is None:
            action_color = [action_palette[a]]
        else:
            action_color = [cm_state_actions_labels(action_labels[an] / max_action_label) for an in na_nodes]

        nx.draw_networkx_nodes(
            G,
            layout,
            na_nodes,
            node_shape="s",
            ax=ax,
            node_size=node_size,
            node_color=action_color,
            edgecolors="black",
        )
        # State -> action edges.
        nx.draw_networkx_edges(G, layout, edgelist=action_edges, ax=ax, edge_color=action_palette[a])

    # Action -> next-state edges, whose width is the transition probability.
    if episodic_form:
        transition_edges = [e for e in G.edges if type(e[0][0]) is tuple]
    else:
        transition_edges = [e for e in G.edges if type(e[0]) is tuple]
    nx.draw_networkx_edges(
        G,
        layout,
        edgelist=transition_edges,
        ax=ax,
        edge_color=action_palette[-3],
        width=[probs[e] for e in transition_edges],
    )
    ax.legend(ncol=ncol, fontsize=legend_fontsize)

    # Labels are written slightly offset from the node they refer to.
    label_layout = toolz.valmap(
        lambda pos: [pos[0] + int_labels_offset_x, pos[1] + int_labels_offset_y], layout
    )

    # None, False and empty dictionaries all mean "no labels".
    if node_labels and not no_written_state_labels:
        if node_labels is True:
            node_labels = {n: (f"h={n[0]},{n[1]}" if episodic_form else str(n)) for n in G_nodes}
        assert all(n in G.nodes for n in node_labels)
        nx.draw_networkx_labels(
            G,
            label_layout,
            node_labels,
            font_color=font_color_state_labels,
            ax=ax,
            verticalalignment="center_baseline",
        )
    if action_labels and not no_written_state_action_labels:
        if action_labels is True:
            action_labels = {
                an: str(an[1])
                for an in G.nodes
                if (type(an[1]) is int if episodic_form else type(an) is tuple)
            }
        assert all(n in G.nodes and type(n[1]) == int for n in action_labels)
        nx.draw_networkx_labels(
            G,
            label_layout,
            action_labels,
            font_color=font_color_state_actions_labels,
            ax=ax,
            verticalalignment="center_baseline",
        )

    ax.axis("off")
    if title is not None:
        ax.set_title(title)
    if save_file is not None:
        plt.savefig(save_file, bbox_inches="tight")
    if show:
        plt.show()


def plot_MCGraph(
    mdp: Union["ContinuousMDP", "EpisodicMDP"],
    node_palette: List[Tuple[float, float, float]] = sns.color_palette("deep"),
    labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = None,
    font_color_labels="k",
    save_file: str = None,
    ax: plt.Axes = None,
    figsize: Tuple[int, int] = None,
    prog: str = None,
    legend_fontsize: int = None,
    node_size=100,
    cm_state_labels=None,
    no_written_state_labels=True,
):
    """
    Plots the Markov chain-based representation of an MDP.

    In the Markov chain-based representation, states are associated with nodes and actions are not visualized. Each
    state is connected to the states that can be reached with probability greater than zero by at least one action.

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP to be visualized.
    node_palette : List[Tuple[float, float, float]]
        The color palette for the nodes. Entries 0 (starting states), 2 (highly rewarding states), 1 (recurrent
        states) and -1 (transient states) are used, and entry -3 colors the edges. By default, the seaborn 'deep'
        palette is used.
    labels : Union[bool, Dict["NODE_TYPE", Union[float, str]]]
        The dictionary mapping nodes to labels (either a float or a str). If True, the string representation of each
        node is used. When cm_state_labels is given, the labels must be numeric and are used to color the nodes. By
        default, no label is plotted.
    font_color_labels
        The color code for the labels of the nodes. By default, it is set to 'k'.
    save_file : str
        The path of the file where the visualization will be stored. By default, the visualization is not stored.
    ax : plt.Axes
        The ax object where the plot will be put. By default, a new axis is created and the plot is shown.
    figsize : Tuple[int, int]
        The size of the figure created when ax is not given. By default, it is inferred from the spread of the
        graph layout.
    prog : str
        The Graphviz program used by networkx to compute the graph layout. By default, it is not specified and the
        layout stored in mdp.graph_layout is used.
    legend_fontsize : int
        The font size of the legend. By default, it is not specified.
    node_size : int
        The size of the nodes in the plot. By default, it is set to 100.
    cm_state_labels : LinearSegmentedColormap
        The matplotlib color map used to color the nodes according to labels, normalised by the largest label. By
        default, it is not specified and the nodes are colored with node_palette.
    no_written_state_labels : bool
        If True and cm_state_labels is given, the labels for the nodes are not written. By default, it is set to True.
    """
    show = ax is None

    # Normalise the labels argument: None/False -> no labels, True -> default string labels.
    if labels is None or labels is False:
        labels = {}
    elif labels is True:
        labels = {n: str(n) for n in mdp.G.nodes}

    _, R = mdp.transition_matrix_and_rewards
    max_reward = R.max()

    # A single layout is used for sizing the axis, for the legend handles and for the drawing itself.
    layout = mdp.graph_layout if prog is None else nx.nx_agraph.graphviz_layout(mdp.G, prog=prog)

    if cm_state_labels is not None:
        max_label = max(labels.values())
        node_color = [cm_state_labels(labels[n] / max_label) for n in mdp.G.nodes]
    else:
        node_color = [
            node_palette[0]  # starting state
            if node in mdp.starting_nodes
            else node_palette[2]  # highly rewarding state
            if R[mdp.node_to_index[node]].max() == max_reward
            else node_palette[1]  # recurrent state
            if node in mdp.recurrent_nodes_set
            else node_palette[-1]  # transient state
            for node in mdp.G.nodes
        ]

    if ax is None:
        ax = _create_ax(layout, figsize)

    if cm_state_labels is None:
        # Lazy way to create nice legend handles: artists drawn at the position of an existing node.
        x, y = next(iter(layout.values()))
        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
        ax.scatter(x, y, color=node_palette[1], label="State")
        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
            ax.scatter(x, y, color=node_palette[-1], label="Transient state")
        ax.scatter(x, y, color=node_palette[0], label="Starting state")

    nx.draw(
        mdp.G,
        layout,
        node_color=node_color,
        node_size=node_size,
        edgecolors="black",
        edge_color=node_palette[-3],
        labels={} if cm_state_labels is not None and no_written_state_labels else labels,
        font_color=font_color_labels,
        ax=ax,
    )

    if cm_state_labels is None:
        ax.legend(fontsize=legend_fontsize)
    if save_file is not None:
        plt.savefig(save_file)
    if show:
        plt.show()


def _create_ax(layout, figsize: Tuple[int, int] = None):
    """
    Creates a new white figure with a single frameless axis covering the whole figure.

    Parameters
    ----------
    layout : Dict
        The mapping from graph nodes to their (x, y) positions. It is only used to infer the figure size.
    figsize : Tuple[int, int]
        The size of the figure. By default, a square figure whose side (between 6 and 20) grows with the largest
        distance between two nodes of the layout is used.

    Returns
    -------
    plt.Axes
        The axis of the newly created figure.
    """
    if figsize is None:
        positions = np.array(list(layout.values()))
        n_positions = len(positions)
        # Largest pairwise distance between nodes; 0 when the layout has fewer than two nodes.
        max_distance = max(
            (
                np.sqrt(np.sum((positions[i] - positions[j]) ** 2))
                for i in range(n_positions)
                for j in range(i + 1, n_positions)
            ),
            default=0.0,
        )
        side = max(6, min(20, int(max_distance / 70)))
        figsize = (side, side)
    fig = plt.figure(None, figsize)
    fig.set_facecolor("w")
    # Public replacement for the private Figure._axstack check: reuse an existing axis, otherwise add one.
    ax = fig.gca() if fig.axes else fig.add_axes((0, 0, 1, 1))
    for spine in ("top", "right", "bottom", "left"):
        ax.spines[spine].set_visible(False)
    return ax


def _create_MDP_graph(mdp: Union["ContinuousMDP", "EpisodicMDP"]):
    """
    Builds the state-action graph of the continuous form of the MDP.

    Each state node is connected to the action nodes (state, action), and each action node is connected to the
    states reached with probability greater than zero.

    Returns
    -------
    Tuple[nx.DiGraph, Dict]
        The directed graph and the dictionary mapping each (action node, next state) edge to its transition
        probability.
    """
    T, _ = mdp.transition_matrix_and_rewards

    probs = dict()
    G = nx.DiGraph()
    for s in range(mdp.n_states):
        n = mdp.index_to_node[s]
        # add_node is a no-op for nodes already added as successors.
        G.add_node(n)
        for a in range(mdp.n_actions):
            an = (n, a)
            G.add_edge(n, an)
            for nn in np.where(T[s, a] > 0)[0]:
                next_node = mdp.index_to_node[nn]
                G.add_edge(an, next_node)
                probs[an, next_node] = T[s, a, nn]

    return G, probs


def _create_epi_MDP_graph(mdp: "EpisodicMDP"):
    """
    Builds the state-action graph of the episodic form of the MDP, whose state nodes are (h, state) pairs.

    Returns
    -------
    Tuple[nx.DiGraph, Dict]
        The directed graph and the dictionary mapping each (action node, next node) edge to its transition
        probability, read from the episodic transition matrix.
    """
    G_epi = mdp.get_episodic_graph(False)
    # NOTE(review): T is indexed as T[h, state, action, next_state]; confirm this shape against the MDP class.
    T, _ = mdp.episodic_transition_matrix_and_rewards
    node_to_index = mdp.node_to_index

    probs = dict()
    G = nx.DiGraph()
    for n in G_epi.nodes:
        G.add_node(n)
        s = node_to_index[n[1]]
        for a in range(mdp.n_actions):
            an = (n, a)
            G.add_edge(n, an)
            for nn in G_epi.successors(n):
                G.add_edge(an, nn)
                probs[an, nn] = T[n[0], s, a, node_to_index[nn[1]]]

    return G, probs
def
plot_MDP_graph( mdp: Union[colosseum.mdp.base_infinite.ContinuousMDP, colosseum.mdp.base_finite.EpisodicMDP], node_palette: List[Tuple[float, float, float]] = [(0.00784313725490196, 0.24313725490196078, 1.0), (1.0, 0.48627450980392156, 0.0), (0.10196078431372549, 0.788235294117647, 0.2196078431372549), (0.9098039215686274, 0.0, 0.043137254901960784), (0.5450980392156862, 0.16862745098039217, 0.8862745098039215), (0.6235294117647059, 0.2823529411764706, 0.0), (0.9450980392156862, 0.2980392156862745, 0.7568627450980392), (0.6392156862745098, 0.6392156862745098, 0.6392156862745098), (1.0, 0.7686274509803922, 0.0), (0.0, 0.8431372549019608, 1.0)], action_palette: List[Tuple[float, float, float]] = [(0.0, 0.10980392156862745, 0.4980392156862745), (0.07058823529411765, 0.44313725490196076, 0.10980392156862745), (0.5490196078431373, 0.03137254901960784, 0.0), (0.34901960784313724, 0.11764705882352941, 0.44313725490196076), (0.34901960784313724, 0.1843137254901961, 0.050980392156862744), (0.6352941176470588, 0.20784313725490197, 0.5098039215686274), (0.23529411764705882, 0.23529411764705882, 0.23529411764705882), (0.7215686274509804, 0.5215686274509804, 0.0392156862745098), (0.0, 0.38823529411764707, 0.4549019607843137)], save_file: str = None, ax: matplotlib.axes._axes.Axes = None, figsize: Tuple[int, int] = None, node_labels: Union[bool, Dict[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], Union[float, str]]] = None, action_labels: Union[bool, Dict[Tuple[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, 
colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], int], Union[float, str]]] = None, int_labels_offset_x: int = 10, int_labels_offset_y: int = 10, continuous_form: bool = True, prog: str = 'neato', ncol: int = 4, title: str = None, legend_fontsize: int = None, font_color_state_labels: str = 'k', font_color_state_actions_labels: str = 'k', cm_state_labels: matplotlib.colors.LinearSegmentedColormap = None, cm_state_actions_labels: matplotlib.colors.LinearSegmentedColormap = None, no_written_state_labels=True, no_written_state_action_labels=True, node_size=150):
def plot_MDP_graph(
    mdp: Union["ContinuousMDP", "EpisodicMDP"],
    node_palette: List[Tuple[float, float, float]] = sns.color_palette("bright"),
    action_palette: List[Tuple[float, float, float]] = custom_dark,
    save_file: str = None,
    ax: plt.Axes = None,
    figsize: Tuple[int, int] = None,
    node_labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = None,
    action_labels: Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]] = None,
    int_labels_offset_x: int = 10,
    int_labels_offset_y: int = 10,
    continuous_form: bool = True,
    prog: str = "neato",
    ncol: int = 4,
    title: str = None,
    legend_fontsize: int = None,
    font_color_state_labels: str = "k",
    font_color_state_actions_labels: str = "k",
    cm_state_labels: LinearSegmentedColormap = None,
    cm_state_actions_labels: LinearSegmentedColormap = None,
    no_written_state_labels=True,
    no_written_state_action_labels=True,
    node_size=150,
):
    """
    Plots the graph representation of an MDP.

    In the MDP graph representation, states are associated with round nodes and actions with square nodes. Each
    state node is connected to one action node for every available action. Each action node is connected to the
    state nodes that are reached with probability greater than zero when playing that action from that state; the
    width of these edges equals the transition probability.

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP to be visualised.
    node_palette : List[Tuple[float, float, float]]
        The color palette for the nodes corresponding to states. Entries 5 (starting states), 2 (highly rewarding
        states), -2 (recurrent states) and -3 (transient states) are used. By default, the seaborn 'bright' palette
        is used.
    action_palette : List[Tuple[float, float, float]]
        The color palette for the nodes corresponding to actions. Entry a is used for action a and entry -3 for the
        transition edges. By default, the seaborn 'dark' palette without its second color is used.
    save_file : str
        The path of the file where the visualization will be stored. By default, the visualization is not stored.
    ax : plt.Axes
        The ax object where the plot will be put. By default, a new axis is created and the plot is shown.
    figsize : Tuple[int, int]
        The size of the figure created when ax is not given. By default, it is inferred from the spread of the
        graph layout.
    node_labels : Union[bool, Dict["NODE_TYPE", Union[float, str]]]
        The dictionary mapping state nodes to labels (either a float or a str). If True, default labels are used.
        When cm_state_labels is given, the labels must be numeric and are used to color the state nodes. By default,
        no label is plotted.
    action_labels : Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]]
        The dictionary mapping action nodes to labels (either a float or a str). If True, default labels (the action
        indices) are used. When cm_state_actions_labels is given, the labels must be numeric and are used to color
        the action nodes. By default, no label is plotted.
    int_labels_offset_x : int
        x-axis offset for the node labels. By default, the offset is set to 10.
    int_labels_offset_y : int
        y-axis offset for the node labels. By default, the offset is set to 10.
    continuous_form : bool
        If False and the MDP is episodic, the episodic form of the MDP, in which states are paired with the time
        step h, is used. By default, the continuous form is used.
    prog : str
        The Graphviz program used by networkx to compute the graph layout. By default, 'neato' is used.
    ncol : int
        The number of columns in the legend. By default, four columns are used.
    title : str
        The title to be given to the plot. By default, no title is used.
    legend_fontsize : int
        The font size of the legend. By default, it is not specified.
    font_color_state_labels : str
        The color code for the labels of the state nodes. By default, it is set to 'k'.
    font_color_state_actions_labels : str
        The color code for the labels of the action nodes. By default, it is set to 'k'.
    cm_state_labels : LinearSegmentedColormap
        The matplotlib color map used to color the state nodes according to node_labels, normalised by the largest
        label. By default, it is not specified and the state nodes are colored with node_palette.
    cm_state_actions_labels : LinearSegmentedColormap
        The matplotlib color map used to color the action nodes according to action_labels, normalised by the
        largest label. By default, it is not specified and the action nodes are colored with action_palette.
    no_written_state_labels : bool
        If True, the labels for the node states are not shown. By default, it is set to True.
    no_written_state_action_labels : bool
        If True, the labels for the node actions are not shown. By default, it is set to True.
    node_size : int
        The size of the nodes in the plot. By default, it is set to 150.

    Raises
    ------
    ValueError
        If cm_state_actions_labels is given but action_labels is not a dictionary.
    """
    show = ax is None
    # Whether the episodic graph (nodes are (h, state) pairs) is drawn instead of the continuous one.
    episodic_form = mdp.is_episodic() and not continuous_form

    sns.reset_defaults()
    G, probs = _create_epi_MDP_graph(mdp) if episodic_form else _create_MDP_graph(mdp)
    _, R = mdp.transition_matrix_and_rewards
    max_reward = R.max()

    layout = nx.nx_agraph.graphviz_layout(G, prog=prog)

    if ax is None:
        ax = _create_ax(layout, figsize)

    # The state nodes, in drawing order.
    G_nodes = mdp.get_episodic_graph(False).nodes if episodic_form else mdp.G.nodes

    if node_labels is not None and cm_state_labels is not None:
        max_node_label = max(node_labels.values())
        node_color = [cm_state_labels(node_labels[node] / max_node_label) for node in G_nodes]
    elif episodic_form:
        node_color = [
            node_palette[5]  # starting state at h = 0
            if node[1] in mdp.starting_nodes and node[0] == 0
            else node_palette[2]  # highly rewarding state
            if R[mdp.node_to_index[node[1]]].max() == max_reward
            else node_palette[-2]  # recurrent state
            if node[1] in mdp.recurrent_nodes_set
            else node_palette[-3]  # transient state
            for node in G_nodes
        ]
    else:
        node_color = [
            node_palette[5]  # starting state
            if node in mdp.starting_nodes
            else node_palette[2]  # highly rewarding state
            if R[mdp.node_to_index[node]].max() == max_reward
            else node_palette[-2]  # recurrent state
            if node in mdp.recurrent_nodes_set
            else node_palette[-3]  # transient state
            for node in G_nodes
        ]

    # Lazy way to create nice legend handles: artists drawn at the position of an existing node.
    x, y = next(iter(layout.values()))
    if cm_state_labels is None:
        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
        ax.scatter(x, y, color=node_palette[-2], label="State")
        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
            ax.scatter(x, y, color=node_palette[-3], label="Transient state")
        ax.scatter(x, y, color=node_palette[5], label="Starting state")
    # Same color as the transition edges drawn below.
    ax.plot(x, y, color=action_palette[-3], label="Transition probability")
    if cm_state_actions_labels is None:
        for a in range(mdp.n_actions):
            ax.plot(x, y, color=action_palette[a], label=f"Action: {a}", marker="s")

    nx.draw_networkx_nodes(
        G,
        layout,
        G_nodes,
        ax=ax,
        node_color=node_color,
        edgecolors="black",
        node_size=node_size,
    )

    if cm_state_actions_labels is not None:
        if not isinstance(action_labels, dict):
            raise ValueError(
                "cm_state_actions_labels requires action_labels to be a dictionary mapping action nodes to numbers."
            )
        max_action_label = max(action_labels.values())

    for a in range(mdp.n_actions):
        # Exact type checks: action nodes are plain tuples, state nodes are not.
        if episodic_form:
            na_nodes = [an for an in G.nodes if type(an[1]) is int and an[-1] == a]
            action_edges = [e for e in G.edges if type(e[0][0]) is not tuple and e[1][1] == a]
        else:
            na_nodes = [an for an in G.nodes if type(an) is tuple and an[-1] == a]
            action_edges = [e for e in G.edges if type(e[0]) is not tuple and e[1][1] == a]

        if cm_state_actions_labels is None:
            action_color = [action_palette[a]]
        else:
            action_color = [cm_state_actions_labels(action_labels[an] / max_action_label) for an in na_nodes]

        nx.draw_networkx_nodes(
            G,
            layout,
            na_nodes,
            node_shape="s",
            ax=ax,
            node_size=node_size,
            node_color=action_color,
            edgecolors="black",
        )
        # State -> action edges.
        nx.draw_networkx_edges(G, layout, edgelist=action_edges, ax=ax, edge_color=action_palette[a])

    # Action -> next-state edges, whose width is the transition probability.
    if episodic_form:
        transition_edges = [e for e in G.edges if type(e[0][0]) is tuple]
    else:
        transition_edges = [e for e in G.edges if type(e[0]) is tuple]
    nx.draw_networkx_edges(
        G,
        layout,
        edgelist=transition_edges,
        ax=ax,
        edge_color=action_palette[-3],
        width=[probs[e] for e in transition_edges],
    )
    ax.legend(ncol=ncol, fontsize=legend_fontsize)

    # Labels are written slightly offset from the node they refer to.
    label_layout = toolz.valmap(
        lambda pos: [pos[0] + int_labels_offset_x, pos[1] + int_labels_offset_y], layout
    )

    # None, False and empty dictionaries all mean "no labels".
    if node_labels and not no_written_state_labels:
        if node_labels is True:
            node_labels = {n: (f"h={n[0]},{n[1]}" if episodic_form else str(n)) for n in G_nodes}
        assert all(n in G.nodes for n in node_labels)
        nx.draw_networkx_labels(
            G,
            label_layout,
            node_labels,
            font_color=font_color_state_labels,
            ax=ax,
            verticalalignment="center_baseline",
        )
    if action_labels and not no_written_state_action_labels:
        if action_labels is True:
            action_labels = {
                an: str(an[1])
                for an in G.nodes
                if (type(an[1]) is int if episodic_form else type(an) is tuple)
            }
        assert all(n in G.nodes and type(n[1]) == int for n in action_labels)
        nx.draw_networkx_labels(
            G,
            label_layout,
            action_labels,
            font_color=font_color_state_actions_labels,
            ax=ax,
            verticalalignment="center_baseline",
        )

    ax.axis("off")
    if title is not None:
        ax.set_title(title)
    if save_file is not None:
        plt.savefig(save_file, bbox_inches="tight")
    if show:
        plt.show()
In the MDP graph representation, states are associated with round nodes and actions with square nodes. Each state is connected to action nodes for all the available actions. Each action node is connected to the state nodes such that the probability of transitioning to them when playing the corresponding action from the corresponding state is greater than zero.
Parameters
- mdp (Union["ContinuousMDP", "EpisodicMDP"]): The MDP to be visualised.
- node_palette (List[Tuple[float, float, float]]): The color palette for the nodes corresponding to states. By default, the seaborn 'bright' palette is used.
- action_palette (List[Tuple[float, float, float]]): The color palette for the nodes corresponding to actions. By default, the seaborn 'dark' palette is used.
- save_file (str): The path of the file where the visualization will be stored. By default, the visualization is not stored.
- ax (plt.Axes): The ax object where the plot will be put. By default, a new axis is created.
- figsize (Tuple[int, int]): The size of the figure created when ax is not given. By default, it is inferred from the spread of the graph layout.
- node_labels (Union[bool, Dict["NODE_TYPE", Union[float, str]]]): The dictionary mapping state nodes to labels (either a float or a str). If provided as bool, then the default labels are used. By default, no label is plotted.
- action_labels (Union[bool, Dict[Tuple["NODE_TYPE", int], Union[float, str]]]): The dictionary mapping action nodes to labels (either a float or a str). If provided as bool, then the default labels are used. By default, no label is plotted.
- int_labels_offset_x (int): x-axis offset for the node labels. By default, the offset is set to 10.
- int_labels_offset_y (int): y-axis offset for the node labels. By default, the offset is set to 10.
- continuous_form (bool): If True and the MDP is episodic, the continuous form of the MDP is used. By default, the continuous form is used.
- prog (str): The default program to create the graph layout to be passed to the networkx library. By default, 'neato' is used.
- ncol (int): The number of columns in the legend. By default, four columns are used.
- title (str): The title to be given to the plot. By default, no title is used.
- legend_fontsize (int): The font size of the legend. By default, it is not specified.
- font_color_state_labels (str): The color code for the labels of the state nodes. By default, it is set to 'k'.
- font_color_state_actions_labels (str): The color code for the labels of the action nodes. By default, it is set to 'k'.
- cm_state_labels (LinearSegmentedColormap): The matplotlib color map used to color the state nodes according to their numeric labels in node_labels. By default, it is not specified.
- cm_state_actions_labels (LinearSegmentedColormap): The matplotlib color map used to color the action nodes according to their numeric labels in action_labels. By default, it is not specified.
- no_written_state_labels (bool): If True, the labels for the node states are not shown. By default, it is set to True.
- no_written_state_action_labels (bool): If True, the labels for the node actions are not shown. By default, it is set to True.
- node_size (int): The size of the nodes in the plot. By default, it is set to 150.
def
plot_MCGraph( mdp: Union[colosseum.mdp.base_infinite.ContinuousMDP, colosseum.mdp.base_finite.EpisodicMDP], node_palette: List[Tuple[float, float, float]] = [(0.2980392156862745, 0.4470588235294118, 0.6901960784313725), (0.8666666666666667, 0.5176470588235295, 0.3215686274509804), (0.3333333333333333, 0.6588235294117647, 0.40784313725490196), (0.7686274509803922, 0.3058823529411765, 0.3215686274509804), (0.5058823529411764, 0.4470588235294118, 0.7019607843137254), (0.5764705882352941, 0.47058823529411764, 0.3764705882352941), (0.8549019607843137, 0.5450980392156862, 0.7647058823529411), (0.5490196078431373, 0.5490196078431373, 0.5490196078431373), (0.8, 0.7254901960784313, 0.4549019607843137), (0.39215686274509803, 0.7098039215686275, 0.803921568627451)], labels: Union[bool, Dict[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], Union[float, str]]] = {}, font_color_labels='k', save_file: str = None, ax: matplotlib.axes._axes.Axes = None, figsize: Tuple[int, int] = None, prog: str = None, legend_fontsize: int = None, node_size=100, cm_state_labels=None, no_written_state_labels=True):
def plot_MCGraph(
    mdp: Union["ContinuousMDP", "EpisodicMDP"],
    node_palette: List[Tuple[float, float, float]] = sns.color_palette("deep"),
    labels: Union[bool, Dict["NODE_TYPE", Union[float, str]]] = None,
    font_color_labels="k",
    save_file: str = None,
    ax: plt.Axes = None,
    figsize: Tuple[int, int] = None,
    prog: str = None,
    legend_fontsize: int = None,
    node_size=100,
    cm_state_labels=None,
    no_written_state_labels=True,
):
    """
    Plots the Markov chain-based representation of an MDP.

    In the Markov chain-based representation, states are associated with nodes and actions are not visualized. Each
    state is connected to every state that can be reached from it through at least one action with transition
    probability greater than zero.

    Parameters
    ----------
    mdp : Union["ContinuousMDP", "EpisodicMDP"]
        The MDP to be visualized.
    node_palette : List[Tuple[float, float, float]]
        The color palette for the nodes. Index 0 colors the starting states, index 2 the states attaining the maximum
        reward, index 1 the remaining recurrent states, index -1 all other states, and index -3 the edges. By default,
        the seaborn 'deep' palette is used.
    labels : Union[bool, Dict["NODE_TYPE", Union[float, str]]]
        The dictionary mapping nodes to labels (either a float or a str). If True, networkx's default labels (the node
        identifiers) are written; if False or None, no label is plotted. By default, no label is plotted.
    font_color_labels : str
        The color code for the labels of the nodes. By default, it is set to 'k'.
    save_file : str
        The path of the file where the visualization will be stored. By default, the visualization is not stored.
    ax : plt.Axes
        The ax object where the plot will be put. By default, a new axis is created and the plot is shown.
    figsize : Tuple[int, int]
        The size of the figure created when ``ax`` is not given. By default, None is provided.
    prog : str
        The graphviz program used, through networkx, to compute the graph layout. By default, the MDP's own
        ``graph_layout`` is used.
    legend_fontsize : int
        The font size of the legend. By default, it is not specified.
    node_size : int
        The size of the nodes in the plot. By default, it is set to 100.
    cm_state_labels : LinearSegmentedColormap
        The matplotlib color map used to color each node according to its label divided by the largest label. When
        given, ``labels`` must be a non-empty dictionary with a numeric label for every node, and no legend is drawn.
        By default, it is not specified.
    no_written_state_labels : bool
        If True and ``cm_state_labels`` is given, the labels are only used to color the nodes and are not written on
        them. By default, it is set to True.

    Raises
    ------
    ValueError
        If ``cm_state_labels`` is given but ``labels`` is not a non-empty dictionary.
    """

    # A figure is shown only when this function created it.
    show = ax is None

    # Separate what nx.draw receives (None makes networkx write the node identifiers, {} writes nothing) from the
    # mapping used for colormap-based coloring.
    if labels is None or isinstance(labels, bool):
        label_map = {}
        draw_labels = None if labels else {}
    else:
        label_map = labels
        draw_labels = labels

    if cm_state_labels is not None:
        if not label_map:
            raise ValueError(
                "cm_state_labels requires labels to be a non-empty dictionary mapping each node to a number."
            )
        # Normalization constant computed once instead of once per node.
        max_label = max(label_map.values())
        node_color = [cm_state_labels(label_map[n] / max_label) for n in mdp.G.nodes]
    else:
        _, R = mdp.transition_matrix_and_rewards
        # Hoisted: the global maximum reward is the same for every node.
        max_reward = R.max()
        node_color = []
        for node in mdp.G.nodes:
            if node in mdp.starting_nodes:
                color = node_palette[0]  # starting state
            elif R[mdp.node_to_index[node]].max() == max_reward:
                color = node_palette[2]  # highly rewarding state
            elif node in mdp.recurrent_nodes_set:
                color = node_palette[1]  # recurrent state
            else:
                color = node_palette[-1]  # transient state
            node_color.append(color)

    # Compute the layout once so that the nodes and the legend proxies share the same coordinates.
    pos = (
        mdp.graph_layout
        if prog is None
        else nx.nx_agraph.graphviz_layout(mdp.G, prog=prog)
    )

    if ax is None:
        ax = _create_ax(mdp.graph_layout, figsize)

    if cm_state_labels is None:
        # Legend proxies: placed at the first node's position, where nx.draw later draws that node over them.
        x, y = next(iter(pos.values()))
        ax.scatter(x, y, color=node_palette[2], label="Highly rewarding state")
        ax.scatter(x, y, color=node_palette[1], label="State")
        if mdp.communication_class == MDPCommunicationClass.WEAKLY_COMMUNICATING:
            ax.scatter(x, y, color=node_palette[-1], label="Transient state")
        ax.scatter(x, y, color=node_palette[0], label="Starting state")

    nx.draw(
        mdp.G,
        pos,
        node_color=node_color,
        node_size=node_size,
        edgecolors="black",
        edge_color=node_palette[-3],
        labels={}
        if cm_state_labels is not None and no_written_state_labels
        else draw_labels,
        font_color=font_color_labels,
        ax=ax,
    )

    if cm_state_labels is None:
        ax.legend(fontsize=legend_fontsize)
    if save_file is not None:
        # Save the figure that owns ``ax``, not whichever figure pyplot currently considers active.
        ax.figure.savefig(save_file)
    if show:
        plt.show()
In the Markov chain-based representation, states are associated with nodes and actions are not visualized. Each state is connected to every state that can be reached from it through at least one action with transition probability greater than zero.
Parameters
- mdp (Union["ContinuousMDP", "EpisodicMDP"]): The MDP to be visualized.
- node_palette (List[Tuple[float, float, float]]): The color palette for the nodes. By default, the seaborn 'deep' palette is used.
- labels (Union[bool, Dict["NODE_TYPE", Union[float, str]]]): The dictionary mapping nodes to labels (either a float or a str). If provided as bool, then the default labels are used. By default, no label is plotted.
- font_color_labels (str): The color code for the labels of the nodes. By default, it is set to 'k'.
- save_file (str): The path of the file where the visualization will be stored. By default, the visualization is not stored.
- ax (plt.Axes): The ax object where the plot will be put. By default, a new axis is created.
- figsize (Tuple[int, int]): The size of the figures in the grid of plots. By default, None is provided.
- prog (str): The default program to create the graph layout to be passed to the networkx library. By default, it is not specified.
- legend_fontsize (int): The font size of the legend. By default, it is not specified.
- node_size (int): The size of the node in the plot. By default, it is set to 100.
- cm_state_labels (LinearSegmentedColormap): The matplotlib color map for the labels of the state nodes. By default, it is not specified.
- no_written_state_labels (bool): If True and cm_state_labels is provided, the labels are only used to color the nodes and are not written on them. By default, it is set to True.