
  1import os
  2from typing import Iterable, List, Optional, Tuple
  4import networkx as nx
  5import numba
  6import numpy as np
  7import scipy
  8from pydtmc import MarkovChain
  9from scipy.sparse import coo_matrix, csr_matrix
 12def get_average_reward(
 13    T: np.ndarray,
 14    R: np.ndarray,
 15    policy: np.ndarray,
 16    next_states_and_probs: Optional,
 17    sparse_threshold_size: int = 500 * 500,
 18) -> float:
 19    """
 20    Returns
 21    -------
 22    float
 23        The expected average reward when following policy for the MDP defined by the given transition matrix and
 24    rewards matrix.
 25    """
 26    assert np.isclose(policy.sum(-1), 1).all(), "the policy specification is incorrect."
 28    average_rewards = get_average_rewards(R, policy)
 29    tps = get_transition_probabilities(T, policy)
 30    sd = get_stationary_distribution(tps, next_states_and_probs, sparse_threshold_size)
 31    return (average_rewards * sd).sum()
 34def get_average_rewards(R: np.ndarray, policy: np.ndarray) -> np.ndarray:
 35    """
 36    Returns
 37    -------
 38    np.ndarray
 39        The expected rewards for each state when following the given policy.
 40    """
 41    return np.einsum("sa,sa->s", R, policy)
 44def get_transition_probabilities(T: np.ndarray, policy: np.ndarray) -> np.ndarray:
 45    """
 46    Returns
 47    -------
 48    np.ndarray
 49        The transition probability matrix of the Markov chain yielded by the given policy.
 50    """
 51    return np.minimum(1.0, np.einsum("saj,sa->sj", T, policy))
 54def get_markov_chain(transition_probabilities: np.ndarray) -> MarkovChain:
 55    """
 56    Returns
 57    -------
 58    MarkovChain
 59        The Markov chain object from the pydtmc package.
 60    """
 61    return MarkovChain(transition_probabilities)
 64def get_stationary_distribution(
 65    tps: np.ndarray,
 66    starting_states_and_probs: Iterable[Tuple[int, float]],
 67    sparse_threshold_size: int = 500 * 500,
 68) -> np.ndarray:
 69    """
 70    returns the stationary distribution of the transition matrix. If there are multiple recurrent classes, and so
 71    multiple stationary distribution, the return stationary distribution is the average of the stationary distributions
 72    weighted using the starting state distribution.
 74    Parameters
 75    ----------
 76    tps : np.ndarray
 77        The transition probabilities matrix.
 78    starting_states_and_probs : List[Tuple[int, float]]
 79        The iterable over the starting states and their corresponding probabilities.
 80    sparse_threshold_size : int
 81        The threshold for the size of the transition probabilities matrix that flags whether it is better to use sparse
 82        matrices.
 84    Returns
 85    -------
 86    np.ndarray
 87        The stationary distribution of the transition matrix.
 88    """
 89    if tps.size > sparse_threshold_size:
 90        G = nx.DiGraph(coo_matrix(tps))
 91    else:
 92        G = nx.DiGraph(tps)
 94    # Obtain the recurrent classes
 95    recurrent_classes = list(map(tuple, nx.attracting_components(G)))
 96    if len(recurrent_classes) == 1 and len(recurrent_classes[0]) < len(tps):
 97        sd = np.zeros(len(tps), np.float32)
 98        if len(recurrent_classes[0]) == 1:
 99            sd[recurrent_classes[0][0]] = 1
