colosseum.utils.miscellanea

  1import cProfile
  2import collections
  3import importlib
  4import inspect
  5import numbers
  6import os
  7from glob import glob
  8from io import StringIO
  9from typing import TYPE_CHECKING, Iterable, List, Type, Union, Any, Dict
 10
 11import dm_env
 12import numpy as np
 13from scipy.stats import rv_continuous
 14from tqdm import tqdm
 15
 16import colosseum
 17import colosseum.config as config
 18
 19if TYPE_CHECKING:
 20    from colosseum.mdp import BaseMDP
 21    from colosseum.agent.agents.base import BaseAgent
 22
 23
 24def rounding_nested_structure(x: Dict):
 25    """
 26    https://stackoverflow.com/questions/7076254/rounding-decimals-in-nested-data-structures-in-python
 27    """
 28    if isinstance(x, str):
 29        return x
 30    if isinstance(x, dict):
 31        return type(x)(
 32            (key, rounding_nested_structure(value)) for key, value in x.items()
 33        )
 34    if isinstance(x, collections.Container):
 35        return type(x)(rounding_nested_structure(value) for value in x)
 36    if isinstance(x, numbers.Number):
 37        return round(x, config.get_n_floating_sampling_hyperparameters())
 38    return x
 39
 40
 41def compare_gin_configs(
 42    gin_configs1: Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
 43    gin_configs2: Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
 44) -> bool:
 45    """
 46    Returns
 47    -------
 48    bool
 49        True, if the two gin configs are identical.
 50    """
 51    if set(gin_configs1) != set(gin_configs2):
 52        return False
 53
 54    gin_configs1 = set(
 55        map(lambda x: x.replace(" ", "").replace("\n", ""), gin_configs1.values())
 56    )
 57    gin_configs2 = set(
 58        map(lambda x: x.replace(" ", "").replace("\n", ""), gin_configs2.values())
 59    )
 60    return gin_configs1 == gin_configs2
 61
 62
 63def sample_mdp_gin_configs(
 64    mdp_class: Type["BaseMDP"], n: int = 1, seed: int = 42
 65) -> List[str]:
 66    """
 67    Parameters
 68    ----------
 69    mdp_class : Type["BaseMDP"]
 70        The MDP class to sample from.
 71    n : int
 72        The number of samples. By default, one sample is taken.
 73    seed : int
 74        The random seed. By default, it is set to 42.
 75
 76    Returns
 77    -------
 78    List[str]
 79        The n sampled gin configs.
 80    """
 81    return [
 82        mdp_class.produce_gin_file_from_mdp_parameters(params, mdp_class.__name__, i)
 83        for i, params in enumerate(mdp_class.sample_parameters(n, seed))
 84    ]
 85
 86
 87def sample_mdp_gin_configs_file(
 88    mdp_class: Type["BaseMDP"], n: int = 1, seed: int = 42
 89) -> str:
 90    """
 91    Parameters
 92    ----------
 93    mdp_class : Type["BaseMDP"]
 94        The MDP class to sample from.
 95    n : int
 96        The number of samples. By default, one sample is taken.
 97    seed : int
 98        The random seed. By default, it is set to 42.
 99
100    Returns
101    -------
102    str
103        The n sampled gin configs as a single file string.
104    """
105    return "\n".join(sample_mdp_gin_configs(mdp_class, n, seed))
106
107
def get_empty_ts(state: Any) -> dm_env.TimeStep:
    """
    Builds a placeholder time step wrapping ``state``.

    Parameters
    ----------
    state : Any
        The observation stored in the time step.

    Returns
    -------
    dm_env.TimeStep
        A ``MID`` time step with zero reward, zero discount and ``state`` as observation.
    """
    return dm_env.TimeStep(
        step_type=dm_env.StepType.MID,
        reward=0,
        discount=0,
        observation=state,
    )
110
111
def profile(file_path):
    """
    Builds a decorator that profiles each call of the decorated function with
    ``cProfile`` and dumps the statistics to ``file_path``.

    Parameters
    ----------
    file_path : str
        Where the profiling statistics are written; the file is rewritten at every call.

    Returns
    -------
    Callable
        The decorator. The decorated function keeps its name, docstring and return value.
    """
    import functools

    def decorator(f):
        print(f"Profiling {f}")

        @functools.wraps(f)
        def inner(*args, **kwargs):
            pr = cProfile.Profile()
            pr.enable()
            try:
                return f(*args, **kwargs)
            finally:
                # Stop profiling and save the stats even if ``f`` raises.
                pr.disable()
                pr.dump_stats(file_path)

        return inner

    return decorator
