colosseum.mdp.utils.custom_samplers

 1import random
 2from typing import TYPE_CHECKING, Iterable, List, Tuple
 3
 4import numpy as np
 5
 6if TYPE_CHECKING:
 7    from colosseum.mdp import NODE_TYPE
 8
 9
class NextStateSampler:
    """
    The `NextStateSampler` handles the sampling of states.

    With a single next node the sampler is deterministic and always returns
    that node. With several next nodes it is stochastic. It draws from
    ``next_nodes`` according to ``probs`` using a ``random.Random`` seeded with
    ``seed``, and pre-draws samples in batches of ``_CACHE_SIZE``.
    """

    # Number of samples drawn from the RNG at once by a stochastic sampler.
    _CACHE_SIZE = 5000

    @property
    def next_nodes_and_probs(self) -> Iterable[Tuple["NODE_TYPE", float]]:
        """
        Returns
        -------
        Iterable[Tuple["NODE_TYPE", float]]
            The iterable over next state and probability pairs.
        """
        return zip(self.next_nodes, self.probs)

    def __init__(
        self, next_nodes: List["NODE_TYPE"], seed: int = None, probs: List[float] = None
    ):
        """
        Parameters
        ----------
        next_nodes : List["NODE_TYPE"]
            The next nodes from which the sampler will sample. Must be non-empty.
        seed : int
            The random seed. Required when more than one next node is given.
        probs : List[float]
            The probabilities corresponding to each state. When None and more
            than one next node is given, the next nodes are equally likely.
        """
        assert len(next_nodes) > 0
        self.next_nodes = next_nodes
        # Memo for `prob`: node -> probability.
        self._probs = dict()

        # Deterministic sampler
        if len(next_nodes) == 1:
            assert probs is None or len(probs) == 1
            self.next_state = next_nodes[0]
            self.probs = [1.0]
            self.is_deterministic = True
        # Stochastic sampler
        else:
            assert seed is not None
            self.n = len(next_nodes)
            if probs is None:
                # `random.choices` treats missing weights as uniform. Store that
                # distribution explicitly so `prob`, `mode` and
                # `next_nodes_and_probs` work instead of failing on None.
                probs = [1.0 / self.n] * self.n
            self.probs = probs
            self._rng = random.Random(seed)
            self.is_deterministic = False
            self._refill_cache()

    def _refill_cache(self) -> None:
        """Draw a fresh batch of samples, stored in reverse draw order."""
        batch = self._rng.choices(
            self.next_nodes, weights=self.probs, k=self._CACHE_SIZE
        )
        # Reversed so `sample` can pop from the end in O(1) while still
        # returning the draws in the order the RNG produced them.
        batch.reverse()
        self.cached_states = batch

    def sample(self) -> "NODE_TYPE":
        """
        Returns
        -------
        NODE_TYPE
            A sample of the next state distribution.
        """
        if self.is_deterministic:
            return self.next_state
        if not self.cached_states:
            self._refill_cache()
        return self.cached_states.pop()

    def mode(self) -> "NODE_TYPE":
        """
        Returns
        -------
        NODE_TYPE
            The most probable next state (the first one on ties).
        """
        if self.is_deterministic:
            return self.next_state
        return self.next_nodes[np.argmax(self.probs)]

    def prob(self, n: "NODE_TYPE") -> float:
        """
        Parameters
        ----------
        n : NODE_TYPE
            The state whose probability is queried.

        Returns
        -------
        float
            The probability of sampling the given state (0.0 if it is not a
            possible next state).
        """
        if n not in self._probs:
            try:
                index = self.next_nodes.index(n)
            except ValueError:
                self._probs[n] = 0.0
            else:
                self._probs[n] = self.probs[index]
        return self._probs[n]
