colosseum.agent.mdp_models.bayesian_models.base_conjugate
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, Union

import numpy as np


class ConjugateModel(ABC):
    """
    Base class for Bayesian conjugate models.

    Subclasses implement ``update_sa``, ``sample`` and ``get_map_estimate``.
    This base class stores the MDP dimensions, a seeded
    ``np.random.RandomState`` and the prior hyper-parameters as a float32
    array. It also provides convenience wrappers that route single
    transitions and batches of data to ``update_sa``.
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        hyper_params: Union[List[float], List[List[float]]],
        seed: int,
    ):
        """
        Parameters
        ----------
        n_states : int
            The number of states of the MDP.
        n_actions : int
            The number of actions of the MDP.
        hyper_params : Union[List[float],List[List[float]]]
            The prior parameters. Either a flat list of scalars, used as the
            same prior for every state-action pair and tiled to an array of
            shape ``(n_states, n_actions, len(hyper_params))``, or a nested
            list/tuple/array of per-pair parameters, converted to an array
            as given (its shape is not validated here).
        seed : int
            The random seed.

        Raises
        ------
        ValueError
            If ``hyper_params`` is empty or its first element is neither a
            real scalar nor a list/tuple/array.
        """

        self.n_actions = n_actions
        self.n_states = n_states
        self._rng = np.random.RandomState(seed)

        # len() instead of truthiness so NumPy arrays are accepted as well.
        if len(hyper_params) == 0:
            raise ValueError("Received an empty list of hyper-parameters.")

        first = hyper_params[0]
        # isinstance covers Python and NumPy ints/floats alike; bool is a
        # subclass of int but is never a meaningful prior parameter.
        is_scalar = isinstance(
            first, (int, float, np.integer, np.floating)
        ) and not isinstance(first, bool)

        if is_scalar:
            # same priors for each state action pair
            self.hyper_params = np.tile(hyper_params, (n_states, n_actions, 1)).astype(
                np.float32
            )
        elif isinstance(first, (list, tuple, np.ndarray)):
            # each state action pair has a different prior
            # NOTE(review): presumably shaped (n_states, n_actions, k) like the
            # shared case; the shape is not checked here.
            self.hyper_params = np.array(hyper_params, np.float32)
        else:
            raise ValueError(
                f"Received incorrect parameters with type "
                f"{type(hyper_params), type(hyper_params[0])}"
            )

    @abstractmethod
    def update_sa(self, s: int, a: int, xs: List):
        """
        updates the beliefs of the given state action pair.
        Parameters
        ----------
        s : int
            The state to update.
        a : int
            The action to update.
        xs : List
            The samples obtained from state action pair (s,a).
        """

    @abstractmethod
    def sample(self, n: int = 1) -> np.ndarray:
        """
        samples from the posterior
        Parameters
        ----------
        n : int
            The number of samples. By default, it is set to one.

        Returns
        -------
        np.ndarray
            The n samples from the posterior.
        """

    @abstractmethod
    def get_map_estimate(self) -> np.ndarray:
        """
        computes the maximum a posteriori (MAP) estimate.
        Returns
        -------
        np.ndarray
            The maximum a posteriori estimates
        """

    def update_single_transition(self, s: int, a: int, x: Any):
        """
        updates the posterior for a single transition.
        Parameters
        ----------
        s : int
            The state to update.
        a : int
            The action to update.
        x : Any
            A sample obtained from state action pair (s,a).
        """
        self.update_sa(s, a, [x])

    def update(self, data: Dict[Tuple[int, int], List[float]]):
        """
        updates the Bayesian model.
        Parameters
        ----------
        data : Dict[Tuple[int, int], List[float]]
            the data to be used to update the model using Bayes' rule.
        """
        for (s, a), xs in data.items():
            self.update_sa(s, a, xs)
class
ConjugateModel(abc.ABC):
class ConjugateModel(ABC):
    """
    Base class for Bayesian conjugate models.

    Subclasses implement ``update_sa``, ``sample`` and ``get_map_estimate``.
    This base class stores the MDP dimensions, a seeded
    ``np.random.RandomState`` and the prior hyper-parameters as a float32
    array. It also provides convenience wrappers that route single
    transitions and batches of data to ``update_sa``.
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        hyper_params: Union[List[float], List[List[float]]],
        seed: int,
    ):
        """
        Parameters
        ----------
        n_states : int
            The number of states of the MDP.
        n_actions : int
            The number of actions of the MDP.
        hyper_params : Union[List[float],List[List[float]]]
            The prior parameters. Either a flat list of scalars, used as the
            same prior for every state-action pair and tiled to an array of
            shape ``(n_states, n_actions, len(hyper_params))``, or a nested
            list/tuple/array of per-pair parameters, converted to an array
            as given (its shape is not validated here).
        seed : int
            The random seed.

        Raises
        ------
        ValueError
            If ``hyper_params`` is empty or its first element is neither a
            real scalar nor a list/tuple/array.
        """

        self.n_actions = n_actions
        self.n_states = n_states
        self._rng = np.random.RandomState(seed)

        # len() instead of truthiness so NumPy arrays are accepted as well.
        if len(hyper_params) == 0:
            raise ValueError("Received an empty list of hyper-parameters.")

        first = hyper_params[0]
        # isinstance covers Python and NumPy ints/floats alike; bool is a
        # subclass of int but is never a meaningful prior parameter.
        is_scalar = isinstance(
            first, (int, float, np.integer, np.floating)
        ) and not isinstance(first, bool)

        if is_scalar:
            # same priors for each state action pair
            self.hyper_params = np.tile(hyper_params, (n_states, n_actions, 1)).astype(
                np.float32
            )
        elif isinstance(first, (list, tuple, np.ndarray)):
            # each state action pair has a different prior
            # NOTE(review): presumably shaped (n_states, n_actions, k) like the
            # shared case; the shape is not checked here.
            self.hyper_params = np.array(hyper_params, np.float32)
        else:
            raise ValueError(
                f"Received incorrect parameters with type "
                f"{type(hyper_params), type(hyper_params[0])}"
            )

    @abstractmethod
    def update_sa(self, s: int, a: int, xs: List):
        """
        updates the beliefs of the given state action pair.
        Parameters
        ----------
        s : int
            The state to update.
        a : int
            The action to update.
        xs : List
            The samples obtained from state action pair (s,a).
        """

    @abstractmethod
    def sample(self, n: int = 1) -> np.ndarray:
        """
        samples from the posterior
        Parameters
        ----------
        n : int
            The number of samples. By default, it is set to one.

        Returns
        -------
        np.ndarray
            The n samples from the posterior.
        """

    @abstractmethod
    def get_map_estimate(self) -> np.ndarray:
        """
        computes the maximum a posteriori (MAP) estimate.
        Returns
        -------
        np.ndarray
            The maximum a posteriori estimates
        """

    def update_single_transition(self, s: int, a: int, x: Any):
        """
        updates the posterior for a single transition.
        Parameters
        ----------
        s : int
            The state to update.
        a : int
            The action to update.
        x : Any
            A sample obtained from state action pair (s,a).
        """
        self.update_sa(s, a, [x])

    def update(self, data: Dict[Tuple[int, int], List[float]]):
        """
        updates the Bayesian model.
        Parameters
        ----------
        data : Dict[Tuple[int, int], List[float]]
            the data to be used to update the model using Bayes' rule.
        """
        for (s, a), xs in data.items():
            self.update_sa(s, a, xs)