127
128
def get_colosseum_mdp_classes(episodic: bool = None) -> List[Type["BaseMDP"]]:
    """
    Parameters
    ----------
    episodic : bool
        If True, only the episodic MDP classes are returned; if False, only the
        continuous ones; if None (default), both, with the episodic ones first.

    Returns
    -------
    List[Type["BaseMDP"]]
        All available MDP classes in the package.
    """
    if episodic is None:
        episodic_classes = _get_colosseum_mdp_classes(True)
        continuous_classes = _get_colosseum_mdp_classes(False)
        return episodic_classes + continuous_classes
    return _get_colosseum_mdp_classes(bool(episodic))
141
142
def _get_colosseum_mdp_classes(episodic=True) -> List[Type["BaseMDP"]]:
    """
    Imports the MDP modules shipped with colosseum and collects their MDP classes.

    Every ``finite_horizon.py`` (episodic) or ``infinite_horizon.py`` (continuous) file
    found recursively under ``colosseum/mdp`` is imported. From each module, the first
    attribute whose name contains ``"Episodic"`` (resp. ``"Continuous"``) but not
    ``"MDP"`` is taken.

    Parameters
    ----------
    episodic : bool
        Whether to collect the episodic or the continuous MDP classes.

    Returns
    -------
    List[Type["BaseMDP"]]
        One class per matching module, ordered by module file path.
    """
    kw = "Episodic" if episodic else "Continuous"
    mdp_path = "finite_horizon" if episodic else "infinite_horizon"

    package_dir = os.path.dirname(inspect.getfile(colosseum))
    # Module names are computed relative to the directory that contains the package.
    # An occurrence of "colosseum" earlier in the install path cannot corrupt them.
    root_dir = os.path.dirname(package_dir)
    pattern = os.path.join(package_dir, "mdp", "**", f"{mdp_path}.py")

    classes = []
    # Sorting makes the result independent of the file system's listing order.
    for mdp_file in sorted(glob(pattern, recursive=True)):
        relative_path = os.path.splitext(os.path.relpath(mdp_file, root_dir))[0]
        module = importlib.import_module(relative_path.replace(os.sep, "."))
        classes.append(
            next(
                value
                for name, value in module.__dict__.items()
                if kw in name and "MDP" not in name
            )
        )
    return classes
162
163
def get_mdp_class_from_name(mdp_class_name: str) -> Type["BaseMDP"]:
    """
    Parameters
    ----------
    mdp_class_name : str
        The ``__name__`` of the MDP class to look up.

    Returns
    -------
    Type["BaseMDP"]
        The MDP class corresponding to the name in input.

    Raises
    ------
    ModuleNotFoundError
        If no class returned by ``get_colosseum_mdp_classes()`` has the given name.
    """
    for mdp_class in get_colosseum_mdp_classes():
        if mdp_class.__name__ == mdp_class_name:
            return mdp_class
    # Raised outside of an ``except`` clause, so the traceback is not cluttered with an
    # unrelated StopIteration.
    raise ModuleNotFoundError(
        f"The MDP class {mdp_class_name} was not found in colosseum. Please check the correct spelling and the result of "
        f"get_colosseum_mdp_classes()"
    )
180
181
def get_colosseum_agent_classes(episodic: bool = None) -> List[Type["BaseAgent"]]:
    """
    Parameters
    ----------
    episodic : bool
        If True, only the episodic agent classes are returned; if False, only the
        continuous ones; if None (default), both, with the episodic ones first.

    Returns
    -------
    List[Type["BaseAgent"]]
        All available agent classes in the package.
    """
    if episodic is None:
        episodic_agents = _get_colosseum_agent_classes(True)
        continuous_agents = _get_colosseum_agent_classes(False)
        return episodic_agents + continuous_agents
    return _get_colosseum_agent_classes(bool(episodic))
