colosseum.mdp.utils.communication_class
from __future__ import annotations

import warnings
from copy import deepcopy
from enum import IntEnum
from functools import reduce
from typing import TYPE_CHECKING, Iterable

import networkx as nx
import numba
import numpy as np
from numba import bool_, types
from numba.core.errors import NumbaTypeSafetyWarning
from numba.typed import Dict
from tqdm import trange

import colosseum.config as config

warnings.simplefilter("ignore", category=NumbaTypeSafetyWarning)


if TYPE_CHECKING:
    from colosseum.mdp import NODE_TYPE


class MDPCommunicationClass(IntEnum):
    """
    The MDP communication classes.
    """

    ERGODIC = 0
    """The ergodic communication class."""
    COMMUNICATING = 1
    """The communicating communication class."""
    WEAKLY_COMMUNICATING = 2
    """The weakly-communicating communication class."""
    # Returned by `_calculate_MDP_class` when the number of closed strongly
    # connected components of the MDP graph is not exactly one.
    NON_WEAKLY_COMMUNICATING = 3


def get_recurrent_nodes_set(
    communication_type: MDPCommunicationClass, G: nx.DiGraph
) -> Iterable[NODE_TYPE]:
    """
    Parameters
    ----------
    communication_type : MDPCommunicationClass
        The communication class of the MDP whose graph is ``G``.
    G : nx.DiGraph
        The transition graph of the MDP.

    Returns
    -------
    Iterable[NODE_TYPE]
        The recurrent states set. Note that for ergodic and communicating MDPs this corresponds to the state space.
    """
    # Every class other than weakly communicating returns the whole node set.
    if communication_type != MDPCommunicationClass.WEAKLY_COMMUNICATING:
        return G.nodes

    # For a weakly communicating MDP the recurrent states are the members of
    # the unique sink (out-degree 0) node of the condensation DAG.
    condensed = nx.condensation(G)
    sink_components = [node for node, out_deg in condensed.out_degree() if out_deg == 0]
    assert len(sink_components) == 1
    return condensed.nodes[sink_components[0]]["members"]


def get_communication_class(T: np.ndarray, G: nx.DiGraph) -> MDPCommunicationClass:
    """
    Parameters
    ----------
    T : np.ndarray
        The transition array. A 4-dimensional array is treated as an episodic
        MDP indexed ``T[h, s, a, s']``; otherwise it is indexed ``T[s, a, s']``.
    G : nx.DiGraph
        The MDP graph. For an episodic MDP its nodes must be ``(h, s)`` pairs.

    Returns
    -------
    MDPCommunicationClass
        The communication class for the MDP.
    """
    if T.ndim == 4:  # episodic MDP
        # Peek at one node instead of materialising the whole node list.
        first_node = next(iter(G.nodes), ())
        try:
            is_episodic_form = len(first_node) == 2
        except TypeError:  # e.g. integer nodes of a continuous-form graph
            is_episodic_form = False
        assert (
            is_episodic_form
        ), "For an episodic MDP, you must input an episodic graph form."
        return _get_episodic_MDP_class(T, G)
    return _get_continuous_MDP_class(T)


def _get_episodic_MDP_class(T, episodic_graph: nx.DiGraph) -> MDPCommunicationClass:
    """
    Classify an episodic MDP as ergodic or communicating.

    Only the edges ``(h, u) -> (h', v)`` that every action takes with positive
    probability are kept before the ergodicity check.
    """
    G = episodic_graph.copy()
    for (h, u), (hp1, v) in episodic_graph.edges():
        if not (T[h, u, :, v] > 0).all():
            G.remove_edge((h, u), (hp1, v))

    if _check_ergodicity(G, T, True):
        return MDPCommunicationClass.ERGODIC
    # An episodic MDP that is not ergodic is, by definition, communicating.
    return MDPCommunicationClass.COMMUNICATING


def _get_continuous_MDP_class(T) -> MDPCommunicationClass:
    """Classify a continuous (non-episodic) MDP with transition array ``T[s, a, s']``."""
    return _calculate_MDP_class(T)