Base class for Bayesian conjugate models.
ConjugateModel( n_states: int, n_actions: int, hyper_params: Union[List[float], List[List[float]]], seed: int)
def __init__(
    self,
    n_states: int,
    n_actions: int,
    hyper_params: Union[List[float], List[List[float]]],
    seed: int,
):
    """
    Parameters
    ----------
    n_states : int
        The number of states of the MDP.
    n_actions : int
        The number of actions of the MDP.
    hyper_params : Union[List[float],List[List[float]]]
        The prior parameters. Either a flat list of scalars, used as the same
        prior for every state-action pair and tiled to an array of shape
        ``(n_states, n_actions, len(hyper_params))``, or a nested
        list/tuple/array of per-pair parameters, converted to an array as
        given (its shape is not validated here).
    seed : int
        The random seed.

    Raises
    ------
    ValueError
        If ``hyper_params`` is empty or its first element is neither a real
        scalar nor a list/tuple/array.
    """

    self.n_actions = n_actions
    self.n_states = n_states
    self._rng = np.random.RandomState(seed)

    # len() instead of truthiness so NumPy arrays are accepted as well.
    if len(hyper_params) == 0:
        raise ValueError("Received an empty list of hyper-parameters.")

    first = hyper_params[0]
    # isinstance covers Python and NumPy ints/floats alike; bool is a
    # subclass of int but is never a meaningful prior parameter.
    is_scalar = isinstance(
        first, (int, float, np.integer, np.floating)
    ) and not isinstance(first, bool)

    if is_scalar:
        # same priors for each state action pair
        self.hyper_params = np.tile(hyper_params, (n_states, n_actions, 1)).astype(
            np.float32
        )
    elif isinstance(first, (list, tuple, np.ndarray)):
        # each state action pair has a different prior
        # NOTE(review): presumably shaped (n_states, n_actions, k) like the
        # shared case; the shape is not checked here.
        self.hyper_params = np.array(hyper_params, np.float32)
    else:
        raise ValueError(
            f"Received incorrect parameters with type "
            f"{type(hyper_params), type(hyper_params[0])}"
        )
