colosseum.utils.non_tabular.bsuite

 1import abc
 2from typing import TYPE_CHECKING
 3
 4import dm_env
 5from bsuite.baselines.base import Agent as BAgent
 6
 7from colosseum.agent.agents.base import BaseAgent
 8from colosseum.emission_maps import EmissionMap
 9from colosseum.utils.acme.specs import MDPSpec
10
11if TYPE_CHECKING:
12    from colosseum.mdp import ACTION_TYPE
13
14
class NonTabularBsuiteAgentWrapper(BaseAgent, abc.ABC):
    """
    A simple wrapper for `bsuite` agents.

    Action selection and per-step updates are delegated to the wrapped agent.
    `is_episode_end` always returns False and the remaining hooks do nothing.
    """

    def __init__(
        self,
        seed: int,
        agent: BAgent,
        mdp_specs: MDPSpec,
    ):
        """
        Parameters
        ----------
        seed : int
            Forwarded unchanged to the base class constructor.
        agent : BAgent
            The `bsuite` agent that receives the `select_action` and `update` calls.
        mdp_specs : MDPSpec
            The MDP specification; its `emission_map` attribute is stored on the wrapper.
        """
        # The attributes are assigned before the base constructor runs, as in the
        # original implementation.
        self._agent = agent
        self._mdp_spec = mdp_specs
        self.emission_map = mdp_specs.emission_map

        # The three remaining positional arguments of the base constructor are None.
        super().__init__(seed, mdp_specs, None, None, None)

    @staticmethod
    def is_emission_map_accepted(emission_map: "EmissionMap") -> bool:
        """
        Returns
        -------
        bool
            True if the emission map is not tabular.
        """
        tabular = emission_map.is_tabular
        return not tabular

    def is_episode_end(
        self,
        ts_t: dm_env.TimeStep,
        a_t: "ACTION_TYPE",
        ts_tp1: dm_env.TimeStep,
        time: int,
    ) -> bool:
        """
        Returns
        -------
        bool
            Always False; none of the arguments is inspected.
        """
        return False

    def select_action(self, ts: dm_env.TimeStep, time: int) -> "ACTION_TYPE":
        """
        Returns
        -------
        ACTION_TYPE
            The action the wrapped agent selects for `ts`. `time` is not used.
        """
        action = self._agent.select_action(ts)
        return action

    def step_update(
        self,
        ts_t: dm_env.TimeStep,
        a_t: "ACTION_TYPE",
        ts_tp1: dm_env.TimeStep,
        h: int,
    ):
        """
        Passes the transition `(ts_t, a_t, ts_tp1)` to the wrapped agent's `update`
        method. `h` is not used.
        """
        self._agent.update(ts_t, a_t, ts_tp1)

    def update_models(self):
        """Does nothing."""

    def _before_new_episode(self):
        """Does nothing."""

    def episode_end_update(self):
        """Does nothing."""

    def before_start_interacting(self):
        """Does nothing."""
class NonTabularBsuiteAgentWrapper(colosseum.agent.agents.base.BaseAgent, abc.ABC):
class NonTabularBsuiteAgentWrapper(BaseAgent, abc.ABC):
    """
    A simple wrapper for `bsuite` agents.

    The wrapped agent handles action selection and per-step updates. All other
    hooks are no-ops, and `is_episode_end` always returns False.
    """

    def __init__(
        self,
        seed: int,
        agent: BAgent,
        mdp_specs: MDPSpec,
    ):
        """
        Parameters
        ----------
        seed : int
            Passed through to the base class constructor.
        agent : BAgent
            The wrapped `bsuite` agent.
        mdp_specs : MDPSpec
            The MDP specification, whose `emission_map` is kept on the wrapper.
        """
        # Attribute assignment precedes the base constructor call, as in the original.
        self._agent = agent
        self._mdp_spec = mdp_specs
        self.emission_map = mdp_specs.emission_map

        # The base constructor gets None for its three remaining positional arguments.
        super().__init__(seed, mdp_specs, None, None, None)

    @staticmethod
    def is_emission_map_accepted(emission_map: "EmissionMap") -> bool:
        """
        Returns
        -------
        bool
            True if the emission map is not tabular.
        """
        tabular = emission_map.is_tabular
        return not tabular

    def is_episode_end(
        self,
        ts_t: dm_env.TimeStep,
        a_t: "ACTION_TYPE",
        ts_tp1: dm_env.TimeStep,
        time: int,
    ) -> bool:
        """
        Returns
        -------
        bool
            Always False, whatever the arguments are.
        """
        return False

    def select_action(self, ts: dm_env.TimeStep, time: int) -> "ACTION_TYPE":
        """
        Returns
        -------
        ACTION_TYPE
            The result of the wrapped agent's `select_action(ts)`; `time` is ignored.
        """
        action = self._agent.select_action(ts)
        return action

    def step_update(
        self,
        ts_t: dm_env.TimeStep,
        a_t: "ACTION_TYPE",
        ts_tp1: dm_env.TimeStep,
        h: int,
    ):
        """
        Forwards `(ts_t, a_t, ts_tp1)` to the wrapped agent's `update` method;
        `h` is ignored.
        """
        self._agent.update(ts_t, a_t, ts_tp1)

    def update_models(self):
        """No-op."""

    def _before_new_episode(self):
        """No-op."""

    def episode_end_update(self):
        """No-op."""

    def before_start_interacting(self):
        """No-op."""

