colosseum.utils.miscellanea
"""Miscellaneous utilities of the colosseum package."""
import cProfile
import collections
import collections.abc
import functools
import importlib
import inspect
import numbers
import os
from glob import glob
from io import StringIO
from typing import TYPE_CHECKING, Iterable, List, Type, Union, Any, Dict

import dm_env
import numpy as np
from scipy.stats import rv_continuous
from tqdm import tqdm

import colosseum
import colosseum.config as config

if TYPE_CHECKING:
    from colosseum.mdp import BaseMDP
    from colosseum.agent.agents.base import BaseAgent


def rounding_nested_structure(x: Dict):
    """
    Recursively rounds every number contained in a nested structure.

    Adapted from
    https://stackoverflow.com/questions/7076254/rounding-decimals-in-nested-data-structures-in-python

    Parameters
    ----------
    x : Dict
        The structure to round. Strings are returned untouched, dictionaries and other containers are rebuilt with
        their own type from the rounded elements, numbers are rounded, and any other object is returned as is.

    Returns
    -------
    The structure with the same layout as ``x`` whose numbers are rounded to
    ``config.get_n_floating_sampling_hyperparameters()`` digits.
    """
    # Strings are containers of strings, so they must be handled before the generic container branch.
    if isinstance(x, str):
        return x
    if isinstance(x, dict):
        return type(x)(
            (key, rounding_nested_structure(value)) for key, value in x.items()
        )
    # ``collections.Container`` was removed in Python 3.10; the ABC lives in ``collections.abc``.
    if isinstance(x, collections.abc.Container):
        return type(x)(rounding_nested_structure(value) for value in x)
    if isinstance(x, numbers.Number):
        return round(x, config.get_n_floating_sampling_hyperparameters())
    return x


def compare_gin_configs(
    gin_configs1: Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
    gin_configs2: Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
) -> bool:
    """
    Compares two mappings from classes to gin configs, ignoring spaces and new lines in the configs.

    Parameters
    ----------
    gin_configs1 : Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str]
        The first mapping from classes to gin config strings.
    gin_configs2 : Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str]
        The second mapping from classes to gin config strings.

    Returns
    -------
    bool
        True, if the two gin configs are identical.
    """
    if set(gin_configs1) != set(gin_configs2):
        return False

    def normalize(gin_config: str) -> str:
        return gin_config.replace(" ", "").replace("\n", "")

    # Compare key by key so that configs attached to different classes cannot be confused with each other.
    return all(
        normalize(gin_configs1[key]) == normalize(gin_configs2[key])
        for key in gin_configs1
    )


def sample_mdp_gin_configs(
    mdp_class: Type["BaseMDP"], n: int = 1, seed: int = 42
) -> List[str]:
    """
    Parameters
    ----------
    mdp_class : Type["BaseMDP"]
        The MDP class to sample from.
    n : int
        The number of samples. By default, one sample is taken.
    seed : int
        The random seed. By default, it is set to 42.

    Returns
    -------
    List[str]
        The n sampled gin configs.
    """
    return [
        mdp_class.produce_gin_file_from_mdp_parameters(params, mdp_class.__name__, i)
        for i, params in enumerate(mdp_class.sample_parameters(n, seed))
    ]


def sample_mdp_gin_configs_file(
    mdp_class: Type["BaseMDP"], n: int = 1, seed: int = 42
) -> str:
    """
    Parameters
    ----------
    mdp_class : Type["BaseMDP"]
        The MDP class to sample from.
    n : int
        The number of samples. By default, one sample is taken.
    seed : int
        The random seed. By default, it is set to 42.

    Returns
    -------
    str
        The n sampled gin configs as a single file string.
    """
    return "\n".join(sample_mdp_gin_configs(mdp_class, n, seed))


def get_empty_ts(state: Any) -> dm_env.TimeStep:
    """
    Returns
    -------
    dm_env.TimeStep
        A ``StepType.MID`` time step built as ``TimeStep(StepType.MID, 0, 0, state)``.
    """
    return dm_env.TimeStep(dm_env.StepType.MID, 0, 0, state)


