Source code for vista.core.World

from typing import List, Dict, Tuple, Optional, Callable
import numpy as np

from .Trace import Trace
from ..entities.agents import Car
from ..utils import logging


class World:
    """ This class specifies the :class:`World` where all entities (i.e.,
    agents and sensors) in VISTA live. The world is built upon one or
    multiple traces pre-collected from the real world. Users can spawn agents
    (:class:`Car`) in the :class:`World`, where all agents will be visible
    while rendering sensory measurement. At reset, the :class:`World` randomly
    samples a pointer in the dataset (conceptually retrieving a frame within a
    trace) as the initial condition for the simulation.

    Args:
        trace_paths (List(str)): A list of paths to traces.
        trace_config (Dict): Configuration of traces. ``None`` (the default)
            is treated as an empty configuration.

    Example usage::

        >>> world = vista.World(trace_paths, trace_config)
        >>> agent = world.spawn_agent(agent_config)
        >>> world.reset()

    """
    def __init__(self,
                 trace_paths: List[str],
                 trace_config: Optional[Dict] = None) -> None:
        # Use a fresh dict instead of a mutable default argument, which would
        # be shared by every World constructed without an explicit config.
        if trace_config is None:
            trace_config = {}

        # A list of traces that define the world
        self._traces: List[Trace] = [
            Trace(trace_path, trace_config) for trace_path in trace_paths
        ]

        # A list of agents within this world. Agents start in a single
        # trace but can be teleported between traces, since they are all in
        # the same world.
        self._agents: List[Car] = []

        # Random seed; will also assign the same seed to traces
        self.set_seed(0)

    def spawn_agent(self, config: Dict) -> Car:
        """ Spawn an agent in this world.

        Args:
            config (Dict): Configuration of the agent.

        Returns:
            Car: The agent being spawned.

        """
        agent = Car(world=self, car_config=config)
        logging.info(f'Spawn an agent ({agent.id})')
        self._agents.append(agent)
        return agent

    def reset(
            self,
            initial_dynamics_fn: Optional[Dict[str, Callable]] = None
    ) -> None:
        """ Reset the world. This includes (1) sampling a new anchor point
        from the real-world dataset to be simulated from and (2) resetting
        states for all agents.

        Args:
            initial_dynamics_fn (Dict[str, Callable]):
                A dict mapping agent names to a function that initializes
                agent poses. Agents without an entry are reset with ``None``.
                Passing ``None`` (the default) is the same as passing an
                empty dict.

        """
        logging.info('World reset')

        # The parameter is Optional, so tolerate an explicit None instead of
        # failing on ``None.get``.
        if initial_dynamics_fn is None:
            initial_dynamics_fn = {}

        # Sample a new trace and a new location at the sampled trace
        new_trace_index, new_segment_index, new_frame_index = \
            self.sample_new_location()

        # Reset agents; all of them are anchored to the same sampled location
        for agent in self.agents:
            agent.reset(new_trace_index, new_segment_index, new_frame_index,
                        initial_dynamics_fn.get(agent.id, None))

    def sample_new_location(self) -> Tuple[int, int, int]:
        """ Sample a pointer to the dataset for simulation.

        Returns:
            Return a tuple (``int_a``, ``int_b``, ``int_c``), where ``int_a``
            is trace index, ``int_b`` is segment index, and ``int_c`` is
            frame index.

        """
        new_trace_index = self.sample_new_trace_index()
        trace = self.traces[new_trace_index]

        new_segment_index = trace.find_segment_reset()
        # note that this is the index of Trace.good_frames instead of element
        new_frame_index = trace.find_frame_reset(new_segment_index)

        return new_trace_index, new_segment_index, new_frame_index

    def sample_new_trace_index(self) -> int:
        """ Sample a new trace index based on number of frames in a trace.

        Each trace is drawn with probability proportional to its
        ``num_of_frames``.

        Returns:
            int: An index to specify which trace to be simulated from.

        Raises:
            ValueError: If the world has no traces or the traces contain no
                frames in total, since no sampling distribution can be
                formed.

        """
        if not self.traces:
            raise ValueError(
                'Cannot sample a trace index: the world has no traces')

        frame_counts = np.array(
            [trace.num_of_frames for trace in self.traces], dtype=float)
        total_frames = frame_counts.sum()
        # ``not (> 0)`` also rejects NaN, which would otherwise propagate into
        # the probabilities.
        if not total_frames > 0:
            raise ValueError(
                'Cannot sample a trace index: traces contain no frames')

        trace_reset_probs = frame_counts / total_frames
        new_trace_index = self._rng.choice(trace_reset_probs.shape[0],
                                           p=trace_reset_probs)
        # Convert the NumPy integer to the built-in int promised by the
        # return annotation.
        return int(new_trace_index)

    def set_seed(self, seed) -> None:
        """ Set random seed. The same seed is also assigned to every trace.

        Args:
            seed (int): Random seed.

        """
        self._seed = seed
        self._rng = np.random.default_rng(self.seed)
        for trace in self.traces:
            trace.set_seed(seed)

    @property
    def seed(self) -> int:
        """ Random seed for sampling pointer to the dataset during
        :meth:`reset`. """
        return self._seed

    @property
    def traces(self) -> List[Trace]:
        """ All traces attached to the :class:`World`. """
        return self._traces

    @property
    def agents(self) -> List[Car]:
        """ All agents that live in the :class:`World`. """
        return self._agents

    def __repr__(self) -> str:
        return (f'<{self.__class__.__name__}> '
                f'#traces: {len(self.traces)} '
                f'#agents: {len(self.agents)} ')