colosseum.dynamic_programming.infinite_horizon
```python
from typing import Tuple, Union

import numba
import numpy as np
import sparse

from colosseum.dynamic_programming import DP_MAX_ITERATION
from colosseum.dynamic_programming.utils import (
    DynamicProgrammingMaxIterationExceeded,
    argmax_2d,
)


def discounted_value_iteration(
    T: Union[np.ndarray, sparse.COO],
    R: np.ndarray,
    gamma=0.99,
    epsilon=1e-3,
    max_abs_value: float = None,
    sparse_n_states_threshold: int = 300 * 3 * 300,
    sparse_nnz_per_threshold: float = 0.2,
) -> Tuple[np.ndarray, np.ndarray]:
    n_states, n_actions, _ = T.shape

    # If T is already sparse, use the sparse implementation directly.
    if type(T) == sparse.COO:
        return _discounted_value_iteration_sparse(T, R, gamma, epsilon, max_abs_value)

    # For large transition tensors, switch to the sparse implementation when T is sparse enough.
    if T.size > sparse_n_states_threshold:
        T_sparse = sparse.COO(T)
        if T_sparse.nnz / T.size < sparse_nnz_per_threshold:
            return _discounted_value_iteration_sparse(
                T_sparse, R, gamma, epsilon, max_abs_value
            )

    try:
        res = _discounted_value_iteration(T, R, gamma, epsilon, max_abs_value)
    except Exception:
        # Failure of discounted value iteration may randomly happen when using multiprocessing.
        # If that happens, we can simply resort to sparse value iteration.
        T_sparse = sparse.COO(T)
        res = _discounted_value_iteration_sparse(
            T_sparse, R, gamma, epsilon, max_abs_value
        )
    return res


def discounted_policy_evaluation(
    T: Union[np.ndarray, sparse.COO],
    R: np.ndarray,
    pi: np.ndarray,
    gamma=0.99,
    epsilon=1e-7,
    sparse_n_states_threshold: int = 200,
    sparse_nnz_per_threshold: float = 0.2,
) -> Tuple[np.ndarray, np.ndarray]:
    n_states, n_actions, _ = T.shape

    if type(T) == sparse.COO:
        return _discounted_policy_evaluation_sparse(T, R, pi, gamma, epsilon)
    if n_states > sparse_n_states_threshold:
        T_sparse = sparse.COO(T)
        if T_sparse.nnz / T.size < sparse_nnz_per_threshold:
            return _discounted_policy_evaluation_sparse(T_sparse, R, pi, gamma, epsilon)
    return _discounted_policy_evaluation(T, R, pi, gamma, epsilon)


@numba.njit()
def extended_value_iteration(
    T: np.ndarray,
    estimated_rewards: np.ndarray,
    beta_r: np.ndarray,
    beta_p: np.ndarray,
    r_max: float,
    epsilon=1e-3,
) -> Union[None, Tuple[float, np.ndarray, np.ndarray]]:
    """
    If successful, it returns the span of the value function, the Q matrix, and the V matrix.
    It returns None when it was not possible to complete the extended value iteration procedure.
    """
    n_states, n_actions = beta_r.shape

    Q = np.zeros((n_states, n_actions), dtype=np.float32)
    V = np.zeros((n_states,), dtype=np.float32)

    u1 = np.zeros(n_states, np.float32)
    sorted_indices = np.arange(n_states)
    u2 = np.zeros(n_states, np.float32)
    vec = np.zeros(n_states, np.float32)

    for _ in range(DP_MAX_ITERATION):
        for s in range(n_states):
            first_action = True
            for a in range(n_actions):
                # Optimistic transition probabilities within the confidence bound beta_p.
                vec = _max_proba(
                    T[s, a], sorted_indices, beta_p[s, a], n_states, n_actions
                )
                vec[s] -= 1
                # Optimistic reward, capped at r_max.
                r_optimal = min(
                    np.float32(r_max),
                    estimated_rewards[s, a] + beta_r[s, a],
                )
                v = r_optimal + np.dot(vec, u1)
                Q[s, a] = v
                if (
                    first_action
                    or v + u1[s] > u2[s]
                    or np.abs(v + u1[s] - u2[s]) < epsilon
                ):  # optimal policy = argmax
                    u2[s] = np.float32(v + u1[s])
                    first_action = False
            V[s] = np.max(Q[s])
        if np.ptp(u2 - u1) < epsilon:  # stopping condition of EVI
            return np.ptp(u1), Q, V
        else:
            u1 = u2
            u2 = np.empty(n_states, np.float32)
            sorted_indices = np.argsort(u1)
    return None


@numba.njit()
def _discounted_value_iteration(
    T: np.ndarray, R: np.ndarray, gamma=0.99, epsilon=1e-3, max_abs_value: float = None
) -> Tuple[np.ndarray, np.ndarray]:

    n_states, n_actions, _ = T.shape
    gamma = np.float32(gamma)

    V = np.zeros(n_states, dtype=np.float32)
    Q = np.zeros((n_states, n_actions), dtype=np.float32)
    for _ in range(DP_MAX_ITERATION):
        V_old = V.copy()
        for s in range(n_states):
            Q[s] = R[s] + gamma * T[s] @ V
            V[s] = Q[s].max()
            if max_abs_value is not None:
                if np.abs(V[s]) > max_abs_value:
                    return None
        diff = np.abs(V_old - V).max()
        if diff < epsilon:
            return Q, V
    raise DynamicProgrammingMaxIterationExceeded()


def _discounted_value_iteration_sparse(
    T: sparse.COO, R: np.ndarray, gamma=0.99, epsilon=1e-3, max_abs_value: float = None
) -> Tuple[np.ndarray, np.ndarray]:
    n_states, n_actions, _ = T.shape
    gamma = np.float32(gamma)

    V = np.zeros(n_states, dtype=np.float32)
    for _ in range(DP_MAX_ITERATION):
        V_old = V.copy()
        Q = R + gamma * (T @ V).squeeze()
        V = Q.max(1, keepdims=True)

        if max_abs_value is not None:
            if np.abs(V).max() > max_abs_value:
                return None

        diff = np.abs(V_old.squeeze() - V.squeeze()).max()
        if diff < epsilon:
            return Q, V.squeeze()
    raise DynamicProgrammingMaxIterationExceeded()


@numba.njit()
def _discounted_policy_evaluation(
    T: np.ndarray, R: np.ndarray, pi: np.ndarray, gamma=0.99, epsilon=1e-7
) -> Tuple[np.ndarray, np.ndarray]:
    n_states, n_actions, _ = T.shape
    gamma = np.array([gamma], dtype=np.float32)

    V = np.zeros(n_states, dtype=np.float32)
    Q = np.zeros((n_states, n_actions), dtype=np.float32)
    for _ in range(DP_MAX_ITERATION):
        V_old = V.copy()
        for s in range(n_states):
            Q[s] = R[s] + gamma * T[s] @ V
            V[s] = (Q[s] * pi[s]).sum()
        diff = np.abs(V_old - V).max()
        if diff < epsilon:
            return Q, V
    raise DynamicProgrammingMaxIterationExceeded()


def _discounted_policy_evaluation_sparse(
    T: Union[np.ndarray, sparse.COO],
    R: np.ndarray,
    pi: np.ndarray,
    gamma=0.99,
    epsilon=1e-7,
) -> Tuple[np.ndarray, np.ndarray]:
    n_states, n_actions, _ = T.shape
    gamma = np.array([gamma], dtype=np.float32)

    V = np.zeros(n_states, dtype=np.float32)
    for _ in range(DP_MAX_ITERATION):
        V_old = V.copy()
        Q = R + gamma * T @ V
        V = (Q * pi).sum(1)
        diff = np.abs(V_old - V).max()
        if diff < epsilon:
            return Q, V
    raise DynamicProgrammingMaxIterationExceeded()


def discounted_policy_iteration(T: np.ndarray, R: np.ndarray, gamma=0.99, epsilon=1e-7):
    n_states, n_actions, _ = T.shape

    Q = np.random.rand(n_states, n_actions)
    pi = argmax_2d(Q)
    for t in range(DP_MAX_ITERATION):
        old_pi = pi.copy()
        Q, V = discounted_policy_evaluation(T, R, pi, gamma, epsilon)
        pi = argmax_2d(Q)
        if (pi != old_pi).sum() == 0:
            return Q, V, pi
    raise DynamicProgrammingMaxIterationExceeded()


@numba.njit()
def _max_proba(
    p: np.ndarray,
    sorted_indices: np.ndarray,
    beta: np.ndarray,
    n_states: int,
    n_actions: int,
) -> np.ndarray:
    # Inner maximization over the transition probabilities: shift as much mass as beta allows
    # onto the highest-value state (last entry of sorted_indices), removing it from the
    # lowest-value states until the result sums to one.
    min1 = min(1.0, (p[sorted_indices[n_states - 1]] + beta / 2)[0])
    if min1 == 1:
        p2 = np.zeros(n_states, np.float32)
        p2[sorted_indices[n_states - 1]] = 1
    else:
        sorted_p = p[sorted_indices]
        support_sorted_p = np.nonzero(sorted_p)[0]
        restricted_sorted_p = sorted_p[support_sorted_p]
        support_p = sorted_indices[support_sorted_p]
        p2 = np.zeros(n_states, np.float32)
        p2[support_p] = restricted_sorted_p
        p2[sorted_indices[n_states - 1]] = min1
        s = 1 - p[sorted_indices[n_states - 1]] + min1
        s2 = s
        for i, proba in enumerate(restricted_sorted_p):
            max1 = max(0, 1 - s + proba)
            s2 += max1 - proba
            p2[support_p[i]] = max1
            s = s2
            if s <= 1:
                break
    return p2
```
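The dispatch logic at the top of the module can be exercised directly. The sketch below is not part of the module: it builds a small random MDP with the shapes the code indexes, T of shape (n_states, n_actions, n_states) and R of shape (n_states, n_actions), and calls discounted_value_iteration once with the dense tensor and once with a sparse.COO copy, which forces the sparse branch. The float32 dtype is an assumption chosen to match the value vectors allocated internally; the MDP itself is invented for illustration.

```python
import numpy as np
import sparse

from colosseum.dynamic_programming.infinite_horizon import discounted_value_iteration

rng = np.random.default_rng(0)
n_states, n_actions = 20, 3

# Random row-stochastic transition tensor T[s, a, s'] and reward matrix R[s, a].
T = rng.random((n_states, n_actions, n_states)).astype(np.float32)
T /= T.sum(-1, keepdims=True)
R = rng.random((n_states, n_actions)).astype(np.float32)

# Dense path: T is small, so no automatic sparse conversion is triggered.
Q_dense, V_dense = discounted_value_iteration(T, R, gamma=0.95, epsilon=1e-6)

# Passing a sparse.COO tensor dispatches to the sparse implementation.
Q_sparse, V_sparse = discounted_value_iteration(sparse.COO(T), R, gamma=0.95, epsilon=1e-6)

# The two paths should agree up to the stopping tolerance.
print(np.allclose(V_dense, V_sparse, atol=1e-2))
```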
```python
def discounted_value_iteration(
    T: Union[numpy.ndarray, sparse._coo.core.COO],
    R: numpy.ndarray,
    gamma=0.99,
    epsilon=0.001,
    max_abs_value: float = None,
    sparse_n_states_threshold: int = 270000,
    sparse_nnz_per_threshold: float = 0.2,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
```
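A minimal usage sketch for discounted_value_iteration on a two-state chain invented for illustration; the commented numbers are the exact discounted optima the returned values should approach, and float32 inputs are assumed to match the dtype used internally.

```python
import numpy as np

from colosseum.dynamic_programming.infinite_horizon import discounted_value_iteration

# Two states, two actions: action 0 stays in place, action 1 jumps to the other state.
T = np.zeros((2, 2, 2), dtype=np.float32)
T[:, 0, :] = np.eye(2)
T[0, 1, 1] = T[1, 1, 0] = 1.0

# Staying in state 1 is the only rewarding transition.
R = np.zeros((2, 2), dtype=np.float32)
R[1, 0] = 1.0

Q, V = discounted_value_iteration(T, R, gamma=0.9, epsilon=1e-6)
# Optimal values: V[1] = 1 / (1 - 0.9) = 10 and V[0] = 0.9 * V[1] = 9,
# so V should be close to [9, 10]; the greedy action is 1 in state 0 and 0 in state 1.
print(V, Q.argmax(1))
```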
```python
def discounted_policy_evaluation(
    T: Union[numpy.ndarray, sparse._coo.core.COO],
    R: numpy.ndarray,
    pi: numpy.ndarray,
    gamma=0.99,
    epsilon=1e-07,
    sparse_n_states_threshold: int = 200,
    sparse_nnz_per_threshold: float = 0.2,
) -> Tuple[numpy.ndarray, numpy.ndarray]:
```
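discounted_policy_evaluation expects pi as a stochastic matrix with one row of action probabilities per state, which is what the (Q[s] * pi[s]).sum() update in the source implies. Below is a sketch that evaluates a uniform policy on an invented random MDP and compares it against the optimal values; the MDP, seed, and float32 dtype are illustrative assumptions.

```python
import numpy as np

from colosseum.dynamic_programming.infinite_horizon import (
    discounted_policy_evaluation,
    discounted_value_iteration,
)

rng = np.random.default_rng(1)
n_states, n_actions = 10, 2

T = rng.random((n_states, n_actions, n_states)).astype(np.float32)
T /= T.sum(-1, keepdims=True)
R = rng.random((n_states, n_actions)).astype(np.float32)

# Uniform policy: each row is a probability distribution over the actions.
pi = np.full((n_states, n_actions), 1.0 / n_actions, dtype=np.float32)

Q_pi, V_pi = discounted_policy_evaluation(T, R, pi, gamma=0.95)
_, V_star = discounted_value_iteration(T, R, gamma=0.95)

# The value of the uniform policy cannot exceed the optimal value (up to the stopping tolerance).
print(np.all(V_pi <= V_star + 1e-2))
```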
```python
@numba.njit()
def extended_value_iteration(
    T: numpy.ndarray,
    estimated_rewards: numpy.ndarray,
    beta_r: numpy.ndarray,
    beta_p: numpy.ndarray,
    r_max: float,
    epsilon=0.001,
) -> Optional[Tuple[float, numpy.ndarray, numpy.ndarray]]:
```
If successful, it returns the span of the value function, the Q matrix, and the V matrix. It returns None when it was not possible to complete the extended value iteration procedure.
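extended_value_iteration performs the optimistic planning step used by UCRL2-style agents: rewards are inflated by beta_r (capped at r_max) and transition probabilities are tilted within beta_p by _max_proba. The sketch below fabricates plausible inputs for illustration; in particular, the trailing singleton axis of beta_p is an assumption inferred from the [0] indexing inside _max_proba, and the float32 dtype mirrors the value vectors allocated internally, so both should be checked against the calling code in colosseum.

```python
import numpy as np

from colosseum.dynamic_programming.infinite_horizon import extended_value_iteration

rng = np.random.default_rng(2)
n_states, n_actions = 5, 2

# Empirical, row-stochastic transition estimates and reward estimates.
T_hat = rng.random((n_states, n_actions, n_states)).astype(np.float32)
T_hat /= T_hat.sum(-1, keepdims=True)
r_hat = rng.random((n_states, n_actions)).astype(np.float32)

# Confidence widths; beta_p gets a trailing singleton axis (assumed, see note above).
beta_r = np.full((n_states, n_actions), 0.05, dtype=np.float32)
beta_p = np.full((n_states, n_actions, 1), 0.1, dtype=np.float32)

out = extended_value_iteration(T_hat, r_hat, beta_r, beta_p, r_max=1.0, epsilon=1e-3)
if out is not None:
    span, Q, V = out
    print(span, V)
```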
```python
def discounted_policy_iteration(T: numpy.ndarray, R: numpy.ndarray, gamma=0.99, epsilon=1e-07):
```
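discounted_policy_iteration alternates discounted_policy_evaluation with greedy improvement (argmax_2d) until the policy stops changing, and returns Q, V, and the greedy policy pi. A quick sketch on an invented random MDP, checking that its values agree with value iteration up to the stopping tolerances; the MDP, seed, and float32 dtype are illustrative assumptions.

```python
import numpy as np

from colosseum.dynamic_programming.infinite_horizon import (
    discounted_policy_iteration,
    discounted_value_iteration,
)

rng = np.random.default_rng(3)
n_states, n_actions = 15, 3

T = rng.random((n_states, n_actions, n_states)).astype(np.float32)
T /= T.sum(-1, keepdims=True)
R = rng.random((n_states, n_actions)).astype(np.float32)

Q_pi, V_pi, pi = discounted_policy_iteration(T, R, gamma=0.9)
Q_vi, V_vi = discounted_value_iteration(T, R, gamma=0.9)

# Both procedures approximate the same optimal value function.
print(np.allclose(V_pi, V_vi, atol=5e-2))
```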