def profile(file_path):
    """
    Decorator factory that profiles every call of the decorated function with ``cProfile``.

    Parameters
    ----------
    file_path
        The path where the profiling statistics are dumped after each call (overwritten every time).

    Returns
    -------
    A decorator whose wrapper forwards arguments, return value and exceptions of the decorated function.
    """

    def decorator(f):
        print(f"Profiling {f}")

        @functools.wraps(f)
        def inner(*args, **kwargs):
            pr = cProfile.Profile()
            pr.enable()
            try:
                return f(*args, **kwargs)
            finally:
                # Stop and dump the statistics even when the profiled function raises.
                pr.disable()
                pr.dump_stats(file_path)

        return inner

    return decorator


def get_colosseum_mdp_classes(episodic: bool = None) -> List[Type["BaseMDP"]]:
    """
    Parameters
    ----------
    episodic : bool
        If None (default), the classes of both settings are returned, the episodic ones first. Otherwise only the
        episodic (truthy value) or the continuous (falsy value) ones.

    Returns
    -------
    List[Type["BaseMDP"]]
        All available MDP classes in the package.
    """
    if episodic is None:
        return _get_colosseum_mdp_classes() + _get_colosseum_mdp_classes(False)
    if episodic:
        return _get_colosseum_mdp_classes()
    return _get_colosseum_mdp_classes(False)


def _module_name_from_file(file_path: str) -> str:
    """
    Converts the path of a ``.py`` file inside the colosseum package into its dotted module name.

    The path is taken relative to the directory that contains the package, so that a parent directory whose
    name happens to contain "colosseum" cannot corrupt the module name.
    """
    package_parent = os.path.dirname(os.path.dirname(inspect.getfile(colosseum)))
    relative_path = os.path.relpath(file_path, package_parent)
    return os.path.splitext(relative_path)[0].replace(os.sep, ".")


def _get_colosseum_mdp_classes(episodic=True) -> List[Type["BaseMDP"]]:
    """
    Imports every ``finite_horizon.py`` (episodic) or ``infinite_horizon.py`` (continuous) module found below the
    ``mdp`` folder of the package and takes, from each of them, the first member whose name contains
    "Episodic"/"Continuous" and does not contain "MDP".
    """
    kw = "Episodic" if episodic else "Continuous"
    mdp_path = "finite_horizon" if episodic else "infinite_horizon"
    pattern = os.path.join(
        os.path.dirname(inspect.getfile(colosseum)), "mdp", "**", f"{mdp_path}.py"
    )
    return [
        next(
            member
            for name, member in vars(
                importlib.import_module(_module_name_from_file(mdp_file))
            ).items()
            if kw in name and "MDP" not in name
        )
        for mdp_file in glob(pattern, recursive=True)
    ]


def get_mdp_class_from_name(mdp_class_name: str) -> Type["BaseMDP"]:
    """
    Returns
    -------
    Type["BaseMDP"]
        The MDP class corresponding to the name in input.

    Raises
    ------
    ModuleNotFoundError
        If no MDP class with the given name is available.
    """
    try:
        return next(
            filter(lambda c: c.__name__ == mdp_class_name, get_colosseum_mdp_classes())
        )
    except StopIteration:
        raise ModuleNotFoundError(
            f"The MDP class {mdp_class_name} was not found in colosseum. Please check the correct spelling and the result of "
            f"get_colosseum_mdp_classes()"
        )


def get_colosseum_agent_classes(episodic: bool = None) -> List[Type["BaseAgent"]]:
    """
    Parameters
    ----------
    episodic : bool
        If None (default), the classes of both settings are returned, the episodic ones first. Otherwise only the
        episodic (truthy value) or the continuous (falsy value) ones.

    Returns
    -------
    List[Type["BaseAgent"]]
        All available agent classes in the package.
    """
    if episodic is None:
        return _get_colosseum_agent_classes(True) + _get_colosseum_agent_classes(False)
    if episodic:
        return _get_colosseum_agent_classes(True)
    return _get_colosseum_agent_classes(False)