class NextStateSampler:
class NextStateSampler:
    """
    The `NextStateSampler` handles the sampling of states.

    With a single next node the sampler is deterministic and always returns
    that node. With several next nodes it is stochastic. It draws from
    ``next_nodes`` according to ``probs`` using a ``random.Random`` seeded with
    ``seed``, and pre-draws samples in batches of ``_CACHE_SIZE``.
    """

    # Number of samples drawn from the RNG at once by a stochastic sampler.
    _CACHE_SIZE = 5000

    @property
    def next_nodes_and_probs(self) -> Iterable[Tuple["NODE_TYPE", float]]:
        """
        Returns
        -------
        Iterable[Tuple["NODE_TYPE", float]]
            The iterable over next state and probability pairs.
        """
        return zip(self.next_nodes, self.probs)

    def __init__(
        self, next_nodes: List["NODE_TYPE"], seed: int = None, probs: List[float] = None
    ):
        """
        Parameters
        ----------
        next_nodes : List["NODE_TYPE"]
            The next nodes from which the sampler will sample. Must be non-empty.
        seed : int
            The random seed. Required when more than one next node is given.
        probs : List[float]
            The probabilities corresponding to each state. When None and more
            than one next node is given, the next nodes are equally likely.
        """
        assert len(next_nodes) > 0
        self.next_nodes = next_nodes
        # Memo for `prob`: node -> probability.
        self._probs = dict()

        # Deterministic sampler
        if len(next_nodes) == 1:
            assert probs is None or len(probs) == 1
            self.next_state = next_nodes[0]
            self.probs = [1.0]
            self.is_deterministic = True
        # Stochastic sampler
        else:
            assert seed is not None
            self.n = len(next_nodes)
            if probs is None:
                # `random.choices` treats missing weights as uniform. Store that
                # distribution explicitly so `prob`, `mode` and
                # `next_nodes_and_probs` work instead of failing on None.
                probs = [1.0 / self.n] * self.n
            self.probs = probs
            self._rng = random.Random(seed)
            self.is_deterministic = False
            self._refill_cache()

    def _refill_cache(self) -> None:
        """Draw a fresh batch of samples, stored in reverse draw order."""
        batch = self._rng.choices(
            self.next_nodes, weights=self.probs, k=self._CACHE_SIZE
        )
        # Reversed so `sample` can pop from the end in O(1) while still
        # returning the draws in the order the RNG produced them.
        batch.reverse()
        self.cached_states = batch

    def sample(self) -> "NODE_TYPE":
        """
        Returns
        -------
        NODE_TYPE
            A sample of the next state distribution.
        """
        if self.is_deterministic:
            return self.next_state
        if not self.cached_states:
            self._refill_cache()
        return self.cached_states.pop()

    def mode(self) -> "NODE_TYPE":
        """
        Returns
        -------
        NODE_TYPE
            The most probable next state (the first one on ties).
        """
        if self.is_deterministic:
            return self.next_state
        return self.next_nodes[np.argmax(self.probs)]

    def prob(self, n: "NODE_TYPE") -> float:
        """
        Parameters
        ----------
        n : NODE_TYPE
            The state whose probability is queried.

        Returns
        -------
        float
            The probability of sampling the given state (0.0 if it is not a
            possible next state).
        """
        if n not in self._probs:
            try:
                index = self.next_nodes.index(n)
            except ValueError:
                self._probs[n] = 0.0
            else:
                self._probs[n] = self.probs[index]
        return self._probs[n]

The NextStateSampler handles the sampling of states.

26    def __init__(
27        self, next_nodes: List["NODE_TYPE"], seed: int = None, probs: List[float] = None
28    ):
29        """
30        Parameters
31        ----------
32        next_nodes : List["NODE_TYPE"]
33            The next nodes from which the sampler will sample.
34        seed : int
35            The random seed.
36        probs : List[float]
37            The probabilities corresponding to each state.
38        """
39        assert len(next_nodes) > 0
40        self.next_nodes = next_nodes
41        self._probs = dict()
42
43        # Deterministic sampler
44        if len(next_nodes) == 1:
45            assert probs is None or len(probs) == 1
46            self.next_state = next_nodes[0]
47            self.probs = [1.0]
48            self.is_deterministic = True
49        # Stochastic sampler
50        else:
51            assert seed is not None
52            self.probs = probs
53            self._rng = random.Random(seed)
54            self.n = len(next_nodes)
55            self.is_deterministic = False
56            self.cached_states = self._rng.choices(
57                self.next_nodes, weights=self.probs, k=5000
58            )
Parameters
  • next_nodes (List["NODE_TYPE"]): The next nodes from which the sampler will sample.
  • seed (int): The random seed.
  • probs (List[float]): The probabilities corresponding to each state.
60    def sample(self) -> "NODE_TYPE":
61        """
62        Returns
63        -------
64        NODE_TYPE
65            A sample of the next state distribution.
66        """
67        if self.is_deterministic:
68            return self.next_state
69        if len(self.cached_states) == 0:
70            self.cached_states = self._rng.choices(
71                self.next_nodes, weights=self.probs, k=5000
72            )
73        return self.cached_states.pop(0)
Returns
  • NODE_TYPE: A sample of the next state distribution.
75    def mode(self) -> "NODE_TYPE":
76        """
77        Returns
78        -------
79        NODE_TYPE
80            The most probable next state.
81        """
82        if self.is_deterministic:
83            return self.next_state
84        return self.next_nodes[np.argmax(self.probs)]
Returns
  • NODE_TYPE: The most probable next state.
86    def prob(self, n: "NODE_TYPE") -> float:
87        """
88        Returns
89        -------
90        float
91            The probability of sampling the given state.
92        """
93        if n not in self._probs:
94            if n not in self.next_nodes:
95                self._probs[n] = 0.0
96            else:
97                self._probs[n] = self.probs[self.next_nodes.index(n)]
98        return self._probs[n]
Returns
  • float: The probability of sampling the given state.