194
195
def _get_colosseum_agent_classes(episodic: bool) -> List[Type["BaseAgent"]]:
    """
    Imports the agent modules shipped with colosseum and collects their agent classes.

    Every ``.py`` file whose name starts with a lowercase letter, found recursively under
    ``colosseum/agent/agents/episodic`` (resp. ``colosseum/agent/agents/infinite_horizon``),
    is imported. From each module, the first attribute whose name contains
    ``"Episodic"`` (resp. ``"Continuous"``) is taken.

    Parameters
    ----------
    episodic : bool
        Whether to collect the episodic or the continuous agent classes.

    Returns
    -------
    List[Type["BaseAgent"]]
        One class per matching module, ordered by module file path.
    """
    agent_path = "episodic" if episodic else "infinite_horizon"
    kw = "Episodic" if episodic else "Continuous"

    package_dir = os.path.dirname(inspect.getfile(colosseum))
    # Module names are computed relative to the directory that contains the package.
    # An occurrence of "colosseum" earlier in the install path cannot corrupt them.
    root_dir = os.path.dirname(package_dir)
    pattern = os.path.join(
        package_dir, "agent", "agents", agent_path, "**", "[a-z]*.py"
    )

    classes = []
    # Sorting makes the result independent of the file system's listing order.
    for agent_file in sorted(glob(pattern, recursive=True)):
        relative_path = os.path.splitext(os.path.relpath(agent_file, root_dir))[0]
        module = importlib.import_module(relative_path.replace(os.sep, "."))
        classes.append(
            next(value for name, value in module.__dict__.items() if kw in name)
        )
    return classes
214
215
def get_agent_class_from_name(agent_class_name: str) -> Type["BaseAgent"]:
    """
    Parameters
    ----------
    agent_class_name : str
        The ``__name__`` of the agent class to look up.

    Returns
    -------
    Type["BaseAgent"]
        The agent class corresponding to the name in input. Both the colosseum agent
        classes and those returned by ``config.get_external_agent_classes()`` are
        searched.

    Raises
    ------
    ModuleNotFoundError
        If none of the searched classes has the given name.
    """
    # Computed once: it is used both for the lookup and for the error message.
    available_classes = (
        get_colosseum_agent_classes() + config.get_external_agent_classes()
    )
    for agent_class in available_classes:
        if agent_class.__name__ == agent_class_name:
            return agent_class
    raise ModuleNotFoundError(
        f"The agent class {agent_class_name} was not found in colosseum. The available classes are {available_classes}"
    )
241
242
def ensure_folder(path: str) -> str:
    """
    Parameters
    ----------
    path : str
        A non-empty folder path.

    Returns
    -------
    str
        The path with the os.sep at the end. A path that already ends with a separator
        (``os.sep`` or, where defined, ``os.altsep``) is returned unchanged.

    Raises
    ------
    ValueError
        If ``path`` is empty.
    """
    if not path:
        # Rejected explicitly: appending a separator would turn "" into the root folder.
        raise ValueError("The path must be a non-empty string.")
    separators = tuple(sep for sep in (os.sep, os.altsep) if sep)
    return path if path.endswith(separators) else path + os.sep
251
252
def get_dist(dist_name, args):
    """
    Instantiates a distribution from its name and positional arguments.

    Parameters
    ----------
    dist_name : str
        Either ``"deterministic"`` or the name of an attribute of ``scipy.stats``
        (presumably a distribution).
    args : Iterable
        The positional arguments passed to the distribution.

    Returns
    -------
    The result of calling the distribution with ``args``.
    """
    if dist_name == "deterministic":
        return deterministic(*args)
    scipy_stats = importlib.import_module("scipy.stats")
    distribution = getattr(scipy_stats, dist_name)
    return distribution(*args)
257
258
# A degenerate continuous distribution with all of its mass on zero. As for every scipy
# distribution, ``loc`` shifts and ``scale`` rescales it.
class deterministic_gen(rv_continuous):
    def _cdf(self, x):
        # Step function: zero strictly below the atom, one from the atom onwards.
        return np.where(x < 0, 0.0, 1.0)

    def _ppf(self, q):
        # scipy handles q == 0 and q == 1 itself, so here 0 < q < 1 and the quantile is
        # always the atom. Without this override, scipy would find it by root-finding on
        # the discontinuous CDF, which is slow and only approximately zero.
        return np.zeros_like(q, dtype=float)

    def _stats(self):
        # Mean, variance, skewness and excess kurtosis.
        return 0.0, 0.0, 0.0, 0.0

    def _rvs(self, size=None, random_state=None):
        # Every sample equals the atom, so no randomness is consumed.
        return np.zeros(shape=size)


deterministic = deterministic_gen(name="deterministic")
271
272
def state_occurencens_to_counts(occurences: List[int], N: int) -> np.ndarray:
    """
    Counts how many times each state index appears in ``occurences``.

    Parameters
    ----------
    occurences : List[int]
        The visited state indices.
    N : int
        The length of the returned array.

    Returns
    -------
    np.ndarray
        A float array of length N whose s-th entry is the number of occurrences of s.
    """
    counts = np.zeros(N)
    unique_states, state_counts = np.unique(occurences, return_counts=True)
    for state, count in zip(unique_states, state_counts):
        counts[state] = count
    return counts
