colosseum.benchmark.utils

  1import os
  2import re
  3import shutil
  4from glob import glob
  5from typing import List, Dict, TYPE_CHECKING, Type, Union
  6
  7import yaml
  8
  9from colosseum.benchmark.benchmark import ColosseumBenchmark, BENCHMARKS_DIRECTORY
 10from colosseum.emission_maps import EmissionMap
 11from colosseum.experiment import ExperimentConfig
 12from colosseum.utils import ensure_folder
 13from colosseum.utils.miscellanea import (
 14    get_agent_class_from_name,
 15    get_mdp_class_from_name,
 16    compare_gin_configs,
 17)
 18
 19if TYPE_CHECKING:
 20    from colosseum.mdp.base import BaseMDP
 21    from colosseum.agent.agents.base import BaseAgent
 22
 23
 24def get_mdps_configs_from_mdps(mdps: List["BaseMDP"]) -> List[str]:
 25    """
 26    Returns
 27    -------
 28    List[str]
 29        The gin configs of the MDPs
 30    """
 31    mdp_configs = dict()
 32    for mdp in mdps:
 33        if not type(mdp) in mdp_configs:
 34            mdp_configs[type(mdp)] = []
 35        mdp_configs[type(mdp)].append(mdp.get_gin_config(len(mdp_configs[type(mdp)])))
 36    return mdp_configs
 37
 38
 39def instantiate_agent_configs(
 40    agents_configs: Dict[Type["BaseAgent"], Union[str, None]],
 41    benchmark_folder: str,
 42):
 43    """
 44    instantiates the gin agent configurations in the given folder.
 45
 46    Parameters
 47    ----------
 48    agents_configs : Dict[Type["BaseAgent"], Union[str, None]]
 49        The dictionary associates agent classes to their gin config files. If no agent config is given, we try to
 50        retrieve it from the cached_hyperparameters folder in package
 51    benchmark_folder : str
 52        The folder where the corresponding benchmark is located.
 53    """
 54
 55    # If no agent config is given, we retrieve it from the cached_hyperparameters folder in package
 56    for ag_cl in list(agents_configs):
 57        if agents_configs[ag_cl] is None:
 58            cached_config = (
 59                BENCHMARKS_DIRECTORY
 60                + "cached_hyperparameters"
 61                + os.sep
 62                + "agent_configs"
 63                + os.sep
 64                + ag_cl.__name__
 65                + ".gin"
 66            )
 67            if os.path.isfile(cached_config):
 68                with open(cached_config, "r") as f:
 69                    agents_configs[ag_cl] = f.read()
 70            else:
 71                raise f"No configuration was given for agent {ag_cl.__name__}"
 72
 73    if os.path.isdir(benchmark_folder + "agents_configs" + os.sep):
 74        try:
 75            local_agent_configs = retrieve_agent_configs(benchmark_folder)
 76            if not compare_gin_configs(agents_configs, local_agent_configs):
 77                raise ValueError(
 78                    f"The existing agent configs in {benchmark_folder} are different from the one in input."
 79                )
 80
 81        # If the local folder is corrupted, we eliminate it
 82        except AssertionError:
 83            shutil.rmtree(benchmark_folder + "agents_configs")
 84    else:
 85        os.makedirs(ensure_folder(benchmark_folder) + "agents_configs", exist_ok=True)
 86        for ag_cl, gin_config in agents_configs.items():
 87            with open(
 88                ensure_folder(benchmark_folder)
 89                + "agents_configs"
 90                + os.sep
 91                + ag_cl.__name__
 92                + ".gin",
 93                "w",
 94            ) as f:
 95                f.write(gin_config)
 96
 97
 98def instantiate_benchmark_folder(benchmark: ColosseumBenchmark, benchmark_folder: str):
 99    """
100    instantiates the benchmark locally. If a local benchmark is found then it merges iff they are the same in terms of
101    MDP configs and experiment configurations.
102
103    Parameters
104    ----------
105    benchmark : ColosseumBenchmark
106        The benchmark to instantiate.
107    benchmark_folder : str
108        The folder where the corresponding benchmark is located.
109    """
110
111    # Check whether the experiment folder has already been created
112    if os.path.isdir(benchmark_folder) and len(os.listdir(benchmark_folder)) > 0:
113        try:
114            local_benchmark = retrieve_benchmark(benchmark_folder)
115            if local_benchmark != benchmark:
116                raise ValueError(
117                    f"The experiment folder {benchmark_folder} is already occupied."
118                )
119
120        # If the local folder is corrupted, we eliminate it
121        except AssertionError:
122            shutil.rmtree(benchmark_folder)
123    else:
124        benchmark.instantiate(benchmark_folder)
125
126
def retrieve_benchmark(
    benchmark_folder: str,
    experiment_config: Union[ExperimentConfig, None] = None,
    postfix: str = "",
) -> ColosseumBenchmark:
    """
    Retrieves a benchmark from a folder.

    The benchmark is named after the folder (followed by the postfix). Its MDP configs are read with
    retrieve_mdp_configs, and its experiment config is read with retrieve_experiment_config unless one is given.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located.
    experiment_config : ExperimentConfig, optional
        The experiment config to be substituted to the default one. By default, no substitution happens.
    postfix : str
        The postfix to add to the name of the benchmark. By default, no postfix is added.

    Returns
    -------
    ColosseumBenchmark
        The retrieved benchmark.
    """
    # The last character of the normalized path (presumably the separator added by ensure_folder) is dropped so
    # that basename returns the folder name.
    name = os.path.basename(ensure_folder(benchmark_folder)[:-1]) + postfix
    mdp_configs = retrieve_mdp_configs(benchmark_folder)
    if experiment_config is None:
        experiment_config = retrieve_experiment_config(benchmark_folder)
    return ColosseumBenchmark(name, mdp_configs, experiment_config)