def _get_colosseum_agent_classes(episodic: bool) -> List[Type["BaseAgent"]]:
    """
    Imports every module matching ``[a-z]*.py`` below the ``agent/agents/episodic`` or
    ``agent/agents/infinite_horizon`` folder of the package and takes, from each of them, the first member whose
    name contains "Episodic"/"Continuous".
    """
    agent_path = "episodic" if episodic else "infinite_horizon"
    kw = "Episodic" if episodic else "Continuous"
    pattern = os.path.join(
        os.path.dirname(inspect.getfile(colosseum)),
        "agent",
        "agents",
        agent_path,
        "**",
        "[a-z]*.py",
    )
    return [
        next(
            member
            for name, member in vars(
                importlib.import_module(_module_name_from_file(agent_file))
            ).items()
            if kw in name
        )
        for agent_file in glob(pattern, recursive=True)
    ]


def get_agent_class_from_name(agent_class_name: str) -> Type["BaseAgent"]:
    """
    Returns
    -------
    Type["BaseAgent"]
        The agent class corresponding to the name in input.

    Raises
    ------
    ModuleNotFoundError
        If neither the package nor the externally registered agent classes contain the given name.
    """
    available_classes = (
        get_colosseum_agent_classes() + config.get_external_agent_classes()
    )
    for agent_class in available_classes:
        if agent_class.__name__ == agent_class_name:
            return agent_class
    raise ModuleNotFoundError(
        f"The agent class {agent_class_name} was not found in colosseum. The available classes are {available_classes}"
    )


def ensure_folder(path: str) -> str:
    """
    Parameters
    ----------
    path : str
        The path of a folder. An empty string is interpreted as the current directory.

    Returns
    -------
    str
        The path with the os.sep at the end.
    """
    if not path:
        return os.curdir + os.sep
    # On platforms with an alternative separator, a trailing one already terminates the path.
    separators = (os.sep,) if os.altsep is None else (os.sep, os.altsep)
    return path if path.endswith(separators) else (path + os.sep)


def get_dist(dist_name, args):
    """
    Returns
    -------
    The result of calling, with ``args``, either the ``deterministic`` distribution of this module (when
    ``dist_name`` is "deterministic") or the attribute of ``scipy.stats`` named ``dist_name``.
    """
    if dist_name == "deterministic":
        return deterministic(*args)
    return getattr(importlib.import_module("scipy.stats"), dist_name)(*args)


class deterministic_gen(rv_continuous):
    """
    Continuous random variable whose samples, before any location or scale shift, are always zero.
    """

    def _cdf(self, x):
        # Step function: 0 for negative values and 1 otherwise.
        return np.where(x < 0, 0.0, 1.0)

    def _stats(self):
        # Reported in the order mu, mu2, g1, g2.
        return 0.0, 0.0, 0.0, 0.0

    def _rvs(self, size=None, random_state=None):
        # No randomness is involved, so random_state is ignored.
        return np.zeros(shape=size)


deterministic = deterministic_gen(name="deterministic")


def state_occurencens_to_counts(occurences: List[int], N: int) -> np.ndarray:
    """
    Parameters
    ----------
    occurences : List[int]
        The values whose occurrences are counted; each distinct value is used as an index into the result.
    N : int
        The length of the returned array.

    Returns
    -------
    np.ndarray
        An array of length N whose element s is the number of times s appears in the input and zero elsewhere.
    """
    x = np.zeros(N)
    for s, c in dict(zip(*np.unique(occurences, return_counts=True))).items():
        x[s] = c
    return x