278
279
def check_distributions(dists: List[Union[rv_continuous, None]], are_stochastic: bool):
    """
    Checks that the distributions given in input respect the necessary conditions.

    Either all or none of the distributions must be defined. When they are defined,
    none of them may be deterministic if ``are_stochastic`` is True, and all of them
    must be deterministic otherwise.

    Parameters
    ----------
    dists : List[Union[rv_continuous, None]]
        The list of frozen distributions (objects exposing ``.dist``), or a list of Nones.
    are_stochastic : bool
        Whether the distributions are supposed to be stochastic.

    Raises
    ------
    AssertionError
        If a condition is violated (unless assertions are disabled with ``-O``).
    """
    # An empty list trivially satisfies every condition (and has no first element).
    if not dists:
        return

    # You either define all or none of the distributions.
    assert dists.count(None) in (0, len(dists)), (
        "Either all or none of the distributions must be defined."
    )
    if dists[0] is None:
        return

    # Double check that the distributions in input match the stochasticity of the parameter.
    if are_stochastic:
        assert all(type(dist.dist) != deterministic_gen for dist in dists), (
            "Stochastic distributions were expected, but a deterministic one was given."
        )
    else:
        assert all(type(dist.dist) == deterministic_gen for dist in dists), (
            "Deterministic distributions were expected, but a stochastic one was given."
        )
300
301
def get_loop(x: Iterable, desc: str = "Diameter calculation") -> Iterable:
    """
    Wraps an iterable in a progress bar according to ``config.VERBOSE_LEVEL``.

    Parameters
    ----------
    x : Iterable
        The iterable to loop over.
    desc : str
        The label of the progress bar. By default, "Diameter calculation".

    Returns
    -------
    Iterable
        ``x`` itself when ``config.VERBOSE_LEVEL`` is zero. Otherwise a ``tqdm``
        progress bar: printed by tqdm to its default stream when the level is an int,
        or written to an in-memory buffer that is never exposed for any other type.
    """
    if config.VERBOSE_LEVEL != 0:
        # Exact type check: bool values (a subclass of int) take the buffered branch.
        if type(config.VERBOSE_LEVEL) == int:
            return tqdm(x, desc=desc, mininterval=5)
        # NOTE(review): the buffer is discarded, so this progress output is lost;
        # confirm whether that is intended.
        s = StringIO()
        return tqdm(x, desc=desc, file=s, mininterval=5)
    return x
def rounding_nested_structure(x: Any):
    """
    Recursively rounds every number contained in a (possibly nested) structure.

    Adapted from
    https://stackoverflow.com/questions/7076254/rounding-decimals-in-nested-data-structures-in-python

    Parameters
    ----------
    x : Any
        A number, a string, a dictionary, a container (list, tuple, set, ...) or any
        nesting of these.

    Returns
    -------
    Any
        An object of the same type as ``x`` in which every number is rounded to
        ``config.get_n_floating_sampling_hyperparameters()`` decimal digits. Strings and
        values of any other type are returned unchanged.
    """
    # ``collections.Container`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``.
    from collections.abc import Container

    if isinstance(x, str):
        # Strings are containers too, but must not be split into characters.
        return x
    if isinstance(x, dict):
        return type(x)(
            (key, rounding_nested_structure(value)) for key, value in x.items()
        )
    if isinstance(x, Container):
        return type(x)(rounding_nested_structure(value) for value in x)
    if isinstance(x, numbers.Number):
        return round(x, config.get_n_floating_sampling_hyperparameters())
    return x
def compare_gin_configs(
    gin_configs1: Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
    gin_configs2: Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str],
) -> bool:
    """
    Checks whether two collections of gin configs are identical, ignoring spaces and
    new lines.

    Parameters
    ----------
    gin_configs1 : Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str]
        The first dictionary mapping classes to their gin config strings.
    gin_configs2 : Dict[Union[Type["BaseMDP"], Type["BaseAgent"]], str]
        The second dictionary mapping classes to their gin config strings.

    Returns
    -------
    bool
        True, if the two dictionaries have the same keys and, for every key, the same
        gin config once spaces and new lines are removed.
    """
    if set(gin_configs1) != set(gin_configs2):
        return False

    def _normalize(gin_config: str) -> str:
        # Formatting differences (spaces, line breaks) are not meaningful.
        return gin_config.replace(" ", "").replace("\n", "")

    # Compare key by key. Comparing the sets of values would accept configs attached
    # to the wrong class and would collapse duplicated configs.
    return all(
        _normalize(gin_configs1[key]) == _normalize(gin_configs2[key])
        for key in gin_configs1
    )
