Source code for vista.tasks.multi_agent_base

from typing import Optional, List, Dict, Any
import numpy as np
from shapely.geometry import box as Box
from shapely import affinity

from .. import World
from ..entities.agents.Car import Car
from ..entities.agents.Dynamics import StateDynamics
from ..entities.sensors.MeshLib import MeshLib
from ..utils import logging, misc, transform


def default_terminal_condition(task, agent_id, **kwargs):
    """ An example definition of terminal condition.

    The episode of an agent terminates when the agent leaves the lane,
    rotates too much with respect to the reference, overlaps too much with
    another agent, or reports ``done`` by itself.

    Args:
        task: The task object; ``task.world.agents`` and
            ``task.config['overlap_threshold']`` are read.
        agent_id: ID of the agent to be checked.
        **kwargs: Not used.

    Returns:
        A tuple ``(done, other_info)`` where ``other_info`` is a dictionary
        with keys ``done``, ``out_of_lane``, ``exceed_max_rot``, and
        ``crashed``.

    Raises:
        IndexError: If no agent in the world has ID ``agent_id``.

    """
    agent = [_cand for _cand in task.world.agents if _cand.id == agent_id][0]

    # Out of lane: lateral offset larger than half of the road width
    half_width = agent.trace.road_width / 2.
    out_of_lane = np.abs(agent.relative_state.x) > half_width

    # Excessive rotation: relative yaw larger than pi / 10
    yaw_limit = np.pi / 10.
    exceed_max_rot = np.abs(agent.relative_state.yaw) > yaw_limit

    # Crash: intersecting area with any other agent, normalized by the area
    # of this agent's polygon, exceeds the configured threshold
    others = [_cand for _cand in task.world.agents if _cand.id != agent_id]
    own_poly = misc.agent2poly(agent, ref_dynamics=agent.human_dynamics)
    others_polys = [
        misc.agent2poly(_other, ref_dynamics=agent.human_dynamics)
        for _other in others
    ]
    overlap_ratio = compute_overlap(own_poly, others_polys) / own_poly.area
    crashed = np.any(overlap_ratio > task.config['overlap_threshold'])

    done = out_of_lane or exceed_max_rot or crashed or agent.done
    other_info = dict(
        done=done,
        out_of_lane=out_of_lane,
        exceed_max_rot=exceed_max_rot,
        crashed=crashed,
    )

    return done, other_info


def default_reward_fn(task, agent_id, **kwargs):
    """ An example definition of reward function.

    Args:
        task: The task object (not used).
        agent_id: ID of the agent (not used).
        **kwargs: Must contain ``done``, whether the episode of the agent
            terminates.

    Returns:
        A tuple ``(reward, info)`` where ``reward`` is ``-1`` if ``done`` is
        truthy and ``0`` otherwise, and ``info`` is an empty dictionary.

    """
    # Simply encourage survival: only termination is penalized
    if kwargs['done']:
        reward = -1
    else:
        reward = 0

    return reward, {}