155
156
def update_emission_map(benchmark_folder: str, emission_map: Type[EmissionMap]):
    """
    Substitutes the emission map in the experiment config of the given benchmark folder with the one given in input.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. It must contain an experiment_config.yml file.
    emission_map : Type[EmissionMap]
        The emission map class. Its name (__name__) is stored under the "emission_map" key of the config file.

    Raises
    ------
    AssertionError
        If the folder does not contain the experiment_config.yml file.
    """
    config_fp = ensure_folder(benchmark_folder) + "experiment_config.yml"
    assert os.path.isfile(
        config_fp
    ), f"The folder {benchmark_folder} does not contain a configuration file."

    # NOTE: yaml.Loader can construct arbitrary Python objects; only use this on trusted, locally created files.
    with open(config_fp, "r") as f:
        config_file = yaml.load(f, yaml.Loader)
    config_file["emission_map"] = emission_map.__name__
    with open(config_fp, "w") as f:
        yaml.dump(config_file, f)
171
172
def retrieve_experiment_config(benchmark_folder: str) -> ExperimentConfig:
    """
    Reads the experiment config stored in the experiment_config.yml file of a benchmark folder.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located.

    Returns
    -------
    ExperimentConfig
        The experiment config built from the keyword arguments stored in the file.

    Raises
    ------
    AssertionError
        If the folder does not contain the experiment_config.yml file.
    """
    config_path = ensure_folder(benchmark_folder) + "experiment_config.yml"
    assert os.path.isfile(config_path), f"The folder {benchmark_folder} does not contain a configuration file."

    # NOTE: yaml.Loader can construct arbitrary Python objects; only use this on trusted, locally created files.
    with open(config_path, "r") as config_file:
        config_kwargs = yaml.load(config_file, yaml.Loader)
    return ExperimentConfig(**config_kwargs)
188
189
def retrieve_mdp_configs(
    benchmark_folder: str, return_string: bool = True
) -> Union[Dict[Type["BaseMDP"], str], Dict[Type["BaseMDP"], Dict[str, str]]]:
    """
    Retrieves the MDP gin configs of a benchmark from its mdp_configs sub-folder.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located.
    return_string : bool
        If True, the gin configs of each MDP class are joined in a single string. If False, each MDP class is
        mapped to a dictionary that associates each parameter prefix (e.g. "prms_0") with its gin config. By
        default, the single string format is used.

    Returns
    -------
    Union[Dict[Type["BaseMDP"], str], Dict[Type["BaseMDP"], Dict[str, str]]]
        The dictionary that maps each MDP class to its gin configs obtained from the given benchmark folder.
    """
    return retrieve_gin_configs(
        ensure_folder(benchmark_folder) + "mdp_configs" + os.sep, return_string
    )