Parameters
- n_states (int): The number of states of the MDP.
- n_actions (int): The number of actions of the MDP.
- hyper_params (Union[List[float],List[List[float]]]): The prior parameters. Either a flat list of parameters that is used identically for every state-action pair, or a nested list (or array) that gives separate parameters for each state-action pair.
- seed (int): The random seed.
@abstractmethod
def
update_sa(self, s: int, a: int, xs: List):
@abstractmethod
def update_sa(self, s: int, a: int, xs: List):
    """
    Incorporates the observations ``xs`` of the state-action pair (s, a)
    into the model's beliefs. Must be implemented by subclasses.

    Parameters
    ----------
    s : int
        The state whose beliefs are updated.
    a : int
        The action whose beliefs are updated.
    xs : List
        The samples collected from the state-action pair (s, a).
    """
updates the beliefs of the given state action pair.
Parameters
- s (int): The state to update.
- a (int): The action to update.
- xs (List): The samples obtained from state action pair (s,a).
@abstractmethod
def
sample(self, n: int = 1) -> numpy.ndarray:
@abstractmethod
def sample(self, n: int = 1) -> np.ndarray:
    """
    Draws ``n`` samples from the posterior. Must be implemented by subclasses.

    Parameters
    ----------
    n : int
        How many samples to draw; defaults to one.

    Returns
    -------
    np.ndarray
        The ``n`` posterior samples.
    """
samples from the posterior
Parameters
- n (int): The number of samples. By default, it is set to one.
Returns
- np.ndarray: The n samples from the posterior.
@abstractmethod
def
get_map_estimate(self) -> numpy.ndarray:
@abstractmethod
def get_map_estimate(self) -> np.ndarray:
    """
    Computes the maximum a posteriori (MAP) estimate.

    Must be implemented by subclasses.

    Returns
    -------
    np.ndarray
        The maximum a posteriori estimates.
    """
computes the maximum a posteriori (MAP) estimate.
Returns
- np.ndarray: The maximum a posteriori estimates
def
update_single_transition(self, s: int, a: int, x: Any):
def update_single_transition(self, s: int, a: int, x: Any):
    """
    Updates the posterior with a single observation of the pair (s, a).

    The sample is wrapped in a one-element list and forwarded to
    ``update_sa``. Nothing is returned.

    Parameters
    ----------
    s : int
        The state whose beliefs are updated.
    a : int
        The action whose beliefs are updated.
    x : Any
        One sample collected from the state-action pair (s, a).
    """
    single_batch = [x]
    self.update_sa(s, a, single_batch)
updates the posterior for a single transition.
Parameters
- s (int): The state to update.
- a (int): The action to update.
- x (Any): A sample obtained from state action pair (s,a).
def
update(self, data: Dict[Tuple[int, int], List[float]]):
def update(self, data: Dict[Tuple[int, int], List[float]]):
    """
    Updates the Bayesian model with a batch of observations.

    Every ``(state, action) -> samples`` entry of ``data`` is passed to
    ``update_sa``, in the mapping's iteration order.

    Parameters
    ----------
    data : Dict[Tuple[int, int], List[float]]
        The observations, grouped by state-action pair, used to update the
        model via Bayes' rule.
    """
    for state_action, samples in data.items():
        state, action = state_action
        self.update_sa(state, action, samples)
updates the Bayesian model.
Parameters
- data (Dict[Tuple[int, int], List[float]]): the data to be used to update the model using Bayes' rule.