colosseum.agent.mdp_models.bayesian_models.conjugate_transitions

 1from typing import List, Tuple, Union
 2
 3import numpy as np
 4
 5from colosseum.agent.mdp_models.bayesian_models import ConjugateModel
 6from colosseum.utils.miscellanea import state_occurencens_to_counts
 7
 8PRIOR_TYPE = Union[
 9    List[
10        float,
11    ],
12    List[
13        List[
14            float,
15        ],
16    ],
17]
18
19
class M_DIR(ConjugateModel):
    """
    Multinomial-Dirichlet conjugate model.

    After construction ``self.hyper_params`` has shape
    ``(n_states, n_actions, n_states)``: entry ``[s, a]`` is the vector of
    Dirichlet concentration parameters over the next state reached from
    state ``s`` under action ``a``.
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        hyper_params: Union[
            List[float],
            List[
                List[
                    float,
                ]
            ],
        ],
        seed: int,
    ):
        """
        Parameters
        ----------
        n_states : int
            The number of states of the MDP.
        n_actions : int
            The number of actions of the MDP.
        hyper_params : Union[List[float], List[List[float]]]
            The prior parameters. They are passed to ``ConjugateModel.__init__``,
            which stores them in ``self.hyper_params`` (the conversion is not
            visible here). If the stored array has shape
            ``(n_states, n_actions, 1)``, the single value per state-action
            pair is replicated across all next states.
        seed : int
            The random seed.
        """
        super(M_DIR, self).__init__(n_states, n_actions, hyper_params, seed)
        if self.hyper_params.shape == (n_states, n_actions, 1):
            # One concentration value per (s, a): replicate it over next states.
            self.hyper_params = np.tile(self.hyper_params, (1, 1, n_states))
        assert self.hyper_params.shape == (n_states, n_actions, n_states), (
            f"Expected hyper_params of shape {(n_states, n_actions, n_states)}, "
            f"got {self.hyper_params.shape}."
        )

    def update_sa(self, s: int, a: int, xs: List[int]):
        """
        Update the beliefs of the given state-action pair.

        Each sample is converted to a count vector of length ``n_states`` and
        the summed counts are added to the concentration parameters of
        ``(s, a)`` (the conjugate Dirichlet update). An empty ``xs`` leaves the
        parameters unchanged.

        Parameters
        ----------
        s : int
            The state to update.
        a : int
            The action to update.
        xs : List[int]
            The samples obtained from the state-action pair ``(s, a)``.
        """
        xs = [state_occurencens_to_counts(x, self.n_states) for x in xs]
        self.hyper_params[s, a] += np.array(xs).sum(0)

    def _sample(self, hyper_params: np.ndarray, n: int) -> np.ndarray:
        """
        Draw ``n`` Dirichlet samples for every concentration vector.

        Independent Gamma(alpha_i, 1) variables are drawn and normalized along
        the last axis, which yields Dirichlet(alpha) samples.

        Parameters
        ----------
        hyper_params : np.ndarray
            Concentration parameters; the last axis indexes the categories.
        n : int
            The number of samples.

        Returns
        -------
        np.ndarray
            Shape ``hyper_params.shape`` when ``n == 1``, otherwise
            ``(n, *hyper_params.shape)``.
        """
        r = self._rng.standard_gamma(hyper_params, (n, *hyper_params.shape)).astype(
            np.float32
        )
        if n == 1:
            # Drop only the sample axis. A blanket ``squeeze()`` would also
            # remove length-1 state/action axes, so the normalization below
            # would run over the wrong axis (or fail on a 0-d array).
            r = r[0]
        # The small constant keeps the denominator non-zero, e.g. when all
        # float32 Gamma draws underflow to 0.
        return r / (1e-5 + r.sum(-1, keepdims=True))

    def sample(self, n: int = 1) -> np.ndarray:
        """
        Sample transition probabilities from the posterior.

        Parameters
        ----------
        n : int
            The number of samples. By default, it is set to one.

        Returns
        -------
        np.ndarray
            The n samples from the posterior, with shape
            ``(n_states, n_actions, n_states)`` when ``n == 1`` and
            ``(n, n_states, n_actions, n_states)`` otherwise.
        """
        flat = self.hyper_params.reshape(self.n_states * self.n_actions, -1)
        r = self._sample(flat, n)
        if n == 1:
            return r.reshape((self.n_states, self.n_actions, -1))
        # Keep the sample axis in front; folding it into the last axis would
        # interleave draws belonging to different state-action pairs.
        return r.reshape((n, self.n_states, self.n_actions, -1))

    def sample_sa(self, sa: Tuple[int, int]) -> np.ndarray:
        """
        Sample a next-state distribution for a single state-action pair.

        Parameters
        ----------
        sa : Tuple[int, int]
            The ``(state, action)`` pair, used directly as an index into
            ``self.hyper_params``.

        Returns
        -------
        np.ndarray
            One sample of shape ``(n_states,)``.
        """
        return self._sample(self.hyper_params[sa], 1)

    def get_map_estimate(self) -> np.ndarray:
        """
        Compute a point estimate of the transition probabilities.

        Despite the method's name, this returns the posterior *mean* of each
        Dirichlet (the concentration parameters normalized to sum to one),
        not the Dirichlet mode ``(alpha_i - 1) / (sum(alpha) - K)``.

        Returns
        -------
        np.ndarray
            Array with the same shape as ``self.hyper_params``.
        """
        return self.hyper_params / self.hyper_params.sum(-1, keepdims=True)
class M_DIR(ConjugateModel):
    """
    Multinomial-Dirichlet conjugate model.

    After construction ``self.hyper_params`` has shape
    ``(n_states, n_actions, n_states)``: entry ``[s, a]`` is the vector of
    Dirichlet concentration parameters over the next state reached from
    state ``s`` under action ``a``.
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        hyper_params: Union[
            List[float],
            List[
                List[
                    float,
                ]
            ],
        ],
        seed: int,
    ):
        """
        Parameters
        ----------
        n_states : int
            The number of states of the MDP.
        n_actions : int
            The number of actions of the MDP.
        hyper_params : Union[List[float], List[List[float]]]
            The prior parameters. They are passed to ``ConjugateModel.__init__``,
            which stores them in ``self.hyper_params`` (the conversion is not
            visible here). If the stored array has shape
            ``(n_states, n_actions, 1)``, the single value per state-action
            pair is replicated across all next states.
        seed : int
            The random seed.
        """
        super(M_DIR, self).__init__(n_states, n_actions, hyper_params, seed)
        if self.hyper_params.shape == (n_states, n_actions, 1):
            # One concentration value per (s, a): replicate it over next states.
            self.hyper_params = np.tile(self.hyper_params, (1, 1, n_states))
        assert self.hyper_params.shape == (n_states, n_actions, n_states), (
            f"Expected hyper_params of shape {(n_states, n_actions, n_states)}, "
            f"got {self.hyper_params.shape}."
        )

    def update_sa(self, s: int, a: int, xs: List[int]):
        """
        Update the beliefs of the given state-action pair.

        Each sample is converted to a count vector of length ``n_states`` and
        the summed counts are added to the concentration parameters of
        ``(s, a)`` (the conjugate Dirichlet update). An empty ``xs`` leaves the
        parameters unchanged.

        Parameters
        ----------
        s : int
            The state to update.
        a : int
            The action to update.
        xs : List[int]
            The samples obtained from the state-action pair ``(s, a)``.
        """
        xs = [state_occurencens_to_counts(x, self.n_states) for x in xs]
        self.hyper_params[s, a] += np.array(xs).sum(0)

    def _sample(self, hyper_params: np.ndarray, n: int) -> np.ndarray:
        """
        Draw ``n`` Dirichlet samples for every concentration vector.

        Independent Gamma(alpha_i, 1) variables are drawn and normalized along
        the last axis, which yields Dirichlet(alpha) samples.

        Parameters
        ----------
        hyper_params : np.ndarray
            Concentration parameters; the last axis indexes the categories.
        n : int
            The number of samples.

        Returns
        -------
        np.ndarray
            Shape ``hyper_params.shape`` when ``n == 1``, otherwise
            ``(n, *hyper_params.shape)``.
        """
        r = self._rng.standard_gamma(hyper_params, (n, *hyper_params.shape)).astype(
            np.float32
        )
        if n == 1:
            # Drop only the sample axis. A blanket ``squeeze()`` would also
            # remove length-1 state/action axes, so the normalization below
            # would run over the wrong axis (or fail on a 0-d array).
            r = r[0]
        # The small constant keeps the denominator non-zero, e.g. when all
        # float32 Gamma draws underflow to 0.
        return r / (1e-5 + r.sum(-1, keepdims=True))

    def sample(self, n: int = 1) -> np.ndarray:
        """
        Sample transition probabilities from the posterior.

        Parameters
        ----------
        n : int
            The number of samples. By default, it is set to one.

        Returns
        -------
        np.ndarray
            The n samples from the posterior, with shape
            ``(n_states, n_actions, n_states)`` when ``n == 1`` and
            ``(n, n_states, n_actions, n_states)`` otherwise.
        """
        flat = self.hyper_params.reshape(self.n_states * self.n_actions, -1)
        r = self._sample(flat, n)
        if n == 1:
            return r.reshape((self.n_states, self.n_actions, -1))
        # Keep the sample axis in front; folding it into the last axis would
        # interleave draws belonging to different state-action pairs.
        return r.reshape((n, self.n_states, self.n_actions, -1))

    def sample_sa(self, sa: Tuple[int, int]) -> np.ndarray:
        """
        Sample a next-state distribution for a single state-action pair.

        Parameters
        ----------
        sa : Tuple[int, int]
            The ``(state, action)`` pair, used directly as an index into
            ``self.hyper_params``.

        Returns
        -------
        np.ndarray
            One sample of shape ``(n_states,)``.
        """
        return self._sample(self.hyper_params[sa], 1)

    def get_map_estimate(self) -> np.ndarray:
        """
        Compute a point estimate of the transition probabilities.

        Despite the method's name, this returns the posterior *mean* of each
        Dirichlet (the concentration parameters normalized to sum to one),
        not the Dirichlet mode ``(alpha_i - 1) / (sum(alpha) - K)``.

        Returns
        -------
        np.ndarray
            Array with the same shape as ``self.hyper_params``.
        """
        return self.hyper_params / self.hyper_params.sum(-1, keepdims=True)

Multinomial-Dirichlet conjugate model.

M_DIR( n_states: int, n_actions: int, hyper_params: Union[List[float], List[List[float]]], seed: int)
26    def __init__(
27        self,
28        n_states: int,
29        n_actions: int,
30        hyper_params: Union[
31            List[float],
32            List[
33                List[
34                    float,
35                ]
36            ],
37        ],
38        seed: int,
39    ):
40        super(M_DIR, self).__init__(n_states, n_actions, hyper_params, seed)
41        if self.hyper_params.shape == (n_states, n_actions, 1):
42            self.hyper_params = np.tile(self.hyper_params, (1, 1, n_states))
43        assert self.hyper_params.shape == (n_states, n_actions, n_states)
Parameters
  • n_states (int): The number of states of the MDP.
  • n_actions (int): The number of actions of the MDP.
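  • hyper_params (Union[List[float],List[List[float]]]): The prior parameters. They can either be a list of parameters used identically for every state-action pair, or a dictionary with state-action pairs as keys and lists of parameters as values (note that the type annotation only lists the list forms). If the resulting parameter array holds a single value per state-action pair (shape (n_states, n_actions, 1)), that value is replicated across all next states.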
  • seed (int): The random seed.
def update_sa(self, s: int, a: int, xs: List[int]):
45    def update_sa(self, s: int, a: int, xs: List[int]):
46        xs = [state_occurencens_to_counts(x, self.n_states) for x in xs]
47        self.hyper_params[s, a] += np.array(xs).sum(0)

Updates the beliefs of the given state-action pair.

Parameters
  • s (int): The state to update.
  • a (int): The action to update.
  • xs (List): The samples obtained from state action pair (s,a).
def sample(self, n: int = 1) -> numpy.ndarray:
57    def sample(self, n: int = 1) -> np.ndarray:
58        r = self._sample(
59            self.hyper_params.reshape(self.n_states * self.n_actions, -1), n
60        )
61        return r.reshape((self.n_states, self.n_actions, -1))

Samples from the posterior.

Parameters
  • n (int): The number of samples. By default, it is set to one.
Returns
  • np.ndarray: The n samples from the posterior.
def sample_sa(self, sa: Tuple[int, int]) -> numpy.ndarray:
63    def sample_sa(self, sa: Tuple[int, int]) -> np.ndarray:
64        return self._sample(self.hyper_params[sa], 1)
def get_map_estimate(self) -> numpy.ndarray:
66    def get_map_estimate(self) -> np.ndarray:
67        return self.hyper_params / self.hyper_params.sum(-1, keepdims=True)

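Computes a point estimate of the transition probabilities. Despite the method's name, this is the posterior mean of each Dirichlet (the concentration parameters normalized to sum to one), not the maximum a posteriori (mode) estimate.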

Returns
  • np.ndarray: The posterior-mean estimates, with the same shape as the hyper-parameters.