colosseum.dynamic_programming.infinite_horizon

from typing import Tuple, Union

import numba
import numpy as np
import sparse

from colosseum.dynamic_programming import DP_MAX_ITERATION
from colosseum.dynamic_programming.utils import (
    DynamicProgrammingMaxIterationExceeded,
    argmax_2d,
)


def discounted_value_iteration(
    T: Union[np.ndarray, sparse.COO],
    R: np.ndarray,
    gamma=0.99,
    epsilon=1e-3,
    max_abs_value: float = None,
    sparse_n_states_threshold: int = 300 * 3 * 300,
    sparse_nnz_per_threshold: float = 0.2,
) -> Tuple[np.ndarray, np.ndarray]:
    n_states, n_actions, _ = T.shape

    if isinstance(T, sparse.COO):
        return _discounted_value_iteration_sparse(T, R, gamma, epsilon, max_abs_value)

    if T.size > sparse_n_states_threshold:
        T_sparse = sparse.COO(T)
        if T_sparse.nnz / T.size < sparse_nnz_per_threshold:
            return _discounted_value_iteration_sparse(
                T_sparse, R, gamma, epsilon, max_abs_value
            )

    try:
        res = _discounted_value_iteration(T, R, gamma, epsilon, max_abs_value)
    except Exception:
        # The Numba-compiled value iteration may occasionally fail when used with
        # multiprocessing. In that case, we fall back to the sparse implementation.
        T_sparse = sparse.COO(T)
        res = _discounted_value_iteration_sparse(
            T_sparse, R, gamma, epsilon, max_abs_value
        )
    return res


def discounted_policy_evaluation(
    T: Union[np.ndarray, sparse.COO],
    R: np.ndarray,
    pi: np.ndarray,
    gamma=0.99,
    epsilon=1e-7,
    sparse_n_states_threshold: int = 200,
    sparse_nnz_per_threshold: float = 0.2,
) -> Tuple[np.ndarray, np.ndarray]:
    n_states, n_actions, _ = T.shape

    if isinstance(T, sparse.COO):
        return _discounted_policy_evaluation_sparse(T, R, pi, gamma, epsilon)
    if n_states > sparse_n_states_threshold:
        T_sparse = sparse.COO(T)
        if T_sparse.nnz / T.size < sparse_nnz_per_threshold:
            return _discounted_policy_evaluation_sparse(T_sparse, R, pi, gamma, epsilon)
    return _discounted_policy_evaluation(T, R, pi, gamma, epsilon)


@numba.njit()
def extended_value_iteration(
    T: np.ndarray,
    estimated_rewards: np.ndarray,
    beta_r: np.ndarray,
    beta_p: np.ndarray,
    r_max: float,
    epsilon=1e-3,
) -> Union[None, Tuple[float, np.ndarray, np.ndarray]]:
    """
    If successful, returns the span of the value function, the Q matrix, and the V matrix.
    Returns None when the extended value iteration procedure could not be completed.
    """
    n_states, n_actions = beta_r.shape

    Q = np.zeros((n_states, n_actions), dtype=np.float32)
    V = np.zeros((n_states,), dtype=np.float32)

    u1 = np.zeros(n_states, np.float32)
    sorted_indices = np.arange(n_states)
    u2 = np.zeros(n_states, np.float32)
    vec = np.zeros(n_states, np.float32)

    for _ in range(DP_MAX_ITERATION):
        for s in range(n_states):
            first_action = True
            for a in range(n_actions):
                # Optimistic transition probabilities within the confidence set.
                vec = _max_proba(
                    T[s, a], sorted_indices, beta_p[s, a], n_states, n_actions
                )
                vec[s] -= 1
                # Optimistic reward, capped at r_max.
                r_optimal = min(
                    np.float32(r_max),
                    estimated_rewards[s, a] + beta_r[s, a],
                )
                v = r_optimal + np.dot(vec, u1)
                Q[s, a] = v
                if (
                    first_action
                    or v + u1[s] > u2[s]
                    or np.abs(v + u1[s] - u2[s]) < epsilon
                ):  # optimal policy = argmax
                    u2[s] = np.float32(v + u1[s])
                first_action = False
            V[s] = np.max(Q[s])
        if np.ptp(u2 - u1) < epsilon:  # stopping condition of EVI
            return np.ptp(u1), Q, V
        else:
            u1 = u2
            u2 = np.empty(n_states, np.float32)
            sorted_indices = np.argsort(u1)
    return None


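As a usage sketch (the shapes and the synthetic values below are assumptions for
illustration, not part of the library), extended_value_iteration expects an
(S, A, S) transition tensor, (S, A) estimated rewards and reward bonuses, and an
(S, A, 1) array of transition bonuses:

    import numpy as np

    from colosseum.dynamic_programming.infinite_horizon import extended_value_iteration

    n_states, n_actions = 5, 2

    # Synthetic estimates for a small MDP (illustrative values only).
    T = np.random.dirichlet(np.ones(n_states), size=(n_states, n_actions))  # (S, A, S)
    estimated_rewards = np.random.rand(n_states, n_actions)                 # (S, A)
    beta_r = 0.05 * np.ones((n_states, n_actions))                          # reward bonuses
    beta_p = 0.10 * np.ones((n_states, n_actions, 1))                       # transition bonuses

    result = extended_value_iteration(T, estimated_rewards, beta_r, beta_p, 1.0)  # r_max = 1.0
    if result is not None:
        span, Q, V = result
        greedy_policy = Q.argmax(1)  # optimistic greedy policy
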
@numba.njit()
def _discounted_value_iteration(
    T: np.ndarray, R: np.ndarray, gamma=0.99, epsilon=1e-3, max_abs_value: float = None
) -> Tuple[np.ndarray, np.ndarray]:

    n_states, n_actions, _ = T.shape
    gamma = np.float32(gamma)

    V = np.zeros(n_states, dtype=np.float32)
    Q = np.zeros((n_states, n_actions), dtype=np.float32)
    for _ in range(DP_MAX_ITERATION):
        V_old = V.copy()
        for s in range(n_states):
            Q[s] = R[s] + gamma * T[s] @ V
            V[s] = Q[s].max()
            if max_abs_value is not None:
                # Abort if the value function exceeds the allowed magnitude.
                if np.abs(V[s]) > max_abs_value:
                    return None
        diff = np.abs(V_old - V).max()
        if diff < epsilon:
            return Q, V
    raise DynamicProgrammingMaxIterationExceeded()


def _discounted_value_iteration_sparse(
    T: sparse.COO, R: np.ndarray, gamma=0.99, epsilon=1e-3, max_abs_value: float = None
) -> Tuple[np.ndarray, np.ndarray]:
    n_states, n_actions, _ = T.shape
    gamma = np.float32(gamma)

    V = np.zeros(n_states, dtype=np.float32)
    for _ in range(DP_MAX_ITERATION):
        V_old = V.copy()
        Q = R + gamma * (T @ V).squeeze()
        V = Q.max(1, keepdims=True)

        if max_abs_value is not None:
            # Abort if the value function exceeds the allowed magnitude.
            if np.abs(V).max() > max_abs_value:
                return None

        diff = np.abs(V_old.squeeze() - V.squeeze()).max()
        if diff < epsilon:
            return Q, V.squeeze()
    raise DynamicProgrammingMaxIterationExceeded()


@numba.njit()
def _discounted_policy_evaluation(
    T: np.ndarray, R: np.ndarray, pi: np.ndarray, gamma=0.99, epsilon=1e-7
) -> Tuple[np.ndarray, np.ndarray]:
    n_states, n_actions, _ = T.shape
    gamma = np.array([gamma], dtype=np.float32)

    V = np.zeros(n_states, dtype=np.float32)
    Q = np.zeros((n_states, n_actions), dtype=np.float32)
    for _ in range(DP_MAX_ITERATION):
        V_old = V.copy()
        for s in range(n_states):
            Q[s] = R[s] + gamma * T[s] @ V
            V[s] = (Q[s] * pi[s]).sum()
        diff = np.abs(V_old - V).max()
        if diff < epsilon:
            return Q, V
    raise DynamicProgrammingMaxIterationExceeded()


def _discounted_policy_evaluation_sparse(
    T: Union[np.ndarray, sparse.COO],
    R: np.ndarray,
    pi: np.ndarray,
    gamma=0.99,
    epsilon=1e-7,
) -> Tuple[np.ndarray, np.ndarray]:
    n_states, n_actions, _ = T.shape
    gamma = np.array([gamma], dtype=np.float32)

    V = np.zeros(n_states, dtype=np.float32)
    for _ in range(DP_MAX_ITERATION):
        V_old = V.copy()
        Q = R + gamma * T @ V
        V = (Q * pi).sum(1)
        diff = np.abs(V_old - V).max()
        if diff < epsilon:
            return Q, V
    raise DynamicProgrammingMaxIterationExceeded()


def discounted_policy_iteration(T: np.ndarray, R: np.ndarray, gamma=0.99, epsilon=1e-7):
    n_states, n_actions, _ = T.shape

    Q = np.random.rand(n_states, n_actions)
    pi = argmax_2d(Q)
    for _ in range(DP_MAX_ITERATION):
        old_pi = pi.copy()
        Q, V = discounted_policy_evaluation(T, R, pi, gamma, epsilon)
        pi = argmax_2d(Q)
        if (pi != old_pi).sum() == 0:
            return Q, V, pi
    raise DynamicProgrammingMaxIterationExceeded()


@numba.njit()
def _max_proba(
    p: np.ndarray,
    sorted_indices: np.ndarray,
    beta: np.ndarray,
    n_states: int,
    n_actions: int,
) -> np.ndarray:
    # Computes an optimistic transition distribution: probability mass is moved toward
    # the highest-value state (last entry of sorted_indices) up to the confidence bonus
    # beta and removed from the lowest-value states.
    min1 = min(1.0, (p[sorted_indices[n_states - 1]] + beta / 2)[0])
    if min1 == 1:
        p2 = np.zeros(n_states, np.float32)
        p2[sorted_indices[n_states - 1]] = 1
    else:
        sorted_p = p[sorted_indices]
        support_sorted_p = np.nonzero(sorted_p)[0]
        restricted_sorted_p = sorted_p[support_sorted_p]
        support_p = sorted_indices[support_sorted_p]
        p2 = np.zeros(n_states, np.float32)
        p2[support_p] = restricted_sorted_p
        p2[sorted_indices[n_states - 1]] = min1
        s = 1 - p[sorted_indices[n_states - 1]] + min1
        s2 = s
        for i, proba in enumerate(restricted_sorted_p):
            max1 = max(0, 1 - s + proba)
            s2 += max1 - proba
            p2[support_p[i]] = max1
            s = s2
            if s <= 1:
                break
    return p2
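
As a quick usage sketch for the discounted solvers (the small random MDP below is
made up for illustration; T is an (S, A, S) transition tensor whose rows sum to one
and R an (S, A) reward matrix):

    import numpy as np

    from colosseum.dynamic_programming.infinite_horizon import (
        discounted_policy_iteration,
        discounted_value_iteration,
    )

    n_states, n_actions = 10, 3

    # A small synthetic MDP (illustrative values only).
    T = np.random.dirichlet(np.ones(n_states), size=(n_states, n_actions))  # (S, A, S)
    R = np.random.rand(n_states, n_actions)                                 # (S, A)

    Q, V = discounted_value_iteration(T, R, gamma=0.95)             # optimal Q and V
    Q_pi, V_pi, pi = discounted_policy_iteration(T, R, gamma=0.95)  # also returns the converged policy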