colosseum.agent.utils
from typing import Type, TYPE_CHECKING, Any, Dict, List

import numpy as np

from colosseum.utils.miscellanea import rounding_nested_structure

if TYPE_CHECKING:
    from colosseum.agent.agents.base import BaseAgent


def sample_agent_hyperparameters(
    agent_class: Type["BaseAgent"], seed: int
) -> Dict[str, Any]:
    """
    Samples one set of hyperparameters from the search spaces of an agent class.

    Parameters
    ----------
    agent_class : Type["BaseAgent"]
        The agent class whose hyperparameter search spaces are sampled.
    seed : int
        The random seed used to reseed NumPy's global random number generator before sampling.

    Returns
    -------
    Dict[str, Any]
        A mapping from hyperparameter name to sampled value, passed through ``rounding_nested_structure``.
    """
    # NOTE: this reseeds NumPy's *global* RNG, so it also affects any np.random
    # calls made afterwards by the caller. Presumably the search-space objects
    # draw from this global RNG, which is what makes the sample reproducible.
    np.random.seed(seed)
    spaces = agent_class.get_hyperparameters_search_spaces()
    sampled = {}
    for name, space in spaces.items():
        sampled[name] = space.sample()
    return rounding_nested_structure(sampled)


def sample_n_agent_hyperparameters(
    n: int, agent_class: Type["BaseAgent"], seed: int
) -> List[Dict[str, Any]]:
    """
    Samples ``n`` independent sets of hyperparameters from the search spaces of an agent class.

    Parameters
    ----------
    n : int
        The number of hyperparameter sets to draw. Non-positive values yield an empty list.
    agent_class : Type["BaseAgent"]
        The agent class whose hyperparameter search spaces are sampled.
    seed : int
        The base random seed. The k-th sample (0-based) is drawn with seed ``seed + k``.

    Returns
    -------
    List[Dict[str, Any]]
        The ``n`` sampled hyperparameter dictionaries, in seed order.
    """
    # Consecutive seeds give each draw its own reproducible RNG state.
    per_sample_seeds = range(seed, seed + n)
    return [sample_agent_hyperparameters(agent_class, s) for s in per_sample_seeds]


def sample_agent_gin_configs(
    agent_class: Type["BaseAgent"], n: int = 1, seed: int = 42
) -> List[str]:
    """
    Samples hyperparameters from the search spaces of an agent class and converts each sample into a gin configuration.

    Parameters
    ----------
    agent_class : Type["BaseAgent"]
        The agent class whose hyperparameter search spaces are sampled.
    n : int
        The number of configurations to produce. By default, it is set to one.
    seed : int
        The base random seed. By default, it is set to :math:`42`.

    Returns
    -------
    List[str]
        One gin configuration string per sample, as produced by ``agent_class.produce_gin_file_from_parameters``.
    """
    sampled_parameters = sample_n_agent_hyperparameters(n, agent_class, seed)
    configs = []
    # The sample's position is forwarded as the second argument to the gin producer.
    for index, params in enumerate(sampled_parameters):
        configs.append(agent_class.produce_gin_file_from_parameters(params, index))
    return configs


def sample_agent_gin_configs_file(
    agent_class: Type["BaseAgent"], n: int = 1, seed: int = 42
) -> str:
    """
    Samples gin configurations from the search spaces of an agent class and joins them into a single string that can be
    written out as a gin config file.

    Parameters
    ----------
    agent_class : Type["BaseAgent"]
        The agent class whose hyperparameter search spaces are sampled.
    n : int
        The number of configurations to produce. By default, it is set to one.
    seed : int
        The base random seed. By default, it is set to :math:`42`.

    Returns
    -------
    str
        The sampled gin configurations separated by newlines.
    """
    gin_configs = sample_agent_gin_configs(agent_class, n, seed)
    return "\n".join(gin_configs)
def
sample_agent_hyperparameters( agent_class: Type[colosseum.agent.agents.base.BaseAgent], seed: int) -> Dict[str, Any]:
def sample_agent_hyperparameters(
    agent_class: Type["BaseAgent"], seed: int
) -> Dict[str, Any]:
    """
    Samples one set of hyperparameters from the search spaces of an agent class.

    Parameters
    ----------
    agent_class : Type["BaseAgent"]
        The agent class whose hyperparameter search spaces are sampled.
    seed : int
        The random seed used to reseed NumPy's global random number generator before sampling.

    Returns
    -------
    Dict[str, Any]
        A mapping from hyperparameter name to sampled value, passed through ``rounding_nested_structure``.
    """
    # NOTE: this reseeds NumPy's *global* RNG, so it also affects any np.random
    # calls made afterwards by the caller. Presumably the search-space objects
    # draw from this global RNG, which is what makes the sample reproducible.
    np.random.seed(seed)
    spaces = agent_class.get_hyperparameters_search_spaces()
    sampled = {}
    for name, space in spaces.items():
        sampled[name] = space.sample()
    return rounding_nested_structure(sampled)