Returns
  • bool: True, if the two gin configs are identical.
def sample_mdp_gin_configs(
    mdp_class: Type["BaseMDP"], n: int = 1, seed: int = 42
) -> List[str]:
    """
    Samples parameters for an MDP class and renders each sample as a gin config.

    Parameters
    ----------
    mdp_class : Type["BaseMDP"]
        The MDP class to sample from.
    n : int
        The number of samples. By default, one sample is taken.
    seed : int
        The random seed. By default, it is set to 42.

    Returns
    -------
    List[str]
        The n sampled gin configs; the i-th one is produced with index i.
    """
    class_name = mdp_class.__name__
    sampled_parameters = mdp_class.sample_parameters(n, seed)
    gin_configs = []
    for index, parameters in enumerate(sampled_parameters):
        gin_configs.append(
            mdp_class.produce_gin_file_from_mdp_parameters(parameters, class_name, index)
        )
    return gin_configs
Parameters
  • mdp_class (Type["BaseMDP"]): The MDP class to sample from.
  • n (int): The number of samples. By default, one sample is taken.
  • seed (int): The random seed. By default, it is set to 42.
Returns
  • List[str]: The n sampled gin configs.
def sample_mdp_gin_configs_file(
    mdp_class: Type["BaseMDP"], n: int = 1, seed: int = 42
) -> str:
    """
    Samples gin configs for an MDP class and joins them into a single file content.

    Parameters
    ----------
    mdp_class : Type["BaseMDP"]
        The MDP class to sample from.
    n : int
        The number of samples. By default, one sample is taken.
    seed : int
        The random seed. By default, it is set to 42.

    Returns
    -------
    str
        The n sampled gin configs separated by new lines.
    """
    gin_configs = sample_mdp_gin_configs(mdp_class, n, seed)
    separator = "\n"
    return separator.join(gin_configs)
Parameters
  • mdp_class (Type["BaseMDP"]): The MDP class to sample from.
  • n (int): The number of samples. By default, one sample is taken.
  • seed (int): The random seed. By default, it is set to 42.
Returns
  • str: The n sampled gin configs as a single file string.
def get_empty_ts(state: Any) -> dm_env.TimeStep:
    """
    Builds a placeholder time step wrapping ``state``.

    Parameters
    ----------
    state : Any
        The observation stored in the time step.

    Returns
    -------
    dm_env.TimeStep
        A ``MID`` time step with zero reward, zero discount and ``state`` as observation.
    """
    return dm_env.TimeStep(
        step_type=dm_env.StepType.MID,
        reward=0,
        discount=0,
        observation=state,
    )
def profile(file_path):
    """
    Builds a decorator that profiles each call of the decorated function with
    ``cProfile`` and dumps the statistics to ``file_path``.

    Parameters
    ----------
    file_path : str
        Where the profiling statistics are written; the file is rewritten at every call.

    Returns
    -------
    Callable
        The decorator. The decorated function keeps its name, docstring and return value.
    """
    import functools

    def decorator(f):
        print(f"Profiling {f}")

        @functools.wraps(f)
        def inner(*args, **kwargs):
            pr = cProfile.Profile()
            pr.enable()
            try:
                return f(*args, **kwargs)
            finally:
                # Stop profiling and save the stats even if ``f`` raises.
                pr.disable()
                pr.dump_stats(file_path)

        return inner

    return decorator
def get_colosseum_mdp_classes(episodic: bool = None) -> List[Type["BaseMDP"]]:
    """
    Parameters
    ----------
    episodic : bool
        If True, only the episodic MDP classes are returned; if False, only the
        continuous ones; if None (default), both, with the episodic ones first.

    Returns
    -------
    List[Type["BaseMDP"]]
        All available MDP classes in the package.
    """
    if episodic is None:
        episodic_classes = _get_colosseum_mdp_classes(True)
        continuous_classes = _get_colosseum_mdp_classes(False)
        return episodic_classes + continuous_classes
    return _get_colosseum_mdp_classes(bool(episodic))
Returns
  • List[Type["BaseMDP"]]: All available MDP classes in the package.