215
216
def retrieve_agent_configs(
    benchmark_folder: str, return_string: bool = True
) -> Union[Dict[Type["BaseAgent"], str], Dict[Type["BaseAgent"], Dict[str, str]]]:
    """
    Retrieves the agent gin configs of a benchmark from its agents_configs sub-folder.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located.
    return_string : bool
        If True, the gin configs of each agent class are joined in a single string. If False, each agent class is
        mapped to a dictionary that associates each parameter prefix (e.g. "prms_0") with its gin config. By
        default, the single string format is used.

    Returns
    -------
    Union[Dict[Type["BaseAgent"], str], Dict[Type["BaseAgent"], Dict[str, str]]]
        The dictionary that maps each agent class to its gin configs obtained from the given benchmark folder.
    """
    return retrieve_gin_configs(
        ensure_folder(benchmark_folder) + "agents_configs" + os.sep, return_string
    )
242
243
def retrieve_gin_configs(
    gin_config_folder: str, return_string: bool
) -> Union[
    Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
    Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], Dict[str, str]],
]:
    """
    Retrieves the gin configs from a folder.

    Each <ClassName>.gin file in the folder is parsed into one config per parameter prefix (prms_0/, prms_1/, ...).
    Each config contains the import lines of the file followed by the lines of that prefix.

    Parameters
    ----------
    gin_config_folder : str
        The folder where the gin configs are stored. If its name contains "agent", the files are interpreted as
        agent configs, otherwise as MDP configs.
    return_string : bool
        If True, the configs of each class are joined, in ascending parameter index order, in a single string. If
        False, each class is mapped to a dictionary that associates each parameter prefix (e.g. "prms_0") with
        its gin config.

    Returns
    -------
    Union[Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
          Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], Dict[str, str]]]
        The dictionary that maps each MDP or agent class to its gin configs obtained from the given folder.

    Raises
    ------
    AssertionError
        If the folder does not contain any .gin file.
    """
    from glob import escape as glob_escape

    gin_config_folder = ensure_folder(gin_config_folder)

    # The folder path is escaped so that characters such as "[" are not read as glob patterns. The result is
    # sorted so that it does not depend on the order in which the file system lists the files.
    configs = sorted(glob(glob_escape(gin_config_folder) + "*.gin"))
    assert (
        len(configs) > 0
    ), f"The folder {gin_config_folder} does not contain config files"

    get_class_from_name = (
        get_agent_class_from_name
        if "agent" in os.path.basename(gin_config_folder[:-1])
        else get_mdp_class_from_name
    )

    gin_configs = dict()
    for f in configs:
        name = os.path.splitext(os.path.basename(f))[0]
        cl = get_class_from_name(name)

        with open(f, "r") as ff:
            gin_config_file = ff.read() + "\n"

        # The import lines are prepended to every parameter config. dict.fromkeys removes duplicates while keeping
        # the order of first appearance, so the output is reproducible (a set's iteration order is not).
        imports = "".join(dict.fromkeys(re.findall(r"from.+?import.+?\n", gin_config_file)))

        prms_configs_of_cl = dict()
        for config_prms in sorted(
            set(re.findall(r"prms_[0-9]+/", gin_config_file)),
            # Ascending order based on the parameter index
            key=lambda x: int(x[len("prms_"):-1]),
        ):
            prms_lines = re.findall(re.escape(config_prms) + r".+?\n", gin_config_file)
            prms_configs_of_cl[config_prms[:-1]] = imports + "".join(prms_lines)

        gin_configs[cl] = (
            "\n".join(prms_configs_of_cl.values()) if return_string else prms_configs_of_cl
        )

    return gin_configs