def check_distributions(dists: List[Union[rv_continuous, None]], are_stochastic: bool):
    """
    Checks that the distributions given in input respect the necessary conditions.

    Parameters
    ----------
    dists : List[Union[rv_continuous, None]]
        is the list of distributions.
    are_stochastic : bool
        whether the distributions are supposed to be stochastic.

    Raises
    ------
    AssertionError
        If only some of the distributions are None, or if the distributions do not match ``are_stochastic``.
    """
    # An empty list trivially respects every condition.
    if len(dists) == 0:
        return

    # You either define all or none of the distribution
    n_undefined = sum(dist is None for dist in dists)
    if n_undefined not in (0, len(dists)):
        raise AssertionError(
            "The distributions must be either all defined or all None."
        )
    if n_undefined:
        return

    # Double check that the distributions in input matches the stochasticity of the reward parameter
    n_deterministic = sum(isinstance(dist.dist, deterministic_gen) for dist in dists)
    if are_stochastic and n_deterministic != 0:
        raise AssertionError(
            "Deterministic distributions were given where stochastic ones are expected."
        )
    if not are_stochastic and n_deterministic != len(dists):
        raise AssertionError(
            "Stochastic distributions were given where deterministic ones are expected."
        )


def get_loop(x: Iterable) -> Iterable:
    """
    Returns
    -------
    Iterable
        An iterable that respects the current level of verbosity.
    """
    if config.VERBOSE_LEVEL != 0:
        if type(config.VERBOSE_LEVEL) == int:
            return tqdm(x, desc="Diameter calculation", mininterval=5)
        # For a non-integer verbosity the progress bar is written to an in-memory buffer instead.
        s = StringIO()
        return tqdm(x, desc="Diameter calculation", file=s, mininterval=5)
    return x
def rounding_nested_structure(x: Dict):
    """
    Recursively rounds every number contained in a nested structure.

    Adapted from
    https://stackoverflow.com/questions/7076254/rounding-decimals-in-nested-data-structures-in-python

    Parameters
    ----------
    x : Dict
        The structure to round. Strings are returned untouched, dictionaries and other containers are rebuilt with
        their own type from the rounded elements, numbers are rounded, and any other object is returned as is.

    Returns
    -------
    The structure with the same layout as ``x`` whose numbers are rounded to
    ``config.get_n_floating_sampling_hyperparameters()`` digits.
    """
    # ``collections.Container`` was removed in Python 3.10; the ABC lives in ``collections.abc``.
    from collections.abc import Container

    # Strings are containers of strings, so they must be handled before the generic container branch.
    if isinstance(x, str):
        return x
    if isinstance(x, dict):
        return type(x)(
            (key, rounding_nested_structure(value)) for key, value in x.items()
        )
    if isinstance(x, Container):
        return type(x)(rounding_nested_structure(value) for value in x)
    if isinstance(x, numbers.Number):
        return round(x, config.get_n_floating_sampling_hyperparameters())
    return x
def compare_gin_configs(
    gin_configs1: Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
    gin_configs2: Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
) -> bool:
    """
    Compares two mappings from classes to gin configs, ignoring spaces and new lines in the configs.

    Parameters
    ----------
    gin_configs1 : Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str]
        The first mapping from classes to gin config strings.
    gin_configs2 : Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str]
        The second mapping from classes to gin config strings.

    Returns
    -------
    bool
        True, if the two gin configs are identical.
    """
    if set(gin_configs1) != set(gin_configs2):
        return False

    def normalize(gin_config: str) -> str:
        return gin_config.replace(" ", "").replace("\n", "")

    # Compare key by key: comparing only the sets of values would treat configs that are swapped between
    # classes, or duplicated across classes, as identical.
    return all(
        normalize(gin_configs1[key]) == normalize(gin_configs2[key])
        for key in gin_configs1
    )
