colosseum.dynamic_programming.utils
```python
import numba
import numpy as np

ARGMAX_SEED = 42
_rng = np.random.RandomState(ARGMAX_SEED)  # module-level RNG (not used by the functions below)


class DynamicProgrammingMaxIterationExceeded(Exception):
    pass


@numba.njit()
def argmax_2d(A: np.ndarray) -> np.ndarray:
    """
    Returns
    -------
    np.ndarray
        An array with the same shape as A, with ones at the row-wise maxima
        and zeros elsewhere. Ties are broken at random.
    """
    # Fixed seed so that tie-breaking among maxima is reproducible.
    np.random.seed(ARGMAX_SEED)
    X = np.zeros_like(A, np.float32)
    for s in range(len(A)):
        i = np.random.choice(np.where(A[s] == A[s].max())[0])
        X[s, i] = 1
    return X


@numba.njit()
def argmax_3d(A: np.ndarray) -> np.ndarray:
    """
    Extends `argmax_2d` to three-dimensional arrays, applying the row-wise
    argmax to each slice along the first axis.
    """
    np.random.seed(ARGMAX_SEED)
    X = np.zeros(A.shape, np.float32)
    for h in range(len(A)):
        for s in range(A.shape[1]):
            i = np.random.choice(np.where(A[h, s] == A[h, s].max())[0])
            X[h, s, i] = 1.0
    return X


@numba.njit()
def get_deterministic_policy_from_q_values(Q: np.ndarray) -> np.ndarray:
    """
    Returns
    -------
    np.ndarray
        The infinite-horizon optimal deterministic policy for the given
        q-values, as one action index per state.
    """
    np.random.seed(ARGMAX_SEED)
    X = np.zeros(Q.shape[:-1], np.int32)
    for s in range(len(Q)):
        i = np.random.choice(np.where(Q[s] == Q[s].max())[0])
        X[s] = np.int32(i)
    return X


@numba.njit()
def get_deterministic_policy_from_q_values_finite_horizon(Q: np.ndarray) -> np.ndarray:
    """
    Returns
    -------
    np.ndarray
        The finite-horizon optimal deterministic policy for the given
        q-values, as one action index per (time step, state) pair.
    """
    np.random.seed(ARGMAX_SEED)
    X = np.zeros(Q.shape[:-1], np.int32)
    for h in range(len(Q)):
        for s in range(Q.shape[1]):
            i = np.random.choice(np.where(Q[h, s] == Q[h, s].max())[0])
            X[h, s] = np.int32(i)
    return X


def get_policy_from_q_values(Q: np.ndarray, stochastic_form=False) -> np.ndarray:
    """
    Parameters
    ----------
    Q : np.ndarray
        The q-value estimates.
    stochastic_form : bool
        If False, the returned array contains the integer indices of the
        optimal actions. If True, it contains one-hot vectors representing
        the corresponding deterministic probability distributions.

    Returns
    -------
    np.ndarray
        The deterministic policy derived from the given q-values.
    """
    # Episodic case: Q has shape (H, S, A).
    if Q.ndim == 3:
        if stochastic_form:
            return argmax_3d(Q)
        return get_deterministic_policy_from_q_values_finite_horizon(Q)

    # Infinite-horizon case: Q has shape (S, A).
    if stochastic_form:
        return argmax_2d(Q)
    return get_deterministic_policy_from_q_values(Q)
```
class DynamicProgrammingMaxIterationExceeded(builtins.Exception):

Raised when a dynamic programming routine exceeds its maximum number of iterations.
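A hypothetical illustration of how the exception might guard a value-iteration loop; `iterate_until_convergence`, `step`, and the tolerance values below are illustrative stand-ins, not part of this module:

```python
import numpy as np

from colosseum.dynamic_programming.utils import (
    DynamicProgrammingMaxIterationExceeded,
)


def iterate_until_convergence(step, V0, tol=1e-8, max_iter=10_000):
    # Hypothetical helper: applies `step` until the value estimates stop
    # changing, raising the module's exception if the cap is exceeded.
    V = V0
    for _ in range(max_iter):
        V_new = step(V)
        if np.abs(V_new - V).max() < tol:
            return V_new
        V = V_new
    raise DynamicProgrammingMaxIterationExceeded(
        f"no convergence within {max_iter} iterations"
    )
```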
@numba.njit()
def argmax_2d(A: numpy.ndarray) -> numpy.ndarray:

```python
@numba.njit()
def argmax_2d(A: np.ndarray) -> np.ndarray:
    """
    Returns
    -------
    np.ndarray
        An array with the same shape as A, with ones at the row-wise maxima
        and zeros elsewhere. Ties are broken at random.
    """
    # Fixed seed so that tie-breaking among maxima is reproducible.
    np.random.seed(ARGMAX_SEED)
    X = np.zeros_like(A, np.float32)
    for s in range(len(A)):
        i = np.random.choice(np.where(A[s] == A[s].max())[0])
        X[s, i] = 1
    return X
```
Returns
- np.ndarray: An array with the same shape as A, with ones at the row-wise maxima and zeros elsewhere; ties are broken at random with a fixed seed.
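A short usage sketch; the Q matrix below is illustrative, and the action chosen for the tied row depends on the seeded tie-breaking:

```python
import numpy as np

from colosseum.dynamic_programming.utils import argmax_2d

# Q-values for 3 states and 2 actions; state 1 has a tie between its actions.
Q = np.array([[0.0, 1.0],
              [2.0, 2.0],
              [5.0, 3.0]], dtype=np.float32)

greedy = argmax_2d(Q)
# Each row is a one-hot distribution: rows 0 and 2 are [0, 1] and [1, 0],
# and the tied row places its 1 on one of the two maxima, chosen
# reproducibly thanks to the fixed ARGMAX_SEED.
print(greedy)
```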
@numba.njit()
def argmax_3d(A: numpy.ndarray) -> numpy.ndarray:

```python
@numba.njit()
def argmax_3d(A: np.ndarray) -> np.ndarray:
    """
    Extends `argmax_2d` to three-dimensional arrays, applying the row-wise
    argmax to each slice along the first axis.
    """
    np.random.seed(ARGMAX_SEED)
    X = np.zeros(A.shape, np.float32)
    for h in range(len(A)):
        for s in range(A.shape[1]):
            i = np.random.choice(np.where(A[h, s] == A[h, s].max())[0])
            X[h, s, i] = 1.0
    return X
```
Extends `argmax_2d` to three-dimensional arrays, applying the row-wise argmax to each slice along the first axis.
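A sketch of the three-dimensional case, using an illustrative horizon-by-states-by-actions array:

```python
import numpy as np

from colosseum.dynamic_programming.utils import argmax_3d

H, S, A = 4, 3, 2  # horizon, number of states, number of actions
Q = np.random.rand(H, S, A).astype(np.float32)

greedy = argmax_3d(Q)
assert greedy.shape == (H, S, A)
# For every time step h and state s, exactly one action carries mass 1.
assert np.allclose(greedy.sum(-1), 1.0)
```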
@numba.njit()
def get_deterministic_policy_from_q_values(Q: numpy.ndarray) -> numpy.ndarray:

```python
@numba.njit()
def get_deterministic_policy_from_q_values(Q: np.ndarray) -> np.ndarray:
    """
    Returns
    -------
    np.ndarray
        The infinite-horizon optimal deterministic policy for the given
        q-values, as one action index per state.
    """
    np.random.seed(ARGMAX_SEED)
    X = np.zeros(Q.shape[:-1], np.int32)
    for s in range(len(Q)):
        i = np.random.choice(np.where(Q[s] == Q[s].max())[0])
        X[s] = np.int32(i)
    return X
```
Returns
- np.ndarray: The infinite-horizon optimal deterministic policy for the given q-values, as one action index per state; ties are broken at random with a fixed seed.
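A minimal sketch, reusing the illustrative Q matrix from above:

```python
import numpy as np

from colosseum.dynamic_programming.utils import (
    get_deterministic_policy_from_q_values,
)

Q = np.array([[0.0, 1.0],
              [2.0, 2.0],
              [5.0, 3.0]], dtype=np.float32)

pi = get_deterministic_policy_from_q_values(Q)
# One int32 action index per state, e.g. pi[0] == 1 and pi[2] == 0; the
# tied state 1 is resolved reproducibly via the fixed seed.
print(pi.shape, pi.dtype)  # (3,) int32
```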
@numba.njit()
def get_deterministic_policy_from_q_values_finite_horizon(Q: numpy.ndarray) -> numpy.ndarray:

```python
@numba.njit()
def get_deterministic_policy_from_q_values_finite_horizon(Q: np.ndarray) -> np.ndarray:
    """
    Returns
    -------
    np.ndarray
        The finite-horizon optimal deterministic policy for the given
        q-values, as one action index per (time step, state) pair.
    """
    np.random.seed(ARGMAX_SEED)
    X = np.zeros(Q.shape[:-1], np.int32)
    for h in range(len(Q)):
        for s in range(Q.shape[1]):
            i = np.random.choice(np.where(Q[h, s] == Q[h, s].max())[0])
            X[h, s] = np.int32(i)
    return X
```
Returns
- np.ndarray: The finite-horizon optimal deterministic policy for the given q-values, as one action index per (time step, state) pair; ties are broken at random with a fixed seed.
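The finite-horizon analogue, sketched with an illustrative random episodic q-value tensor:

```python
import numpy as np

from colosseum.dynamic_programming.utils import (
    get_deterministic_policy_from_q_values_finite_horizon,
)

H, S, A = 4, 3, 2
Q = np.random.rand(H, S, A).astype(np.float32)

pi = get_deterministic_policy_from_q_values_finite_horizon(Q)
assert pi.shape == (H, S)  # one action index per (time step, state) pair
```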
def get_policy_from_q_values(Q: numpy.ndarray, stochastic_form=False) -> numpy.ndarray:

```python
def get_policy_from_q_values(Q: np.ndarray, stochastic_form=False) -> np.ndarray:
    """
    Parameters
    ----------
    Q : np.ndarray
        The q-value estimates.
    stochastic_form : bool
        If False, the returned array contains the integer indices of the
        optimal actions. If True, it contains one-hot vectors representing
        the corresponding deterministic probability distributions.

    Returns
    -------
    np.ndarray
        The deterministic policy derived from the given q-values.
    """
    # Episodic case: Q has shape (H, S, A).
    if Q.ndim == 3:
        if stochastic_form:
            return argmax_3d(Q)
        return get_deterministic_policy_from_q_values_finite_horizon(Q)

    # Infinite-horizon case: Q has shape (S, A).
    if stochastic_form:
        return argmax_2d(Q)
    return get_deterministic_policy_from_q_values(Q)
```
Parameters
- Q (np.ndarray): The q-value estimates, of shape (H, S, A) in the episodic case or (S, A) in the infinite-horizon case.
- stochastic_form (bool): If False, the returned array contains the integer indices of the optimal actions. If True, it contains one-hot vectors representing the corresponding deterministic probability distributions.

Returns
- np.ndarray: The deterministic policy derived from the given q-values.
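A sketch showing how the dispatcher picks the right helper from the shape of Q and the requested form; both q-value arrays below are illustrative:

```python
import numpy as np

from colosseum.dynamic_programming.utils import get_policy_from_q_values

Q_episodic = np.random.rand(4, 3, 2).astype(np.float32)  # (H, S, A)
Q_infinite = np.random.rand(3, 2).astype(np.float32)     # (S, A)

# Index form: one optimal action per (h, s) pair.
pi_idx = get_policy_from_q_values(Q_episodic)
assert pi_idx.shape == (4, 3)

# Stochastic form: a one-hot (degenerate) distribution over actions per state.
pi_oh = get_policy_from_q_values(Q_infinite, stochastic_form=True)
assert pi_oh.shape == (3, 2) and np.allclose(pi_oh.sum(-1), 1.0)
```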