
  1import cProfile
  2import collections
  3import importlib
  4import inspect
  5import numbers
  6import os
  7from glob import glob
  8from io import StringIO
  9from typing import TYPE_CHECKING, Iterable, List, Type, Union, Any, Dict
 11import dm_env
 12import numpy as np
 13from scipy.stats import rv_continuous
 14from tqdm import tqdm
 16import colosseum
 17import colosseum.config as config
 20    from colosseum.mdp import BaseMDP
 21    from colosseum.agent.agents.base import BaseAgent
 24def rounding_nested_structure(x: Dict):
 25    """
 27    """
 28    if isinstance(x, str):
 29        return x
 30    if isinstance(x, dict):
 31        return type(x)(
 32            (key, rounding_nested_structure(value)) for key, value in x.items()
 33        )
 34    if isinstance(x, collections.Container):
 35        return type(x)(rounding_nested_structure(value) for value in x)
 36    if isinstance(x, numbers.Number):
 37        return round(x, config.get_n_floating_sampling_hyperparameters())
 38    return x
 41def compare_gin_configs(
 42    gin_configs1: Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
 43    gin_configs2: Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
 44) -> bool:
 45    """
 46    Returns
 47    -------
 48    bool
 49        True, if the two gin configs are identical.
 50    """
 51    if set(gin_configs1) != set(gin_configs2):
 52        return False
 54    gin_configs1 = set(
 55        map(lambda x: x.replace(" ", "").replace("\n", ""), gin_configs1.values())
 56    )
 57    gin_configs2 = set(
 58        map(lambda x: x.replace(" ", "").replace("\n", ""), gin_configs2.values())
 59    )
 60    return gin_configs1 == gin_configs2
 63def sample_mdp_gin_configs(
 64    mdp_class: Type["BaseMDP"], n: int = 1, seed: int = 42
 65) -> List[str]:
 66    """
 67    Parameters
 68    ----------
 69    mdp_class : Type["BaseMDP"]
 70        The MDP class to sample from.
 71    n : int
 72        The number of samples. By default, one sample is taken.
 73    seed : int
 74        The random seed. By default, it is set to 42.
 76    Returns
 77    -------
 78    List[str]
 79        The n sampled gin configs.
 80    """
 81    return [
 82        mdp_class.produce_gin_file_from_mdp_parameters(params, mdp_class.__name__, i)
 83        for i, params in enumerate(mdp_class.sample_parameters(n, seed))
 84    ]
 87def sample_mdp_gin_configs_file(
 88    mdp_class: Type["BaseMDP"], n: int = 1, seed: int = 42
 89) -> str:
 90    """
 91    Parameters
 92    ----------
 93    mdp_class : Type["BaseMDP"]
 94        The MDP class to sample from.
 95    n : int
 96        The number of samples. By default, one sample is taken.
 97    seed : int
 98        The random seed. By default, it is set to 42.
100    Returns
101    -------
102    str
103        The n sampled gin configs as a single file string.
104    """
105    return "\n".join(sample_mdp_gin_configs(mdp_class, n, seed))
def get_empty_ts(state: Any) -> dm_env.TimeStep:
    """
    Returns
    -------
    dm_env.TimeStep
        A time step of type MID built around the given state, with the two fields that follow the step type set to
        zero (presumably the reward and the discount, as in the dm_env TimeStep layout).
    """
    step_type = dm_env.StepType.MID
    zero_reward = 0
    zero_discount = 0
    return dm_env.TimeStep(step_type, zero_reward, zero_discount, state)
def profile(file_path):
    """
    Creates a decorator that profiles every call of the decorated function with cProfile.

    Parameters
    ----------
    file_path : str
        The path of the file where the profiling statistics are dumped after each call.

    Returns
    -------
    Callable
        A decorator. The function it produces forwards its arguments to the decorated function and returns its
        result, dumping the profiling statistics to file_path even when the decorated function raises.
    """
    import functools

    def decorator(f):
        print(f"Profiling {f}")

        # Preserve the name and the docstring of the decorated function.
        @functools.wraps(f)
        def inner(*args, **kwargs):
            pr = cProfile.Profile()
            pr.enable()
            try:
                # The result is forwarded so that decorating a function does not change what it returns.
                return f(*args, **kwargs)
            finally:
                # The profiler is always disabled and the statistics are always stored, also in case of errors.
                pr.disable()
                pr.dump_stats(file_path)

        return inner

    return decorator
def get_colosseum_mdp_classes(episodic: bool = None) -> List[Type["BaseMDP"]]:
    """
    Parameters
    ----------
    episodic : bool
        If True only the episodic MDP classes are returned, if False only the continuous ones. By default, it is
        None and both kinds are returned, the episodic classes first.

    Returns
    -------
    List[Type["BaseMDP"]]
        All available MDP classes in the package.
    """
    if episodic is not None:
        # A single setting was requested.
        return _get_colosseum_mdp_classes(True if episodic else False)

    episodic_classes = _get_colosseum_mdp_classes(True)
    continuous_classes = _get_colosseum_mdp_classes(False)
    return episodic_classes + continuous_classes
