colosseum.dynamic_programming.utils

import numba
import numpy as np

ARGMAX_SEED = 42
_rng = np.random.RandomState(ARGMAX_SEED)


class DynamicProgrammingMaxIterationExceeded(Exception):
    pass


@numba.njit()
def argmax_2d(A: np.ndarray) -> np.ndarray:
    """
    Returns
    -------
    np.ndarray
        An array with the same shape as A, with a single one in each row at a maximizing entry (ties broken at
        random) and zeros elsewhere.
    """
    np.random.seed(ARGMAX_SEED)
    X = np.zeros_like(A, np.float32)
    for s in range(len(A)):
        i = np.random.choice(np.where(A[s] == A[s].max())[0])
        X[s, i] = 1
    return X


@numba.njit()
def argmax_3d(A: np.ndarray) -> np.ndarray:
    """
    Implements the three-dimensional version of `argmax_2d`, applying the row-wise argmax to every slice along the
    first (horizon) dimension.
    """
    np.random.seed(ARGMAX_SEED)
    X = np.zeros(A.shape, np.float32)
    for h in range(len(A)):
        for s in range(A.shape[1]):
            i = np.random.choice(np.where(A[h, s] == A[h, s].max())[0])
            X[h, s, i] = 1.0
    return X


@numba.njit()
def get_deterministic_policy_from_q_values(Q: np.ndarray) -> np.ndarray:
    """
    Returns
    -------
    np.ndarray
        The infinite-horizon optimal deterministic policy for the given q-values.
    """
    np.random.seed(ARGMAX_SEED)
    X = np.zeros(Q.shape[:-1], np.int32)
    for s in range(len(Q)):
        i = np.random.choice(np.where(Q[s] == Q[s].max())[0])
        X[s] = np.int32(i)
    return X


@numba.njit()
def get_deterministic_policy_from_q_values_finite_horizon(Q: np.ndarray) -> np.ndarray:
    """
    Returns
    -------
    np.ndarray
        The finite-horizon optimal deterministic policy for the given q-values.
    """
    np.random.seed(ARGMAX_SEED)
    X = np.zeros(Q.shape[:-1], np.int32)
    for h in range(len(Q)):
        for s in range(Q.shape[1]):
            i = np.random.choice(np.where(Q[h, s] == Q[h, s].max())[0])
            X[h, s] = np.int32(i)
    return X


def get_policy_from_q_values(Q: np.ndarray, stochastic_form=False) -> np.ndarray:
    """
    Parameters
    ----------
    Q : np.ndarray
        The q-value estimates.
    stochastic_form : bool
        If False, the returned array contains the integer indices of the optimal actions. If True, it contains
        one-hot vectors representing the corresponding deterministic probability distributions over actions.

    Returns
    -------
    np.ndarray
        The deterministic policy derived from the given q-values.
    """
    # Episodic case
    if Q.ndim == 3:
        if stochastic_form:
            return argmax_3d(Q)
        return get_deterministic_policy_from_q_values_finite_horizon(Q)

    # Infinite horizon case
    if stochastic_form:
        return argmax_2d(Q)
    return get_deterministic_policy_from_q_values(Q)
class DynamicProgrammingMaxIterationExceeded(builtins.Exception):

Exception raised when a dynamic programming procedure exceeds its maximum number of iterations.

Inherited Members
  • builtins.Exception: Exception
  • builtins.BaseException: with_traceback
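
A minimal sketch of how this exception might be used; the driver function run_until_convergence and the step callback below are hypothetical illustrations, not part of colosseum:

from colosseum.dynamic_programming.utils import DynamicProgrammingMaxIterationExceeded

def run_until_convergence(step, max_iter=10_000):
    # Hypothetical driver: `step()` performs one dynamic programming update
    # and returns True once the iterates have converged.
    for _ in range(max_iter):
        if step():
            return
    raise DynamicProgrammingMaxIterationExceeded(
        f"dynamic programming did not converge within {max_iter} iterations"
    )
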
@numba.njit()
def argmax_2d(A: numpy.ndarray) -> numpy.ndarray:

Returns
  • np.ndarray: An array with the same shape as A, with a single one in each row at a maximizing entry (ties broken at random) and zeros elsewhere.
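
A minimal usage sketch, assuming colosseum is installed; the input values below are illustrative:

import numpy as np
from colosseum.dynamic_programming.utils import argmax_2d

# One row per state, one column per action; the second row contains a tie.
A = np.array([[0.1, 0.9],
              [0.5, 0.5],
              [0.7, 0.2]], dtype=np.float32)
X = argmax_2d(A)
# Each row of X is a one-hot vector placed at a maximizing entry of the
# corresponding row of A; ties (as in the second row) are broken at random
# using the fixed ARGMAX_SEED.
assert X.shape == A.shape and np.all(X.sum(1) == 1.0)
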
@numba.njit()
def argmax_3d(A: numpy.ndarray) -> numpy.ndarray:

Implements the three-dimensional version of argmax_2d, applying the row-wise argmax to every slice along the first (horizon) dimension.
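
A minimal usage sketch, assuming colosseum is installed; the shapes below are illustrative (horizon 2, 2 states, 3 actions):

import numpy as np
from colosseum.dynamic_programming.utils import argmax_3d

Q = np.random.rand(2, 2, 3).astype(np.float32)   # (horizon, states, actions)
X = argmax_3d(Q)
# X has the same shape as Q, and X[h, s] is a one-hot vector at a maximizing action.
assert X.shape == Q.shape and np.all(X.sum(-1) == 1.0)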

@numba.njit()
def get_deterministic_policy_from_q_values(Q: numpy.ndarray) -> numpy.ndarray:

Returns
  • np.ndarray: The infinite-horizon optimal deterministic policy for the given q-values.
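
A minimal usage sketch, assuming colosseum is installed; the q-values below are illustrative and have no ties, so the output is unambiguous:

import numpy as np
from colosseum.dynamic_programming.utils import get_deterministic_policy_from_q_values

Q = np.array([[1.0, 3.0],    # state 0: action 1 is optimal
              [2.0, 0.5]],   # state 1: action 0 is optimal
             dtype=np.float32)
pi = get_deterministic_policy_from_q_values(Q)
print(pi)  # [1 0]
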
@numba.njit()
def get_deterministic_policy_from_q_values_finite_horizon(Q: numpy.ndarray) -> numpy.ndarray:

Returns
  • np.ndarray: The finite-horizon optimal deterministic policy for the given q-values.
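
A minimal usage sketch, assuming colosseum is installed; the q-values below are illustrative (horizon 2, 2 states, 2 actions) and have no ties:

import numpy as np
from colosseum.dynamic_programming.utils import (
    get_deterministic_policy_from_q_values_finite_horizon,
)

Q = np.array([[[0.0, 1.0], [2.0, 0.0]],
              [[3.0, 1.0], [0.0, 4.0]]], dtype=np.float32)
pi = get_deterministic_policy_from_q_values_finite_horizon(Q)
print(pi)  # [[1 0]
           #  [0 1]]
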
def get_policy_from_q_values(Q: numpy.ndarray, stochastic_form=False) -> numpy.ndarray:

Parameters
  • Q (np.ndarray): The q-value estimates.
  • stochastic_form (bool): If False, the returned array contains the integer indices of the optimal actions. If True, it contains one-hot vectors representing the corresponding deterministic probability distributions over actions.

Returns
  • np.ndarray: The deterministic policy derived from the given q-values.
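
A minimal usage sketch, assuming colosseum is installed. The function dispatches on Q.ndim, so the same call handles both the episodic (3-dimensional Q) and infinite-horizon (2-dimensional Q) cases; the q-values below are illustrative and have no ties:

import numpy as np
from colosseum.dynamic_programming.utils import get_policy_from_q_values

Q = np.array([[1.0, 3.0],
              [2.0, 0.5]], dtype=np.float32)            # infinite-horizon: (states, actions)

print(get_policy_from_q_values(Q))                       # [1 0]
print(get_policy_from_q_values(Q, stochastic_form=True)) # [[0. 1.]
                                                          #  [1. 0.]]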