A simple wrapper for bsuite agents.

@staticmethod
def is_emission_map_accepted(emission_map: colosseum.emission_maps.base.EmissionMap) -> bool:
21    @staticmethod
22    def is_emission_map_accepted(emission_map: "EmissionMap") -> bool:
23        return not emission_map.is_tabular
Returns
  • bool: True if the agent class accepts the emission map.
def is_episode_end( self, ts_t: dm_env._environment.TimeStep, a_t: Union[int, float, numpy.ndarray], ts_tp1: dm_env._environment.TimeStep, time: int) -> bool:
25    def is_episode_end(
26        self,
27        ts_t: dm_env.TimeStep,
28        a_t: "ACTION_TYPE",
29        ts_tp1: dm_env.TimeStep,
30        time: int,
31    ) -> bool:
32        return False

Checks whether the episode is terminated. By default, this checks whether the current time step exceeds the time horizon; this wrapper overrides that default and always returns False. In the continuous case, this can be used to define artificial episodes.

Parameters
  • ts_t (dm_env.TimeStep): The TimeStep at time t.
  • a_t ("ACTION_TYPE"): The action taken by the agent at time t.
  • ts_tp1 (dm_env.TimeStep): The TimeStep at time t + 1.
  • time (int): The current time of the environment. In the episodic case, this refers to the in-episode time, whereas in the continuous case this refers to the total number of previous interactions.
Returns
  • bool: True if the episode terminated at time t+1.
def select_action( self, ts: dm_env._environment.TimeStep, time: int) -> Union[int, float, numpy.ndarray]:
48    def select_action(self, ts: dm_env.TimeStep, time: int) -> "ACTION_TYPE":
49        return self._agent.select_action(ts)
Parameters
  • ts (dm_env.TimeStep): The TimeStep for which the agent is required to calculate the next action.
  • time (int): The current time of the environment. In the episodic case, this refers to the in-episode time, whereas in the continuous case this refers to the total number of previous interactions.
Returns
  • action (ACTION_TYPE): The action that the agent suggests to take given the observation and the time step.
def step_update( self, ts_t: dm_env._environment.TimeStep, a_t: Union[int, float, numpy.ndarray], ts_tp1: dm_env._environment.TimeStep, h: int):
51    def step_update(
52        self, ts_t: dm_env.TimeStep, a_t: "ACTION_TYPE", ts_tp1: dm_env.TimeStep, h: int
53    ):
54        self._agent.update(ts_t, a_t, ts_tp1)

Adds the input transition to the MDP model. In this wrapper, the transition is forwarded to the wrapped bsuite agent's update method.

Parameters
  • ts_t (dm_env.TimeStep): The TimeStep at time t.
  • a_t ("ACTION_TYPE"): The action taken by the agent at time t.
  • ts_tp1 (dm_env.TimeStep): The TimeStep at time t + 1.
  • h (int): The current time of the environment. In the episodic case, this refers to the in-episode time, whereas in the continuous case this refers to the total number of previous interactions. This wrapper does not use it.
def update_models(self):
56    def update_models(self):
57        pass
def episode_end_update(self):
62    def episode_end_update(self):
63        pass

Called when an episode ends. In the infinite-horizon case, this refers to artificial episodes.

def before_start_interacting(self):
65    def before_start_interacting(self):
66        pass

Called before the agent starts interacting with the MDP.