colosseum.agent.mdp_models.bayesian_models.base_conjugate

  1from abc import ABC, abstractmethod
  2from typing import Any, Dict, List, Tuple, Union
  3
  4import numpy as np
  5
  6
  7class ConjugateModel(ABC):
  8    """
  9    Base class for Bayesian conjugate models.
 10    """
 11
 12    def __init__(
 13        self,
 14        n_states: int,
 15        n_actions: int,
 16        hyper_params: Union[
 17            List[float],
 18            List[
 19                List[
 20                    float,
 21                ]
 22            ],
 23        ],
 24        seed: int,
 25    ):
 26        """
 27        Parameters
 28        ----------
 29        n_states : int
 30            The number of states of the MDP.
 31        n_actions : int
 32            The number of action of the MDP.
 33        hyper_params : Union[List[float],List[List[float]]]
 34            The prior parameters can either be a list of parameters that are set identical for each
 35            state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters
 36            as value.
 37        seed : int
 38            The random seed.
 39        """
 40
 41        self.n_actions = n_actions
 42        self.n_states = n_states
 43        self._rng = np.random.RandomState(seed)
 44
 45        if type(hyper_params[0]) in [int, float] or "numpy.flo" in str(
 46            type(hyper_params[0])
 47        ):
 48            # same priors for each state action pair
 49            self.hyper_params = np.tile(hyper_params, (n_states, n_actions, 1)).astype(
 50                np.float32
 51            )
 52        elif type(hyper_params[0]) in [list, tuple, np.ndarray]:
 53            # each state action pair has a different prior
 54            self.hyper_params = np.array(hyper_params, np.float32)
 55        else:
 56            raise ValueError(
 57                f"Received incorrect parameters  with type "
 58                f"{type(hyper_params), type(hyper_params[0])}"
 59            )
 60
 61    @abstractmethod
 62    def update_sa(self, s: int, a: int, xs: List):
 63        """
 64        updates the beliefs of the given state action pair.
 65        Parameters
 66        ----------
 67        s : int
 68            The state to update.
 69        a : int
 70            The action to update.
 71        xs : List
 72            The samples obtained from state action pair (s,a).
 73        """
 74
 75    @abstractmethod
 76    def sample(self, n: int = 1) -> np.ndarray:
 77        """
 78        samples from the posterior
 79        Parameters
 80        ----------
 81        n : int
 82            The number of samples. By default, it is set to one.
 83
 84        Returns
 85        -------
 86        np.ndarray
 87            The n samples from the posterior.
 88        """
 89
 90    @abstractmethod
 91    def get_map_estimate(self) -> np.ndarray:
 92        """
 93        computes the maximum a posterior estimate.
 94        Returns
 95        -------
 96        np.ndarray
 97            The maximum a posteriori estimates
 98        """
 99
100    def update_single_transition(self, s: int, a: int, x: Any):
101        """
102        updates the posterior for a single transition.
103        Parameters
104        ----------
105        s : int
106            The state to update.
107        a : int
108            The action to update.
109        x : Any
110            A sample obtained from state action pair (s,a).
111        """
112        self.update_sa(s, a, [x])
113
114    def update(self, data: Dict[Tuple[int, int], List[float]]):
115        """
116        updates the Bayesian model.
117        Parameters
118        ----------
119        data : Dict[Tuple[int, int], List[float]]
120            the data to be used to update the model using Bayes rule.
121        """
122        for (s, a), xs in data.items():
123            self.update_sa(s, a, xs)
class ConjugateModel(abc.ABC):
  8class ConjugateModel(ABC):
  9    """
 10    Base class for Bayesian conjugate models.
 11    """
 12
 13    def __init__(
 14        self,
 15        n_states: int,
 16        n_actions: int,
 17        hyper_params: Union[
 18            List[float],
 19            List[
 20                List[
 21                    float,
 22                ]
 23            ],
 24        ],
 25        seed: int,
 26    ):
 27        """
 28        Parameters
 29        ----------
 30        n_states : int
 31            The number of states of the MDP.
 32        n_actions : int
 33            The number of action of the MDP.
 34        hyper_params : Union[List[float],List[List[float]]]
 35            The prior parameters can either be a list of parameters that are set identical for each
 36            state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters
 37            as value.
 38        seed : int
 39            The random seed.
 40        """
 41
 42        self.n_actions = n_actions
 43        self.n_states = n_states
 44        self._rng = np.random.RandomState(seed)
 45
 46        if type(hyper_params[0]) in [int, float] or "numpy.flo" in str(
 47            type(hyper_params[0])
 48        ):
 49            # same priors for each state action pair
 50            self.hyper_params = np.tile(hyper_params, (n_states, n_actions, 1)).astype(
 51                np.float32
 52            )
 53        elif type(hyper_params[0]) in [list, tuple, np.ndarray]:
 54            # each state action pair has a different prior
 55            self.hyper_params = np.array(hyper_params, np.float32)
 56        else:
 57            raise ValueError(
 58                f"Received incorrect parameters  with type "
 59                f"{type(hyper_params), type(hyper_params[0])}"
 60            )
 61
 62    @abstractmethod
 63    def update_sa(self, s: int, a: int, xs: List):
 64        """
 65        updates the beliefs of the given state action pair.
 66        Parameters
 67        ----------
 68        s : int
 69            The state to update.
 70        a : int
 71            The action to update.
 72        xs : List
 73            The samples obtained from state action pair (s,a).
 74        """
 75
 76    @abstractmethod
 77    def sample(self, n: int = 1) -> np.ndarray:
 78        """
 79        samples from the posterior
 80        Parameters
 81        ----------
 82        n : int
 83            The number of samples. By default, it is set to one.
 84
 85        Returns
 86        -------
 87        np.ndarray
 88            The n samples from the posterior.
 89        """
 90
 91    @abstractmethod
 92    def get_map_estimate(self) -> np.ndarray:
 93        """
 94        computes the maximum a posterior estimate.
 95        Returns
 96        -------
 97        np.ndarray
 98            The maximum a posteriori estimates
 99        """
100
101    def update_single_transition(self, s: int, a: int, x: Any):
102        """
103        updates the posterior for a single transition.
104        Parameters
105        ----------
106        s : int
107            The state to update.
108        a : int
109            The action to update.
110        x : Any
111            A sample obtained from state action pair (s,a).
112        """
113        self.update_sa(s, a, [x])
114
115    def update(self, data: Dict[Tuple[int, int], List[float]]):
116        """
117        updates the Bayesian model.
118        Parameters
119        ----------
120        data : Dict[Tuple[int, int], List[float]]
121            the data to be used to update the model using Bayes rule.
122        """
123        for (s, a), xs in data.items():
124            self.update_sa(s, a, xs)

Base class for Bayesian conjugate models.

ConjugateModel( n_states: int, n_actions: int, hyper_params: Union[List[float], List[List[float]]], seed: int)
13    def __init__(
14        self,
15        n_states: int,
16        n_actions: int,
17        hyper_params: Union[
18            List[float],
19            List[
20                List[
21                    float,
22                ]
23            ],
24        ],
25        seed: int,
26    ):
27        """
28        Parameters
29        ----------
30        n_states : int
31            The number of states of the MDP.
32        n_actions : int
33            The number of action of the MDP.
34        hyper_params : Union[List[float],List[List[float]]]
35            The prior parameters can either be a list of parameters that are set identical for each
36            state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters
37            as value.
38        seed : int
39            The random seed.
40        """
41
42        self.n_actions = n_actions
43        self.n_states = n_states
44        self._rng = np.random.RandomState(seed)
45
46        if type(hyper_params[0]) in [int, float] or "numpy.flo" in str(
47            type(hyper_params[0])
48        ):
49            # same priors for each state action pair
50            self.hyper_params = np.tile(hyper_params, (n_states, n_actions, 1)).astype(
51                np.float32
52            )
53        elif type(hyper_params[0]) in [list, tuple, np.ndarray]:
54            # each state action pair has a different prior
55            self.hyper_params = np.array(hyper_params, np.float32)
56        else:
57            raise ValueError(
58                f"Received incorrect parameters  with type "
59                f"{type(hyper_params), type(hyper_params[0])}"
60            )
Parameters
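  • n_states (int): The number of states of the MDP.
  • n_actions (int): The number of actions of the MDP.
  • hyper_params (Union[List[float],List[List[float]]]): The prior parameters. Either a flat list of numbers, which is applied identically to every state-action pair (it is tiled into a float32 array of shape (n_states, n_actions, len(hyper_params))), or a nested list/tuple/array that gives a separate list of parameters for each state-action pair and is converted directly to a float32 array. Any other element type raises a ValueError.
  • seed (int): The random seed used to initialise the model's NumPy RandomState.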
@abstractmethod
def update_sa(self, s: int, a: int, xs: List):
62    @abstractmethod
63    def update_sa(self, s: int, a: int, xs: List):
64        """
65        updates the beliefs of the given state action pair.
66        Parameters
67        ----------
68        s : int
69            The state to update.
70        a : int
71            The action to update.
72        xs : List
73            The samples obtained from state action pair (s,a).
74        """

Updates the beliefs of the given state-action pair.

Parameters
  • s (int): The state to update.
  • a (int): The action to update.
  • xs (List): The samples obtained from the state-action pair (s, a).
@abstractmethod
def sample(self, n: int = 1) -> numpy.ndarray:
76    @abstractmethod
77    def sample(self, n: int = 1) -> np.ndarray:
78        """
79        samples from the posterior
80        Parameters
81        ----------
82        n : int
83            The number of samples. By default, it is set to one.
84
85        Returns
86        -------
87        np.ndarray
88            The n samples from the posterior.
89        """

Samples from the posterior.

Parameters
  • n (int): The number of samples. Defaults to one.
Returns
  • np.ndarray: The n samples drawn from the posterior.
@abstractmethod
def get_map_estimate(self) -> numpy.ndarray:
91    @abstractmethod
92    def get_map_estimate(self) -> np.ndarray:
93        """
94        computes the maximum a posterior estimate.
95        Returns
96        -------
97        np.ndarray
98            The maximum a posteriori estimates
99        """

Computes the maximum a posteriori (MAP) estimate.

Returns
  • np.ndarray: The maximum a posteriori estimates.
def update_single_transition(self, s: int, a: int, x: Any):
101    def update_single_transition(self, s: int, a: int, x: Any):
102        """
103        updates the posterior for a single transition.
104        Parameters
105        ----------
106        s : int
107            The state to update.
108        a : int
109            The action to update.
110        x : Any
111            A sample obtained from state action pair (s,a).
112        """
113        self.update_sa(s, a, [x])

Updates the posterior with a single transition; equivalent to calling update_sa(s, a, [x]).

Parameters
  • s (int): The state to update.
  • a (int): The action to update.
  • x (Any): A sample obtained from the state-action pair (s, a).
def update(self, data: Dict[Tuple[int, int], List[float]]):
115    def update(self, data: Dict[Tuple[int, int], List[float]]):
116        """
117        updates the Bayesian model.
118        Parameters
119        ----------
120        data : Dict[Tuple[int, int], List[float]]
121            the data to be used to update the model using Bayes rule.
122        """
123        for (s, a), xs in data.items():
124            self.update_sa(s, a, xs)

Updates the Bayesian model.

Parameters
  • data (Dict[Tuple[int, int], List[float]]): Maps each state-action pair (s, a) to the list of samples observed for it; each entry is passed to update_sa to update the model using Bayes' rule.