def get_mdps_configs_from_mdps(mdps: List[colosseum.mdp.base.BaseMDP]) -> List[str]:
def get_mdps_configs_from_mdps(
    mdps: List["BaseMDP"],
) -> Dict[Type["BaseMDP"], List[str]]:
    """
    Groups the gin configs of the given MDPs by MDP class.

    The i-th MDP of a given class (in input order) is asked for its gin config with index i, so the configs of
    the MDPs of the same class are numbered consecutively starting from zero.

    Parameters
    ----------
    mdps : List["BaseMDP"]
        The MDP instances whose gin configs are collected.

    Returns
    -------
    Dict[Type["BaseMDP"], List[str]]
        The dictionary that maps each MDP class to the gin configs of its instances, in input order.
    """
    mdp_configs = dict()
    for mdp in mdps:
        configs_of_class = mdp_configs.setdefault(type(mdp), [])
        # The index is the number of MDPs of the same class that have already been processed.
        configs_of_class.append(mdp.get_gin_config(len(configs_of_class)))
    return mdp_configs
Returns
  • Dict[Type["BaseMDP"], List[str]]: The gin configs of the MDPs, grouped by MDP class.
def instantiate_agent_configs( agents_configs: Dict[Type[colosseum.agent.agents.base.BaseAgent], Optional[str]], benchmark_folder: str):
def instantiate_agent_configs(
    agents_configs: Dict[Type["BaseAgent"], Union[str, None]],
    benchmark_folder: str,
):
    """
    Instantiates the gin agent configurations in the given benchmark folder.

    Agents mapped to None get the config cached in the cached_hyperparameters folder of the package, and
    agents_configs is updated in place with the retrieved configs. If the benchmark folder already contains an
    agents_configs folder, its configs must coincide with the given ones. If it cannot be read (the retrieval
    raises an AssertionError), it is considered corrupted, removed, and written again from the given configs.

    Parameters
    ----------
    agents_configs : Dict[Type["BaseAgent"], Union[str, None]]
        The dictionary associates agent classes to their gin config files. If no agent config is given, we try to
        retrieve it from the cached_hyperparameters folder of the package.
    benchmark_folder : str
        The folder where the corresponding benchmark is located.

    Raises
    ------
    ValueError
        If an agent has no config and no cached config exists for it, or if the agent configs already stored in
        the benchmark folder differ from the given ones.
    """
    # If no agent config is given, we retrieve it from the cached_hyperparameters folder in package
    for ag_cl in list(agents_configs):
        if agents_configs[ag_cl] is None:
            cached_config = (
                BENCHMARKS_DIRECTORY
                + "cached_hyperparameters"
                + os.sep
                + "agent_configs"
                + os.sep
                + ag_cl.__name__
                + ".gin"
            )
            if not os.path.isfile(cached_config):
                # Raising a plain string is a TypeError in Python 3, so a real exception is raised instead.
                raise ValueError(f"No configuration was given for agent {ag_cl.__name__}")
            with open(cached_config, "r") as f:
                agents_configs[ag_cl] = f.read()

    # The same normalized path is used for checking, removing, and creating the folder.
    agents_configs_folder = ensure_folder(benchmark_folder) + "agents_configs"

    if os.path.isdir(agents_configs_folder):
        try:
            local_agent_configs = retrieve_agent_configs(benchmark_folder)
            if not compare_gin_configs(agents_configs, local_agent_configs):
                raise ValueError(
                    f"The existing agent configs in {benchmark_folder} are different from the one in input."
                )
            # The local configs match the given ones, so there is nothing to write.
            return
        # If the local folder is corrupted, we eliminate it and write the configs again below
        except AssertionError:
            shutil.rmtree(agents_configs_folder)

    os.makedirs(agents_configs_folder, exist_ok=True)
    for ag_cl, gin_config in agents_configs.items():
        with open(agents_configs_folder + os.sep + ag_cl.__name__ + ".gin", "w") as f:
            f.write(gin_config)

instantiates the gin agent configurations in the given folder.

Parameters
  • agents_configs (Dict[Type["BaseAgent"], Union[str, None]]): The dictionary that associates agent classes with their gin config files. If no config is given for an agent, it is retrieved from the cached_hyperparameters folder of the package.
  • benchmark_folder (str): The folder where the corresponding benchmark is located.