100        else:
101            sd[list(recurrent_classes[0])] = _get_stationary_distribution(
102                tps[np.ix_(recurrent_classes[0], recurrent_classes[0])],
103                sparse_threshold_size,
104            )
105        return sd
107    elif len(recurrent_classes) > 1 and len(recurrent_classes[0]) < len(tps):
109        sd = np.zeros(len(tps))
110        if len(recurrent_classes) > 1:
111            # Weight the stationary distribution of the recurrent classes with starting states distribution
112            for ss, p in starting_states_and_probs:
113                for recurrent_class in recurrent_classes:
114                    try:
115                        # this means that the starting state ss is connected to recurrent_class
116                        nx.shortest_path_length(G, ss, recurrent_class[0])
118                        # Weighting the stationary distribution with the probability of the starting state
119                        sd[list(recurrent_class)] += p * _get_stationary_distribution(
120                            tps[np.ix_(recurrent_class, recurrent_class)],
121                            sparse_threshold_size,
122                        )
123                        break
124                    except nx.exception.NetworkXNoPath:
125                        pass
126        else:
127            # No need to weight with the starting state distribution since there is only one recurrent class
128            sd[list(recurrent_classes[0])] += _get_stationary_distribution(
129                tps[np.ix_(recurrent_classes[0], recurrent_classes[0])],
130                sparse_threshold_size,
131            )
133        return sd
135    sd = _get_stationary_distribution(tps)
136    return sd
def _gth_solve_numba(tps: np.ndarray) -> np.ndarray:
    """
    Returns the stationary distribution of a transition probabilities matrix with a single recurrent class using the
    GTH (Grassmann-Taksar-Heyman) method.

    Parameters
    ----------
    tps : np.ndarray
        The square transition probabilities matrix. It is not modified.

    Returns
    -------
    np.ndarray
        The stationary distribution as a float64 vector with one entry per state of ``tps``. When the reduction stops
        early because the recurrent class is contained in the first states, the remaining entries are zero.
    """
    # astype returns a fresh array, so the reduction below never touches the caller's matrix.
    a = tps.astype(np.float64)
    size = a.shape[0]
    n = size

    # Reduction: eliminate the states one at a time, folding their transitions into the remaining ones.
    for i in range(n - 1):
        scale = np.sum(a[i, i + 1 : n])

        if scale <= 0.0:  # pragma: no cover
            # The recurrent class is contained in the states 0..i, so only those need to be solved for.
            n = i + 1
            break

        a[i + 1 : n, i] /= scale
        a[i + 1 : n, i + 1 : n] += np.outer(
            a[i + 1 : n, i : i + 1], a[i : i + 1, i + 1 : n]
        )

    # The result always has one entry per input state, even when the reduction stopped early.
    x = np.zeros(size, np.float64)

    # Backward substitution, starting from an arbitrary positive value for the last retained state.
    x[n - 1] = 1.0
    for i in range(n - 2, -1, -1):
        x[i] = np.sum(x[i + 1 : n] * a[i + 1 : n, i])

    x /= np.sum(x)
    return x
def _convertToRateMatrix(tps: csr_matrix):
    """
    Turns a sparse transition matrix into a rate matrix whose rows all sum to zero.

    Parameters
    ----------
    tps : csr_matrix
        The sparse transition probabilities matrix.

    Returns
    -------
    The sparse matrix obtained by subtracting, from each diagonal entry of ``tps``, the sum of the corresponding row.
    """
    n_states = tps.shape[0]
    diagonal_positions = np.arange(n_states)
    row_totals = tps.sum(axis=1).getA1()

    # Diagonal matrix holding the row sums, so that the subtraction zeroes out every row total.
    row_total_diagonal = coo_matrix(
        (row_totals, (diagonal_positions, diagonal_positions)), shape=tps.shape
    )
    return tps - row_total_diagonal.tocsr()