def _calculate_MDP_class(T) -> MDPCommunicationClass:
    """
    Classify a continuous MDP from its transition array ``T[s, a, s']``.

    1. Ergodic: the "all actions" graph collapses to one component under
       repeated condensation (see ``_check_ergodicity``).
    2. Communicating: the "some action" graph is strongly connected.
    3. Weakly communicating: exactly one strongly connected component of the
       "some action" graph cannot be left under any action.
    Otherwise the MDP is reported as non weakly communicating.
    """
    # Edge s -> s' when every action moves s to s' with positive probability.
    G_1 = nx.DiGraph(np.all(T > 0, axis=1))
    if _check_ergodicity(G_1, T, False):
        return MDPCommunicationClass.ERGODIC

    # Edge s -> s' when at least one action moves s to s' with positive probability.
    G_2 = nx.DiGraph(np.any(T > 0, axis=1))
    G_2.remove_edges_from(list(nx.selfloop_edges(G_2)))
    sccs = list(nx.strongly_connected_components(G_2))
    if len(sccs) == 1:
        return MDPCommunicationClass.COMMUNICATING

    # A component is closed when no action of any of its states puts positive
    # probability on a state outside it.
    # NOTE(review): a non-closed component that some policy can stay in forever
    # would also be recurrent under that policy; this count does not detect it.
    # Confirm this heuristic matches the intended definition.
    n_closed = 0
    for C_k in sccs:
        members = list(C_k)
        leaks_out = np.delete(T[members], members, axis=-1) > 0
        if not np.any(leaks_out):
            n_closed += 1

    if n_closed == 1:
        return MDPCommunicationClass.WEAKLY_COMMUNICATING
    return MDPCommunicationClass.NON_WEAKLY_COMMUNICATING


def _condense_mpd_graph_old(G_ccs, T):
    # Legacy vectorised variant of `_condense_mpd_graph`; nothing in this module
    # calls it any more.
    adj = np.zeros((len(G_ccs.keys()), len(G_ccs.keys())), dtype=bool)
    for k in G_ccs:
        for l in G_ccs:
            if k == l:
                continue
            if (
                T[np.array(G_ccs[k]).reshape(-1, 1), :, np.array(G_ccs[l])]
                .sum(1)
                .min(1)
                .max()
                > 0
            ):
                adj[k, l] = True
    return adj


@numba.njit()
def _condense_mpd_graph(G_ccs, T, d):
    """
    Adjacency matrix between the ``d`` components in ``G_ccs``.

    Component ``k`` links to component ``l`` when some state ``r`` of ``k``
    reaches ``l`` with positive total probability under *every* action.
    """
    _, n_actions, _ = T.shape
    adj = np.zeros((d, d), dtype=bool_)
    for k in G_ccs:
        for l in G_ccs:
            if k == l:
                continue
            M = np.zeros(len(G_ccs[k]), np.float32)
            for i, r in enumerate(G_ccs[k]):
                min_a = np.inf
                for a in range(n_actions):
                    summation = 0.0
                    for s in G_ccs[l]:
                        summation += T[r, a, s]
                        # This action can no longer lower the minimum.
                        if summation > min_a:
                            break
                    min_a = min(min_a, summation)
                M[i] = min_a
            if M.max() > 0:
                adj[k, l] = True
    return adj


@numba.njit()
def _condense_mpd_graph_episodic(G_ccs, T, d):
    """
    Episodic counterpart of ``_condense_mpd_graph``.

    Components hold ``(h, s)`` rows. Only transitions from stage ``h`` to stage
    ``h + 1`` are counted, or from the last stage ``H - 1`` back to stage 0.
    """
    H, _, n_actions, _ = T.shape
    adj = np.zeros((d, d), dtype=bool_)
    for k in G_ccs:
        for l in G_ccs:
            if k == l:
                continue
            M = np.zeros(len(G_ccs[k]), np.float32)
            for i, (hr, r) in enumerate(G_ccs[k]):
                min_a = np.inf
                for a in range(n_actions):
                    summation = 0.0
                    for hs, s in G_ccs[l]:
                        if hr + 1 == hs or (hr + 1 == H and hs == 0):
                            summation += T[hr, r, a, s]
                        if summation > min_a:
                            break
                    min_a = min(min_a, summation)
                M[i] = min_a
            if M.max() > 0:
                adj[k, l] = True
    return adj


def _get_ultimate_condensation(G, T, is_episodic=False):
    """
    Repeatedly merge the strongly connected components of ``G``.

    Merging uses the component adjacency computed by ``_condense_mpd_graph``
    (or its episodic variant). It stops once no further merge happens.

    Returns
    -------
    dict
        Maps a component id to the tuple of its member nodes.
    """
    mapping = {i: tuple(cc) for i, cc in enumerate(nx.strongly_connected_components(G))}

    loop = (
        trange(1_000_000, desc="Communication class calculation", mininterval=5)
        if config.VERBOSE_LEVEL > 0
        else range(1_000_000)
    )
    for _ in loop:
        # int64 (rather than int16) so that state indices and component ids of
        # large MDPs do not silently wrap around.
        if is_episodic:
            d = Dict.empty(
                key_type=types.int64, value_type=types.Array(types.int64, 2, "A")
            )
            for k, v in mapping.items():
                d[k] = np.array(v).reshape(-1, 2).astype(np.int64)
            adj = _condense_mpd_graph_episodic(d, T, len(mapping))
        else:
            d = Dict.empty(
                key_type=types.int64, value_type=types.Array(types.int64, 1, "A")
            )
            for k, v in mapping.items():
                d[k] = np.array(v).astype(np.int64)
            adj = _condense_mpd_graph(d, T, len(mapping))

        new_mapping = {
            i: reduce(lambda x, y: x + y, (mapping[c] for c in cc))
            for i, cc in enumerate(nx.strongly_connected_components(nx.DiGraph(adj)))
        }
        # Each new component is a union of old ones. An unchanged count means
        # nothing merged and the partition is final. This test is immune to
        # the SCC relabelling that an exact dict comparison is sensitive to.
        if len(new_mapping) == len(mapping):
            return new_mapping
        mapping = new_mapping
    return mapping