def instantiate_benchmark_folder( benchmark: colosseum.benchmark.benchmark.ColosseumBenchmark, benchmark_folder: str):
 99def instantiate_benchmark_folder(benchmark: ColosseumBenchmark, benchmark_folder: str):
100    """
101    instantiates the benchmark locally. If a local benchmark is found then it merges iff they are the same in terms of
102    MDP configs and experiment configurations.
103
104    Parameters
105    ----------
106    benchmark : ColosseumBenchmark
107        The benchmark to instantiate.
108    benchmark_folder : str
109        The folder where the corresponding benchmark is located.
110    """
111
112    # Check whether the experiment folder has already been created
113    if os.path.isdir(benchmark_folder) and len(os.listdir(benchmark_folder)) > 0:
114        try:
115            local_benchmark = retrieve_benchmark(benchmark_folder)
116            if local_benchmark != benchmark:
117                raise ValueError(
118                    f"The experiment folder {benchmark_folder} is already occupied."
119                )
120
121        # If the local folder is corrupted, we eliminate it
122        except AssertionError:
123            shutil.rmtree(benchmark_folder)
124    else:
125        benchmark.instantiate(benchmark_folder)

Instantiates the benchmark locally. If a local benchmark is found, it is reused if and only if it is the same as the given one in terms of MDP configs and experiment configuration; otherwise a ValueError is raised.

Parameters
  • benchmark (ColosseumBenchmark): The benchmark to instantiate.
  • benchmark_folder (str): The folder where the corresponding benchmark is located.
def retrieve_benchmark( benchmark_folder: str, experiment_config: colosseum.experiment.config.ExperimentConfig = None, postfix: str = '') -> colosseum.benchmark.benchmark.ColosseumBenchmark:
def retrieve_benchmark(
    benchmark_folder: str,
    experiment_config: Union[ExperimentConfig, None] = None,
    postfix: str = "",
) -> ColosseumBenchmark:
    """
    Retrieves a benchmark from a folder.

    The benchmark is named after the folder (followed by the postfix). Its MDP configs are read with
    retrieve_mdp_configs, and its experiment config is read with retrieve_experiment_config unless one is given.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located.
    experiment_config : ExperimentConfig, optional
        The experiment config to be substituted to the default one. By default, no substitution happens.
    postfix : str
        The postfix to add to the name of the benchmark. By default, no postfix is added.

    Returns
    -------
    ColosseumBenchmark
        The retrieved benchmark.
    """
    # The last character of the normalized path (presumably the separator added by ensure_folder) is dropped so
    # that basename returns the folder name.
    name = os.path.basename(ensure_folder(benchmark_folder)[:-1]) + postfix
    mdp_configs = retrieve_mdp_configs(benchmark_folder)
    if experiment_config is None:
        experiment_config = retrieve_experiment_config(benchmark_folder)
    return ColosseumBenchmark(name, mdp_configs, experiment_config)

retrieves a benchmark from a folder.

Parameters
  • benchmark_folder (str): The folder where the benchmark is located.
  • experiment_config (ExperimentConfig): The experiment config to be substituted to the default one. By default, no substitution happens.
  • postfix (str): The postfix to add to the name of the benchmark. By default, no postfix is added.
Returns
  • ColosseumBenchmark: The retrieved benchmark.
def update_emission_map( benchmark_folder: str, emission_map: colosseum.emission_maps.base.EmissionMap):
def update_emission_map(benchmark_folder: str, emission_map: Type[EmissionMap]):
    """
    Substitutes the emission map in the experiment config of the given benchmark folder with the one given in input.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. It must contain an experiment_config.yml file.
    emission_map : Type[EmissionMap]
        The emission map class. Its name (__name__) is stored under the "emission_map" key of the config file.

    Raises
    ------
    AssertionError
        If the folder does not contain the experiment_config.yml file.
    """
    config_fp = ensure_folder(benchmark_folder) + "experiment_config.yml"
    assert os.path.isfile(
        config_fp
    ), f"The folder {benchmark_folder} does not contain a configuration file."

    # NOTE: yaml.Loader can construct arbitrary Python objects; only use this on trusted, locally created files.
    with open(config_fp, "r") as f:
        config_file = yaml.load(f, yaml.Loader)
    config_file["emission_map"] = emission_map.__name__
    with open(config_fp, "w") as f:
        yaml.dump(config_file, f)

