colosseum.benchmark.utils
import os
import re
import shutil
from glob import glob
from typing import List, Dict, TYPE_CHECKING, Type, Union

import yaml

from colosseum.benchmark.benchmark import ColosseumBenchmark, BENCHMARKS_DIRECTORY
from colosseum.emission_maps import EmissionMap
from colosseum.experiment import ExperimentConfig
from colosseum.utils import ensure_folder
from colosseum.utils.miscellanea import (
    get_agent_class_from_name,
    get_mdp_class_from_name,
    compare_gin_configs,
)

if TYPE_CHECKING:
    from colosseum.mdp.base import BaseMDP
    from colosseum.agent.agents.base import BaseAgent


def get_mdps_configs_from_mdps(
    mdps: List["BaseMDP"],
) -> Dict[Type["BaseMDP"], List[str]]:
    """
    Groups the gin configs of the given MDPs by MDP class.

    Parameters
    ----------
    mdps : List["BaseMDP"]
        The MDP instances whose gin configs are collected.

    Returns
    -------
    Dict[Type["BaseMDP"], List[str]]
        For each MDP class, the gin configs of its instances in input order. The config of the i-th instance of a
        class is obtained with ``get_gin_config(i)``, so the index is counted per class.
    """
    mdp_configs = dict()
    for mdp in mdps:
        class_configs = mdp_configs.setdefault(type(mdp), [])
        class_configs.append(mdp.get_gin_config(len(class_configs)))
    return mdp_configs


def instantiate_agent_configs(
    agents_configs: Dict[Type["BaseAgent"], Union[str, None]],
    benchmark_folder: str,
):
    """
    Instantiates the gin agent configurations in the given benchmark folder.

    Agent classes mapped to ``None`` are filled in place (the input dictionary is modified) with the cached config
    stored in ``BENCHMARKS_DIRECTORY/cached_hyperparameters/agent_configs/<AgentClassName>.gin``. If the
    ``agents_configs`` sub-folder already exists, its content must match the given configs. If reading it fails
    with an AssertionError (corrupted folder), it is removed and written again.

    Parameters
    ----------
    agents_configs : Dict[Type["BaseAgent"], Union[str, None]]
        The dictionary that associates agent classes with their gin configs. If no agent config is given, we try to
        retrieve it from the cached_hyperparameters folder in the package.
    benchmark_folder : str
        The folder where the corresponding benchmark is located.

    Raises
    ------
    ValueError
        If an agent has neither a given nor a cached config, or if the existing agent configs in the benchmark
        folder differ from the given ones.
    """

    # If no agent config is given, we retrieve it from the cached_hyperparameters folder in package
    for ag_cl in list(agents_configs):
        if agents_configs[ag_cl] is None:
            cached_config = (
                BENCHMARKS_DIRECTORY
                + "cached_hyperparameters"
                + os.sep
                + "agent_configs"
                + os.sep
                + ag_cl.__name__
                + ".gin"
            )
            if not os.path.isfile(cached_config):
                # Raising a plain string is itself a TypeError in Python 3, so a proper exception is used
                raise ValueError(
                    f"No configuration was given for agent {ag_cl.__name__}"
                )
            with open(cached_config, "r") as f:
                agents_configs[ag_cl] = f.read()

    # The same normalised path is used for the existence check, the removal, and the writing
    agents_folder = ensure_folder(benchmark_folder) + "agents_configs" + os.sep

    if os.path.isdir(agents_folder):
        try:
            local_agent_configs = retrieve_agent_configs(benchmark_folder)
            if not compare_gin_configs(agents_configs, local_agent_configs):
                raise ValueError(
                    f"The existing agent configs in {benchmark_folder} are different from the one in input."
                )
            # The existing configs match the given ones, so there is nothing to write
            return

        # If the local folder is corrupted, we eliminate it and write the configs from scratch below.
        # NOTE(review): corruption is detected through `assert` statements in the retrieval helpers, which are
        # stripped when Python runs with -O.
        except AssertionError:
            shutil.rmtree(agents_folder)

    os.makedirs(agents_folder, exist_ok=True)
    for ag_cl, gin_config in agents_configs.items():
        with open(agents_folder + ag_cl.__name__ + ".gin", "w") as f:
            f.write(gin_config)


def instantiate_benchmark_folder(benchmark: ColosseumBenchmark, benchmark_folder: str):
    """
    Instantiates the benchmark locally.

    If a non-empty benchmark folder already exists, it is kept as it is when it contains the same benchmark (same
    MDP configs and experiment configuration), and an error is raised otherwise. If the existing folder is
    corrupted, it is removed and the benchmark is instantiated from scratch.

    Parameters
    ----------
    benchmark : ColosseumBenchmark
        The benchmark to instantiate.
    benchmark_folder : str
        The folder where the corresponding benchmark is located.

    Raises
    ------
    ValueError
        If the folder already contains a different benchmark.
    """

    # Check whether the experiment folder has already been created
    if os.path.isdir(benchmark_folder) and len(os.listdir(benchmark_folder)) > 0:
        try:
            local_benchmark = retrieve_benchmark(benchmark_folder)
            if local_benchmark != benchmark:
                raise ValueError(
                    f"The experiment folder {benchmark_folder} is already occupied."
                )
            # The local benchmark is the given one, so there is nothing to instantiate
            return

        # If the local folder is corrupted, we eliminate it and instantiate the benchmark again below
        except AssertionError:
            shutil.rmtree(benchmark_folder)

    benchmark.instantiate(benchmark_folder)


def retrieve_benchmark(
    benchmark_folder: str,
    experiment_config: Union[ExperimentConfig, None] = None,
    postfix: str = "",
) -> ColosseumBenchmark:
    """
    Retrieves a benchmark from a folder.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. Its base name is used as the name of the benchmark.
    experiment_config : ExperimentConfig
        The experiment config to use instead of the one stored in the benchmark folder. By default, the stored one
        is used.
    postfix : str
        The postfix to add to the name of the benchmark. By default, no postfix is added.

    Returns
    -------
    ColosseumBenchmark
        The retrieved benchmark.
    """
    name = os.path.basename(ensure_folder(benchmark_folder)[:-1]) + postfix
    mdp_configs = retrieve_mdp_configs(benchmark_folder)
    if experiment_config is None:
        experiment_config = retrieve_experiment_config(benchmark_folder)
    return ColosseumBenchmark(name, mdp_configs, experiment_config)


def update_emission_map(benchmark_folder: str, emission_map: Type[EmissionMap]):
    """
    Substitutes the emission map in the experiment config of the given benchmark folder with the one given in input.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. It must contain an ``experiment_config.yml`` file.
    emission_map : Type[EmissionMap]
        The emission map class. Its ``__name__`` is written in the ``emission_map`` field of the experiment config.
    """
    config_fp = ensure_folder(benchmark_folder) + "experiment_config.yml"
    assert os.path.isfile(
        config_fp
    ), f"The folder {benchmark_folder} does not contain a configuration file."

    with open(config_fp, "r") as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects; only use on trusted config files.
        config_file = yaml.load(f, yaml.Loader)
    # Only the class name is stored, which is why a class (not an instance) is expected
    config_file["emission_map"] = emission_map.__name__
    with open(config_fp, "w") as f:
        yaml.dump(config_file, f)


def retrieve_experiment_config(benchmark_folder: str) -> ExperimentConfig:
    """
    Loads the experiment configuration stored in a benchmark folder.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. It must contain an ``experiment_config.yml`` file.

    Returns
    -------
    ExperimentConfig
        The experiment config built from the keyword arguments stored in the YAML file.
    """
    yml_path = ensure_folder(benchmark_folder) + "experiment_config.yml"
    assert os.path.isfile(yml_path), (
        f"The folder {benchmark_folder} does not contain a configuration file."
    )

    with open(yml_path, "r") as stream:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects; only use on trusted config files.
        parsed = yaml.load(stream, yaml.Loader)
    return ExperimentConfig(**parsed)


def retrieve_mdp_configs(
    benchmark_folder: str, return_string: bool = True
) -> Union[Dict[Type["BaseMDP"], str], Dict[Type["BaseMDP"], Dict[str, str]]]:
    """
    Retrieves the MDP gin configs of a benchmark.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. The configs are read from its ``mdp_configs`` sub-folder.
    return_string : bool
        If True, the configs of each MDP class are joined in a single string. If False, they are returned as a
        dictionary mapping each parameter scope (e.g. "prms_0") to its gin config. By default, the single string
        format is used.

    Returns
    -------
    Union[Dict[Type["BaseMDP"], str], Dict[Type["BaseMDP"], Dict[str, str]]]
        The dictionary that, for each MDP class, contains the gin configs obtained from the given benchmark folder.
    """
    return retrieve_gin_configs(
        ensure_folder(benchmark_folder) + "mdp_configs" + os.sep, return_string
    )


def retrieve_agent_configs(
    benchmark_folder: str, return_string: bool = True
) -> Union[Dict[Type["BaseAgent"], str], Dict[Type["BaseAgent"], Dict[str, str]]]:
    """
    Retrieves the agent gin configs of a benchmark.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. The configs are read from its ``agents_configs`` sub-folder.
    return_string : bool
        If True, the configs of each agent class are joined in a single string. If False, they are returned as a
        dictionary mapping each parameter scope (e.g. "prms_0") to its gin config. By default, the single string
        format is used.

    Returns
    -------
    Union[Dict[Type["BaseAgent"], str], Dict[Type["BaseAgent"], Dict[str, str]]]
        The dictionary that, for each agent class, contains the gin configs obtained from the given benchmark folder.
    """
    return retrieve_gin_configs(
        ensure_folder(benchmark_folder) + "agents_configs" + os.sep, return_string
    )


def retrieve_gin_configs(
    gin_config_folder: str, return_string: bool
) -> Union[
    Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
    Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], Dict[str, str]],
]:
    """
    Retrieves the gin configs from a folder.

    Each ``<ClassName>.gin`` file is mapped to the corresponding agent class (when the folder name contains
    "agent") or MDP class (otherwise). The lines of a file are split by their ``prms_<i>/`` scope. The
    ``from ... import ...`` lines of the file are prepended to the config of every scope.

    Parameters
    ----------
    gin_config_folder : str
        The folder where the gin configs are stored.
    return_string : bool
        If True, the per-scope configs of each class are joined, in ascending scope index order, into a single
        string separated by newlines. If False, they are returned as a dictionary mapping each scope name
        (e.g. "prms_0") to its config.

    Returns
    -------
    Union[Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str], Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], Dict[str, str]]]
        The dictionary that, for each MDP or agent class found in the folder, contains its gin configs.
    """

    gin_config_folder = ensure_folder(gin_config_folder)

    # Sorted so that the processing order does not depend on the file system
    configs = sorted(glob(gin_config_folder + "*.gin"))
    assert (
        len(configs) > 0
    ), f"The folder {gin_config_folder} does not contain config files"

    load_agents = "agent" in os.path.basename(gin_config_folder[:-1])

    gin_configs = dict()
    for f in configs:
        name = os.path.basename(f).replace(".gin", "")
        cl = (
            get_agent_class_from_name(name)
            if load_agents
            else get_mdp_class_from_name(name)
        )

        with open(f, "r") as ff:
            # The extra newline guarantees that the last line is matched by the regexes below
            gin_config_file = ff.read() + "\n"

        # The import lines are the same for every scope, so they are extracted once. Duplicates are removed while
        # keeping the order of appearance (joining a set made the order vary across interpreter runs).
        imports = "".join(
            dict.fromkeys(re.findall(r"from.+?import.+?\n", gin_config_file))
        )

        scope_configs = dict()
        for config_prms in sorted(
            set(re.findall(r"prms_[0-9]+/", gin_config_file)),
            # Ascending order based on the parameter index
            key=lambda x: int(x.replace("prms_", "")[:-1]),
        ):
            prms_configs = "".join(re.findall(config_prms + ".+?\n", gin_config_file))
            # Key without the trailing "/", e.g. "prms_0"
            scope_configs[config_prms[:-1]] = imports + prms_configs

        gin_configs[cl] = (
            "\n".join(scope_configs.values()) if return_string else scope_configs
        )

    return gin_configs
def get_mdps_configs_from_mdps(
    mdps: List["BaseMDP"],
) -> Dict[Type["BaseMDP"], List[str]]:
    """
    Groups the gin configs of the given MDPs by MDP class.

    Parameters
    ----------
    mdps : List["BaseMDP"]
        The MDP instances whose gin configs are collected.

    Returns
    -------
    Dict[Type["BaseMDP"], List[str]]
        For each MDP class, the gin configs of its instances in input order. The config of the i-th instance of a
        class is obtained with ``get_gin_config(i)``, so the index is counted per class.
    """
    mdp_configs = dict()
    for mdp in mdps:
        class_configs = mdp_configs.setdefault(type(mdp), [])
        class_configs.append(mdp.get_gin_config(len(class_configs)))
    return mdp_configs
Returns
- Dict[Type["BaseMDP"], List[str]]: The gin configs of the MDPs, grouped by MDP class.
def
instantiate_agent_configs( agents_configs: Dict[Type[colosseum.agent.agents.base.BaseAgent], Optional[str]], benchmark_folder: str):
def instantiate_agent_configs(
    agents_configs: Dict[Type["BaseAgent"], Union[str, None]],
    benchmark_folder: str,
):
    """
    Instantiates the gin agent configurations in the given benchmark folder.

    Agent classes mapped to ``None`` are filled in place (the input dictionary is modified) with the cached config
    stored in ``BENCHMARKS_DIRECTORY/cached_hyperparameters/agent_configs/<AgentClassName>.gin``. If the
    ``agents_configs`` sub-folder already exists, its content must match the given configs. If reading it fails
    with an AssertionError (corrupted folder), it is removed and written again.

    Parameters
    ----------
    agents_configs : Dict[Type["BaseAgent"], Union[str, None]]
        The dictionary that associates agent classes with their gin configs. If no agent config is given, we try to
        retrieve it from the cached_hyperparameters folder in the package.
    benchmark_folder : str
        The folder where the corresponding benchmark is located.

    Raises
    ------
    ValueError
        If an agent has neither a given nor a cached config, or if the existing agent configs in the benchmark
        folder differ from the given ones.
    """

    # If no agent config is given, we retrieve it from the cached_hyperparameters folder in package
    for ag_cl in list(agents_configs):
        if agents_configs[ag_cl] is None:
            cached_config = (
                BENCHMARKS_DIRECTORY
                + "cached_hyperparameters"
                + os.sep
                + "agent_configs"
                + os.sep
                + ag_cl.__name__
                + ".gin"
            )
            if not os.path.isfile(cached_config):
                # Raising a plain string is itself a TypeError in Python 3, so a proper exception is used
                raise ValueError(
                    f"No configuration was given for agent {ag_cl.__name__}"
                )
            with open(cached_config, "r") as f:
                agents_configs[ag_cl] = f.read()

    # The same normalised path is used for the existence check, the removal, and the writing
    agents_folder = ensure_folder(benchmark_folder) + "agents_configs" + os.sep

    if os.path.isdir(agents_folder):
        try:
            local_agent_configs = retrieve_agent_configs(benchmark_folder)
            if not compare_gin_configs(agents_configs, local_agent_configs):
                raise ValueError(
                    f"The existing agent configs in {benchmark_folder} are different from the one in input."
                )
            # The existing configs match the given ones, so there is nothing to write
            return

        # If the local folder is corrupted, we eliminate it and write the configs from scratch below.
        # NOTE(review): corruption is detected through `assert` statements in the retrieval helpers, which are
        # stripped when Python runs with -O.
        except AssertionError:
            shutil.rmtree(agents_folder)

    os.makedirs(agents_folder, exist_ok=True)
    for ag_cl, gin_config in agents_configs.items():
        with open(agents_folder + ag_cl.__name__ + ".gin", "w") as f:
            f.write(gin_config)
instantiates the gin agent configurations in the given folder.
Parameters
- agents_configs (Dict[Type["BaseAgent"], Union[str, None]]): The dictionary that associates agent classes with their gin configs. If no config is given for an agent (i.e. its value is None), we try to retrieve it from the cached_hyperparameters folder in the package.
- benchmark_folder (str): The folder where the corresponding benchmark is located.
def
instantiate_benchmark_folder( benchmark: colosseum.benchmark.benchmark.ColosseumBenchmark, benchmark_folder: str):
def instantiate_benchmark_folder(benchmark: ColosseumBenchmark, benchmark_folder: str):
    """
    Instantiates the benchmark locally.

    If a non-empty benchmark folder already exists, it is kept as it is when it contains the same benchmark (same
    MDP configs and experiment configuration), and an error is raised otherwise. If the existing folder is
    corrupted, it is removed and the benchmark is instantiated from scratch.

    Parameters
    ----------
    benchmark : ColosseumBenchmark
        The benchmark to instantiate.
    benchmark_folder : str
        The folder where the corresponding benchmark is located.

    Raises
    ------
    ValueError
        If the folder already contains a different benchmark.
    """

    # Check whether the experiment folder has already been created
    if os.path.isdir(benchmark_folder) and len(os.listdir(benchmark_folder)) > 0:
        try:
            local_benchmark = retrieve_benchmark(benchmark_folder)
            if local_benchmark != benchmark:
                raise ValueError(
                    f"The experiment folder {benchmark_folder} is already occupied."
                )
            # The local benchmark is the given one, so there is nothing to instantiate
            return

        # If the local folder is corrupted, we eliminate it and instantiate the benchmark again below
        except AssertionError:
            shutil.rmtree(benchmark_folder)

    benchmark.instantiate(benchmark_folder)
Instantiates the benchmark locally. If a non-empty local benchmark folder is found, it is kept only if it contains the same benchmark (same MDP configs and experiment configuration); otherwise an error is raised. A corrupted local folder is removed.
Parameters
- benchmark (ColosseumBenchmark): The benchmark to instantiate.
- benchmark_folder (str): The folder where the corresponding benchmark is located.
def
retrieve_benchmark( benchmark_folder: str, experiment_config: colosseum.experiment.config.ExperimentConfig = None, postfix: str = '') -> colosseum.benchmark.benchmark.ColosseumBenchmark:
def retrieve_benchmark(
    benchmark_folder: str,
    experiment_config: Union[ExperimentConfig, None] = None,
    postfix: str = "",
) -> ColosseumBenchmark:
    """
    Retrieves a benchmark from a folder.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. Its base name is used as the name of the benchmark.
    experiment_config : ExperimentConfig
        The experiment config to use instead of the one stored in the benchmark folder. By default, the stored one
        is used.
    postfix : str
        The postfix to add to the name of the benchmark. By default, no postfix is added.

    Returns
    -------
    ColosseumBenchmark
        The retrieved benchmark.
    """
    name = os.path.basename(ensure_folder(benchmark_folder)[:-1]) + postfix
    mdp_configs = retrieve_mdp_configs(benchmark_folder)
    if experiment_config is None:
        experiment_config = retrieve_experiment_config(benchmark_folder)
    return ColosseumBenchmark(name, mdp_configs, experiment_config)
retrieves a benchmark from a folder.
Parameters
- benchmark_folder (str): The folder where the benchmark is located.
- experiment_config (ExperimentConfig): The experiment config to use instead of the one stored in the benchmark folder. By default, the stored one is used.
- postfix (str): The postfix to add to the name of the benchmark. By default, no postfix is added.
Returns
- ColosseumBenchmark: The retrieved benchmark.
def
update_emission_map( benchmark_folder: str, emission_map: colosseum.emission_maps.base.EmissionMap):
def update_emission_map(benchmark_folder: str, emission_map: Type[EmissionMap]):
    """
    Substitutes the emission map in the experiment config of the given benchmark folder with the one given in input.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. It must contain an ``experiment_config.yml`` file.
    emission_map : Type[EmissionMap]
        The emission map class. Its ``__name__`` is written in the ``emission_map`` field of the experiment config.
    """
    config_fp = ensure_folder(benchmark_folder) + "experiment_config.yml"
    assert os.path.isfile(
        config_fp
    ), f"The folder {benchmark_folder} does not contain a configuration file."

    with open(config_fp, "r") as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects; only use on trusted config files.
        config_file = yaml.load(f, yaml.Loader)
    # Only the class name is stored, which is why a class (not an instance) is expected
    config_file["emission_map"] = emission_map.__name__
    with open(config_fp, "w") as f:
        yaml.dump(config_file, f)
substitutes the emission map in the experiment config of the given experiment folder with the one given in input.
def
retrieve_experiment_config(benchmark_folder: str) -> colosseum.experiment.config.ExperimentConfig:
def retrieve_experiment_config(benchmark_folder: str) -> ExperimentConfig:
    """
    Loads the experiment configuration stored in a benchmark folder.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. It must contain an ``experiment_config.yml`` file.

    Returns
    -------
    ExperimentConfig
        The experiment config built from the keyword arguments stored in the YAML file.
    """
    yml_path = ensure_folder(benchmark_folder) + "experiment_config.yml"
    assert os.path.isfile(yml_path), (
        f"The folder {benchmark_folder} does not contain a configuration file."
    )

    with open(yml_path, "r") as stream:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects; only use on trusted config files.
        parsed = yaml.load(stream, yaml.Loader)
    return ExperimentConfig(**parsed)
Returns
- ExperimentConfig: The experiment config from the given benchmark folder.
def
retrieve_mdp_configs( benchmark_folder: str, return_string=True) -> Union[Dict[Type[colosseum.mdp.base.BaseMDP], str], Dict[Type[colosseum.mdp.base.BaseMDP], Dict[str, str]]]:
def retrieve_mdp_configs(
    benchmark_folder: str, return_string: bool = True
) -> Union[Dict[Type["BaseMDP"], str], Dict[Type["BaseMDP"], Dict[str, str]]]:
    """
    Retrieves the MDP gin configs of a benchmark.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. The configs are read from its ``mdp_configs`` sub-folder.
    return_string : bool
        If True, the configs of each MDP class are joined in a single string. If False, they are returned as a
        dictionary mapping each parameter scope (e.g. "prms_0") to its gin config. By default, the single string
        format is used.

    Returns
    -------
    Union[Dict[Type["BaseMDP"], str], Dict[Type["BaseMDP"], Dict[str, str]]]
        The dictionary that, for each MDP class, contains the gin configs obtained from the given benchmark folder.
    """
    return retrieve_gin_configs(
        ensure_folder(benchmark_folder) + "mdp_configs" + os.sep, return_string
    )
retrieves the MDP gin configs of a benchmark.
Parameters
- benchmark_folder (str): The folder where the benchmark is located.
- return_string (bool): If True, the gin configs of each MDP class are joined in a single string. If False, they are returned as a dictionary mapping each parameter scope (e.g. "prms_0") to its gin config. By default, the single string format is used.
Returns
- Union[Dict[Type["BaseMDP"], str], Dict[Type["BaseMDP"], Dict[str, str]]]: The dictionary that, for each MDP class, contains the gin configs obtained from the given benchmark folder.
def
retrieve_agent_configs( benchmark_folder: str, return_string=True) -> Union[Dict[Type[colosseum.agent.agents.base.BaseAgent], str], Dict[Type[colosseum.agent.agents.base.BaseAgent], Dict[str, str]]]:
def retrieve_agent_configs(
    benchmark_folder: str, return_string: bool = True
) -> Union[Dict[Type["BaseAgent"], str], Dict[Type["BaseAgent"], Dict[str, str]]]:
    """
    Retrieves the agent gin configs of a benchmark.

    Parameters
    ----------
    benchmark_folder : str
        The folder where the benchmark is located. The configs are read from its ``agents_configs`` sub-folder.
    return_string : bool
        If True, the configs of each agent class are joined in a single string. If False, they are returned as a
        dictionary mapping each parameter scope (e.g. "prms_0") to its gin config. By default, the single string
        format is used.

    Returns
    -------
    Union[Dict[Type["BaseAgent"], str], Dict[Type["BaseAgent"], Dict[str, str]]]
        The dictionary that, for each agent class, contains the gin configs obtained from the given benchmark folder.
    """
    return retrieve_gin_configs(
        ensure_folder(benchmark_folder) + "agents_configs" + os.sep, return_string
    )
retrieves the agent gin configs of a benchmark.
Parameters
- benchmark_folder (str): The folder where the benchmark is located.
- return_string (bool): If True, the gin configs of each agent class are joined in a single string. If False, they are returned as a dictionary mapping each parameter scope (e.g. "prms_0") to its gin config. By default, the single string format is used.
Returns
- Union[Dict[Type["BaseAgent"], str], Dict[Type["BaseAgent"], Dict[str, str]]]: The dictionary that, for each agent class, contains the gin configs obtained from the given benchmark folder.
def
retrieve_gin_configs( gin_config_folder: str, return_string: bool) -> Dict[Union[Type[colosseum.mdp.base.BaseMDP], Type[colosseum.agent.agents.base.BaseAgent]], str]:
def retrieve_gin_configs(
    gin_config_folder: str, return_string: bool
) -> Union[
    Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
    Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], Dict[str, str]],
]:
    """
    Retrieves the gin configs from a folder.

    Each ``<ClassName>.gin`` file is mapped to the corresponding agent class (when the folder name contains
    "agent") or MDP class (otherwise). The lines of a file are split by their ``prms_<i>/`` scope. The
    ``from ... import ...`` lines of the file are prepended to the config of every scope.

    Parameters
    ----------
    gin_config_folder : str
        The folder where the gin configs are stored.
    return_string : bool
        If True, the per-scope configs of each class are joined, in ascending scope index order, into a single
        string separated by newlines. If False, they are returned as a dictionary mapping each scope name
        (e.g. "prms_0") to its config.

    Returns
    -------
    Union[Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str], Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], Dict[str, str]]]
        The dictionary that, for each MDP or agent class found in the folder, contains its gin configs.
    """

    gin_config_folder = ensure_folder(gin_config_folder)

    # Sorted so that the processing order does not depend on the file system
    configs = sorted(glob(gin_config_folder + "*.gin"))
    assert (
        len(configs) > 0
    ), f"The folder {gin_config_folder} does not contain config files"

    load_agents = "agent" in os.path.basename(gin_config_folder[:-1])

    gin_configs = dict()
    for f in configs:
        name = os.path.basename(f).replace(".gin", "")
        cl = (
            get_agent_class_from_name(name)
            if load_agents
            else get_mdp_class_from_name(name)
        )

        with open(f, "r") as ff:
            # The extra newline guarantees that the last line is matched by the regexes below
            gin_config_file = ff.read() + "\n"

        # The import lines are the same for every scope, so they are extracted once. Duplicates are removed while
        # keeping the order of appearance (joining a set made the order vary across interpreter runs).
        imports = "".join(
            dict.fromkeys(re.findall(r"from.+?import.+?\n", gin_config_file))
        )

        scope_configs = dict()
        for config_prms in sorted(
            set(re.findall(r"prms_[0-9]+/", gin_config_file)),
            # Ascending order based on the parameter index
            key=lambda x: int(x.replace("prms_", "")[:-1]),
        ):
            prms_configs = "".join(re.findall(config_prms + ".+?\n", gin_config_file))
            # Key without the trailing "/", e.g. "prms_0"
            scope_configs[config_prms[:-1]] = imports + prms_configs

        gin_configs[cl] = (
            "\n".join(scope_configs.values()) if return_string else scope_configs
        )

    return gin_configs
retrieves the gin configs from a folder.
Parameters
- gin_config_folder (str): The folder where the gin configs are stored.
- return_string (bool): If True, the gin configs of each class are joined in a single string. If False, they are returned as a dictionary mapping each parameter scope (e.g. "prms_0") to its gin config.
Returns
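- Union[Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str], Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], Dict[str, str]]]: The dictionary that, for each MDP or agent class found in the given folder, contains its gin configs, either joined in a single string or keyed by parameter scope.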