colosseum.agent.mdp_models.bayesian_models.conjugate_transitions
from typing import List, Tuple, Union

import numpy as np

from colosseum.agent.mdp_models.bayesian_models import ConjugateModel
from colosseum.utils.miscellanea import state_occurencens_to_counts

PRIOR_TYPE = Union[
    List[
        float,
    ],
    List[
        List[
            float,
        ],
    ],
]


class M_DIR(ConjugateModel):
    """
    Multinomial-Dirichlet conjugate model.

    After construction ``hyper_params`` has shape
    ``(n_states, n_actions, n_states)``: one vector of Dirichlet parameters
    over next states for every state-action pair.
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        hyper_params: Union[
            List[float],
            List[
                List[
                    float,
                ]
            ],
        ],
        seed: int,
    ):
        """
        Parameters
        ----------
        n_states : int
            The number of states of the MDP.
        n_actions : int
            The number of actions of the MDP.
        hyper_params : Union[List[float], List[List[float]]]
            The prior parameters, forwarded unchanged to the base class.
            NOTE(review): the base class is expected to expose them as an
            array in ``self.hyper_params``; its implementation is not visible
            here.
        seed : int
            The random seed, forwarded to the base class.
        """
        super().__init__(n_states, n_actions, hyper_params, seed)
        # A single prior value per state-action pair is broadcast to every
        # next state.
        if self.hyper_params.shape == (n_states, n_actions, 1):
            self.hyper_params = np.tile(self.hyper_params, (1, 1, n_states))
        assert self.hyper_params.shape == (
            n_states,
            n_actions,
            n_states,
        ), "hyper_params must have shape (n_states, n_actions, n_states)"

    def update_sa(self, s: int, a: int, xs: List[int]):
        """
        Updates the beliefs of the given state-action pair.

        Parameters
        ----------
        s : int
            The state to update.
        a : int
            The action to update.
        xs : List[int]
            The samples obtained from the state-action pair (s, a). Each one is
            passed to ``state_occurencens_to_counts`` together with
            ``n_states``; presumably this yields a per-state count vector
            (confirm against the helper).
        """
        xs = [state_occurencens_to_counts(x, self.n_states) for x in xs]
        # Sum the per-sample vectors and add them to the parameters of (s, a).
        self.hyper_params[s, a] += np.array(xs).sum(0)

    def _sample(self, hyper_params: np.ndarray, n: int) -> np.ndarray:
        """
        Draws ``n`` normalised gamma samples for the given parameters.

        The result has the shape of ``hyper_params`` when ``n == 1`` and shape
        ``(n, *hyper_params.shape)`` otherwise.
        """
        r = self._rng.standard_gamma(hyper_params, (n, *hyper_params.shape)).astype(
            np.float32
        )
        # Drop only the sample axis. A blanket squeeze would also remove
        # unit-length state/action axes and make the normalisation below run
        # over the wrong axis.
        if n == 1:
            r = r[0]
        # The small constant guards against division by zero.
        return r / (1e-5 + r.sum(-1, keepdims=True))

    def sample(self, n: int = 1) -> np.ndarray:
        """
        Samples from the posterior.

        Parameters
        ----------
        n : int
            The number of samples. By default, it is set to one.

        Returns
        -------
        np.ndarray
            A single sample of shape ``(n_states, n_actions, n_states)`` when
            ``n == 1``; otherwise the ``n`` samples stacked along a leading
            axis, i.e. shape ``(n, n_states, n_actions, n_states)``.
        """
        r = self._sample(
            self.hyper_params.reshape(self.n_states * self.n_actions, -1), n
        )
        if n == 1:
            return r.reshape((self.n_states, self.n_actions, -1))
        # Keep the sample axis separate so draws are not interleaved with the
        # state and action axes.
        return r.reshape(
            (n, self.n_states, self.n_actions, self.hyper_params.shape[-1])
        )

    def sample_sa(self, sa: Tuple[int, int]) -> np.ndarray:
        """
        Draws one sample for the state-action pair ``sa``.
        """
        return self._sample(self.hyper_params[sa], 1)

    def get_map_estimate(self) -> np.ndarray:
        """
        Returns ``hyper_params`` normalised to sum to one along the last axis.

        NOTE(review): this ratio is the mean of a Dirichlet distribution, not
        its mode; confirm that this is the intended estimate.
        """
        return self.hyper_params / self.hyper_params.sum(-1, keepdims=True)
class M_DIR(ConjugateModel):
    """
    Multinomial-Dirichlet conjugate model.

    After construction ``hyper_params`` has shape
    ``(n_states, n_actions, n_states)``: one vector of Dirichlet parameters
    over next states for every state-action pair.
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        hyper_params: Union[
            List[float],
            List[
                List[
                    float,
                ]
            ],
        ],
        seed: int,
    ):
        """
        Parameters
        ----------
        n_states : int
            The number of states of the MDP.
        n_actions : int
            The number of actions of the MDP.
        hyper_params : Union[List[float], List[List[float]]]
            The prior parameters, forwarded unchanged to the base class.
            NOTE(review): the base class is expected to expose them as an
            array in ``self.hyper_params``; its implementation is not visible
            here.
        seed : int
            The random seed, forwarded to the base class.
        """
        super().__init__(n_states, n_actions, hyper_params, seed)
        # A single prior value per state-action pair is broadcast to every
        # next state.
        if self.hyper_params.shape == (n_states, n_actions, 1):
            self.hyper_params = np.tile(self.hyper_params, (1, 1, n_states))
        assert self.hyper_params.shape == (
            n_states,
            n_actions,
            n_states,
        ), "hyper_params must have shape (n_states, n_actions, n_states)"

    def update_sa(self, s: int, a: int, xs: List[int]):
        """
        Updates the beliefs of the given state-action pair.

        Parameters
        ----------
        s : int
            The state to update.
        a : int
            The action to update.
        xs : List[int]
            The samples obtained from the state-action pair (s, a). Each one is
            passed to ``state_occurencens_to_counts`` together with
            ``n_states``; presumably this yields a per-state count vector
            (confirm against the helper).
        """
        xs = [state_occurencens_to_counts(x, self.n_states) for x in xs]
        # Sum the per-sample vectors and add them to the parameters of (s, a).
        self.hyper_params[s, a] += np.array(xs).sum(0)

    def _sample(self, hyper_params: np.ndarray, n: int) -> np.ndarray:
        """
        Draws ``n`` normalised gamma samples for the given parameters.

        The result has the shape of ``hyper_params`` when ``n == 1`` and shape
        ``(n, *hyper_params.shape)`` otherwise.
        """
        r = self._rng.standard_gamma(hyper_params, (n, *hyper_params.shape)).astype(
            np.float32
        )
        # Drop only the sample axis. A blanket squeeze would also remove
        # unit-length state/action axes and make the normalisation below run
        # over the wrong axis.
        if n == 1:
            r = r[0]
        # The small constant guards against division by zero.
        return r / (1e-5 + r.sum(-1, keepdims=True))

    def sample(self, n: int = 1) -> np.ndarray:
        """
        Samples from the posterior.

        Parameters
        ----------
        n : int
            The number of samples. By default, it is set to one.

        Returns
        -------
        np.ndarray
            A single sample of shape ``(n_states, n_actions, n_states)`` when
            ``n == 1``; otherwise the ``n`` samples stacked along a leading
            axis, i.e. shape ``(n, n_states, n_actions, n_states)``.
        """
        r = self._sample(
            self.hyper_params.reshape(self.n_states * self.n_actions, -1), n
        )
        if n == 1:
            return r.reshape((self.n_states, self.n_actions, -1))
        # Keep the sample axis separate so draws are not interleaved with the
        # state and action axes.
        return r.reshape(
            (n, self.n_states, self.n_actions, self.hyper_params.shape[-1])
        )

    def sample_sa(self, sa: Tuple[int, int]) -> np.ndarray:
        """
        Draws one sample for the state-action pair ``sa``.
        """
        return self._sample(self.hyper_params[sa], 1)

    def get_map_estimate(self) -> np.ndarray:
        """
        Returns ``hyper_params`` normalised to sum to one along the last axis.

        NOTE(review): this ratio is the mean of a Dirichlet distribution, not
        its mode; confirm that this is the intended estimate.
        """
        return self.hyper_params / self.hyper_params.sum(-1, keepdims=True)