Returns
- bool: True, if the two gin configs are identical.
def sample_mdp_gin_configs(
    mdp_class: Type["BaseMDP"], n: int = 1, seed: int = 42
) -> List[str]:
    """
    Samples parameters from an MDP class and renders each sample as a gin config.

    Parameters
    ----------
    mdp_class : Type["BaseMDP"]
        The MDP class to sample from.
    n : int
        The number of samples. By default, one sample is taken.
    seed : int
        The random seed. By default, it is set to 42.

    Returns
    -------
    List[str]
        The n sampled gin configs.
    """
    class_name = mdp_class.__name__
    gin_configs = []
    # The position of each sample is passed along as the index of its gin config.
    for index, parameters in enumerate(mdp_class.sample_parameters(n, seed)):
        gin_configs.append(
            mdp_class.produce_gin_file_from_mdp_parameters(
                parameters, class_name, index
            )
        )
    return gin_configs
Parameters
- mdp_class (Type["BaseMDP"]): The MDP class to sample from.
- n (int): The number of samples. By default, one sample is taken.
- seed (int): The random seed. By default, it is set to 42.
Returns
- List[str]: The n sampled gin configs.
def sample_mdp_gin_configs_file(
    mdp_class: Type["BaseMDP"], n: int = 1, seed: int = 42
) -> str:
    """
    Samples gin configs from an MDP class and joins them, one per line block, into a single string.

    Parameters
    ----------
    mdp_class : Type["BaseMDP"]
        The MDP class to sample from.
    n : int
        The number of samples. By default, one sample is taken.
    seed : int
        The random seed. By default, it is set to 42.

    Returns
    -------
    str
        The n sampled gin configs as a single file string.
    """
    sampled_configs = sample_mdp_gin_configs(mdp_class, n, seed)
    separator = "\n"
    return separator.join(sampled_configs)
Parameters
- mdp_class (Type["BaseMDP"]): The MDP class to sample from.
- n (int): The number of samples. By default, one sample is taken.
- seed (int): The random seed. By default, it is set to 42.
Returns
- str: The n sampled gin configs as a single file string.
def profile(file_path):
    """
    Decorator factory that profiles every call of the decorated function with ``cProfile``.

    Parameters
    ----------
    file_path
        The path where the profiling statistics are dumped after each call (overwritten every time).

    Returns
    -------
    A decorator whose wrapper forwards arguments, return value and exceptions of the decorated function.
    """
    # Imported locally because functools is not among the module-level imports.
    import functools

    def decorator(f):
        print(f"Profiling {f}")

        @functools.wraps(f)
        def inner(*args, **kwargs):
            pr = cProfile.Profile()
            pr.enable()
            try:
                # The result is handed back to the caller instead of being discarded.
                return f(*args, **kwargs)
            finally:
                # Stop and dump the statistics even when the profiled function raises.
                pr.disable()
                pr.dump_stats(file_path)

        return inner

    return decorator
def get_colosseum_mdp_classes(episodic: bool = None) -> List[Type["BaseMDP"]]:
    """
    Collects the MDP classes of the package, optionally restricted to one setting.

    Parameters
    ----------
    episodic : bool
        If None (default), the classes of both settings are returned, the episodic ones first. Otherwise only the
        episodic (truthy value) or the continuous (falsy value) ones.

    Returns
    -------
    List[Type["BaseMDP"]]
        All available MDP classes in the package.
    """
    if episodic is not None:
        # Any non-None value selects a single setting according to its truthiness.
        return _get_colosseum_mdp_classes(True if episodic else False)

    episodic_classes = _get_colosseum_mdp_classes(True)
    continuous_classes = _get_colosseum_mdp_classes(False)
    return episodic_classes + continuous_classes
Returns
- List[Type["BaseMDP"]]: All available MDP classes in the package.
def get_mdp_class_from_name(mdp_class_name: str) -> Type["BaseMDP"]:
    """
    Looks up an MDP class of the package by its class name.

    Returns
    -------
    Type["BaseMDP"]
        The MDP class corresponding to the name in input.

    Raises
    ------
    ModuleNotFoundError
        If no available MDP class has the given name.
    """
    try:
        matching_classes = (
            mdp_class
            for mdp_class in get_colosseum_mdp_classes()
            if mdp_class.__name__ == mdp_class_name
        )
        # The first match wins; an exhausted generator means that the name is unknown.
        return next(matching_classes)
    except StopIteration:
        raise ModuleNotFoundError(
            f"The MDP class {mdp_class_name} was not found in colosseum. Please check the correct spelling and the result of "
            f"get_colosseum_mdp_classes()"
        )