Samples one set of parameters from the agent class's hyperparameter search spaces.
Parameters
- agent_class (Type["BaseAgent"]): The agent class to sample from.
- seed (int): The random seed.
Returns
- Dict[str, Any]: The parameters sampled from the agent hyperparameter spaces.
def
sample_n_agent_hyperparameters( n: int, agent_class: Type[colosseum.agent.agents.base.BaseAgent], seed: int) -> List[Dict[str, Any]]:
def sample_n_agent_hyperparameters(
    n: int, agent_class: Type["BaseAgent"], seed: int
) -> List[Dict[str, Any]]:
    """
    Samples ``n`` independent sets of hyperparameters from the search spaces of an agent class.

    Parameters
    ----------
    n : int
        The number of hyperparameter sets to draw. Non-positive values yield an empty list.
    agent_class : Type["BaseAgent"]
        The agent class whose hyperparameter search spaces are sampled.
    seed : int
        The base random seed. The k-th sample (0-based) is drawn with seed ``seed + k``.

    Returns
    -------
    List[Dict[str, Any]]
        The ``n`` sampled hyperparameter dictionaries, in seed order.
    """
    # Consecutive seeds give each draw its own reproducible RNG state.
    per_sample_seeds = range(seed, seed + n)
    return [sample_agent_hyperparameters(agent_class, s) for s in per_sample_seeds]
Samples n sets of parameters from the agent class's hyperparameter search spaces.
Parameters
- n (int): The number of samples.
- agent_class (Type["BaseAgent"]): The agent class to sample from.
- seed (int): The random seed.
Returns
- List[Dict[str, Any]]: The list of n parameter dictionaries sampled from the agent's hyperparameter search spaces.
def
sample_agent_gin_configs( agent_class: Type[colosseum.agent.agents.base.BaseAgent], n: int = 1, seed: int = 42) -> List[str]:
def sample_agent_gin_configs(
    agent_class: Type["BaseAgent"], n: int = 1, seed: int = 42
) -> List[str]:
    """
    Samples hyperparameters from the search spaces of an agent class and converts each sample into a gin configuration.

    Parameters
    ----------
    agent_class : Type["BaseAgent"]
        The agent class whose hyperparameter search spaces are sampled.
    n : int
        The number of configurations to produce. By default, it is set to one.
    seed : int
        The base random seed. By default, it is set to :math:`42`.

    Returns
    -------
    List[str]
        One gin configuration string per sample, as produced by ``agent_class.produce_gin_file_from_parameters``.
    """
    sampled_parameters = sample_n_agent_hyperparameters(n, agent_class, seed)
    configs = []
    # The sample's position is forwarded as the second argument to the gin producer.
    for index, params in enumerate(sampled_parameters):
        configs.append(agent_class.produce_gin_file_from_parameters(params, index))
    return configs
Samples gin configurations from the agent class's hyperparameter search spaces.
Parameters
- agent_class (Type["BaseAgent"]): The agent class to sample from.
- n (int): The number of samples. By default, it is set to one.
- seed (int): The random seed. By default, it is set to \( 42 \).
Returns
- List[str]: The list containing the sampled gin configs.
def
sample_agent_gin_configs_file( agent_class: Type[colosseum.agent.agents.base.BaseAgent], n: int = 1, seed: int = 42) -> str:
def sample_agent_gin_configs_file(
    agent_class: Type["BaseAgent"], n: int = 1, seed: int = 42
) -> str:
    """
    Samples gin configurations from the search spaces of an agent class and joins them into a single string that can be
    written out as a gin config file.

    Parameters
    ----------
    agent_class : Type["BaseAgent"]
        The agent class whose hyperparameter search spaces are sampled.
    n : int
        The number of configurations to produce. By default, it is set to one.
    seed : int
        The base random seed. By default, it is set to :math:`42`.

    Returns
    -------
    str
        The sampled gin configurations separated by newlines.
    """
    gin_configs = sample_agent_gin_configs(agent_class, n, seed)
    return "\n".join(gin_configs)
Samples gin configurations from the agent class's hyperparameter search spaces and stores them in a string that can be used to create a gin config file.
Parameters
- agent_class (Type["BaseAgent"]): The agent class to sample from.
- n (int): The number of samples. By default, it is set to one.
- seed (int): The random seed. By default, it is set to \( 42 \).
Returns
- str: The gin configuration file.