def _check_ergodicity(G_1, T, is_episodic):
    """
    Return True when ``G_1`` collapses into a single component.

    Self-loops are removed from ``G_1`` in place before condensing.
    """
    G_1.remove_edges_from(list(nx.selfloop_edges(G_1)))
    G_1_c_star_mapping = _get_ultimate_condensation(G_1, T, is_episodic=is_episodic)
    return len(G_1_c_star_mapping) == 1
class
MDPCommunicationClass(enum.IntEnum):
class MDPCommunicationClass(IntEnum):
    """
    The MDP communication classes.
    """

    ERGODIC = 0
    """The ergodic communication class."""
    COMMUNICATING = 1
    """The communicating communication class."""
    WEAKLY_COMMUNICATING = 2
    """The weakly-communicating communication class."""
    # Returned by the module's continuous-MDP classifier when the number of
    # closed strongly connected components is not exactly one. It was
    # referenced there but missing from the enum.
    NON_WEAKLY_COMMUNICATING = 3
The MDP communication classes.
WEAKLY_COMMUNICATING = <MDPCommunicationClass.WEAKLY_COMMUNICATING: 2>
The weakly-communicating communication class.
Inherited Members
- enum.Enum
- name
- value
- builtins.int
- conjugate
- bit_length
- to_bytes
- from_bytes
- as_integer_ratio
- real
- imag
- numerator
- denominator
def
get_recurrent_nodes_set( communication_type: colosseum.mdp.utils.communication_class.MDPCommunicationClass, G: networkx.classes.digraph.DiGraph) -> Iterable[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode]]:
def get_recurrent_nodes_set(
    communication_type: MDPCommunicationClass, G: nx.DiGraph
) -> Iterable[NODE_TYPE]:
    """
    Parameters
    ----------
    communication_type : MDPCommunicationClass
        The communication class of the MDP whose graph is ``G``.
    G : nx.DiGraph
        The transition graph of the MDP.

    Returns
    -------
    Iterable[NODE_TYPE]
        The recurrent states set. Note that for ergodic and communicating MDPs this corresponds to the state space.
    """
    # Every class other than weakly communicating returns the whole node set.
    if communication_type != MDPCommunicationClass.WEAKLY_COMMUNICATING:
        return G.nodes

    # For a weakly communicating MDP the recurrent states are the members of
    # the unique sink (out-degree 0) node of the condensation DAG.
    condensed = nx.condensation(G)
    sink_components = [node for node, out_deg in condensed.out_degree() if out_deg == 0]
    assert len(sink_components) == 1
    return condensed.nodes[sink_components[0]]["members"]
Returns
- Iterable[NODE_TYPE]: The recurrent states set. Note that for ergodic and communicating MDPs this corresponds to the state space.
def
get_communication_class( T: numpy.ndarray, G: networkx.classes.digraph.DiGraph) -> colosseum.mdp.utils.communication_class.MDPCommunicationClass:
def get_communication_class(T: np.ndarray, G: nx.DiGraph) -> MDPCommunicationClass:
    """
    Parameters
    ----------
    T : np.ndarray
        The transition array. A 4-dimensional array is treated as an episodic MDP.
    G : nx.DiGraph
        The MDP graph. For an episodic MDP its nodes must be pairs such as ``(h, s)``.

    Returns
    -------
    MDPCommunicationClass
        The communication class for the MDP.
    """
    if T.ndim == 4:  # episodic MDP
        # Peek at one node instead of materialising the whole node list.
        first_node = next(iter(G.nodes), ())
        try:
            is_episodic_form = len(first_node) == 2
        except TypeError:  # e.g. integer nodes of a continuous-form graph
            is_episodic_form = False
        assert (
            is_episodic_form
        ), "For an episodic MDP, you must input an episodic graph form."
        return _get_episodic_MDP_class(T, G)
    return _get_continuous_MDP_class(T)
Returns
- MDPCommunicationClass: The communication class for the MDP.