def get_mdp_class_from_name(mdp_class_name: str) -> Type["BaseMDP"]:
    """
    Parameters
    ----------
    mdp_class_name : str
        The ``__name__`` of the MDP class to look up.

    Returns
    -------
    Type["BaseMDP"]
        The MDP class corresponding to the name in input.

    Raises
    ------
    ModuleNotFoundError
        If no class returned by ``get_colosseum_mdp_classes()`` has the given name.
    """
    for mdp_class in get_colosseum_mdp_classes():
        if mdp_class.__name__ == mdp_class_name:
            return mdp_class
    # Raised outside of an ``except`` clause, so the traceback is not cluttered with an
    # unrelated StopIteration.
    raise ModuleNotFoundError(
        f"The MDP class {mdp_class_name} was not found in colosseum. Please check the correct spelling and the result of "
        f"get_colosseum_mdp_classes()"
    )
Returns
  • Type["BaseMDP"]: The MDP class corresponding to the name in input.
def get_colosseum_agent_classes(episodic: bool = None) -> List[Type["BaseAgent"]]:
    """
    Parameters
    ----------
    episodic : bool
        If True, only the episodic agent classes are returned; if False, only the
        continuous ones; if None (default), both, with the episodic ones first.

    Returns
    -------
    List[Type["BaseAgent"]]
        All available agent classes in the package.
    """
    if episodic is None:
        episodic_agents = _get_colosseum_agent_classes(True)
        continuous_agents = _get_colosseum_agent_classes(False)
        return episodic_agents + continuous_agents
    return _get_colosseum_agent_classes(bool(episodic))
Returns
  • List[Type["BaseAgent"]]: All available agent classes in the package.
def get_agent_class_from_name(agent_class_name: str) -> Type["BaseAgent"]:
    """
    Parameters
    ----------
    agent_class_name : str
        The ``__name__`` of the agent class to look up.

    Returns
    -------
    Type["BaseAgent"]
        The agent class corresponding to the name in input. Both the colosseum agent
        classes and those returned by ``config.get_external_agent_classes()`` are
        searched.

    Raises
    ------
    ModuleNotFoundError
        If none of the searched classes has the given name.
    """
    # Computed once: it is used both for the lookup and for the error message.
    available_classes = (
        get_colosseum_agent_classes() + config.get_external_agent_classes()
    )
    for agent_class in available_classes:
        if agent_class.__name__ == agent_class_name:
            return agent_class
    raise ModuleNotFoundError(
        f"The agent class {agent_class_name} was not found in colosseum. The available classes are {available_classes}"
    )
Returns
  • Type["BaseAgent"]: The agent class corresponding to the name in input.
def ensure_folder(path: str) -> str:
    """
    Parameters
    ----------
    path : str
        A non-empty folder path.

    Returns
    -------
    str
        The path with the os.sep at the end. A path that already ends with a separator
        (``os.sep`` or, where defined, ``os.altsep``) is returned unchanged.

    Raises
    ------
    ValueError
        If ``path`` is empty.
    """
    if not path:
        # Rejected explicitly: appending a separator would turn "" into the root folder.
        raise ValueError("The path must be a non-empty string.")
    separators = tuple(sep for sep in (os.sep, os.altsep) if sep)
    return path if path.endswith(separators) else path + os.sep
Returns
  • str: The path with the os.sep at the end.
def get_dist(dist_name, args):
    """
    Instantiates a distribution from its name and positional arguments.

    Parameters
    ----------
    dist_name : str
        Either ``"deterministic"`` or the name of an attribute of ``scipy.stats``
        (presumably a distribution).
    args : Iterable
        The positional arguments passed to the distribution.

    Returns
    -------
    The result of calling the distribution with ``args``.
    """
    if dist_name == "deterministic":
        return deterministic(*args)
    scipy_stats = importlib.import_module("scipy.stats")
    distribution = getattr(scipy_stats, dist_name)
    return distribution(*args)
# A degenerate continuous distribution with all of its mass on zero. As for every scipy
# distribution, ``loc`` shifts and ``scale`` rescales it.
class deterministic_gen(rv_continuous):
    def _cdf(self, x):
        # Step function: zero strictly below the atom, one from the atom onwards.
        return np.where(x < 0, 0.0, 1.0)

    def _ppf(self, q):
        # scipy handles q == 0 and q == 1 itself, so here 0 < q < 1 and the quantile is
        # always the atom. Without this override, scipy would find it by root-finding on
        # the discontinuous CDF, which is slow and only approximately zero.
        return np.zeros_like(q, dtype=float)

    def _stats(self):
        # Mean, variance, skewness and excess kurtosis.
        return 0.0, 0.0, 0.0, 0.0

    def _rvs(self, size=None, random_state=None):
        # Every sample equals the atom, so no randomness is consumed.
        return np.zeros(shape=size)

A generic continuous random variable class meant for subclassing.

