1from typing import Callable, List
2
3import numpy as np
4from scipy.stats import wishart, multivariate_normal
5
6from colosseum.noises.base import Noise
7
8
class GaussianCorrelated(Noise):
    """
    The class that creates Gaussian correlated noise.

    The noise is drawn from a zero-mean multivariate normal distribution over
    the flattened emission map. Its covariance matrix is itself sampled from a
    Wishart distribution with scale matrix ``scale * I`` (the degrees of freedom
    are not passed, so scipy's default is used). The covariance is sampled
    lazily, on the first call to ``_sample_noise``, and then reused for every
    subsequent sample.
    """

    def _sample_noise(self, n: int) -> np.ndarray:
        """
        Samples correlated Gaussian noise.

        Parameters
        ----------
        n : int
            The number of noise samples to draw.

        Returns
        -------
        np.ndarray
            The noise samples, reshaped to ``(n, *self.shape)``.
        """
        # NOTE(review): ``self.shape`` and ``self._rng`` are not set in this
        # class; they are assumed to be provided by the ``Noise`` base class.
        if self.rv is None:
            # int() is required: np.prod returns the float 1.0 for an empty
            # shape, which cannot be used as a matrix dimension.
            dim = int(np.prod(self.shape))
            covariance = wishart(scale=self._scale * np.eye(dim)).rvs(
                size=1, random_state=self._rng
            )
            self.rv = multivariate_normal(cov=covariance)
        return self.rv.rvs(size=n, random_state=self._rng).reshape(n, *self.shape)

    def __init__(self, seed: int, shape_f: Callable[[], List[int]], scale: float = 0.1):
        """
        Parameters
        ----------
        seed : int
            The random seed.
        shape_f : Callable[[], List[int]]
            The function that returns the shape of the emission map.
        scale : float
            The scale parameter for the Wishart distribution for the covariance matrix. By default, it is 0.1.
            It must be strictly positive so that the Wishart scale matrix ``scale * I`` is positive definite.

        Raises
        ------
        ValueError
            If ``scale`` is not strictly positive (this includes NaN).
        """
        # Fail fast: otherwise an invalid scale only surfaces on the first
        # sampling call, as an opaque linear-algebra error from scipy.
        if not scale > 0:
            raise ValueError(f"The scale must be strictly positive, got {scale}.")

        super().__init__(seed, shape_f)

        self._scale = scale
        # Frozen multivariate normal, created on the first call to
        # ``_sample_noise`` and cached afterwards.
        self.rv = None