colosseum.agent.mdp_models.bayesian_models.conjugate_rewards
from typing import Dict, List, Tuple, Union

import numpy as np

from colosseum.agent.mdp_models.bayesian_models import ConjugateModel

PRIOR_TYPE = Union[
    List[
        float,
    ],
    Dict[
        Tuple[int, int],
        List[
            float,
        ],
    ],
]


class N_NIG(ConjugateModel):
    """
    The Normal-Normal Inverse Gamma conjugate model.
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        hyper_params: Union[List[float], List[List[float]]],
        seed: int,
        interpretable_parameters: bool = True,
    ):
        """
        Parameters
        ----------
        n_states : int
            The number of states of the MDP.
        n_actions : int
            The number of action of the MDP.
        hyper_params : Union[List[float],List[List[float]]]
            The prior parameters can either be a list of parameters that are set identical for each
            state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters
            as value.
        seed : int
            The random seed.
        interpretable_parameters : bool
            If True, the parameters are given in the natural way of speaking of NIG parameters.
        """

        super(N_NIG, self).__init__(n_states, n_actions, hyper_params, seed)

        assert self.hyper_params.shape == (n_states, n_actions, 4)

        if interpretable_parameters:
            for i in range(self.hyper_params.shape[0]):
                for j in range(self.hyper_params.shape[1]):
                    mu, n_mu, tau, n_tau = self.hyper_params[i, j]
                    self.hyper_params[i, j] = (
                        mu,
                        n_mu,
                        n_tau * 0.5,
                        (0.5 * n_tau) / tau,
                    )

    def update_sa(self, s: int, a: int, rs: List[float]):
        # Unpack the prior
        (mu0, lambda0, alpha0, beta0) = self.hyper_params[s, a]

        n = len(rs)
        y_bar = np.mean(rs)

        # Updating normal component
        lambda1 = lambda0 + n
        mu1 = (lambda0 * mu0 + n * y_bar) / lambda1

        # Updating Inverse-Gamma component
        alpha1 = alpha0 + (n * 0.5)
        ssq = n * np.var(rs)
        prior_disc = lambda0 * n * ((y_bar - mu0) ** 2) / lambda1
        beta1 = beta0 + 0.5 * (ssq + prior_disc)

        self.hyper_params[s, a] = (mu1, lambda1, alpha1, beta1)

    def sample(self, n: int = 1) -> np.ndarray:
        # Unpack the prior
        (mu, lambda0, alpha, beta) = self.hyper_params.reshape(
            self.n_states * self.n_actions, -1
        ).T

        # Sample scaling tau from a gamma distribution
        tau = self._rng.gamma(shape=alpha, scale=1.0 / beta).astype(np.float32)
        var = 1.0 / (lambda0 * tau)

        # Sample mean from normal mean mu, var
        mean = self._rng.normal(loc=mu, scale=np.sqrt(var), size=(n, *mu.shape)).astype(
            np.float32
        )

        return mean.reshape(self.n_states, self.n_actions).squeeze()

    def get_map_estimate(self) -> np.ndarray:
        return self.hyper_params[:, :, 0]


class N_N(ConjugateModel):
    """
    The Normal-Normal conjugate model.
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        hyper_params: Union[List[float], List[List[float]]],
        seed: int,
    ):
        super(N_N, self).__init__(n_states, n_actions, hyper_params, seed)

        assert self.hyper_params.shape == (n_states, n_actions, 2)

    def update_sa(self, s: int, a: int, xs: List[float]):
        for r in xs:
            mu0, tau0 = self.hyper_params[s, a]
            tau1 = tau0 + 1
            mu1 = (mu0 * tau0 + r * 1) / tau1
            self.hyper_params[s, a] = (mu1, tau1)

    def sample(self, n: int = 1) -> np.ndarray:
        return (
            self._rng.normal(
                loc=self.hyper_params[:, :, 0], scale=self.hyper_params[:, :, 1], size=n
            )
            .astype(np.float32)
            .squeeze()
        )

    def get_map_estimate(self) -> np.ndarray:
        return self.hyper_params[:, :, 0]
The Normal-Normal Inverse Gamma conjugate model.
N_NIG(n_states: int, n_actions: int, hyper_params: Union[List[float], List[List[float]]], seed: int, interpretable_parameters: bool = True)
Parameters
- n_states (int): The number of states of the MDP.
- n_actions (int): The number of actions of the MDP.
- hyper_params (Union[List[float],List[List[float]]]): The prior parameters can either be a single list of parameters, which is applied identically to every state-action pair, or a dictionary with state-action pairs as keys and lists of parameters as values.
- seed (int): The random seed.
- interpretable_parameters (bool): If True, the hyper-parameters are given in the interpretable form suggested by the NIG parameter names, i.e. prior mean, a pseudo-count for the mean, prior precision, and a pseudo-count for the precision; they are converted internally to the standard (mu, lambda, alpha, beta) parameterization.
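A minimal instantiation sketch; it assumes the base ConjugateModel broadcasts a single parameter list to every state-action pair, as the description above suggests, and the prior values are hypothetical. The reading of the interpretable form follows the variable names used in __init__.

from colosseum.agent.mdp_models.bayesian_models.conjugate_rewards import N_NIG

# Interpretable form: (prior mean, mean pseudo-count, prior precision, precision pseudo-count).
# With interpretable_parameters=True, __init__ converts these to the standard
# (mu, lambda, alpha, beta) Normal-Inverse-Gamma hyper-parameters.
model = N_NIG(
    n_states=3,
    n_actions=2,
    hyper_params=[0.0, 1.0, 1.0, 1.0],
    seed=42,
    interpretable_parameters=True,
)
assert model.hyper_params.shape == (3, 2, 4)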
def update_sa(self, s: int, a: int, rs: List[float]):
Updates the beliefs about the given state-action pair.
Parameters
- s (int): The state to update.
- a (int): The action to update.
- rs (List[float]): The reward samples obtained from the state-action pair (s, a).
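The update implements the standard Normal-Inverse-Gamma conjugate update. Below is a self-contained NumPy sketch of the same arithmetic, useful for checking a posterior by hand; the function name and the prior values are illustrative, not part of the library.

import numpy as np

def nig_update(mu0, lambda0, alpha0, beta0, rs):
    # Posterior of a Normal likelihood with unknown mean and variance
    # under a Normal-Inverse-Gamma prior, given observations rs.
    rs = np.asarray(rs, dtype=float)
    n = rs.size
    y_bar = rs.mean()
    lambda1 = lambda0 + n
    mu1 = (lambda0 * mu0 + n * y_bar) / lambda1
    alpha1 = alpha0 + 0.5 * n
    ssq = n * rs.var()  # sum of squared deviations from the sample mean
    prior_disc = lambda0 * n * (y_bar - mu0) ** 2 / lambda1
    beta1 = beta0 + 0.5 * (ssq + prior_disc)
    return mu1, lambda1, alpha1, beta1

# A vague prior updated with three reward observations.
print(nig_update(0.0, 1.0, 0.5, 0.5, [0.2, 0.4, 0.3]))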
def sample(self, n: int = 1) -> numpy.ndarray:
Samples from the posterior for every state-action pair.
Parameters
- n (int): The number of samples. By default, it is set to one.
Returns
- np.ndarray: The n samples from the posterior.
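A short end-to-end sketch of updating and sampling; it again assumes a flat prior list is accepted and broadcast by the base class, and the concrete values are hypothetical.

from colosseum.agent.mdp_models.bayesian_models.conjugate_rewards import N_NIG

model = N_NIG(n_states=4, n_actions=2, hyper_params=[0.5, 1.0, 1.0, 1.0], seed=0)

# Observe rewards for state 1, action 0 and update the corresponding posterior.
model.update_sa(1, 0, [0.9, 1.1, 1.0])

# One posterior draw of the mean reward for every state-action pair,
# e.g. for posterior (Thompson) sampling over rewards.
sampled_rewards = model.sample()        # array of shape (n_states, n_actions)
map_rewards = model.get_map_estimate()  # posterior means, same shape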
The Normal-Normal conjugate model.
N_N(n_states: int, n_actions: int, hyper_params: Union[List[float], List[List[float]]], seed: int)
Parameters
- n_states (int): The number of states of the MDP.
- n_actions (int): The number of actions of the MDP.
- hyper_params (Union[List[float],List[List[float]]]): The prior parameters can either be a single list of parameters, which is applied identically to every state-action pair, or a dictionary with state-action pairs as keys and lists of parameters as values.
- seed (int): The random seed.
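A minimal instantiation sketch, again assuming that a single (mu, tau) pair is broadcast to every state-action pair by the base class; the values are hypothetical.

from colosseum.agent.mdp_models.bayesian_models.conjugate_rewards import N_N

# hyper_params holds (mu, tau) per state-action pair.
model = N_N(n_states=3, n_actions=2, hyper_params=[0.0, 1.0], seed=7)
assert model.hyper_params.shape == (3, 2, 2)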
def update_sa(self, s: int, a: int, xs: List[float]):
Updates the beliefs about the given state-action pair.
Parameters
- s (int): The state to update.
- a (int): The action to update.
- xs (List[float]): The samples obtained from the state-action pair (s, a).
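Each observation increments the pseudo-count tau by one and pulls the mean toward the new sample; the loop is therefore equivalent to a single pseudo-count-weighted average, as this standalone sketch illustrates (function names are illustrative, not part of the library).

def nn_update(mu0, tau0, xs):
    # Sequential update, as in update_sa above.
    mu, tau = mu0, tau0
    for x in xs:
        mu, tau = (mu * tau + x) / (tau + 1), tau + 1
    return mu, tau

def nn_update_batch(mu0, tau0, xs):
    # Equivalent closed form: weighted average of the prior mean and the samples.
    n = len(xs)
    return (mu0 * tau0 + sum(xs)) / (tau0 + n), tau0 + n

assert nn_update(0.0, 1.0, [1.0, 2.0]) == nn_update_batch(0.0, 1.0, [1.0, 2.0]) == (1.0, 3.0)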
def sample(self, n: int = 1) -> numpy.ndarray:
Samples from the posterior for every state-action pair.
Parameters
- n (int): The number of samples. By default, it is set to one.
Returns
- np.ndarray: The n samples from the posterior.
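The draws are made independently for every state-action pair. A standalone NumPy sketch with the same shapes, where the arrays are hypothetical placeholders and the scale parameter is used as a standard deviation:

import numpy as np

rng = np.random.default_rng(0)
mu = np.zeros((3, 2))     # per-pair location parameters
scale = np.ones((3, 2))   # per-pair scale parameters

# One draw per state-action pair; a larger leading axis would hold multiple draws.
samples = rng.normal(loc=mu, scale=scale, size=(1, *mu.shape)).squeeze()
print(samples.shape)  # (3, 2)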