colosseum.agent.utils

  1from typing import Type, TYPE_CHECKING, Any, Dict, List
  2
  3import numpy as np
  4
  5from colosseum.utils.miscellanea import rounding_nested_structure
  6
  7if TYPE_CHECKING:
  8    from colosseum.agent.agents.base import BaseAgent
  9
 10
 11def sample_agent_hyperparameters(
 12    agent_class: Type["BaseAgent"], seed: int
 13) -> Dict[str, Any]:
 14    """
 15    samples parameters from the agent class sample spaces.
 16
 17    Parameters
 18    ----------
 19    agent_class : Type["BaseAgent"]
 20        The agent class for which we are sampling from.
 21    seed : int
 22        The random seed.
 23
 24    Returns
 25    -------
 26    Dict[str, Any]
 27        The parameters sampled from the agent hyperparameter spaces.
 28    """
 29    np.random.seed(seed)
 30    search_spaces = agent_class.get_hyperparameters_search_spaces()
 31    return rounding_nested_structure({k: v.sample() for k, v in search_spaces.items()})
 32
 33
 34def sample_n_agent_hyperparameters(
 35    n: int, agent_class: Type["BaseAgent"], seed: int
 36) -> List[Dict[str, Any]]:
 37    """
 38    samples n parameters from the agent class sample spaces.
 39
 40    Parameters
 41    ----------
 42    n : int
 43        The number of samples.
 44    agent_class : Type["BaseAgent"]
 45        The agent class for which we are sampling from.
 46    seed : int
 47        The random seed.
 48
 49    Returns
 50    -------
 51    List[Dict[str, Any]]
 52        The list of n parameters sampled from the agent hyperparameter spaces.
 53    """
 54    return [sample_agent_hyperparameters(agent_class, seed + i) for i in range(n)]
 55
 56
 57def sample_agent_gin_configs(
 58    agent_class: Type["BaseAgent"], n: int = 1, seed: int = 42
 59) -> List[str]:
 60    """
 61    samples gin configurations from the agent class sample spaces.
 62
 63    Parameters
 64    ----------
 65    agent_class : Type["BaseAgent"]
 66        The agent class for which we are sampling from.
 67    n : int
 68        The number of samples. By default, it is set to one.
 69    seed : int
 70        The random seed. By default, it is set to :math:`42`.
 71
 72    Returns
 73    -------
 74    List[str]
 75        The list containing the sampled gin configs.
 76    """
 77    return [
 78        agent_class.produce_gin_file_from_parameters(params, i)
 79        for i, params in enumerate(sample_n_agent_hyperparameters(n, agent_class, seed))
 80    ]
 81
 82
 83def sample_agent_gin_configs_file(
 84    agent_class: Type["BaseAgent"], n: int = 1, seed: int = 42
 85) -> str:
 86    """
 87    samples gin configurations from the agent class sample spaces and store them in a string that can be used to create
 88    a gin config file.
 89
 90    Parameters
 91    ----------
 92    agent_class : Type["BaseAgent"]
 93        The agent class for which we are sampling from.
 94    n : int
 95        The number of samples. By default, it is set to one.
 96    seed : int
 97        The random seed. By default, it is set to :math:`42`.
 98
 99    Returns
100    -------
101    str
102        The gin configuration file.
103    """
104    return "\n".join(sample_agent_gin_configs(agent_class, n, seed))
def sample_agent_hyperparameters( agent_class: Type[colosseum.agent.agents.base.BaseAgent], seed: int) -> Dict[str, Any]:
def sample_agent_hyperparameters(
    agent_class: Type["BaseAgent"], seed: int
) -> Dict[str, Any]:
    """
    Draws one set of hyperparameters from the search spaces of the agent class.

    Parameters
    ----------
    agent_class : Type["BaseAgent"]
        The agent class whose hyperparameter search spaces are sampled.
    seed : int
        The random seed. It is used to seed the global NumPy random state.

    Returns
    -------
    Dict[str, Any]
        One sampled value per search space entry, after being passed through
        ``rounding_nested_structure``.
    """
    # The global NumPy seed is set before the search spaces are retrieved and sampled.
    np.random.seed(seed)

    sampled = {}
    for name, space in agent_class.get_hyperparameters_search_spaces().items():
        sampled[name] = space.sample()
    return rounding_nested_structure(sampled)

