# Source code for fluidgym.wrappers.obs_extraction

"""A wrapper that extracts specific observations from the observation dictionary."""

import torch
from gymnasium import spaces

from fluidgym.types import FluidEnvLike
from fluidgym.wrappers.fluid_wrapper import FluidWrapper


class ObsExtraction(FluidWrapper):
    """A wrapper that extracts specific observations from the observation dictionary.

    It extracts only the observations specified in the ``keys`` list; all
    other entries of the underlying Dict observation are dropped.

    Parameters
    ----------
    env: FluidEnvLike
        The environment to wrap.
    keys: list[str]
        The non-empty list of keys to extract from the observation dictionary.

    Raises
    ------
    ValueError
        If ``keys`` is empty, the wrapped environment's observation space is
        not a :class:`gymnasium.spaces.Dict`, or any key is not present in
        that space.
    """

    def __init__(self, env: FluidEnvLike, keys: list[str]) -> None:
        super().__init__(env)
        if len(keys) == 0:
            raise ValueError("Keys list must be non-empty.")
        if not isinstance(self._env.observation_space, spaces.Dict):
            raise ValueError(
                "ObsExtraction wrapper only supports Dict observation spaces."
            )
        for k in keys:
            if k not in self._env.observation_space.spaces:
                raise ValueError(f"Key '{k}' not found in observation space.")
        # Copy so later caller-side mutation of `keys` cannot change the filter.
        self.__keys = list(keys)
        # Narrowed observation space containing only the extracted sub-spaces.
        self.__observation_space = spaces.Dict(
            {k: self._env.observation_space.spaces[k] for k in self.__keys}
        )

    @property
    def observation_space(self) -> spaces.Dict:
        """The filtered observation space containing only the extracted keys."""
        return self.__observation_space

    def __filter_obs(self, obs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # __init__ guarantees a validated, non-empty key list, so filter
        # unconditionally (the previous `is not None` guard was dead code).
        return {k: obs[k] for k in self.__keys}

    def reset(
        self,
        seed: int | None = None,
        randomize: bool | None = None,
    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Resets the environment to an initial internal state, returning an
        initial observation and info.

        Parameters
        ----------
        seed: int | None
            The seed to use for random number generation. If None, the
            current seed is used.
        randomize: bool | None
            Whether to randomize the initial state. If None, the default
            behavior is used.

        Returns
        -------
        tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
            A tuple containing the (filtered) initial observation and an
            info dictionary.
        """
        obs, info = self._env.reset(seed=seed, randomize=randomize)
        return self.__filter_obs(obs), info

    def step(
        self, action: torch.Tensor
    ) -> tuple[
        dict[str, torch.Tensor], torch.Tensor, bool, bool, dict[str, torch.Tensor]
    ]:
        """Run one timestep of the environment's dynamics using the agent actions.

        When the end of an episode is reached (``terminated or truncated``),
        it is necessary to call :meth:`reset` to reset this environment's
        state for the next episode.

        Parameters
        ----------
        action: torch.Tensor
            The action to take.

        Returns
        -------
        tuple[
            dict[str, torch.Tensor], torch.Tensor, bool, bool,
            dict[str, torch.Tensor]]
            A tuple containing the (filtered) observation, reward, terminated
            flag, truncated flag, and info dictionary.
        """
        obs, reward, terminated, truncated, info = self._env.step(action)
        return self.__filter_obs(obs), reward, terminated, truncated, info