colosseum.mdp.utils.custom_samplers
import random
from typing import TYPE_CHECKING, Iterable, List, Tuple

import numpy as np

if TYPE_CHECKING:
    from colosseum.mdp import NODE_TYPE


class NextStateSampler:
    """
    The `NextStateSampler` handles the sampling of states.

    With a single next node the sampler is deterministic and always returns
    that node. With several next nodes it draws from them according to
    `probs` using a private `random.Random` generator seeded with `seed`;
    draws are produced in batches of 5000 and served from a cache.
    """

    @property
    def next_nodes_and_probs(self) -> Iterable[Tuple["NODE_TYPE", float]]:
        """
        Returns
        -------
        Iterable[Tuple["NODE_TYPE", float]]
            The iterable over next state and probability pairs.
        """
        return zip(self.next_nodes, self.probs)

    def __init__(
        self, next_nodes: List["NODE_TYPE"], seed: int = None, probs: List[float] = None
    ):
        """
        Parameters
        ----------
        next_nodes : List["NODE_TYPE"]
            The next nodes from which the sampler will sample. It must not be
            empty.
        seed : int
            The random seed. It is required when more than one next node is
            given.
        probs : List[float]
            The probabilities corresponding to each state, in the same order
            as `next_nodes`. When omitted for a stochastic sampler, every
            next node is given the same probability.
        """
        assert len(next_nodes) > 0
        self.next_nodes = next_nodes
        # Memoizes the answers of `prob`, keyed by node.
        self._probs = dict()

        # Deterministic sampler
        if len(next_nodes) == 1:
            assert probs is None or len(probs) == 1
            self.next_state = next_nodes[0]
            self.probs = [1.0]
            self.is_deterministic = True
        # Stochastic sampler
        else:
            assert seed is not None
            self.n = len(next_nodes)
            if probs is None:
                # `random.choices` samples uniformly when it receives no
                # weights; store that distribution explicitly so that `prob`,
                # `mode` and `next_nodes_and_probs` agree with the sampling.
                probs = [1.0 / self.n] * self.n
            self.probs = probs
            self._rng = random.Random(seed)
            self.is_deterministic = False
            self._refill_cache()

    def _refill_cache(self) -> None:
        """Draw a new batch of 5000 samples into `cached_states`."""
        batch = self._rng.choices(self.next_nodes, weights=self.probs, k=5000)
        # Stored in reverse so that `sample` can hand the draws out in their
        # original order with an O(1) `pop()` from the end of the list.
        batch.reverse()
        self.cached_states = batch

    def sample(self) -> "NODE_TYPE":
        """
        Returns
        -------
        NODE_TYPE
            A sample of the next state distribution.
        """
        if self.is_deterministic:
            return self.next_state
        if not self.cached_states:
            self._refill_cache()
        return self.cached_states.pop()

    def mode(self) -> "NODE_TYPE":
        """
        Returns
        -------
        NODE_TYPE
            The most probable next state.
        """
        if self.is_deterministic:
            return self.next_state
        return self.next_nodes[np.argmax(self.probs)]

    def prob(self, n: "NODE_TYPE") -> float:
        """
        Parameters
        ----------
        n : NODE_TYPE
            The state whose probability is requested.

        Returns
        -------
        float
            The probability of sampling the given state, which is zero for a
            state that is not among the next nodes.
        """
        if n not in self._probs:
            # A node listed more than once is drawn with the sum of the
            # probabilities of all its entries, so add them up. The identity
            # check mirrors the matching rule of `list.index`.
            self._probs[n] = sum(
                (
                    p
                    for node, p in zip(self.next_nodes, self.probs)
                    if node is n or node == n
                ),
                0.0,
            )
        return self._probs[n]
class
NextStateSampler:
class NextStateSampler:
    """
    The `NextStateSampler` handles the sampling of states.

    With a single next node the sampler is deterministic and always returns
    that node. With several next nodes it draws from them according to
    `probs` using a private `random.Random` generator seeded with `seed`;
    draws are produced in batches of 5000 and served from a cache.
    """

    @property
    def next_nodes_and_probs(self) -> Iterable[Tuple["NODE_TYPE", float]]:
        """
        Returns
        -------
        Iterable[Tuple["NODE_TYPE", float]]
            The iterable over next state and probability pairs.
        """
        return zip(self.next_nodes, self.probs)

    def __init__(
        self, next_nodes: List["NODE_TYPE"], seed: int = None, probs: List[float] = None
    ):
        """
        Parameters
        ----------
        next_nodes : List["NODE_TYPE"]
            The next nodes from which the sampler will sample. It must not be
            empty.
        seed : int
            The random seed. It is required when more than one next node is
            given.
        probs : List[float]
            The probabilities corresponding to each state, in the same order
            as `next_nodes`. When omitted for a stochastic sampler, every
            next node is given the same probability.
        """
        assert len(next_nodes) > 0
        self.next_nodes = next_nodes
        # Memoizes the answers of `prob`, keyed by node.
        self._probs = dict()

        # Deterministic sampler
        if len(next_nodes) == 1:
            assert probs is None or len(probs) == 1
            self.next_state = next_nodes[0]
            self.probs = [1.0]
            self.is_deterministic = True
        # Stochastic sampler
        else:
            assert seed is not None
            self.n = len(next_nodes)
            if probs is None:
                # `random.choices` samples uniformly when it receives no
                # weights; store that distribution explicitly so that `prob`,
                # `mode` and `next_nodes_and_probs` agree with the sampling.
                probs = [1.0 / self.n] * self.n
            self.probs = probs
            self._rng = random.Random(seed)
            self.is_deterministic = False
            self._refill_cache()

    def _refill_cache(self) -> None:
        """Draw a new batch of 5000 samples into `cached_states`."""
        batch = self._rng.choices(self.next_nodes, weights=self.probs, k=5000)
        # Stored in reverse so that `sample` can hand the draws out in their
        # original order with an O(1) `pop()` from the end of the list.
        batch.reverse()
        self.cached_states = batch

    def sample(self) -> "NODE_TYPE":
        """
        Returns
        -------
        NODE_TYPE
            A sample of the next state distribution.
        """
        if self.is_deterministic:
            return self.next_state
        if not self.cached_states:
            self._refill_cache()
        return self.cached_states.pop()

    def mode(self) -> "NODE_TYPE":
        """
        Returns
        -------
        NODE_TYPE
            The most probable next state.
        """
        if self.is_deterministic:
            return self.next_state
        return self.next_nodes[np.argmax(self.probs)]

    def prob(self, n: "NODE_TYPE") -> float:
        """
        Parameters
        ----------
        n : NODE_TYPE
            The state whose probability is requested.

        Returns
        -------
        float
            The probability of sampling the given state, which is zero for a
            state that is not among the next nodes.
        """
        if n not in self._probs:
            # A node listed more than once is drawn with the sum of the
            # probabilities of all its entries, so add them up. The identity
            # check mirrors the matching rule of `list.index`.
            self._probs[n] = sum(
                (
                    p
                    for node, p in zip(self.next_nodes, self.probs)
                    if node is n or node == n
                ),
                0.0,
            )
        return self._probs[n]
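The `NextStateSampler` handles the sampling of next states: when a single next node is given it always returns that node, and otherwise it draws from the given nodes according to `probs` using a seeded random generator.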
NextStateSampler( next_nodes: List[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode]], seed: int = None, probs: List[float] = None)
26 def __init__( 27 self, next_nodes: List["NODE_TYPE"], seed: int = None, probs: List[float] = None 28 ): 29 """ 30 Parameters 31 ---------- 32 next_nodes : List["NODE_TYPE"] 33 The next nodes from which the sampler will sample. 34 seed : int 35 The random seed. 36 probs : List[float] 37 The probabilities corresponding to each state. 38 """ 39 assert len(next_nodes) > 0 40 self.next_nodes = next_nodes 41 self._probs = dict() 42 43 # Deterministic sampler 44 if len(next_nodes) == 1: 45 assert probs is None or len(probs) == 1 46 self.next_state = next_nodes[0] 47 self.probs = [1.0] 48 self.is_deterministic = True 49 # Stochastic sampler 50 else: 51 assert seed is not None 52 self.probs = probs 53 self._rng = random.Random(seed) 54 self.n = len(next_nodes) 55 self.is_deterministic = False 56 self.cached_states = self._rng.choices( 57 self.next_nodes, weights=self.probs, k=5000 58 )
Parameters
- next_nodes (List["NODE_TYPE"]): The next nodes from which the sampler will sample.
- seed (int): The random seed. It must be provided whenever more than one next node is given.
- probs (List[float]): The probabilities corresponding to each state, given in the same order as `next_nodes`.
next_nodes_and_probs: Iterable[Tuple[Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode], float]]
Returns
- Iterable[Tuple["NODE_TYPE", float]]: The iterable over next state and probability pairs.
def
sample( self) -> Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode]:
def sample(self) -> "NODE_TYPE":
    """
    Draw one next state.

    A deterministic sampler returns its only next state. A stochastic sampler
    serves draws from `cached_states`, front first, and refills that cache
    with a new batch of 5000 draws once it is empty.

    Returns
    -------
    NODE_TYPE
        A sample of the next state distribution.
    """
    if not self.is_deterministic:
        pending = self.cached_states
        if len(pending) == 0:
            # The cache is exhausted: draw a fresh batch before serving.
            pending = self._rng.choices(self.next_nodes, weights=self.probs, k=5000)
            self.cached_states = pending
        return pending.pop(0)
    return self.next_state
Returns
- NODE_TYPE: A sample of the next state distribution.
def
mode( self) -> Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode]:
def mode(self) -> "NODE_TYPE":
    """
    Return the next state with the largest probability.

    A deterministic sampler returns its only next state; otherwise the node
    stored at the position of the largest entry of `probs` is returned.

    Returns
    -------
    NODE_TYPE
        The most probable next state.
    """
    if not self.is_deterministic:
        best_position = np.argmax(self.probs)
        return self.next_nodes[best_position]
    return self.next_state
Returns
- NODE_TYPE: The most probable next state.
def
prob( self, n: Union[colosseum.mdp.custom_mdp.CustomNode, colosseum.mdp.river_swim.base.RiverSwimNode, colosseum.mdp.deep_sea.base.DeepSeaNode, colosseum.mdp.frozen_lake.base.FrozenLakeNode, colosseum.mdp.simple_grid.base.SimpleGridNode, colosseum.mdp.minigrid_empty.base.MiniGridEmptyNode, colosseum.mdp.minigrid_rooms.base.MiniGridRoomsNode, colosseum.mdp.taxi.base.TaxiNode]) -> float:
def prob(self, n: "NODE_TYPE") -> float:
    """
    Parameters
    ----------
    n : NODE_TYPE
        The state whose probability is requested.

    Returns
    -------
    float
        The probability of sampling the given state, which is zero for a
        state that is not among the next nodes.
    """
    if n not in self._probs:
        # A node listed more than once is drawn with the sum of the
        # probabilities of all its entries, so add them up in a single pass.
        # The identity check mirrors the matching rule of `list.index`.
        self._probs[n] = sum(
            (
                p
                for node, p in zip(self.next_nodes, self.probs)
                if node is n or node == n
            ),
            0.0,
        )
    return self._probs[n]
Returns
- float: The probability of sampling the given state.