Samples parameters from the agent class sample spaces.

Parameters
  • agent_class (Type["BaseAgent"]): The agent class for which we are sampling.
  • seed (int): The random seed.
Returns
  • Dict[str, Any]: The parameters sampled from the agent hyperparameter spaces.
def sample_n_agent_hyperparameters( n: int, agent_class: Type[colosseum.agent.agents.base.BaseAgent], seed: int) -> List[Dict[str, Any]]:
def sample_n_agent_hyperparameters(
    n: int, agent_class: Type["BaseAgent"], seed: int
) -> List[Dict[str, Any]]:
    """
    Draws n sets of hyperparameters from the search spaces of the agent class.

    Parameters
    ----------
    n : int
        The number of samples.
    agent_class : Type["BaseAgent"]
        The agent class whose hyperparameter search spaces are sampled.
    seed : int
        The base random seed. The i-th sample is drawn with seed ``seed + i``.

    Returns
    -------
    List[Dict[str, Any]]
        The list of n sampled hyperparameter dictionaries.
    """
    samples = []
    for offset in range(n):
        # Each sample gets its own seed, shifted from the base seed.
        samples.append(sample_agent_hyperparameters(agent_class, seed + offset))
    return samples

Samples n parameters from the agent class sample spaces.

Parameters
  • n (int): The number of samples.
  • agent_class (Type["BaseAgent"]): The agent class for which we are sampling.
  • seed (int): The random seed.
Returns
  • List[Dict[str, Any]]: The list of n parameters sampled from the agent hyperparameter spaces.
def sample_agent_gin_configs( agent_class: Type[colosseum.agent.agents.base.BaseAgent], n: int = 1, seed: int = 42) -> List[str]:
def sample_agent_gin_configs(
    agent_class: Type["BaseAgent"], n: int = 1, seed: int = 42
) -> List[str]:
    """
    Samples hyperparameters for the agent class and turns each sample into a gin configuration.

    Parameters
    ----------
    agent_class : Type["BaseAgent"]
        The agent class whose hyperparameter search spaces are sampled.
    n : int
        The number of samples. By default, it is set to one.
    seed : int
        The random seed. By default, it is set to :math:`42`.

    Returns
    -------
    List[str]
        The list containing the sampled gin configs, one per sample.
    """
    sampled_parameters = sample_n_agent_hyperparameters(n, agent_class, seed)

    gin_configs = []
    for index, params in enumerate(sampled_parameters):
        # The position of the sample is forwarded together with its parameters.
        gin_configs.append(
            agent_class.produce_gin_file_from_parameters(params, index)
        )
    return gin_configs

Samples gin configurations from the agent class sample spaces.

Parameters
  • agent_class (Type["BaseAgent"]): The agent class for which we are sampling.
  • n (int): The number of samples. By default, it is set to one.
  • seed (int): The random seed. By default, it is set to \( 42 \).
Returns
  • List[str]: The list containing the sampled gin configs.
def sample_agent_gin_configs_file( agent_class: Type[colosseum.agent.agents.base.BaseAgent], n: int = 1, seed: int = 42) -> str:
 84def sample_agent_gin_configs_file(
 85    agent_class: Type["BaseAgent"], n: int = 1, seed: int = 42
 86) -> str:
 87    """
 88    samples gin configurations from the agent class sample spaces and store them in a string that can be used to create
 89    a gin config file.
 90
 91    Parameters
 92    ----------
 93    agent_class : Type["BaseAgent"]
 94        The agent class for which we are sampling from.
 95    n : int
 96        The number of samples. By default, it is set to one.
 97    seed : int
 98        The random seed. By default, it is set to :math:`42`.
 99
100    Returns
101    -------
102    str
103        The gin configuration file.
104    """
105    return "\n".join(sample_agent_gin_configs(agent_class, n, seed))

Samples gin configurations from the agent class sample spaces and stores them in a string that can be used to create a gin config file.

Parameters
  • agent_class (Type["BaseAgent"]): The agent class for which we are sampling.
  • n (int): The number of samples. By default, it is set to one.
  • seed (int): The random seed. By default, it is set to \( 42 \).
Returns
  • str: The gin configuration file.