rv_continuous is a base class to construct specific distribution classes and instances for continuous random variables. It cannot be used directly as a distribution.

Parameters
  • momtype (int, optional): The type of generic moment calculation to use: 0 for pdf, 1 (default) for ppf.
  • a (float, optional): Lower bound of the support of the distribution, default is minus infinity.
  • b (float, optional): Upper bound of the support of the distribution, default is plus infinity.
  • xtol (float, optional): The tolerance for fixed point calculation for generic ppf.
  • badvalue (float, optional): The value in a result arrays that indicates a value that for which some argument restriction is violated, default is np.nan.
  • name (str, optional): The name of the instance. This string is used to construct the default example for distributions.
  • longname (str, optional): This string is used as part of the first line of the docstring returned when a subclass has no docstring of its own. Note: longname exists for backwards compatibility, do not use for new subclasses.
  • shapes (str, optional): The shape of the distribution. For example "m, n" for a distribution that takes two integers as the two shape arguments for all its methods. If not provided, shape parameters will be inferred from the signature of the private methods, _pdf and _cdf of the instance.
  • extradoc (str, optional, deprecated): This string is used as the last part of the docstring returned when a subclass has no docstring of its own. Note: extradoc exists for backwards compatibility, do not use for new subclasses.
  • seed ({None, int, numpy.random.Generator, numpy.random.RandomState}, optional):

    If seed is None (or np.random), the numpy.random.RandomState singleton is used. If seed is an int, a new RandomState instance is used, seeded with seed. If seed is already a Generator or RandomState instance then that instance is used.

Methods

rvs pdf logpdf cdf logcdf sf logsf ppf isf moment stats entropy expect median mean std var interval __call__ fit fit_loc_scale nnlf support

Notes

Public methods of an instance of a distribution class (e.g., pdf, cdf) check their arguments and pass valid arguments to private, computational methods (_pdf, _cdf). For pdf(x), x is valid if it is within the support of the distribution. Whether a shape parameter is valid is decided by an _argcheck method (which defaults to checking that its arguments are strictly positive.)

Subclassing

New random variables can be defined by subclassing the rv_continuous class and re-defining at least the _pdf or the _cdf method (normalized to location 0 and scale 1).

If positive argument checking is not correct for your RV then you will also need to re-define the _argcheck method.

For most of the scipy.stats distributions, the support interval doesn't depend on the shape parameters. x being in the support interval is equivalent to self.a <= x <= self.b. If either of the endpoints of the support do depend on the shape parameters, then i) the distribution must implement the _get_support method; and ii) those dependent endpoints must be omitted from the distribution's call to the rv_continuous initializer.

Correct, but potentially slow defaults exist for the remaining methods but for speed and/or accuracy you can over-ride::

_logpdf, _cdf, _logcdf, _ppf, _rvs, _isf, _sf, _logsf

The default method _rvs relies on the inverse of the cdf, _ppf, applied to a uniform random variate. In order to generate random variates efficiently, either the default _ppf needs to be overwritten (e.g. if the inverse cdf can expressed in an explicit form) or a sampling method needs to be implemented in a custom _rvs method.

If possible, you should override _isf, _sf or _logsf. The main reason would be to improve numerical accuracy: for example, the survival function _sf is computed as 1 - _cdf which can result in loss of precision if _cdf(x) is close to one.

Methods that can be overwritten by subclasses ::

_rvs _pdf _cdf _sf _ppf _isf _stats _munp _entropy _argcheck _get_support

There are additional (internal and private) generic methods that can be useful for cross-checking and for debugging, but they might not work in all cases when called directly.

A note on shapes: subclasses need not specify them explicitly. In this case, shapes will be automatically deduced from the signatures of the overridden methods (pdf, cdf etc). If, for some reason, you prefer to avoid relying on introspection, you can specify shapes explicitly as an argument to the instance constructor.

Frozen Distributions

Normally, you must provide shape parameters (and, optionally, location and scale parameters) to each call of a method of a distribution.

Alternatively, the object may be called (as a function) to fix the shape, location, and scale parameters returning a "frozen" continuous RV object:

rv = generic(<shape(s)>, loc=0, scale=1): an rv_frozen object with the same methods, but holding the given shape, location, and scale fixed.

Statistics

Statistics are computed using numerical integration by default. For speed you can redefine this using _stats:

  • take shape parameters and return mu, mu2, g1, g2
  • If you can't compute one of these, return it as None
  • Can also be defined with a keyword argument moments, which is a string composed of "m", "v", "s", and/or "k". Only the components appearing in string should be computed and returned in the order "m", "v", "s", or "k" with missing values returned as None.