Returns
- Type["BaseMDP"]: The MDP class corresponding to the name in input.
def get_colosseum_agent_classes(episodic: bool = None) -> List[Type["BaseAgent"]]:
    """
    Collects the agent classes of the package, optionally restricted to one setting.

    Parameters
    ----------
    episodic : bool
        If None (default), the classes of both settings are returned, the episodic ones first. Otherwise only the
        episodic (truthy value) or the continuous (falsy value) ones.

    Returns
    -------
    List[Type["BaseAgent"]]
        All available agent classes in the package.
    """
    if episodic is not None:
        # Any non-None value selects a single setting according to its truthiness.
        return _get_colosseum_agent_classes(True if episodic else False)

    episodic_classes = _get_colosseum_agent_classes(True)
    continuous_classes = _get_colosseum_agent_classes(False)
    return episodic_classes + continuous_classes
Returns
- List[Type["BaseAgent"]]: All available agent classes in the package.
def get_agent_class_from_name(agent_class_name: str) -> Type["BaseAgent"]:
    """
    Looks up an agent class by its class name among the package agents and the externally registered ones.

    Returns
    -------
    Type["BaseAgent"]
        The agent class corresponding to the name in input.

    Raises
    ------
    ModuleNotFoundError
        If neither the package nor ``config.get_external_agent_classes()`` contain a class with the given name.
    """
    # An unconditional early return used to bypass the error handling below, so unknown names surfaced as a
    # bare StopIteration instead of the informative ModuleNotFoundError.
    available_classes = (
        get_colosseum_agent_classes() + config.get_external_agent_classes()
    )
    for agent_class in available_classes:
        if agent_class.__name__ == agent_class_name:
            return agent_class
    raise ModuleNotFoundError(
        f"The agent class {agent_class_name} was not found in colosseum. The available classes are {available_classes}"
    )
Returns
- Type["BaseAgent"]: The agent class corresponding to the name in input.
def ensure_folder(path: str) -> str:
    """
    Makes sure that a folder path ends with a path separator.

    Parameters
    ----------
    path : str
        The path of a folder. An empty string is interpreted as the current directory.

    Returns
    -------
    str
        The path with the os.sep at the end.
    """
    # Indexing the last character of an empty string would raise an IndexError.
    if not path:
        return os.curdir + os.sep
    # On platforms with an alternative separator, a trailing one already terminates the path.
    separators = (os.sep,) if os.altsep is None else (os.sep, os.altsep)
    return path if path.endswith(separators) else (path + os.sep)
Returns
- str: The path with the os.sep at the end.
class deterministic_gen(rv_continuous):
    """
    Continuous random variable whose samples, before any location or scale shift, are always zero.
    """

    def _cdf(self, x):
        """Step function that is 0 for negative values and 1 otherwise."""
        is_negative = x < 0
        return np.where(is_negative, 0.0, 1.0)

    def _stats(self):
        """All four statistics (mu, mu2, g1, g2) are reported as zero."""
        mu = mu2 = g1 = g2 = 0.0
        return mu, mu2, g1, g2

    def _rvs(self, size=None, random_state=None):
        """Samples are zeros of the requested shape; ``random_state`` is ignored."""
        samples = np.zeros(size)
        return samples