def _eigen_method(tps, tol=1e-8, maxiter=1e5):
    """
    Returns the stationary distribution of a transition probabilities matrix with a single recurrent class using the
    eigenvalue method.

    Parameters
    ----------
    tps
        The sparse transition probabilities matrix, passed to ``_convertToRateMatrix``.
    tol
        The tolerance forwarded to ``scipy.sparse.linalg.eigs``.
    maxiter
        The maximum number of iterations forwarded to ``scipy.sparse.linalg.eigs``.

    Returns
    -------
    The real part of the computed eigenvector, divided by its sum and then clipped from below at zero. Since the
    clipping happens after the normalisation, the result may not sum exactly to one.
    """
    rate_matrix = _convertToRateMatrix(tps)
    n_states = rate_matrix.shape[0]

    # Shift-invert around a small positive sigma: the stationary distribution is the left eigenvector of the rate
    # matrix for the eigenvalue zero, i.e. an eigenvector of its transpose.
    _, eigenvectors = scipy.sparse.linalg.eigs(
        rate_matrix.T,
        k=1,
        v0=np.ones(n_states, dtype=float),
        sigma=1e-6,
        which="LM",
        tol=tol,
        maxiter=maxiter,
    )

    stationary = eigenvectors[:, 0].real
    stationary /= stationary.sum()
    return np.maximum(stationary, 0.0)
def _is_invalid_distribution(sd: np.ndarray) -> bool:
    """Tells whether ``sd`` contains NaNs or does not sum (approximately) to one."""
    return bool(np.isnan(sd).any() or not np.isclose(sd.sum(), 1.0))


def _renormalise_or_fail(sd: np.ndarray, tps: np.ndarray) -> np.ndarray:
    """Rescales ``sd`` when its sum is off, and dumps ``tps`` and raises if it is still not a distribution."""
    # The very loose tolerance only rules out sums that are far away from one; anything else is rescaled in place.
    if not np.isclose(sd.sum(), 1.0) and np.isclose(sd.sum(), 1, rtol=4):
        sd /= sd.sum()
    if _is_invalid_distribution(sd):
        # Keep the offending matrix for later inspection before failing.
        np.save("tmp/tps.npy", tps)
        raise AssertionError(
            "the stationary distribution could not be computed; the transition matrix was saved to tmp/tps.npy."
        )
    return sd


def _get_stationary_distribution(
    tps: np.ndarray, sparse_threshold_size: int = 500 * 500
) -> np.ndarray:
    """
    Returns the stationary distribution of a transition probabilities matrix with a single recurrent class.

    Matrices larger than the threshold are first solved with the sparse eigenvalue method, falling back to the GTH
    method when the result is not a valid distribution. Smaller matrices are solved with the GTH method, falling back
    to the eigenvalue method on the row-normalised matrix. Matrices for which a solver fails are saved under the
    ``tmp`` directory for later inspection.

    Parameters
    ----------
    tps : np.ndarray
        The square transition probabilities matrix.
    sparse_threshold_size : int
        The threshold for the size of the transition probabilities matrix above which the sparse eigenvalue method is
        tried first.

    Returns
    -------
    np.ndarray
        The stationary distribution of the transition matrix.

    Raises
    ------
    AssertionError
        If neither method yields a vector without NaNs that sums to one.
    """
    # Created unconditionally so that the directory used for the diagnostic dumps exists after any call.
    os.makedirs("tmp", exist_ok=True)

    # A single state necessarily carries all the probability mass.
    if len(tps) == 1:
        return np.ones(1, np.float32)

    if tps.size > sparse_threshold_size:
        sd = _eigen_method(csr_matrix(tps))
        if not _is_invalid_distribution(sd):
            # The eigen method succeeded: return its result instead of recomputing it with the slower GTH method.
            return sd

        # Sometimes the eigen method fails so we use gth that is slower but more reliable.
        # The failing matrix is stored in the first free slot for later inspection.
        os.makedirs("tmp/sd_failures", exist_ok=True)
        for i in range(1000):
            dump_path = f"tmp/sd_failures/tps{i}.npy"
            if not os.path.isfile(dump_path):
                np.save(dump_path, tps)
                break

        return _renormalise_or_fail(_gth_solve_numba(tps), tps)

    sd = _gth_solve_numba(tps)
    if _is_invalid_distribution(sd):
        np.save("tmp/tps.npy", tps)

        # Retry with the eigen method on the matrix with its rows rescaled to sum to one.
        tps = tps / tps.sum(1, keepdims=True)
        sd = _renormalise_or_fail(_eigen_method(csr_matrix(tps)), tps)
    return sd