Multinomial-Dirichlet conjugate model.
M_DIR( n_states: int, n_actions: int, hyper_params: Union[List[float], List[List[float]]], seed: int)
def __init__(
    self,
    n_states: int,
    n_actions: int,
    hyper_params: Union[List[float], List[List[float]]],
    seed: int,
):
    """
    Builds the model and expands the prior to one value per next state.

    Parameters
    ----------
    n_states : int
        The number of states of the MDP.
    n_actions : int
        The number of actions of the MDP.
    hyper_params : Union[List[float], List[List[float]]]
        The prior parameters, forwarded unchanged to the base class.
        NOTE(review): the base class is expected to expose them as an array
        in ``self.hyper_params``; its implementation is not visible here.
    seed : int
        The random seed, forwarded to the base class.
    """
    super(M_DIR, self).__init__(n_states, n_actions, hyper_params, seed)

    shared_prior_shape = (n_states, n_actions, 1)
    full_prior_shape = (n_states, n_actions, n_states)

    # A single prior value per state-action pair is repeated for every next
    # state.
    if self.hyper_params.shape == shared_prior_shape:
        self.hyper_params = np.tile(self.hyper_params, (1, 1, n_states))

    assert self.hyper_params.shape == full_prior_shape
Parameters
- n_states (int): The number of states of the MDP.
- n_actions (int): The number of actions of the MDP.
- hyper_params (Union[List[float],List[List[float]]]): The prior parameters can either be a list of parameters that are set identically for each state-action pair, or a dictionary with the state-action pair as key and a list of parameters as value.
- seed (int): The random seed.
def
update_sa(self, s: int, a: int, xs: List[int]):
def update_sa(self, s: int, a: int, xs: List[int]):
    """
    Updates the beliefs of the given state-action pair.

    Parameters
    ----------
    s : int
        The state to update.
    a : int
        The action to update.
    xs : List[int]
        The samples obtained from the state-action pair (s, a). Each one is
        passed to ``state_occurencens_to_counts`` together with ``n_states``;
        presumably this yields a per-state count vector (confirm against the
        helper).
    """
    per_sample_counts = np.array(
        [state_occurencens_to_counts(x, self.n_states) for x in xs]
    )
    # Sum over the samples and add the result to the parameters of (s, a).
    self.hyper_params[s, a] += per_sample_counts.sum(0)
