colosseum.mdp.utils.communication_class

  1from __future__ import annotations
  2
  3import warnings
  4from copy import deepcopy
  5from enum import IntEnum
  6from functools import reduce
  7from typing import TYPE_CHECKING, Iterable
  8
  9import networkx as nx
 10import numba
 11import numpy as np
 12from numba import bool_, types
 13from numba.core.errors import NumbaTypeSafetyWarning
 14from numba.typed import Dict
 15from tqdm import trange
 16
 17import colosseum.config as config
 18
# Globally silence numba's type-safety warnings.
# NOTE(review): presumably these come from the implicit casts of Python ints to the int16 keys/values of the
# numba typed dictionaries built in this module — confirm before removing the filter.
warnings.simplefilter("ignore", category=NumbaTypeSafetyWarning)
 20
 21
 22if TYPE_CHECKING:
 23    from colosseum.mdp import NODE_TYPE
 24
 25
 26class MDPCommunicationClass(IntEnum):
 27    """
 28    The MDP communication classes.
 29    """
 30
 31    ERGODIC = 0
 32    """The ergodic communication class."""
 33    COMMUNICATING = 1
 34    """The communicating communication class."""
 35    WEAKLY_COMMUNICATING = 2
 36    """The weakly-communicating communication class."""
 37
 38
 39def get_recurrent_nodes_set(
 40    communication_type: MDPCommunicationClass, G: nx.DiGraph
 41) -> Iterable[NODE_TYPE]:
 42    """
 43    Returns
 44    -------
 45    Iterable[NODE_TYPE]
 46        The recurrent states set. Note that for ergodic and communicating MDPs this corresponds to the state space.
 47    """
 48    if communication_type == MDPCommunicationClass.WEAKLY_COMMUNICATING:
 49        c = nx.condensation(G)
 50        leaf_nodes = [x for x in c.nodes() if c.out_degree(x) == 0]
 51        assert len(leaf_nodes) == 1
 52        return c.nodes(data="members")[leaf_nodes[0]]
 53    return G.nodes
 54
 55
 56def get_communication_class(T: np.ndarray, G: nx.DiGraph) -> MDPCommunicationClass:
 57    """
 58    Returns
 59    -------
 60    MDPCommunicationClass
 61        The communication class for the MDP.
 62    """
 63    if T.ndim == 4:  # episodic MDP
 64        assert (
 65            len(list(G.nodes)[0]) == 2
 66        ), "For an episodic MDP, you must input a episodic graph form."
 67        return _get_episodic_MDP_class(T, G)
 68    return _get_continuous_MDP_class(T)
 69
 70
 71def _get_episodic_MDP_class(T, episodic_graph: nx.DiGraph):
 72    G = episodic_graph.copy()
 73    for (h, u), (hp1, v) in episodic_graph.edges():
 74        if not (T[h, u, :, v] > 0).all():
 75            G.remove_edge((h, u), (hp1, v))
 76
 77    if _check_ergodicity(G, T, True):
 78        return MDPCommunicationClass.ERGODIC
 79    return (
 80        MDPCommunicationClass.COMMUNICATING
 81    )  # if an episodic MDP is not ergodic is, by definition, communicating
 82
 83
 84def _get_continuous_MDP_class(T):
 85    return _calculate_MDP_class(T)
 86
 87
 88def _calculate_MDP_class(T) -> MDPCommunicationClass:
 89    G_1 = nx.DiGraph(np.all(T > 0, axis=1))
 90    if _check_ergodicity(G_1, T, False):
 91        return MDPCommunicationClass.ERGODIC
 92
 93    G_2 = nx.DiGraph(np.any(T > 0, axis=1))
 94    G_2.remove_edges_from(nx.selfloop_edges(G_2))
 95    sccs = list(nx.strongly_connected_components(G_2))
 96    if len(sccs) == 1:
 97        return MDPCommunicationClass.COMMUNICATING
 98
 99    m = 0
100    T_ = set()
101    R = []
102    for C_k in sccs:
103        is_closed = not np.any(
104            np.delete(
105                T[
106                    list(C_k),
107                ],
108                list(C_k),
109                axis=-1,
110            )
111            > 0
112        )
113        if is_closed:
114            m += 1
115            R.append(C_k)
116        else:
117            T_ = T_.union(C_k)
118
119    if m == 1:
120        return MDPCommunicationClass.WEAKLY_COMMUNICATING
121
122    return MDPCommunicationClass.NON_WEAKLY_COMMUNICATING
123
124
def _condense_mpd_graph_old(G_ccs, T):
    """
    Pure-NumPy construction of the adjacency matrix between the classes listed in ``G_ccs``.

    ``G_ccs`` maps a class index to the collection of state indices it contains and ``T`` is indexed as
    ``T[state, action, next_state]``. Entry ``[k, l]`` of the returned boolean matrix is set when some state of
    class ``k`` puts positive total probability on class ``l`` under every action. The diagonal is left ``False``.
    """
    n_classes = len(G_ccs.keys())
    adj = np.zeros((n_classes, n_classes), dtype=bool)
    for src_id in G_ccs:
        # Column vector so that it broadcasts against the destination indices.
        src_states = np.array(G_ccs[src_id]).reshape(-1, 1)
        for dst_id in G_ccs:
            if src_id == dst_id:
                continue
            dst_states = np.array(G_ccs[dst_id])
            # Shape (|src|, |dst|, n_actions): the broadcast index axes come first.
            probs = T[src_states, :, dst_states]
            mass_per_action = probs.sum(1)  # total mass on the destination class
            worst_action = mass_per_action.min(1)  # least favourable action per source state
            if worst_action.max() > 0:
                adj[src_id, dst_id] = True
    return adj
140
141
@numba.njit()
def _condense_mpd_graph(G_ccs, T, d):
    """
    Build the ``d x d`` boolean adjacency matrix between the classes in ``G_ccs`` (numba-compiled).

    ``G_ccs`` maps a class index to the state indices it contains; the keys are used directly as row/column
    indices of the result. ``T`` is three-dimensional and indexed as ``T[state, action, next_state]``.
    Entry ``[k, l]`` (``k != l``) is set when, for at least one state of class ``k``, the minimum over actions of
    the probability mass placed on class ``l`` is strictly positive. The diagonal is left ``False``.
    """
    _, n_actions, _ = T.shape
    adj = np.zeros((d, d), dtype=bool_)
    for k in G_ccs:
        for l in G_ccs:
            if k == l:
                continue
            # M[i]: minimum over actions of the mass that the i-th state of class k puts on class l.
            M = np.zeros(len(G_ccs[k]), np.float32)
            for i, r in enumerate(G_ccs[k]):
                min_a = np.inf
                for a in range(n_actions):
                    summation = 0.0
                    for s in G_ccs[l]:
                        summation += T[r, a, s]
                        # Early exit: the partial sum already exceeds the running minimum, so this action
                        # cannot lower it (this assumes the entries of T are non-negative).
                        if summation > min_a:
                            break
                    min_a = min(min_a, summation)
                M[i] = min_a
            # A single state of class k that reaches class l under every action is enough.
            if M.max() > 0:
                adj[k, l] = True
    return adj
164
165
@numba.njit()
def _condense_mpd_graph_episodic(G_ccs, T, d):
    """
    Episodic counterpart of the class adjacency construction (numba-compiled).

    ``G_ccs`` maps a class index to a collection of ``(h, state)`` pairs; the keys are used directly as row/column
    indices of the ``d x d`` boolean result. ``T`` is four-dimensional and indexed as
    ``T[h, state, action, next_state]``, with ``H = T.shape[0]``. Only destination pairs whose first component is
    ``h + 1``, or ``0`` when ``h + 1 == H``, contribute to the mass of a source pair ``(h, state)``.
    Entry ``[k, l]`` (``k != l``) is set when, for at least one pair of class ``k``, the minimum over actions of
    that mass is strictly positive. The diagonal is left ``False``.
    """
    H, _, n_actions, _ = T.shape
    adj = np.zeros((d, d), dtype=bool_)
    for k in G_ccs:
        for l in G_ccs:
            if k == l:
                continue
            # M[i]: minimum over actions of the mass that the i-th pair of class k puts on class l.
            M = np.zeros(len(G_ccs[k]), np.float32)
            for i, (hr, r) in enumerate(G_ccs[k]):
                min_a = np.inf
                for a in range(n_actions):
                    summation = 0.0
                    for hs, s in G_ccs[l]:
                        # Only the next step counts, wrapping from the last step back to step 0.
                        if hr + 1 == hs or (hr + 1 == H and hs == 0):
                            summation += T[hr, r, a, s]
                            # Early exit: this action can no longer lower the running minimum
                            # (this assumes the entries of T are non-negative).
                            if summation > min_a:
                                break
                    min_a = min(min_a, summation)
                M[i] = min_a
            # A single pair of class k that reaches class l under every action is enough.
            if M.max() > 0:
                adj[k, l] = True
    return adj
189
190
def _get_ultimate_condensation(G, T, is_episodic=False):
    """
    Repeatedly condense ``G`` until the partition of its nodes into classes stops changing.

    The initial classes are the strongly connected components of ``G``. At each round an adjacency matrix between
    the current classes is computed from ``T`` (with the episodic or the continuous kernel, depending on
    ``is_episodic``), and the classes that form a strongly connected component of that matrix are merged.

    Parameters
    ----------
    G
        The graph whose strongly connected components seed the procedure.
    T
        The transition array passed to the condensation kernel.
    is_episodic
        Whether the nodes are ``(h, state)`` pairs and ``T`` is the four-dimensional episodic array.

    Returns
    -------
    dict
        A mapping from class index to the tuple of nodes in that class, returned as soon as a round leaves the
        mapping unchanged. NOTE(review): if the hard cap of 1,000,000 rounds were ever exhausted, the function
        would fall through and implicitly return ``None``.
    """
    mapping = {i: tuple(cc) for i, cc in enumerate(nx.strongly_connected_components(G))}

    # The kernel and the rank of the per-class index arrays only depend on the MDP kind.
    if is_episodic:
        condense, index_ndim = _condense_mpd_graph_episodic, 2
    else:
        condense, index_ndim = _condense_mpd_graph, 1
    value_type = types.Array(types.int16, index_ndim, "A")

    loop = (
        trange(1_000_000, desc="Communication class calculation", mininterval=5)
        if config.VERBOSE_LEVEL > 0
        else range(1_000_000)
    )
    for _ in loop:
        # The numba kernels need a typed dictionary of int16 index arrays.
        d = Dict.empty(key_type=types.int16, value_type=value_type)
        for k, v in mapping.items():
            members = np.array(v)
            if is_episodic:
                members = members.reshape(-1, 2)  # one (h, state) pair per row
            d[k] = members.astype(np.int16)
        new_G_c = nx.DiGraph(condense(d, T, len(mapping)))

        # Merge the classes that are strongly connected in the condensed graph.
        new_mapping = {
            i: reduce(lambda x, y: x + y, (mapping[c] for c in cc))
            for i, cc in enumerate(nx.strongly_connected_components(new_G_c))
        }
        # ``mapping`` is never mutated and only holds tuples, so it can be compared
        # directly without taking a deep copy first.
        if mapping == new_mapping:
            return new_mapping
        mapping = new_mapping
223
224
def _check_ergodicity(G_1, T, is_episodic):
    """
    Return ``True`` when the ultimate condensation of ``G_1`` consists of a single class.

    Note that the self-loops of ``G_1`` are removed in place before the condensation is computed.
    """
    self_loops = nx.selfloop_edges(G_1)
    G_1.remove_edges_from(self_loops)
    classes = _get_ultimate_condensation(G_1, T, is_episodic=is_episodic)
    return len(classes.keys()) == 1
class MDPCommunicationClass(enum.IntEnum):
class MDPCommunicationClass(IntEnum):
    """
    The MDP communication classes.

    The integer values of the pre-existing members are part of the interface and are left unchanged.
    """

    ERGODIC = 0
    """The ergodic communication class."""
    COMMUNICATING = 1
    """The communicating communication class."""
    WEAKLY_COMMUNICATING = 2
    """The weakly-communicating communication class."""
    NON_WEAKLY_COMMUNICATING = 3
    """The class of MDPs that are not even weakly communicating."""

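The MDP communication classes.

ERGODIC = <MDPCommunicationClass.ERGODIC: 0>

The ergodic communication class.

COMMUNICATING = <MDPCommunicationClass.COMMUNICATING: 1>

The communicating communication class.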

WEAKLY_COMMUNICATING = <MDPCommunicationClass.WEAKLY_COMMUNICATING: 2>

The weakly-communicating communication class.

Inherited Members
enum.Enum
name
value
builtins.int
conjugate
bit_length
to_bytes
from_bytes
as_integer_ratio
real
imag
numerator
denominator
def get_recurrent_nodes_set(
    communication_type: MDPCommunicationClass, G: nx.DiGraph
) -> Iterable[NODE_TYPE]:
    """
    Returns
    -------
    Iterable[NODE_TYPE]
        The recurrent states set. For weakly-communicating MDPs this is the membership of the single sink component
        of the condensation of ``G``; for every other communication class it is the whole node set of ``G``.
    """
    if communication_type != MDPCommunicationClass.WEAKLY_COMMUNICATING:
        return G.nodes

    condensed = nx.condensation(G)
    # Sink components: strongly connected components with no outgoing edge in the condensation.
    sinks = [scc for scc in condensed.nodes() if condensed.out_degree(scc) == 0]
    assert len(sinks) == 1
    (sink,) = sinks
    return condensed.nodes(data="members")[sink]
Returns
  • Iterable[NODE_TYPE]: The recurrent states set. Note that for ergodic and communicating MDPs this corresponds to the state space.
def get_communication_class( T: numpy.ndarray, G: networkx.classes.digraph.DiGraph) -> colosseum.mdp.utils.communication_class.MDPCommunicationClass:
def get_communication_class(T: np.ndarray, G: nx.DiGraph) -> MDPCommunicationClass:
    """
    Returns
    -------
    MDPCommunicationClass
        The communication class for the MDP. A four-dimensional ``T`` is treated as an episodic MDP, in which case
        the nodes of ``G`` must be pairs; any other ``T`` is treated as a continuous MDP and ``G`` is not used.
    """
    if T.ndim != 4:  # continuous MDP: the class is derived from the transition array alone
        return _get_continuous_MDP_class(T)

    first_node = list(G.nodes)[0]
    assert (
        len(first_node) == 2
    ), "For an episodic MDP, you must input a episodic graph form."
    return _get_episodic_MDP_class(T, G)
Returns
  • MDPCommunicationClass: The communication class for the MDP.