A generic continuous random variable class meant for subclassing.
rv_continuous
is a base class to construct specific distribution classes
and instances for continuous random variables. It cannot be used
directly as a distribution.
Parameters
- momtype (int, optional): The type of generic moment calculation to use: 0 for pdf, 1 (default) for ppf.
- a (float, optional): Lower bound of the support of the distribution, default is minus infinity.
- b (float, optional): Upper bound of the support of the distribution, default is plus infinity.
- xtol (float, optional): The tolerance for fixed point calculation for generic ppf.
- badvalue (float, optional): The value in a result array that indicates a value for which some argument restriction is violated, default is np.nan.
- name (str, optional): The name of the instance. This string is used to construct the default example for distributions.
- longname (str, optional): This string is used as part of the first line of the docstring returned when a subclass has no docstring of its own. Note: `longname` exists for backwards compatibility, do not use for new subclasses.
- shapes (str, optional): The shape of the distribution. For example "m, n" for a distribution that takes two integers as the two shape arguments for all its methods. If not provided, shape parameters will be inferred from the signature of the private methods, `_pdf` and `_cdf` of the instance.
- extradoc (str, optional, deprecated): This string is used as the last part of the docstring returned when a subclass has no docstring of its own. Note: `extradoc` exists for backwards compatibility, do not use for new subclasses.
- seed ({None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional): If `seed` is None (or `np.random`), the `numpy.random.RandomState` singleton is used. If `seed` is an int, a new `RandomState` instance is used, seeded with `seed`. If `seed` is already a `Generator` or `RandomState` instance then that instance is used.
Methods
rvs pdf logpdf cdf logcdf sf logsf ppf isf moment stats entropy expect median mean std var interval __call__ fit fit_loc_scale nnlf support
Notes
Public methods of an instance of a distribution class (e.g., pdf
,
cdf
) check their arguments and pass valid arguments to private,
computational methods (_pdf
, _cdf
). For pdf(x)
, x
is valid
if it is within the support of the distribution.
Whether a shape parameter is valid is decided by an _argcheck
method
(which defaults to checking that its arguments are strictly positive.)
Subclassing
New random variables can be defined by subclassing the rv_continuous
class
and re-defining at least the _pdf
or the _cdf
method (normalized
to location 0 and scale 1).
If positive argument checking is not correct for your RV
then you will also need to re-define the _argcheck
method.
For most of the scipy.stats distributions, the support interval doesn't
depend on the shape parameters. x
being in the support interval is
equivalent to self.a <= x <= self.b
. If either of the endpoints of
the support do depend on the shape parameters, then
i) the distribution must implement the _get_support
method; and
ii) those dependent endpoints must be omitted from the distribution's
call to the rv_continuous
initializer.
Correct, but potentially slow defaults exist for the remaining methods but for speed and/or accuracy you can over-ride::
_logpdf, _cdf, _logcdf, _ppf, _rvs, _isf, _sf, _logsf
The default method _rvs
relies on the inverse of the cdf, _ppf
,
applied to a uniform random variate. In order to generate random variates
efficiently, either the default _ppf
needs to be overwritten (e.g.
if the inverse cdf can be expressed in an explicit form) or a sampling
method needs to be implemented in a custom _rvs
method.
If possible, you should override _isf
, _sf
or _logsf
.
The main reason would be to improve numerical accuracy: for example,
the survival function _sf
is computed as 1 - _cdf
which can
result in loss of precision if _cdf(x)
is close to one.
Methods that can be overwritten by subclasses ::
_rvs _pdf _cdf _sf _ppf _isf _stats _munp _entropy _argcheck _get_support
There are additional (internal and private) generic methods that can be useful for cross-checking and for debugging, but might not work in all cases when directly called.
A note on shapes
: subclasses need not specify them explicitly. In this
case, shapes
will be automatically deduced from the signatures of the
overridden methods (pdf
, cdf
etc).
If, for some reason, you prefer to avoid relying on introspection, you can
specify shapes
explicitly as an argument to the instance constructor.
Frozen Distributions
Normally, you must provide shape parameters (and, optionally, location and scale parameters) to each call of a method of a distribution.
Alternatively, the object may be called (as a function) to fix the shape, location, and scale parameters returning a "frozen" continuous RV object:
rv = generic(<shape(s)>, loc=0, scale=1)
Statistics
Statistics are computed using numerical integration by default.
For speed you can redefine this using _stats
:
- take shape parameters and return mu, mu2, g1, g2
- If you can't compute one of these, return it as None
- Can also be defined with a keyword argument
moments
, which is a string composed of "m", "v", "s", and/or "k". Only the components appearing in string should be computed and returned in the order "m", "v", "s", or "k" with missing values returned as None.
Alternatively, you can override _munp
, which takes n
and shape
parameters and returns the n-th non-central moment of the distribution.
Examples
To create a new Gaussian distribution, we would do the following:
>>> from scipy.stats import rv_continuous
>>> class gaussian_gen(rv_continuous):
... "Gaussian distribution"
... def _pdf(self, x):
... return np.exp(-x**2 / 2.) / np.sqrt(2.0 * np.pi)
>>> gaussian = gaussian_gen(name='gaussian')
scipy.stats
distributions are instances, so here we subclass
rv_continuous
and create an instance. With this, we now have
a fully functional distribution with all relevant methods automagically
generated by the framework.
Note that above we defined a standard normal distribution, with zero mean
and unit variance. Shifting and scaling of the distribution can be done
by using loc
and scale
parameters: gaussian.pdf(x, loc, scale)
essentially computes y = (x - loc) / scale
and
gaussian._pdf(y) / scale
.
Inherited Members
- scipy.stats._distn_infrastructure.rv_continuous
- rv_continuous
- logpdf
- cdf
- logcdf
- sf
- logsf
- ppf
- isf
- fit
- fit_loc_scale
- expect
- scipy.stats._distn_infrastructure.rv_generic
- random_state
- freeze
- rvs
- stats
- entropy
- moment
- median
- mean
- var
- std
- interval
- support
- nnlf
def check_distributions(dists: List[Union[rv_continuous, None]], are_stochastic: bool):
    """
    Checks that the distributions given in input respect the necessary conditions.

    Parameters
    ----------
    dists : List[Union[rv_continuous, None]]
        is the list of distributions.
    are_stochastic : bool
        whether the distributions are supposed to be stochastic.

    Raises
    ------
    AssertionError
        If only some of the distributions are None, or if the distributions do not match ``are_stochastic``.
    """
    # An empty list trivially respects every condition (indexing its first element would raise an IndexError).
    if len(dists) == 0:
        return

    # You either define all or none of the distribution. The error is raised explicitly, rather than with an
    # assert statement, so that the check is not stripped when Python runs with -O.
    n_undefined = sum(dist is None for dist in dists)
    if n_undefined not in (0, len(dists)):
        raise AssertionError(
            "The distributions must be either all defined or all None."
        )
    if n_undefined:
        return

    # Double check that the distributions in input matches the stochasticity of the reward parameter
    n_deterministic = sum(isinstance(dist.dist, deterministic_gen) for dist in dists)
    if are_stochastic and n_deterministic != 0:
        raise AssertionError(
            "Deterministic distributions were given where stochastic ones are expected."
        )
    if not are_stochastic and n_deterministic != len(dists):
        raise AssertionError(
            "Stochastic distributions were given where deterministic ones are expected."
        )
Checks that the distributions given in input respect the necessary conditions.
Parameters
- dists (List[Union[rv_continuous, None]]): is the list of distributions.
- are_stochastic (bool): whether the distributions are supposed to be stochastic.
def get_loop(x: Iterable) -> Iterable:
    """
    Wraps an iterable in a progress bar according to ``config.VERBOSE_LEVEL``.

    Returns
    -------
    Iterable
        An iterable that respects the current level of verbosity.
    """
    verbosity = config.VERBOSE_LEVEL
    if verbosity == 0:
        # Verbosity disabled: the input is handed back untouched.
        return x

    progress_bar_options = dict(desc="Diameter calculation", mininterval=5)
    if type(verbosity) != int:
        # For a non-integer verbosity the progress bar is written to an in-memory buffer instead.
        progress_bar_options["file"] = StringIO()
    return tqdm(x, **progress_bar_options)
Returns
- Iterable: An iterable that respects the current level of verbosity.