Alternatively, you can override _munp, which takes n and shape parameters and returns the n-th non-central moment of the distribution.

Examples

To create a new Gaussian distribution, we would do the following:

>>> from scipy.stats import rv_continuous
>>> class gaussian_gen(rv_continuous):
...     "Gaussian distribution"
...     def _pdf(self, x):
...         return np.exp(-x**2 / 2.) / np.sqrt(2.0 * np.pi)
>>> gaussian = gaussian_gen(name='gaussian')

scipy.stats distributions are instances, so here we subclass rv_continuous and create an instance. With this, we now have a fully functional distribution with all relevant methods automagically generated by the framework.

Note that above we defined a standard normal distribution, with zero mean and unit variance. Shifting and scaling of the distribution can be done by using loc and scale parameters: gaussian.pdf(x, loc, scale) essentially computes y = (x - loc) / scale and gaussian._pdf(y) / scale.

Inherited Members
scipy.stats._distn_infrastructure.rv_continuous
rv_continuous
pdf
logpdf
cdf
logcdf
sf
logsf
ppf
isf
fit
fit_loc_scale
expect
scipy.stats._distn_infrastructure.rv_generic
random_state
freeze
rvs
stats
entropy
moment
median
mean
var
std
interval
support
nnlf
def state_occurencens_to_counts(occurences: List[int], N: int) -> np.ndarray:
    """
    Counts how many times each state index appears in ``occurences``.

    Parameters
    ----------
    occurences : List[int]
        The visited state indices.
    N : int
        The length of the returned array.

    Returns
    -------
    np.ndarray
        A float array of length N whose s-th entry is the number of occurrences of s.
    """
    counts = np.zeros(N)
    unique_states, state_counts = np.unique(occurences, return_counts=True)
    for state, count in zip(unique_states, state_counts):
        counts[state] = count
    return counts
def check_distributions(dists: List[Union[rv_continuous, None]], are_stochastic: bool):
    """
    Checks that the distributions given in input respect the necessary conditions.

    Either all or none of the distributions must be defined. When they are defined,
    none of them may be deterministic if ``are_stochastic`` is True, and all of them
    must be deterministic otherwise.

    Parameters
    ----------
    dists : List[Union[rv_continuous, None]]
        The list of frozen distributions (objects exposing ``.dist``), or a list of Nones.
    are_stochastic : bool
        Whether the distributions are supposed to be stochastic.

    Raises
    ------
    AssertionError
        If a condition is violated (unless assertions are disabled with ``-O``).
    """
    # An empty list trivially satisfies every condition (and has no first element).
    if not dists:
        return

    # You either define all or none of the distributions.
    assert dists.count(None) in (0, len(dists)), (
        "Either all or none of the distributions must be defined."
    )
    if dists[0] is None:
        return

    # Double check that the distributions in input match the stochasticity of the parameter.
    if are_stochastic:
        assert all(type(dist.dist) != deterministic_gen for dist in dists), (
            "Stochastic distributions were expected, but a deterministic one was given."
        )
    else:
        assert all(type(dist.dist) == deterministic_gen for dist in dists), (
            "Deterministic distributions were expected, but a stochastic one was given."
        )

checks that the distribution given in input respects the necessary conditions.

Parameters
  • dists (List[Union[rv_continuous, None]]): is the list of distributions.
  • are_stochastic (bool): whether the distributions are supposed to be stochastic.
def get_loop(x: Iterable, desc: str = "Diameter calculation") -> Iterable:
    """
    Wraps an iterable in a progress bar according to ``config.VERBOSE_LEVEL``.

    Parameters
    ----------
    x : Iterable
        The iterable to loop over.
    desc : str
        The label of the progress bar. By default, "Diameter calculation".

    Returns
    -------
    Iterable
        ``x`` itself when ``config.VERBOSE_LEVEL`` is zero. Otherwise a ``tqdm``
        progress bar: printed by tqdm to its default stream when the level is an int,
        or written to an in-memory buffer that is never exposed for any other type.
    """
    if config.VERBOSE_LEVEL != 0:
        # Exact type check: bool values (a subclass of int) take the buffered branch.
        if type(config.VERBOSE_LEVEL) == int:
            return tqdm(x, desc=desc, mininterval=5)
        # NOTE(review): the buffer is discarded, so this progress output is lost;
        # confirm whether that is intended.
        s = StringIO()
        return tqdm(x, desc=desc, file=s, mininterval=5)
    return x
Returns
  • Iterable: An iterable that respects the current level of verbosity.