substitutes the emission map in the experiment config of the given experiment folder with the one given in input.

def retrieve_experiment_config(benchmark_folder: str) -> colosseum.experiment.config.ExperimentConfig:
def retrieve_experiment_config(benchmark_folder: str) -> ExperimentConfig:
    """
    Reads the experiment config stored in the experiment_config.yml file of a benchmark folder.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located.

    Returns
    -------
    ExperimentConfig
        The experiment config built from the keyword arguments stored in the file.

    Raises
    ------
    AssertionError
        If the folder does not contain the experiment_config.yml file.
    """
    config_path = ensure_folder(benchmark_folder) + "experiment_config.yml"
    assert os.path.isfile(config_path), f"The folder {benchmark_folder} does not contain a configuration file."

    # NOTE: yaml.Loader can construct arbitrary Python objects; only use this on trusted, locally created files.
    with open(config_path, "r") as config_file:
        config_kwargs = yaml.load(config_file, yaml.Loader)
    return ExperimentConfig(**config_kwargs)
Returns
  • ExperimentConfig: The experiment config from the given benchmark folder.
def retrieve_mdp_configs( benchmark_folder: str, return_string=True) -> Union[Dict[Type[colosseum.mdp.base.BaseMDP], str], Dict[Type[colosseum.mdp.base.BaseMDP], Dict[str, str]]]:
def retrieve_mdp_configs(
    benchmark_folder: str, return_string: bool = True
) -> Union[Dict[Type["BaseMDP"], str], Dict[Type["BaseMDP"], Dict[str, str]]]:
    """
    Retrieves the MDP gin configs of a benchmark from its mdp_configs sub-folder.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located.
    return_string : bool
        If True, the gin configs of each MDP class are joined in a single string. If False, each MDP class is
        mapped to a dictionary that associates each parameter prefix (e.g. "prms_0") with its gin config. By
        default, the single string format is used.

    Returns
    -------
    Union[Dict[Type["BaseMDP"], str], Dict[Type["BaseMDP"], Dict[str, str]]]
        The dictionary that maps each MDP class to its gin configs obtained from the given benchmark folder.
    """
    return retrieve_gin_configs(
        ensure_folder(benchmark_folder) + "mdp_configs" + os.sep, return_string
    )

retrieves the MDP gin configs of a benchmark.

Parameters
  • benchmark_folder (str): The folder where the benchmark is located.
  • return_string (bool): If True, the gin configs of each MDP class are joined in a single string. If False, each MDP class is mapped to a dictionary that associates each parameter prefix (e.g. prms_0) with its gin config. By default, the single string format is used.
Returns
  • Union[Dict[Type["BaseMDP"], str], Dict[Type["BaseMDP"], Dict[str, str]]]: The dictionary that maps each MDP class to its gin configs obtained from the given benchmark folder.
def retrieve_agent_configs( benchmark_folder: str, return_string=True) -> Union[Dict[Type[colosseum.agent.agents.base.BaseAgent], str], Dict[Type[colosseum.agent.agents.base.BaseAgent], Dict[str, str]]]:
def retrieve_agent_configs(
    benchmark_folder: str, return_string: bool = True
) -> Union[Dict[Type["BaseAgent"], str], Dict[Type["BaseAgent"], Dict[str, str]]]:
    """
    Retrieves the agent gin configs of a benchmark from its agents_configs sub-folder.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located.
    return_string : bool
        If True, the gin configs of each agent class are joined in a single string. If False, each agent class is
        mapped to a dictionary that associates each parameter prefix (e.g. "prms_0") with its gin config. By
        default, the single string format is used.

    Returns
    -------
    Union[Dict[Type["BaseAgent"], str], Dict[Type["BaseAgent"], Dict[str, str]]]
        The dictionary that maps each agent class to its gin configs obtained from the given benchmark folder.
    """
    return retrieve_gin_configs(
        ensure_folder(benchmark_folder) + "agents_configs" + os.sep, return_string
    )