def _get_colosseum_mdp_classes(episodic=True) -> List[Type["BaseMDP"]]:
    """
    Collects the MDP classes of one setting by importing the matching modules of the colosseum ``mdp`` folder.

    Every ``finite_horizon.py`` (episodic) or ``infinite_horizon.py`` (continuous) file found recursively in the
    ``mdp`` folder of the package is imported, and the first object of the module whose name contains the setting
    keyword ("Episodic" or "Continuous") but not "MDP" is selected.

    Parameters
    ----------
    episodic : bool
        Whether the episodic or the continuous classes are collected. By default, it is set to True.

    Returns
    -------
    List[Type["BaseMDP"]]
        The objects selected from the imported modules, one for each file found.

    Raises
    ------
    StopIteration
        If one of the imported modules does not contain any object with a matching name.
    """
    import colosseum

    kw = "Episodic" if episodic else "Continuous"
    mdp_path = "finite_horizon" if episodic else "infinite_horizon"
    package_dir = os.path.dirname(inspect.getfile(colosseum))

    mdp_classes = []
    for mdp_file in glob(
        f"{package_dir}{os.sep}mdp{os.sep}**{os.sep}{mdp_path}.py",
        recursive=True,
    ):
        # The module name is derived from the location of the file relative to the package folder. Searching for the
        # first "colosseum" in the absolute path breaks when a parent folder (e.g. a virtual environment) also
        # contains that word.
        relative_module = os.path.relpath(mdp_file, package_dir)[:-3]
        module_name = ".".join([colosseum.__name__] + relative_module.split(os.sep))
        module = importlib.import_module(module_name)
        mdp_classes.append(
            next(
                filter(
                    lambda x: kw in x[0] and "MDP" not in x[0],
                    module.__dict__.items(),
                )
            )[1]
        )
    return mdp_classes
def get_mdp_class_from_name(mdp_class_name: str) -> Type["BaseMDP"]:
    """
    Parameters
    ----------
    mdp_class_name : str
        The name of the MDP class to look for.

    Returns
    -------
    Type["BaseMDP"]
        The MDP class corresponding to the name in input.

    Raises
    ------
    ModuleNotFoundError
        If no class returned by get_colosseum_mdp_classes() has the given name.
    """
    try:
        # Take the first available class whose name matches exactly.
        matching_class = next(
            mdp_class
            for mdp_class in get_colosseum_mdp_classes()
            if mdp_class.__name__ == mdp_class_name
        )
    except StopIteration:
        raise ModuleNotFoundError(
            f"The MDP class {mdp_class_name} was not found in colosseum. Please check the correct spelling and the result of "
            f"get_colosseum_mdp_classes()"
        )
    return matching_class
def get_colosseum_agent_classes(episodic: bool = None) -> List[Type["BaseAgent"]]:
    """
    Parameters
    ----------
    episodic : bool
        If True only the episodic agent classes are returned, if False only the continuous ones. By default, it is
        None and both kinds are returned, the episodic classes first.

    Returns
    -------
    List[Type["BaseAgent"]]
        All available agent classes in the package.
    """
    if episodic is not None:
        # A single setting was requested.
        return _get_colosseum_agent_classes(True if episodic else False)

    episodic_agents = _get_colosseum_agent_classes(True)
    continuous_agents = _get_colosseum_agent_classes(False)
    return episodic_agents + continuous_agents
def _get_colosseum_agent_classes(episodic: bool) -> List[Type["BaseAgent"]]:
    """
    Collects the agent classes of one setting by importing the modules of the colosseum ``agent/agents`` folder.

    Every Python file whose name starts with a lowercase letter and that is found recursively in the ``episodic`` or
    ``infinite_horizon`` sub-folder is imported, and the first object of the module whose name contains the setting
    keyword ("Episodic" or "Continuous") is selected.

    Parameters
    ----------
    episodic : bool
        Whether the episodic or the continuous classes are collected.

    Returns
    -------
    List[Type["BaseAgent"]]
        The objects selected from the imported modules, one for each file found.

    Raises
    ------
    StopIteration
        If one of the imported modules does not contain any object with a matching name.
    """
    agent_path = "episodic" if episodic else "infinite_horizon"
    kw = "Episodic" if episodic else "Continuous"
    package_dir = os.path.dirname(inspect.getfile(colosseum))

    agent_classes = []
    for agent_file in glob(
        f"{package_dir}{os.sep}agent{os.sep}agents{os.sep}{agent_path}"
        f"{os.sep}**{os.sep}[a-z]*.py",
        recursive=True,
    ):
        # The module name is derived from the location of the file relative to the package folder. Searching for the
        # first "colosseum" in the absolute path breaks when a parent folder (e.g. a virtual environment) also
        # contains that word.
        relative_module = os.path.relpath(agent_file, package_dir)[:-3]
        module_name = ".".join([colosseum.__name__] + relative_module.split(os.sep))
        module = importlib.import_module(module_name)
        agent_classes.append(
            next(filter(lambda x: kw in x[0], module.__dict__.items()))[1]
        )
    return agent_classes
