colosseum.agent.mdp_models.bayesian_models.conjugate_rewards

  1from typing import Dict, List, Tuple, Union
  2
  3import numpy as np
  4
  5from colosseum.agent.mdp_models.bayesian_models import ConjugateModel
  6
  7PRIOR_TYPE = Union[
  8    List[
  9        float,
 10    ],
 11    Dict[
 12        Tuple[int, int],
 13        List[
 14            float,
 15        ],
 16    ],
 17]
 18
 19
 20class N_NIG(ConjugateModel):
 21    """
 22    The Normal-Normal Inverse Gamma conjugate model.
 23    """
 24
 25    def __init__(
 26        self,
 27        n_states: int,
 28        n_actions: int,
 29        hyper_params: Union[List[float], List[List[float]]],
 30        seed: int,
 31        interpretable_parameters: bool = True,
 32    ):
 33        """
 34        Parameters
 35        ----------
 36        n_states : int
 37            The number of states of the MDP.
 38        n_actions : int
 39            The number of action of the MDP.
 40        hyper_params : Union[List[float],List[List[float]]]
 41            The prior parameters can either be a list of parameters that are set identical for each
 42            state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters
 43            as value.
 44        seed : int
 45            The random seed.
 46        interpretable_parameters : bool
 47            If True, the parameters are given in the natural way of speaking of NIG parameters.
 48        """
 49
 50        super(N_NIG, self).__init__(n_states, n_actions, hyper_params, seed)
 51
 52        assert self.hyper_params.shape == (n_states, n_actions, 4)
 53
 54        if interpretable_parameters:
 55            for i in range(self.hyper_params.shape[0]):
 56                for j in range(self.hyper_params.shape[1]):
 57                    mu, n_mu, tau, n_tau = self.hyper_params[i, j]
 58                    self.hyper_params[i, j] = (
 59                        mu,
 60                        n_mu,
 61                        n_tau * 0.5,
 62                        (0.5 * n_tau) / tau,
 63                    )
 64
 65    def update_sa(self, s: int, a: int, rs: List[float]):
 66        # Unpack the prior
 67        (mu0, lambda0, alpha0, beta0) = self.hyper_params[s, a]
 68
 69        n = len(rs)
 70        y_bar = np.mean(rs)
 71
 72        # Updating normal component
 73        lambda1 = lambda0 + n
 74        mu1 = (lambda0 * mu0 + n * y_bar) / lambda1
 75
 76        # Updating Inverse-Gamma component
 77        alpha1 = alpha0 + (n * 0.5)
 78        ssq = n * np.var(rs)
 79        prior_disc = lambda0 * n * ((y_bar - mu0) ** 2) / lambda1
 80        beta1 = beta0 + 0.5 * (ssq + prior_disc)
 81
 82        self.hyper_params[s, a] = (mu1, lambda1, alpha1, beta1)
 83
 84    def sample(self, n: int = 1) -> np.ndarray:
 85        # Unpack the prior
 86        (mu, lambda0, alpha, beta) = self.hyper_params.reshape(
 87            self.n_states * self.n_actions, -1
 88        ).T
 89
 90        # Sample scaling tau from a gamma distribution
 91        tau = self._rng.gamma(shape=alpha, scale=1.0 / beta).astype(np.float32)
 92        var = 1.0 / (lambda0 * tau)
 93
 94        # Sample mean from normal mean mu, var
 95        mean = self._rng.normal(loc=mu, scale=np.sqrt(var), size=(n, *mu.shape)).astype(
 96            np.float32
 97        )
 98
 99        return mean.reshape(self.n_states, self.n_actions).squeeze()
100
101    def get_map_estimate(self) -> np.ndarray:
102        return self.hyper_params[:, :, 0]
103
104
class N_N(ConjugateModel):
    """
    The Normal-Normal conjugate model.

    For every state-action pair, ``self.hyper_params[s, a]`` holds ``(mu, tau)``. ``update_sa``
    treats each observation as having unit precision. So ``mu`` is the posterior mean of the
    reward, and ``tau`` is the posterior precision of that mean (prior pseudo-count plus the
    number of observations).
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        hyper_params: Union[List[float], List[List[float]]],
        seed: int,
    ):
        """
        Parameters
        ----------
        n_states : int
            The number of states of the MDP.
        n_actions : int
            The number of actions of the MDP.
        hyper_params : Union[List[float],List[List[float]]]
            The prior parameters can either be a list of parameters that are set identical for each
            state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters
            as value.
        seed : int
            The random seed.

        Raises
        ------
        AssertionError
            If the processed hyperparameters do not have shape ``(n_states, n_actions, 2)``.
        """
        super(N_N, self).__init__(n_states, n_actions, hyper_params, seed)

        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
        expected_shape = (n_states, n_actions, 2)
        if self.hyper_params.shape != expected_shape:
            raise AssertionError(
                f"hyper_params has shape {self.hyper_params.shape}, expected {expected_shape}"
            )

    def update_sa(self, s: int, a: int, xs: List[float]):
        """
        Updates the beliefs of the given state-action pair, one observation at a time.

        Parameters
        ----------
        s : int
            The state to update.
        a : int
            The action to update.
        xs : List[float]
            The samples obtained from the state-action pair (s, a).
        """
        for r in xs:
            mu0, tau0 = self.hyper_params[s, a]
            tau1 = tau0 + 1
            # Unit-precision observation: the new mean is the tau-weighted average.
            mu1 = (mu0 * tau0 + r) / tau1
            self.hyper_params[s, a] = (mu1, tau1)

    def sample(self, n: int = 1) -> np.ndarray:
        """
        Samples mean rewards from the posterior.

        Parameters
        ----------
        n : int
            The number of samples. By default, it is set to one.

        Returns
        -------
        np.ndarray
            float32 samples of shape ``(n, n_states, n_actions)`` with singleton dimensions squeezed.
        """
        mu = self.hyper_params[:, :, 0]
        tau = self.hyper_params[:, :, 1]
        # tau is a precision (it grows by one per observation), so the standard deviation is
        # 1 / sqrt(tau). The explicit leading sample axis keeps ``size`` compatible with the
        # (n_states, n_actions)-shaped parameters.
        return (
            self._rng.normal(loc=mu, scale=1.0 / np.sqrt(tau), size=(n, *mu.shape))
            .astype(np.float32)
            .squeeze()
        )

    def get_map_estimate(self) -> np.ndarray:
        """
        Computes the maximum a posteriori estimate.

        Returns
        -------
        np.ndarray
            The posterior mean ``mu`` of every state-action pair (a view, not a copy).
        """
        return self.hyper_params[:, :, 0]
 21class N_NIG(ConjugateModel):
 22    """
 23    The Normal-Normal Inverse Gamma conjugate model.
 24    """
 25
 26    def __init__(
 27        self,
 28        n_states: int,
 29        n_actions: int,
 30        hyper_params: Union[List[float], List[List[float]]],
 31        seed: int,
 32        interpretable_parameters: bool = True,
 33    ):
 34        """
 35        Parameters
 36        ----------
 37        n_states : int
 38            The number of states of the MDP.
 39        n_actions : int
 40            The number of action of the MDP.
 41        hyper_params : Union[List[float],List[List[float]]]
 42            The prior parameters can either be a list of parameters that are set identical for each
 43            state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters
 44            as value.
 45        seed : int
 46            The random seed.
 47        interpretable_parameters : bool
 48            If True, the parameters are given in the natural way of speaking of NIG parameters.
 49        """
 50
 51        super(N_NIG, self).__init__(n_states, n_actions, hyper_params, seed)
 52
 53        assert self.hyper_params.shape == (n_states, n_actions, 4)
 54
 55        if interpretable_parameters:
 56            for i in range(self.hyper_params.shape[0]):
 57                for j in range(self.hyper_params.shape[1]):
 58                    mu, n_mu, tau, n_tau = self.hyper_params[i, j]
 59                    self.hyper_params[i, j] = (
 60                        mu,
 61                        n_mu,
 62                        n_tau * 0.5,
 63                        (0.5 * n_tau) / tau,
 64                    )
 65
 66    def update_sa(self, s: int, a: int, rs: List[float]):
 67        # Unpack the prior
 68        (mu0, lambda0, alpha0, beta0) = self.hyper_params[s, a]
 69
 70        n = len(rs)
 71        y_bar = np.mean(rs)
 72
 73        # Updating normal component
 74        lambda1 = lambda0 + n
 75        mu1 = (lambda0 * mu0 + n * y_bar) / lambda1
 76
 77        # Updating Inverse-Gamma component
 78        alpha1 = alpha0 + (n * 0.5)
 79        ssq = n * np.var(rs)
 80        prior_disc = lambda0 * n * ((y_bar - mu0) ** 2) / lambda1
 81        beta1 = beta0 + 0.5 * (ssq + prior_disc)
 82
 83        self.hyper_params[s, a] = (mu1, lambda1, alpha1, beta1)
 84
 85    def sample(self, n: int = 1) -> np.ndarray:
 86        # Unpack the prior
 87        (mu, lambda0, alpha, beta) = self.hyper_params.reshape(
 88            self.n_states * self.n_actions, -1
 89        ).T
 90
 91        # Sample scaling tau from a gamma distribution
 92        tau = self._rng.gamma(shape=alpha, scale=1.0 / beta).astype(np.float32)
 93        var = 1.0 / (lambda0 * tau)
 94
 95        # Sample mean from normal mean mu, var
 96        mean = self._rng.normal(loc=mu, scale=np.sqrt(var), size=(n, *mu.shape)).astype(
 97            np.float32
 98        )
 99
100        return mean.reshape(self.n_states, self.n_actions).squeeze()
101
102    def get_map_estimate(self) -> np.ndarray:
103        return self.hyper_params[:, :, 0]

The Normal-Normal Inverse Gamma conjugate model.

N_NIG( n_states: int, n_actions: int, hyper_params: Union[List[float], List[List[float]]], seed: int, interpretable_parameters: bool = True)
def __init__(
    self,
    n_states: int,
    n_actions: int,
    hyper_params: Union[List[float], List[List[float]]],
    seed: int,
    interpretable_parameters: bool = True,
):
    """
    Parameters
    ----------
    n_states : int
        The number of states of the MDP.
    n_actions : int
        The number of actions of the MDP.
    hyper_params : Union[List[float],List[List[float]]]
        The prior parameters can either be a list of parameters that are set identical for each
        state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters
        as value.
    seed : int
        The random seed.
    interpretable_parameters : bool
        If True, each prior is given as ``(mu, n_mu, tau, n_tau)``: a prior mean ``mu`` backed by
        ``n_mu`` pseudo-observations and a prior precision ``tau`` backed by ``n_tau``
        pseudo-observations. It is converted to the NIG parameters
        ``(mu, lambda, alpha, beta) = (mu, n_mu, n_tau / 2, n_tau / (2 * tau))``, so the prior mean
        of the precision, ``alpha / beta``, equals ``tau``. If False, the parameters are used as
        ``(mu, lambda, alpha, beta)`` directly.

    Raises
    ------
    AssertionError
        If the processed hyperparameters do not have shape ``(n_states, n_actions, 4)``.
    """
    super(N_NIG, self).__init__(n_states, n_actions, hyper_params, seed)

    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``;
    # the exception type is unchanged for callers.
    expected_shape = (n_states, n_actions, 4)
    if self.hyper_params.shape != expected_shape:
        raise AssertionError(
            f"hyper_params has shape {self.hyper_params.shape}, expected {expected_shape}"
        )

    if interpretable_parameters:
        # Compute both new columns before writing: beta needs the original tau column.
        alpha = 0.5 * self.hyper_params[..., 3]
        beta = alpha / self.hyper_params[..., 2]
        self.hyper_params[..., 2] = alpha
        self.hyper_params[..., 3] = beta
Parameters
  • n_states (int): The number of states of the MDP.
  • n_actions (int): The number of actions of the MDP.
  • hyper_params (Union[List[float],List[List[float]]]): The prior parameters can either be a list of parameters that are set identical for each state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters as value.
  • seed (int): The random seed.
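  • interpretable_parameters (bool): If True, each prior is given as (mu, n_mu, tau, n_tau), i.e. a prior mean mu backed by n_mu pseudo-observations and a prior precision tau backed by n_tau pseudo-observations. It is converted to the standard NIG parameters (mu, lambda, alpha, beta) = (mu, n_mu, n_tau / 2, n_tau / (2 tau)). If False, the parameters are taken as (mu, lambda, alpha, beta) directly.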
def update_sa(self, s: int, a: int, rs: List[float]):
def update_sa(self, s: int, a: int, rs: List[float]):
    """
    Updates the beliefs of the given state-action pair.

    Parameters
    ----------
    s : int
        The state to update.
    a : int
        The action to update.
    rs : List[float]
        The reward samples observed for (s, a). An empty sequence leaves the posterior unchanged.
    """
    rewards = np.asarray(rs, dtype=float)
    n = rewards.size
    if n == 0:
        # The mean/variance of an empty sample are NaN and would corrupt the posterior.
        return

    # Unpack the prior
    mu0, lambda0, alpha0, beta0 = self.hyper_params[s, a]
    y_bar = rewards.mean()

    # Normal component: lambda-weighted combination of the prior mean and the sample mean.
    lambda1 = lambda0 + n
    mu1 = (lambda0 * mu0 + n * y_bar) / lambda1

    # Inverse-Gamma component.
    alpha1 = alpha0 + 0.5 * n
    ssq = n * rewards.var()  # sum of squared deviations from the sample mean
    prior_disc = lambda0 * n * (y_bar - mu0) ** 2 / lambda1
    beta1 = beta0 + 0.5 * (ssq + prior_disc)

    self.hyper_params[s, a] = (mu1, lambda1, alpha1, beta1)

Updates the beliefs of the given state-action pair.

Parameters
  • s (int): The state to update.
  • a (int): The action to update.
  • rs (List[float]): The reward samples obtained from the state-action pair (s, a).
def sample(self, n: int = 1) -> numpy.ndarray:
 85    def sample(self, n: int = 1) -> np.ndarray:
 86        # Unpack the prior
 87        (mu, lambda0, alpha, beta) = self.hyper_params.reshape(
 88            self.n_states * self.n_actions, -1
 89        ).T
 90
 91        # Sample scaling tau from a gamma distribution
 92        tau = self._rng.gamma(shape=alpha, scale=1.0 / beta).astype(np.float32)
 93        var = 1.0 / (lambda0 * tau)
 94
 95        # Sample mean from normal mean mu, var
 96        mean = self._rng.normal(loc=mu, scale=np.sqrt(var), size=(n, *mu.shape)).astype(
 97            np.float32
 98        )
 99
100        return mean.reshape(self.n_states, self.n_actions).squeeze()

samples from the posterior

Parameters
  • n (int): The number of samples. By default, it is set to one.
Returns
  • np.ndarray: The n samples from the posterior.
def get_map_estimate(self) -> numpy.ndarray:
def get_map_estimate(self) -> np.ndarray:
    """
    Computes the maximum a posteriori estimate.

    Returns
    -------
    np.ndarray
        The first hyperparameter (the posterior location ``mu``) of every state-action pair.
        This is a view into ``self.hyper_params``, not a copy.
    """
    location_index = 0
    return self.hyper_params[..., location_index]

Computes the maximum a posteriori estimate.

Returns
  • np.ndarray: The maximum a posteriori estimates
class N_N(ConjugateModel):
    """
    The Normal-Normal conjugate model.

    For every state-action pair, ``self.hyper_params[s, a]`` holds ``(mu, tau)``. ``update_sa``
    treats each observation as having unit precision. So ``mu`` is the posterior mean of the
    reward, and ``tau`` is the posterior precision of that mean (prior pseudo-count plus the
    number of observations).
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        hyper_params: Union[List[float], List[List[float]]],
        seed: int,
    ):
        """
        Parameters
        ----------
        n_states : int
            The number of states of the MDP.
        n_actions : int
            The number of actions of the MDP.
        hyper_params : Union[List[float],List[List[float]]]
            The prior parameters can either be a list of parameters that are set identical for each
            state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters
            as value.
        seed : int
            The random seed.

        Raises
        ------
        AssertionError
            If the processed hyperparameters do not have shape ``(n_states, n_actions, 2)``.
        """
        super(N_N, self).__init__(n_states, n_actions, hyper_params, seed)

        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
        expected_shape = (n_states, n_actions, 2)
        if self.hyper_params.shape != expected_shape:
            raise AssertionError(
                f"hyper_params has shape {self.hyper_params.shape}, expected {expected_shape}"
            )

    def update_sa(self, s: int, a: int, xs: List[float]):
        """
        Updates the beliefs of the given state-action pair, one observation at a time.

        Parameters
        ----------
        s : int
            The state to update.
        a : int
            The action to update.
        xs : List[float]
            The samples obtained from the state-action pair (s, a).
        """
        for r in xs:
            mu0, tau0 = self.hyper_params[s, a]
            tau1 = tau0 + 1
            # Unit-precision observation: the new mean is the tau-weighted average.
            mu1 = (mu0 * tau0 + r) / tau1
            self.hyper_params[s, a] = (mu1, tau1)

    def sample(self, n: int = 1) -> np.ndarray:
        """
        Samples mean rewards from the posterior.

        Parameters
        ----------
        n : int
            The number of samples. By default, it is set to one.

        Returns
        -------
        np.ndarray
            float32 samples of shape ``(n, n_states, n_actions)`` with singleton dimensions squeezed.
        """
        mu = self.hyper_params[:, :, 0]
        tau = self.hyper_params[:, :, 1]
        # tau is a precision (it grows by one per observation), so the standard deviation is
        # 1 / sqrt(tau). The explicit leading sample axis keeps ``size`` compatible with the
        # (n_states, n_actions)-shaped parameters.
        return (
            self._rng.normal(loc=mu, scale=1.0 / np.sqrt(tau), size=(n, *mu.shape))
            .astype(np.float32)
            .squeeze()
        )

    def get_map_estimate(self) -> np.ndarray:
        """
        Computes the maximum a posteriori estimate.

        Returns
        -------
        np.ndarray
            The posterior mean ``mu`` of every state-action pair (a view, not a copy).
        """
        return self.hyper_params[:, :, 0]

The Normal-Normal conjugate model.

N_N( n_states: int, n_actions: int, hyper_params: Union[List[float], List[List[float]]], seed: int)
def __init__(
    self,
    n_states: int,
    n_actions: int,
    hyper_params: Union[List[float], List[List[float]]],
    seed: int,
):
    """
    Parameters
    ----------
    n_states : int
        The number of states of the MDP.
    n_actions : int
        The number of actions of the MDP.
    hyper_params : Union[List[float],List[List[float]]]
        The prior parameters can either be a list of parameters that are set identical for each
        state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters
        as value.
    seed : int
        The random seed.

    Raises
    ------
    AssertionError
        If the processed hyperparameters do not have shape ``(n_states, n_actions, 2)``.
    """
    super(N_N, self).__init__(n_states, n_actions, hyper_params, seed)

    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    expected_shape = (n_states, n_actions, 2)
    if self.hyper_params.shape != expected_shape:
        raise AssertionError(
            f"hyper_params has shape {self.hyper_params.shape}, expected {expected_shape}"
        )
Parameters
  • n_states (int): The number of states of the MDP.
  • n_actions (int): The number of actions of the MDP.
  • hyper_params (Union[List[float],List[List[float]]]): The prior parameters can either be a list of parameters that are set identical for each state-action pair, or it can be a dictionary with the state action pair as key and a list of parameters as value.
  • seed (int): The random seed.
def update_sa(self, s: int, a: int, xs: List[float]):
def update_sa(self, s: int, a: int, xs: List[float]):
    """
    Updates the beliefs of the given state-action pair, folding in one observation at a time.

    Parameters
    ----------
    s : int
        The state to update.
    a : int
        The action to update.
    xs : List[float]
        The samples obtained from the state-action pair (s, a).
    """
    for reward in xs:
        prior_mean, prior_precision = self.hyper_params[s, a]
        posterior_precision = prior_precision + 1
        # The stored mean becomes the precision-weighted average of the old mean and the reward.
        self.hyper_params[s, a] = (
            (prior_mean * prior_precision + reward) / posterior_precision,
            posterior_precision,
        )

Updates the beliefs of the given state-action pair.

Parameters
  • s (int): The state to update.
  • a (int): The action to update.
  • xs (List): The samples obtained from state action pair (s,a).
def sample(self, n: int = 1) -> numpy.ndarray:
def sample(self, n: int = 1) -> np.ndarray:
    """
    Samples mean rewards from the posterior.

    Parameters
    ----------
    n : int
        The number of samples. By default, it is set to one.

    Returns
    -------
    np.ndarray
        float32 samples of shape ``(n, n_states, n_actions)`` with singleton dimensions squeezed.
    """
    mu = self.hyper_params[:, :, 0]
    tau = self.hyper_params[:, :, 1]
    # tau is a precision (update_sa grows it by one per observation), so the standard
    # deviation is 1 / sqrt(tau). The explicit leading sample axis keeps ``size``
    # compatible with the (n_states, n_actions)-shaped parameters.
    return (
        self._rng.normal(loc=mu, scale=1.0 / np.sqrt(tau), size=(n, *mu.shape))
        .astype(np.float32)
        .squeeze()
    )

samples from the posterior

Parameters
  • n (int): The number of samples. By default, it is set to one.
Returns
  • np.ndarray: The n samples from the posterior.
def get_map_estimate(self) -> numpy.ndarray:
def get_map_estimate(self) -> np.ndarray:
    """
    Computes the maximum a posteriori estimate.

    Returns
    -------
    np.ndarray
        The first hyperparameter (the posterior mean ``mu``) of every state-action pair.
        This is a view into ``self.hyper_params``, not a copy.
    """
    mean_index = 0
    return self.hyper_params[..., mean_index]

Computes the maximum a posteriori estimate.

Returns
  • np.ndarray: The maximum a posteriori estimates