retrieves the agent gin configs of a benchmark.

Parameters
  • benchmark_folder (str): The folder where the benchmark is located.
  • return_string (bool): If True, the gin configs of each agent class are joined in a single string. If False, each agent class is mapped to a dictionary that associates each parameter prefix (e.g. prms_0) with its gin config. By default, the single string format is used.
Returns
  • Union[Dict[Type["BaseAgent"], str], Dict[Type["BaseAgent"], Dict[str, str]]]: The dictionary that maps each agent class to its gin configs obtained from the given benchmark folder.
def retrieve_gin_configs( gin_config_folder: str, return_string: bool) -> Dict[Union[Type[colosseum.mdp.base.BaseMDP], Type[colosseum.agent.agents.base.BaseAgent]], str]:
def retrieve_gin_configs(
    gin_config_folder: str, return_string: bool
) -> Union[
    Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
    Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], Dict[str, str]],
]:
    """
    Retrieves the gin configs from a folder.

    Each <ClassName>.gin file in the folder is parsed into one config per parameter prefix (prms_0/, prms_1/, ...).
    Each config contains the import lines of the file followed by the lines of that prefix.

    Parameters
    ----------
    gin_config_folder : str
        The folder where the gin configs are stored. If its name contains "agent", the files are interpreted as
        agent configs, otherwise as MDP configs.
    return_string : bool
        If True, the configs of each class are joined, in ascending parameter index order, in a single string. If
        False, each class is mapped to a dictionary that associates each parameter prefix (e.g. "prms_0") with
        its gin config.

    Returns
    -------
    Union[Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
          Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], Dict[str, str]]]
        The dictionary that maps each MDP or agent class to its gin configs obtained from the given folder.

    Raises
    ------
    AssertionError
        If the folder does not contain any .gin file.
    """
    from glob import escape as glob_escape

    gin_config_folder = ensure_folder(gin_config_folder)

    # The folder path is escaped so that characters such as "[" are not read as glob patterns. The result is
    # sorted so that it does not depend on the order in which the file system lists the files.
    configs = sorted(glob(glob_escape(gin_config_folder) + "*.gin"))
    assert (
        len(configs) > 0
    ), f"The folder {gin_config_folder} does not contain config files"

    get_class_from_name = (
        get_agent_class_from_name
        if "agent" in os.path.basename(gin_config_folder[:-1])
        else get_mdp_class_from_name
    )

    gin_configs = dict()
    for f in configs:
        name = os.path.splitext(os.path.basename(f))[0]
        cl = get_class_from_name(name)

        with open(f, "r") as ff:
            gin_config_file = ff.read() + "\n"

        # The import lines are prepended to every parameter config. dict.fromkeys removes duplicates while keeping
        # the order of first appearance, so the output is reproducible (a set's iteration order is not).
        imports = "".join(dict.fromkeys(re.findall(r"from.+?import.+?\n", gin_config_file)))

        prms_configs_of_cl = dict()
        for config_prms in sorted(
            set(re.findall(r"prms_[0-9]+/", gin_config_file)),
            # Ascending order based on the parameter index
            key=lambda x: int(x[len("prms_"):-1]),
        ):
            prms_lines = re.findall(re.escape(config_prms) + r".+?\n", gin_config_file)
            prms_configs_of_cl[config_prms[:-1]] = imports + "".join(prms_lines)

        gin_configs[cl] = (
            "\n".join(prms_configs_of_cl.values()) if return_string else prms_configs_of_cl
        )

    return gin_configs

retrieves the gin configs from a folder.

Parameters
  • gin_config_folder (str): The folder where the gin configs are stored.
  • return_string (bool): If True, the gin configs of each class are joined in a single string. If False, each class is mapped to a dictionary that associates each parameter prefix (e.g. prms_0) with its gin config.
Returns
  • Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str]: The dictionary that maps each MDP or agent class to its gin configs obtained from the given folder.