def get_agent_class_from_name(agent_class_name: str) -> Type["BaseAgent"]:
    """
    Parameters
    ----------
    agent_class_name : str
        The name of the agent class to look for, among the colosseum agent classes and the external agent classes
        registered in the configuration.

    Returns
    -------
    Type["BaseAgent"]
        The agent class corresponding to the name in input.

    Raises
    ------
    ModuleNotFoundError
        If no available agent class has the given name.
    """
    # The available classes are computed once and reused for the error message.
    available_classes = (
        get_colosseum_agent_classes() + config.get_external_agent_classes()
    )
    try:
        return next(
            filter(lambda c: c.__name__ == agent_class_name, available_classes)
        )
    except StopIteration:
        raise ModuleNotFoundError(
            f"The agent class {agent_class_name} was not found in colosseum. The available classes are {available_classes}"
        )
def ensure_folder(path: str) -> str:
    """
    Parameters
    ----------
    path : str
        A non-empty path.

    Returns
    -------
    str
        The path with the os.sep at the end.
    """
    last_character = path[-1]
    if last_character == os.sep:
        # The separator is already there.
        return path
    return path + os.sep
def get_dist(dist_name, args):
    """
    Instantiates a distribution from its name.

    Parameters
    ----------
    dist_name : str
        Either "deterministic" or the name of an attribute of the scipy.stats module.
    args : Iterable
        The positional arguments passed to the distribution.

    Returns
    -------
    The result of calling the selected distribution with the given arguments.
    """
    if dist_name != "deterministic":
        # Any other name is looked up in scipy.stats.
        stats_module = importlib.import_module("scipy.stats")
        distribution = getattr(stats_module, dist_name)
        return distribution(*args)
    return deterministic(*args)
class deterministic_gen(rv_continuous):
    """
    Degenerate distribution whose samples are always zero before scipy applies the location and scale arguments.
    """

    def _cdf(self, x):
        # Step function: zero strictly below the origin, one from the origin onwards.
        is_below_origin = x < 0
        return np.where(is_below_origin, 0.0, 1.0)

    def _stats(self):
        # All the four statistics reported to scipy are zero.
        return (0.0,) * 4

    def _rvs(self, size=None, random_state=None):
        # No randomness is involved, so the random state is ignored.
        samples = np.zeros(shape=size)
        return samples
# Module-level instance of the deterministic distribution; get_dist calls it for the "deterministic" name.
deterministic = deterministic_gen(name="deterministic")
def state_occurencens_to_counts(occurences: List[int], N: int) -> np.ndarray:
    """
    Parameters
    ----------
    occurences : List[int]
        The indices of the states that occurred, possibly with repetitions.
    N : int
        The length of the array of counts.

    Returns
    -------
    np.ndarray
        An array of length N whose s-th entry is the number of times s appears in the occurrences.
    """
    counts = np.zeros(N)
    unique_states, frequencies = np.unique(occurences, return_counts=True)
    # Entries of states that never occurred are left at zero.
    for state, frequency in zip(unique_states, frequencies):
        counts[state] = frequency
    return counts
def check_distributions(dists: List[Union[rv_continuous, None]], are_stochastic: bool):
    """
    checks that the distribution given in input respects the necessary conditions.

    Parameters
    ----------
    dists : List[Union[rv_continuous, None]]
        is the list of distributions.
    are_stochastic : bool
        whether the distributions are supposed to be stochastic.

    Raises
    ------
    AssertionError
        If only some of the distributions are None, or if the distributions do not match the expected stochasticity.
    """
    # You either define all or none of the distribution
    assert dists.count(None) in [0, len(dists)]

    # There is nothing else to check when no distribution is given. The emptiness check avoids an IndexError on an
    # empty list.
    if not dists or dists[0] is None:
        return

    # Double check that the distributions in input matches the stochasticity of the reward parameter
    if are_stochastic:
        assert all(type(dist.dist) is not deterministic_gen for dist in dists)
    else:
        assert all(type(dist.dist) is deterministic_gen for dist in dists)
def get_loop(x: Iterable) -> Iterable:
    """
    Wraps the iterable in a progress bar depending on ``config.VERBOSE_LEVEL``.

    Returns
    -------
    Iterable
        An iterable that respects the current level of verbosity.
    """
    if config.VERBOSE_LEVEL == 0:
        # No verbosity: the iterable is returned untouched.
        return x

    if type(config.VERBOSE_LEVEL) != int:
        # For non-integer verbosity levels the progress bar writes to an in-memory buffer.
        buffer = StringIO()
        return tqdm(x, desc="Diameter calculation", file=buffer, mininterval=5)
    return tqdm(x, desc="Diameter calculation", mininterval=5)