class MultiAgentBase:
    """ This class builds a simple environment with multiple cars in the scene,
    which involves randomly initializing ado cars in the front of the ego car,
    checking collision between cars, handling meshes for all virtual agents,
    and determining terminal condition.

    Args:
        trace_paths (List[str]): A list of trace paths.
        trace_config (Dict): Configuration of the trace.
        car_configs (List[Dict]): Configuration of ``every`` car.
        sensors_configs (List[List[Dict]]): Configuration of sensors on
            ``every`` car, i.e., one list of sensor configurations per car.
            Every sensor configuration needs a ``type`` key; only ``camera``
            is supported. The given dictionaries are not modified.
        task_config (Dict): Configuration of the task; ``None`` (default) is
            treated as an empty dictionary. It is merged with
            ``DEFAULT_CONFIG`` through ``misc.merge_dict``. An example
            (default) is,

            >>> DEFAULT_CONFIG = {
                    'n_agents': 1,
                    'mesh_dir': None,
                    'overlap_threshold': 0.05,
                    'max_resample_tries': 10,
                    'init_dist_range': [5., 10.],
                    'init_lat_noise_range': [-1., 1.],
                    'init_yaw_noise_range': [-0.0, 0.0],
                    'reward_fn': default_reward_fn,
                    'terminal_condition': default_terminal_condition
                }

            Note that both ``reward_fn`` and ``terminal_condition`` have
            function signature as ``f(task, agent_id, **kwargs) -> (value,
            dict)``. For more details, please check the source code.
        logging_level (str): Logging level (``DEBUG``, ``INFO``, ``WARNING``,
            ``ERROR``, ``CRITICAL``); default set to ``WARNING``.

    """
    DEFAULT_CONFIG = {
        'n_agents': 1,
        'mesh_dir': None,
        'overlap_threshold': 0.05,
        'max_resample_tries': 10,
        'init_dist_range': [5., 10.],
        'init_lat_noise_range': [-1., 1.],
        'init_yaw_noise_range': [-0.0, 0.0],
        'reward_fn': default_reward_fn,
        'terminal_condition': default_terminal_condition
    }

    def __init__(self,
                 trace_paths: List[str],
                 trace_config: Dict,
                 car_configs: List[Dict],
                 sensors_configs: List[List[Dict]],
                 task_config: Optional[Dict] = None,
                 logging_level: Optional[str] = 'WARNING'):
        logging.setLevel(getattr(logging, logging_level))

        # Use a fresh dictionary per construction instead of a default
        # instance that is shared by all calls
        if task_config is None:
            task_config = dict()
        self._config = misc.merge_dict(task_config, self.DEFAULT_CONFIG)

        n_agents = self.config['n_agents']
        assert len(car_configs) == n_agents, \
            'Number of car config is not consistent with number of agents'
        assert len(sensors_configs) == n_agents, \
            'Number of sensors config is not consistent with number of agents'
        assert car_configs[0]['lookahead_road'], \
            '\'lookahead_road\' in the first car config should be set to True'

        self._world: World = World(trace_paths, trace_config)
        for i in range(n_agents):
            agent = self._world.spawn_agent(car_configs[i])
            for sensor_config in sensors_configs[i]:
                # Pop from a copy such that the caller's configuration stays
                # intact and can be reused to build another environment
                sensor_config = dict(sensor_config)
                sensor_type = sensor_config.pop('type')
                if sensor_type == 'camera':
                    agent.spawn_camera(sensor_config)
                else:
                    raise NotImplementedError(
                        f'Unrecognized sensor type {sensor_type}')

        # The mesh library is only needed (and only created) when there are
        # other agents to be rendered
        if n_agents > 1:
            assert self.config['mesh_dir'] is not None, \
                'Specify mesh_dir if n_agents > 1'
            self._meshlib = MeshLib(self.config['mesh_dir'])

        self.set_seed(0)

    def reset(self) -> Dict:
        """ Reset the environment. This involves regular world reset, randomly
        initializing ado agent in the front of the ego agent, and resetting
        the mesh library for all virtual agents.

        Returns:
            Dict: A dictionary with keys as agent IDs and values as
            observation for each agent, which is also a dictionary with keys
            as sensor IDs and values as sensory measurements.

        """
        # Reset world; all agents are initialized at the same pointer to the trace
        new_trace_index, new_segment_index, new_frame_index = \
            self.world.sample_new_location()
        for agent in self.world.agents:
            agent.reset(new_trace_index,
                        new_segment_index,
                        new_frame_index,
                        step_sensors=False)

        # Randomly initialize ado agents in the front
        ref_dynamics = self.ego_agent.human_dynamics
        polys = [misc.agent2poly(self.ego_agent, ref_dynamics)]
        # Always place an agent at least once; otherwise a non-positive
        # 'max_resample_tries' leaves no polygon to be registered
        max_tries = max(1, self.config['max_resample_tries'])
        for agent in self.world.agents:
            if agent == self.ego_agent:
                continue

            for _ in range(max_tries):
                self._randomly_place_agent(agent)
                poly = misc.agent2poly(agent, ref_dynamics)
                overlap = compute_overlap(poly, polys) / poly.area
                if np.all(overlap <= self.config['overlap_threshold']):
                    break
            # NOTE: if all tries are used up, the last sampled pose is kept
            # even though it still overlaps with previously placed agents
            polys.append(poly)

        # Reset mesh library
        if len(self.world.agents) > 1:
            self._reset_meshlib()

        # Get observation
        self._sensor_capture()
        observations = {_a.id: _a.observations for _a in self.world.agents}

        return observations

    def step(self, actions, dt=1 / 30.):
        """ Step the environment. This includes updating agents' states,
        synthesizing agents' observations, checking terminal conditions, and
        computing rewards.

        Args:
            actions (Dict[str, np.ndarray]): A dictionary with keys as agent
                IDs and values as actions to be executed to interact with
                the environment and other agents.
            dt (float): Elapsed time in second; default set to 1/30.

        Returns:
            Return a tuple (``dict_a``, ``dict_b``, ``dict_c``, ``dict_d``),
            where ``dict_a`` is the observation, ``dict_b`` is the reward,
            ``dict_c`` is whether the episode terminates, ``dict_d`` is
            additional information for every agent (the dictionary returned
            by ``terminal_condition``); keys of every dictionary are agent
            IDs.

        """
        # Update agents' dynamics (state)
        for agent in self.world.agents:
            action = actions[agent.id]
            agent.step_dynamics(action, dt=dt)

        # Get agents' sensory measurement
        self._sensor_capture()
        observations = {_a.id: _a.observations for _a in self.world.agents}

        # Check terminal conditions
        dones = dict()
        infos_from_terminal_condition = dict()
        terminal_condition = self.config['terminal_condition']
        for agent in self.world.agents:
            dones[agent.id], infos_from_terminal_condition[agent.id] = \
                terminal_condition(self, agent.id)

        # Compute reward; info from terminal condition is passed as keyword
        # arguments to the reward function
        rewards = dict()
        reward_fn = self.config['reward_fn']
        for agent in self.world.agents:
            rewards[agent.id], _ = reward_fn(
                self, agent.id, **infos_from_terminal_condition[agent.id])

        # Get info; expose a copy of the per-agent terminal-condition info
        # such that the returned dictionary is keyed by agent IDs
        infos = {
            _id: dict(_info)
            for _id, _info in infos_from_terminal_condition.items()
        }

        return observations, rewards, dones, infos

    def set_seed(self, seed) -> None:
        """ Set random seed.

        Args:
            seed (int): Random seed.

        """
        self._seed = seed
        self._rng = np.random.default_rng(self.seed)
        self.world.set_seed(seed)

    def _randomly_place_agent(self, agent: Car):
        # Randomly sampled a pose in the front of ego agent that is still on
        # the road. This can be achieved by,
        # (1) randomly sampling a distance from the ego agent
        # (2) fetch the closest pointer from the road kept by the ego agent
        # (3) slightly perturb the associated pose.
        tgt_dist = self._rng.uniform(*self.config['init_dist_range'])
        road = np.array(self.ego_agent.road)
        dist_from_ego = np.linalg.norm(road[:, :2], axis=1)
        tgt_idx = np.argmin(np.abs(tgt_dist - dist_from_ego))
        tgt_pose = road[tgt_idx].copy()

        lat_noise = self._rng.uniform(*self.config['init_lat_noise_range'])
        tgt_pose[0] += lat_noise * np.cos(tgt_pose[2])
        tgt_pose[1] += lat_noise * np.sin(tgt_pose[2])
        yaw_noise = self._rng.uniform(*self.config['init_yaw_noise_range'])
        tgt_pose[2] += yaw_noise

        # Place agent given the randomly sampled pose
        agent.ego_dynamics.update(*tgt_pose)
        agent.step_dynamics(tgt_pose[-2:], dt=1e-8)

    def _reset_meshlib(self):
        self._meshlib.reset(self.config['n_agents'])

        # Assign car width and length based on mesh size
        for i, agent in enumerate(self.world.agents):
            agent._width = self._meshlib.agents_meshes_dim[i][0]
            agent._length = self._meshlib.agents_meshes_dim[i][1]

    def _sensor_capture(self) -> None:
        # Update mesh of virtual agents
        if self.config['n_agents'] > 1:
            for agent in self.world.agents:
                if len(agent.sensors) == 0:
                    continue
                for i, other_agent in enumerate(self.world.agents):
                    if other_agent.id == agent.id:
                        continue
                    # compute relative pose to the ego agent
                    latlongyaw = transform.compute_relative_latlongyaw(
                        other_agent.ego_dynamics.numpy()[:3],
                        agent.human_dynamics.numpy()[:3])
                    # add to sensor
                    scene_object = self._meshlib.agents_meshes[i]
                    name = f'agent_{i}'
                    for sensor in agent.sensors:
                        sensor.update_scene_object(name, scene_object,
                                                   latlongyaw)

        # Step sensors
        for agent in self.world.agents:
            agent.step_sensors()

    @property
    def config(self) -> Dict:
        """ Configuration of this task. """
        return self._config

    @property
    def ego_agent(self) -> Car:
        """ Ego agent. """
        return self.world.agents[0]

    @property
    def world(self) -> World:
        """ :class:`World` of this task. """
        return self._world

    @property
    def seed(self) -> int:
        """ Random seed for the task and the associated :class:`World`. """
        return self._seed
def compute_overlap(poly: Box, polys: List[Box]) -> np.ndarray:
    """ Compute overlapping area between 1 polygon and N polygons.

    Args:
        poly (shapely.geometry.Box): A polygon.
        polys (List): A list of polygons.

    Returns:
        np.ndarray: A 1-D float array with one entry per element of
        ``polys``; the i-th entry is the area of the intersection between
        ``polys[i]`` and ``poly``. The array is empty if ``polys`` is empty.

    """
    # An array (instead of a list) is returned since callers normalize and
    # threshold the result element-wise
    return np.array([_p.intersection(poly).area for _p in polys], dtype=float)