colosseum.agent.agents.infinite_horizon.boot_dqn
```python
from typing import Any, Dict, Callable, TYPE_CHECKING

import dm_env
import gin
import numpy as np
import sonnet as snt
import tensorflow as tf
from bsuite.baselines.tf.boot_dqn import BootstrappedDqn, make_ensemble
from ray import tune

from colosseum.dynamic_programming.utils import argmax_2d
from colosseum.utils.non_tabular.bsuite import NonTabularBsuiteAgentWrapper

if TYPE_CHECKING:
    from colosseum.agent.agents.base import BaseAgent
    from colosseum.utils.acme.specs import MDPSpec
    from colosseum.mdp import ACTION_TYPE


@gin.configurable
class BootDQNContinuous(NonTabularBsuiteAgentWrapper):
    """
    The wrapper for the `BootDQN` agent from `bsuite`.
    """

    @staticmethod
    def produce_gin_file_from_parameters(parameters: Dict[str, Any], index: int = 0):
        string = ""
        for k, v in parameters.items():
            string += f"prms_{index}/BootDQNContinuous.{k} = {v}\n"
        return string[:-1]

    @staticmethod
    def is_episodic() -> bool:
        return False

    @staticmethod
    def get_hyperparameters_search_spaces() -> Dict[str, tune.sample.Domain]:
        return {
            "network_width": tune.choice([64, 128, 256]),
            "network_depth": tune.choice([2, 4]),
            "batch_size": tune.choice([32, 64, 128]),
            "sgd_period": tune.choice([1, 4, 8]),
            "target_update_period": tune.choice([4, 16, 32]),
            "mask_prob": tune.choice([0.8, 0.9, 1.0]),
            "noise_scale": tune.choice([0.0, 0.05, 0.1]),
            "n_ensemble": tune.choice([8, 16, 20]),
        }

    @staticmethod
    def get_agent_instance_from_parameters(
        seed: int,
        optimization_horizon: int,
        mdp_specs: "MDPSpec",
        parameters: Dict[str, Any],
    ) -> "BaseAgent":
        return BootDQNContinuous(
            seed,
            mdp_specs,
            optimization_horizon,
            parameters["network_width"],
            parameters["network_depth"],
            parameters["batch_size"],
            parameters["sgd_period"],
            parameters["target_update_period"],
            parameters["mask_prob"],
            parameters["noise_scale"],
            parameters["n_ensemble"],
        )

    @property
    def current_optimal_stochastic_policy(self) -> np.ndarray:
        qvals = tf.stop_gradient(
            self._agent._forward[self._agent._active_head](
                self.emission_map.all_observations
            )
        ).numpy()
        policy = argmax_2d(qvals)
        assert np.isclose(policy.sum(-1).mean(), 1)
        return policy

    def step_update(
        self, ts_t: dm_env.TimeStep, a_t: "ACTION_TYPE", ts_tp1: dm_env.TimeStep, h: int
    ):
        if self._rng_fast.random() < self._switch_prob:
            self._agent._active_head = self._rng.randint(self._agent._num_ensemble)
        super(BootDQNContinuous, self).step_update(ts_t, a_t, ts_tp1, h)

    def __init__(
        self,
        seed: int,
        mdp_specs: "MDPSpec",
        optimization_horizon: int,
        # MDP model parameters
        network_width: int,
        network_depth: int,
        batch_size: int,
        sgd_period: int,
        target_update_period: int,
        # Actor parameters
        mask_prob: float,
        noise_scale: float,
        n_ensemble: int,
        switch_prob: float = 0.1,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
        replay_capacity: int = 10000,
        epsilon_fn: Callable[[int], float] = lambda t: 0,  # lambda t: 10 / (10 + t)
    ):
        r"""
        Parameters
        ----------
        seed : int
            The random seed.
        mdp_specs : MDPSpec
            The full specification of the MDP.
        optimization_horizon : int
            The total number of interactions that the agent is expected to have with the MDP.
        network_width : int
            The width of the neural networks of the agent.
        network_depth : int
            The depth of the neural networks of the agent.
        batch_size : int
            The batch size for training the agent.
        sgd_period : int
            The stochastic gradient descent update period.
        target_update_period : int
            The interval length between updating the target network.
        mask_prob : float
            The masking probability for the bootstrapping procedure.
        noise_scale : float
            The scale of the Gaussian noise added to the value estimates.
        n_ensemble : int
            The number of members of the bootstrapped ensemble.
        switch_prob : float
            The probability of changing the ensemble whose q-estimates are used to select actions.
            By default, it is set to :math:`0.1`.
        learning_rate : float
            The learning rate of the optimizer. By default, it is set to 1e-3.
        discount : float
            The discount factor.
        replay_capacity : int
            The maximum capacity of the replay buffer. By default, it is set to 10 000.
        epsilon_fn : Callable[[int], float]
            The :math:`\epsilon` greedy probability as a function of time. By default, it is set to zero.
        """

        self._switch_prob = switch_prob

        tf.random.set_seed(seed)
        np.random.seed(seed)

        ensemble = make_ensemble(
            mdp_specs.actions.num_values, n_ensemble, network_depth, network_width
        )
        optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

        agent = BootstrappedDqn(
            obs_spec=mdp_specs.observations,
            action_spec=mdp_specs.actions,
            ensemble=ensemble,
            batch_size=batch_size,
            discount=discount,
            replay_capacity=replay_capacity,
            min_replay_size=batch_size,
            sgd_period=sgd_period,
            target_update_period=target_update_period,
            optimizer=optimizer,
            mask_prob=mask_prob,
            noise_scale=noise_scale,
            seed=seed,
            epsilon_fn=epsilon_fn,
        )
        super(BootDQNContinuous, self).__init__(seed, agent, mdp_specs)
```
@gin.configurable
class BootDQNContinuous(NonTabularBsuiteAgentWrapper):
The wrapper for the `BootDQN` agent from `bsuite`.
BootDQNContinuous( seed: int, mdp_specs: colosseum.utils.acme.specs.MDPSpec, optimization_horizon: int, network_width: int, network_depth: int, batch_size: int, sgd_period: int, target_update_period: int, mask_prob: float, noise_scale: float, n_ensemble: int, switch_prob: float = 0.1, learning_rate: float = 0.001, discount: float = 0.99, replay_capacity: int = 10000, epsilon_fn: Callable[[int], float] = lambda t: 0)
Parameters
- seed (int): The random seed.
- mdp_specs (MDPSpec): The full specification of the MDP.
- optimization_horizon (int): The total number of interactions that the agent is expected to have with the MDP.
- network_width (int): The width of the neural networks of the agent.
- network_depth (int): The depth of the neural networks of the agent.
- batch_size (int): The batch size for training the agent.
- sgd_period (int): The stochastic gradient descent update period.
- target_update_period (int): The interval length between updating the target network.
- mask_prob (float): The masking probability for the bootstrapping procedure.
- noise_scale (float): The scale of the Gaussian noise added to the value estimates.
- n_ensemble (int): The number of members of the bootstrapped ensemble.
- switch_prob (float): The probability of changing the ensemble whose q-estimates are used to select actions. By default, it is set to \( 0.1 \).
- learning_rate (float): The learning rate of the optimizer. By default, it is set to 1e-3.
- discount (float): The discount factor.
- replay_capacity (int): The maximum capacity of the replay buffer. By default, it is set to 10 000.
- epsilon_fn (Callable[[int], float]): The \( \epsilon \) greedy probability as a function of time. By default, it is set to zero.
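As a quick orientation, the sketch below constructs the agent directly from the documented signature. The hyperparameter values are illustrative only, and `mdp_specs` is assumed to be an `MDPSpec` already built for a continuous (infinite-horizon) Colosseum MDP.

```python
from colosseum.agent.agents.infinite_horizon.boot_dqn import BootDQNContinuous

# Illustrative instantiation, not a prescribed configuration. `mdp_specs` is
# assumed to be an MDPSpec obtained from a continuous Colosseum MDP.
agent = BootDQNContinuous(
    seed=42,
    mdp_specs=mdp_specs,
    optimization_horizon=50_000,
    network_width=128,
    network_depth=2,
    batch_size=64,
    sgd_period=4,
    target_update_period=16,
    mask_prob=0.9,
    noise_scale=0.05,
    n_ensemble=16,
)
```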
@staticmethod
def produce_gin_file_from_parameters(parameters: Dict[str, Any], index: int = 0):
Produces a string containing the gin config file corresponding to the parameters given in input.
Parameters
- parameters (Dict[str, Any]): The dictionary containing the parameters of the agent.
- index (int): The index assigned to the gin configuration.
Returns
- gin_config (str): The gin configuration file.
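A small usage sketch of the method above; the parameter dictionary is arbitrary and only meant to show the format of the produced gin bindings.

```python
params = {"network_width": 128, "network_depth": 2, "batch_size": 64}
print(BootDQNContinuous.produce_gin_file_from_parameters(params, index=3))
# prms_3/BootDQNContinuous.network_width = 128
# prms_3/BootDQNContinuous.network_depth = 2
# prms_3/BootDQNContinuous.batch_size = 64
```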
@staticmethod
def is_episodic() -> bool:
Returns
- bool: True if the agent is suited for the episodic setting.
@staticmethod
def get_hyperparameters_search_spaces() -> Dict[str, ray.tune.sample.Domain]:
Returns
- Dict[str, tune.sample.Domain]: The dictionary with key-value pairs corresponding to hyperparameter names and the associated ray.tune samplers.
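As an illustration, one way to draw a random configuration from these domains is via ray.tune's `Domain.sample()`; this loop is a sketch, not part of the documented API.

```python
# Draw one random hyperparameter configuration from the search space.
search_spaces = BootDQNContinuous.get_hyperparameters_search_spaces()
random_config = {name: domain.sample() for name, domain in search_spaces.items()}
# e.g. {'network_width': 128, 'network_depth': 2, ..., 'n_ensemble': 16}
```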
@staticmethod
def get_agent_instance_from_parameters( seed: int, optimization_horizon: int, mdp_specs: colosseum.utils.acme.specs.MDPSpec, parameters: Dict[str, Any]) -> colosseum.agent.agents.base.BaseAgent:
Returns an agent instance for the MDP specification and agent parameters given in input.
Parameters
- seed (int): The random seed.
- optimization_horizon (int): The total number of interactions that the agent is expected to have with the MDP.
- mdp_specs (MDPSpec): The full specification of the MDP.
- parameters (Dict[str, Any]): The dictionary containing the parameters of the agent.
Returns
- BaseAgent: The agent instance.
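A hedged sketch of how the two static methods compose during hyperparameter search; as above, `mdp_specs` is assumed to be an already-built `MDPSpec` for a continuous Colosseum MDP.

```python
# Sample a configuration and build the corresponding agent instance.
config = {
    name: domain.sample()
    for name, domain in BootDQNContinuous.get_hyperparameters_search_spaces().items()
}
agent = BootDQNContinuous.get_agent_instance_from_parameters(
    seed=0,
    optimization_horizon=50_000,
    mdp_specs=mdp_specs,  # assumed to be available
    parameters=config,
)
```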
current_optimal_stochastic_policy: numpy.ndarray
Returns
- np.ndarray: The estimate of the optimal policy given the current knowledge of the agent, in the form of a distribution over actions.
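A short usage sketch of the property, assuming `agent` is an instantiated `BootDQNContinuous`; the sum check mirrors the assertion in the source above.

```python
import numpy as np

policy = agent.current_optimal_stochastic_policy
# One probability distribution over actions per observation/state.
assert np.allclose(policy.sum(-1), 1.0)
greedy_actions = policy.argmax(-1)  # actions preferred by the active head
```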
def step_update( self, ts_t: dm_env._environment.TimeStep, a_t: Union[int, float, numpy.ndarray], ts_tp1: dm_env._environment.TimeStep, h: int):
Updates the agent with the transition in input and, with probability `switch_prob`, switches the ensemble head whose Q-value estimates are used to select actions.
Parameters
- ts_t (dm_env.TimeStep): The TimeStep at time t.
- a_t ("ACTION_TYPE"): The action taken by the agent at time t.
- ts_tp1 (dm_env.TimeStep): The TimeStep at time t + 1.
- h (int): The current time of the environment. In the episodic case, this refers to the in-episode time, whereas in the continuous case this refers to the total number of previous interactions.
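A hedged sketch of feeding a single transition to the agent; in practice, Colosseum's benchmarking loop performs these calls. `obs_t`, `obs_tp1`, `a_t`, `r_tp1`, and `t` are placeholders for values produced by the environment, and `dm_env.transition` is used here only to build the TimeStep objects.

```python
import dm_env

# Placeholder transition; real values come from the MDP being benchmarked.
ts_t = dm_env.transition(reward=0.0, observation=obs_t)
ts_tp1 = dm_env.transition(reward=r_tp1, observation=obs_tp1)
agent.step_update(ts_t, a_t, ts_tp1, h=t)  # t: total number of interactions so far
```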