Updates the beliefs of the given state-action pair.
Parameters
- s (int): The state to update.
- a (int): The action to update.
- xs (List): The samples obtained from the state-action pair (s, a).
def
sample(self, n: int = 1) -> numpy.ndarray:
def sample(self, n: int = 1) -> np.ndarray:
    """
    Samples from the posterior.

    Parameters
    ----------
    n : int
        The number of samples. By default, it is set to one.

    Returns
    -------
    np.ndarray
        A single sample of shape ``(n_states, n_actions, n_states)`` when
        ``n == 1``; otherwise the ``n`` samples stacked along a leading axis,
        i.e. shape ``(n, n_states, n_actions, n_states)``.
    """
    # Flatten the state and action axes so that all pairs are drawn in one
    # call.
    flat_params = self.hyper_params.reshape(self.n_states * self.n_actions, -1)
    r = self._sample(flat_params, n)
    if n == 1:
        return r.reshape((self.n_states, self.n_actions, -1))
    # For several samples the sample axis comes first in the output of
    # ``_sample``. Reshaping straight to (n_states, n_actions, -1) would
    # interleave the draws with the state and action axes, so the sample axis
    # is kept separate.
    return r.reshape((n, self.n_states, self.n_actions, self.hyper_params.shape[-1]))
Samples from the posterior.
Parameters
- n (int): The number of samples. By default, it is set to one.
Returns
- np.ndarray: The n samples from the posterior.