from typing import Optional, Dict, Tuple, List, Any
from collections import deque
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm, patches
from shapely.geometry import LineString
from descartes import PolygonPatch
from . import World
from ..entities.agents.Dynamics import StateDynamics, update_with_perfect_controller, \
curvature2tireangle
from ..entities.agents import Car
from ..entities.sensors import Camera, EventCamera, Lidar
from ..entities.sensors.camera_utils import CameraParams
from ..entities.sensors.lidar_utils import Pointcloud
from ..utils import logging, transform, misc
class Display:
""" This is a visualizer of VISTA simulator. It renders an image that contains visualization
of all sensors from all agents and a top-down view that depicts road and all cars in the scene
within a predefined range based on the state of the simulator (:class:`World`).
Args:
world (vista.core.World): World to be visualized.
        fps (int): Frames per second.
display_config (Dict): Configuration for the display (visualization).
Raises:
AssertionError: Grid spec is inconsistent with maximal number of sensors across agents.
Example usage::
>>> display_config = {
'road_buffer_size': 200,
'birdseye_map_size': (30, 20), # size of bev map in vertical and horizontal directions
'gs_bev_w': 2, # grid spec width for the birdseye view block
'gs_agent_w': 4, # grid spec width for an agent's block
'gs_h': 6, # grid spec height
'gui_scale': 1.0, # a global scale that determines the size of the figure
        'vis_full_frame': False, # if True, do not crop/resize the camera view for visualization
}
    >>> display = Display(world, display_config=display_config)
"""
DEFAULT_DISPLAY_CONFIG = {
'road_buffer_size': 200,
'birdseye_map_size': (30, 20),
'gs_bev_w': 2,
'gs_agent_w': 4,
'gs_h': 6,
'gui_scale': 1.0,
'vis_full_frame': False
}
    def __init__(self,
                 world: World,
                 fps: Optional[float] = 30,
                 display_config: Optional[Dict] = None):
        # Get arguments
        self._world: World = world
        self._fps: float = fps
        self._config: Dict = misc.merge_dict(display_config or dict(),
                                             self.DEFAULT_DISPLAY_CONFIG)
# Initialize data for plotting
self._road: deque[np.ndarray] = deque(
maxlen=self._config['road_buffer_size'])
self._road_frame_idcs: deque[int] = deque(
maxlen=self._config['road_buffer_size'])
self._road_dynamics: StateDynamics = StateDynamics()
# Get agents with sensors (to be visualized)
self._agents_with_sensors: List[Car] = []
for agent in self._world.agents:
if len(agent.sensors) > 0:
                is_visualizable = [
                    isinstance(_v, (Camera, EventCamera, Lidar))
                    for _v in agent.sensors
                ]
                if any(is_visualizable):
                    self._agents_with_sensors.append(agent)
                if sum(is_visualizable) < len(agent.sensors):
                    logging.warning('Cannot visualize sensor types other '
                                    'than Camera, EventCamera, and Lidar')
n_agents_with_sensors = len(self._agents_with_sensors)
if n_agents_with_sensors > 0:
max_n_sensors = max(
[len(_v.sensors) for _v in self._agents_with_sensors])
else:
max_n_sensors = 1
# Specify colors for agents and road
colors = list(cm.get_cmap('Set1').colors)
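        # Blend each base color toward white by its alpha (rgb + a * (1 - rgb))
        # so agent patches stay readable on the black background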
rgba2rgb = lambda rgba: np.clip(
(1 - rgba[:3]) * rgba[3] + rgba[:3], 0., 1.)
colors = [np.array(list(c) + [0.6]) for c in colors]
colors = list(map(rgba2rgb, colors))
        self._agent_colors: List[np.ndarray] = colors
self._road_color: Tuple = list(cm.get_cmap('Dark2').colors)[-1]
# Initialize figure
        self._artists: Dict[str, Any] = dict()
        self._axes: Dict[str, plt.Axes] = dict()
gui_scale = self._config['gui_scale']
figsize = (6.4 * gui_scale * n_agents_with_sensors + 3.2 * gui_scale,
3.2 * gui_scale * max_n_sensors)
self._fig: plt.Figure = plt.figure(figsize=figsize)
self._fig.patch.set_facecolor('black') # use black background
self._gs = self._fig.add_gridspec(
self._config['gs_h'],
self._config['gs_agent_w'] * n_agents_with_sensors +
self._config['gs_bev_w'])
        assert self._config['gs_h'] % max_n_sensors == 0, \
            (f'Height of grid ({self._config["gs_h"]}) cannot be evenly ' + \
             f'divided by max number of sensors ({max_n_sensors})')
gs_agent_h = self._config['gs_h'] // max_n_sensors
# Initialize birds eye view
self._axes['bev'] = self._fig.add_subplot(
self._gs[:, -self._config['gs_bev_w']:])
self._axes['bev'].set_facecolor('black')
self._axes['bev'].set_xticks([])
self._axes['bev'].set_yticks([])
self._axes['bev'].set_title('Top-down View',
color='white',
size=20,
weight='bold')
self._axes['bev'].set_xlim(-self._config['birdseye_map_size'][1] / 2.,
self._config['birdseye_map_size'][1] / 2.)
self._axes['bev'].set_ylim(-self._config['birdseye_map_size'][0] / 2.,
self._config['birdseye_map_size'][0] / 2.)
# Initialize plot for sensory measurement
logging.debug(
'Does not handle preprocessed (cropped/resized) observation')
for i, agent in enumerate(self._agents_with_sensors):
for j, sensor in enumerate(agent.sensors):
if (isinstance(sensor, Camera)
or isinstance(sensor, EventCamera)):
param = sensor.camera_param
img_shape = (param.get_height(), param.get_width(), 3)
elif isinstance(sensor, Lidar):
x_dim, y_dim = sensor.view_synthesis._dims[:, 0]
                    # Cut width in half and stack on top
img_shape = (y_dim * 2, x_dim // 2, 3)
else:
logging.error(f'Unrecognized sensor type {type(sensor)}')
continue
gs_ij = self._gs[gs_agent_h * j:gs_agent_h * (j + 1),
self._config['gs_agent_w'] *
i:self._config['gs_agent_w'] * (i + 1)]
ax_name = 'a{}s{}'.format(i, j)
self._axes[ax_name] = self._fig.add_subplot(gs_ij,
facecolor='black')
self._axes[ax_name].set_xticks([])
self._axes[ax_name].set_yticks([])
self._axes[ax_name].set_title('Init',
color='white',
size=20,
weight='bold')
placeholder = fit_img_to_ax(
self._fig, self._axes[ax_name],
np.zeros(img_shape, dtype=np.uint8))
self._artists['im:{}'.format(
ax_name)] = self._axes[ax_name].imshow(placeholder)
self._fig.tight_layout()
    def reset(self) -> None:
        """ Reset the visualizer. This should be called every time after a
        :class:`World` reset. It resets the cache of road data used in the
        top-down view visualization.
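
        Example usage (a minimal sketch; assumes ``world`` and ``display``
        were constructed elsewhere)::

            >>> world.reset()
            >>> display.reset()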
"""
# Reset road deque
self._road.clear()
self._road.append(self.ref_agent.human_dynamics.numpy()[:3])
self._road_dynamics = self.ref_agent.human_dynamics.copy()
self._road_frame_idcs.clear()
self._road_frame_idcs.append(self.ref_agent.frame_index)
    def render(self):
        """ Render an image that visualizes the simulator. This includes
        visualization of all sensors of every agent and a top-down view that
        depicts the road and all cars in the scene within a certain range.
        Note that it renders based on the current state of the world and
        should be called whenever there is any update to the simulator.
Returns:
np.ndarray: An image of visualization for the simulator.
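
        Example usage (a minimal sketch; assumes the world has been reset and
        stepped so that agent observations exist)::

            >>> img = display.render()  # RGB image as np.ndarray
            >>> cv2.imwrite('vis.png', img[:, :, ::-1])  # OpenCV expects BGR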
"""
# Update road (in global coordinate)
exceed_end = False
while self._road_frame_idcs[-1] < (
self.ref_agent.frame_index +
self._config['road_buffer_size'] / 2.) and not exceed_end:
exceed_end, ts = self._get_timestamp(self._road_frame_idcs[-1])
self._road_frame_idcs.append(self._road_frame_idcs[-1] + 1)
exceed_end, next_ts = self._get_timestamp(
self._road_frame_idcs[-1])
state = [
curvature2tireangle(self.ref_agent.trace.f_curvature(ts),
self.ref_agent.wheel_base),
self.ref_agent.trace.f_speed(ts)
]
update_with_perfect_controller(state, next_ts - ts,
self._road_dynamics)
self._road.append(self._road_dynamics.numpy()[:3])
# Update road in birds eye view (in reference agent's coordinate)
ref_pose = self.ref_agent.human_dynamics.numpy()[:3]
logging.debug('Computation of road in reference frame not vectorized')
road_in_ref = np.array([
transform.compute_relative_latlongyaw(_v, ref_pose)
for _v in self._road
])
road_half_width = self.ref_agent.trace.road_width / 2.
patch = LineString(road_in_ref).buffer(road_half_width)
patch = PolygonPatch(patch,
fc=self._road_color,
ec=self._road_color,
zorder=1)
self._update_patch(self._axes['bev'], 'patch:road', patch)
# Update agent in birds eye view
for i, agent in enumerate(self._world.agents):
poly = misc.agent2poly(agent, self.ref_agent.human_dynamics)
color = self._agent_colors[i]
patch = PolygonPatch(poly, fc=color, ec=color, zorder=2)
self._update_patch(self._axes['bev'], 'patch:agent_{}'.format(i),
patch)
# Update sensory measurements
for i, agent in enumerate(self._agents_with_sensors):
cameras = {
_v.name: _v
for _v in agent.sensors if isinstance(_v, Camera)
}
event_cameras = {
_v.name: _v
for _v in agent.sensors if isinstance(_v, EventCamera)
}
lidars = {
_v.name: _v
for _v in agent.sensors if isinstance(_v, Lidar)
}
for j, (obs_name, obs) in enumerate(agent.observations.items()):
ax_name = 'a{}s{}'.format(i, j)
if obs_name in cameras.keys():
obs = plot_roi(obs.copy(),
cameras[obs_name].camera_param.get_roi())
sensor = cameras[obs_name]
noodle = curvature2noodle(self.ref_agent.curvature,
sensor.camera_param,
mode='camera')
obs = cv2.polylines(obs, [noodle], False, (255, 0, 0), 2)
if not self._config["vis_full_frame"]:
# Black out the sides for visualization
h, w = obs.shape[:2]
h_, w_ = (0.65 * h, 0.65 * w)
hs, ws = (int((h - h_) // 2), int((w - w_) // 2))
obs = cv2.resize(obs[hs:-hs, ws:-ws], (w, h))
obs_render = fit_img_to_ax(self._fig, self._axes[ax_name],
obs[:, :, ::-1])
elif obs_name in event_cameras.keys():
event_cam_param = event_cameras[obs_name].camera_param
frame_obs = events2frame(obs, event_cam_param.get_height(),
event_cam_param.get_width())
sensor = event_cameras[obs_name]
frame_obs = plot_roi(frame_obs.copy(),
sensor.camera_param.get_roi())
noodle = curvature2noodle(self.ref_agent.curvature,
sensor.camera_param,
mode='camera')
frame_obs = cv2.polylines(frame_obs, [noodle], False,
(0, 0, 255), 2)
obs_render = fit_img_to_ax(self._fig, self._axes[ax_name],
frame_obs[:, :, ::-1])
elif obs_name in lidars.keys():
if isinstance(obs, Pointcloud):
obs_render = None
ax = self._axes[ax_name]
ax.clear()
obs = obs[::10] # sub-sample the pointcloud for vis
ax, scat = plot_pointcloud(
obs,
ax=ax,
color_by="z",
max_dist=20.,
car_dims=(self.ref_agent.length,
self.ref_agent.width),
cmap="nipy_spectral")
# Plot the noodle
noodle = curvature2noodle(self.ref_agent.curvature,
mode='lidar')
ax.plot(noodle[:, 0], noodle[:, 1], '-r', linewidth=3)
else: # dense image
obs = np.roll(obs, -obs.shape[1] // 4, axis=1) # shift
obs = np.concatenate(np.split(obs, 2, axis=1),
0) # stack
obs = np.clip(4 * obs, 0, 255).astype(np.uint8) # norm
obs = cv2.applyColorMap(obs, cv2.COLORMAP_JET) # color
obs_render = fit_img_to_ax(self._fig,
self._axes[ax_name], obs)
else:
logging.error(f'Unrecognized observation {obs_name}')
continue
title = '{}: {}'.format(agent.id, obs_name)
self._axes[ax_name].set_title(title,
color='white',
size=20,
weight='bold')
if obs_render is not None:
self._artists['im:{}'.format(ax_name)].set_data(obs_render)
# Convert to image
img = fig2img(self._fig)
return img
def _update_patch(self, ax: plt.Axes, name: str,
patch: PolygonPatch) -> None:
if name in self._artists:
self._artists[name].remove()
ax.add_patch(patch)
self._artists[name] = patch
    def _get_timestamp(self, frame_index: int) -> Tuple[bool, float]:
return self.ref_agent.trace.get_master_timestamp(
self.ref_agent.segment_index, frame_index, check_end=True)
@property
def ref_agent(self) -> Car:
""" Agent as a reference to compute poses of objects (e.g., cars, road)
in visualization. """
return self._world.agents[0]
def curvature2noodle(curvature: float,
camera_param: Optional[CameraParams] = None,
mode: Optional[str] = 'camera') -> np.ndarray:
""" Construct a curly line (noodle) based on the curvature for visualizing
steering control command.
Args:
curvature (float): Curvature (steering angle control command).
camera_param (vista.entities.sensors.camera_utils.CameraParams): Camera parameters; used if
mode is set to camera.
mode (str): Sensor type for the visualization.
Returns:
np.ndarray: A curly line that visualizes the given curvature.
Raises:
NotImplementedError: Unrecognized mode to draw the noodle.
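
    Example usage (a sketch; assumes ``camera`` is a configured
    :class:`Camera` sensor and ``frame`` is an image from it)::

        >>> noodle = curvature2noodle(0.01, camera.camera_param)
        >>> frame = cv2.polylines(frame, [noodle], False, (255, 0, 0), 2)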
"""
lookaheads = np.linspace(0, 15, 10) # meters
if mode == 'camera':
assert camera_param is not None
K = camera_param.get_K()
normal = camera_param.get_ground_plane()[0:3]
normal = np.reshape(normal, [1, 3])
d = camera_param.get_ground_plane()[3]
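        # The noodle lies on the ground plane (normal . p = d): for each
        # lookahead depth z, approximately solve for the height y on the plane
        # (ignoring the small lateral term), then for the lateral offset x on
        # the turning circle of radius 1 / curvature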
A, B, C = normal[0]
radius = 1. / (curvature + 1e-8)
z_vals = lookaheads
y_vals = (d - C * z_vals) / B
x_sq_r = radius**2 - z_vals**2 - (y_vals - d)**2
x_vals = np.sqrt(x_sq_r[x_sq_r > 0]) - abs(radius)
y_vals = y_vals[x_sq_r > 0]
z_vals = z_vals[x_sq_r > 0]
if radius < 0:
x_vals *= -1
world_coords = np.stack((x_vals, y_vals, z_vals))
theta = camera_param.get_yaw()
R = np.array([[np.cos(theta), 0.0, -np.sin(theta)], [0.0, 1.0, 0.0],
[np.sin(theta), 0.0, np.cos(theta)]])
tf_world_coords = np.matmul(R, world_coords)
img_coords = np.matmul(K, tf_world_coords)
norm = np.divide(img_coords, img_coords[2] + 1e-10)
valid_inds = np.multiply(norm[0] >= 0,
norm[0] < camera_param.get_width())
valid_inds = np.multiply(valid_inds, norm[1] >= 0)
valid_inds = np.multiply(valid_inds,
norm[1] < camera_param.get_height())
noodle = norm[:2, valid_inds].astype(np.int32).T
elif mode == 'lidar':
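        # In the top-down lidar frame the noodle is an arc of the turning
        # circle: at longitudinal distance l, the lateral shift from a straight
        # path is sign(r) * (|r| - sqrt(r^2 - l^2)) for turning radius r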
turning_r = 1 / (curvature + 1e-8)
shifts = (np.sqrt(turning_r**2 - lookaheads**2) - abs(turning_r))
shifts = -1 * np.sign(turning_r) * shifts
noodle = np.stack([lookaheads, shifts], axis=1)
else:
raise NotImplementedError(
'Unrecognized mode {} in drawing noodle'.format(mode))
return noodle
def plot_roi(img: np.ndarray,
roi: List[int],
color: Optional[List[int]] = (0, 0, 255),
thickness: Optional[int] = 2) -> np.ndarray:
""" Plot a bounding box that shows ROI on an image.
Args:
img (np.ndarray): An image to be plotted.
roi (List[int]): Region of interest.
color (List[int]): Color of the bounding box.
thickness (int): Thickness of the bounding box.
Returns:
np.ndarray: An image with the ROI bounding box.
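
    Example usage (a sketch; assumes ``camera`` is a configured
    :class:`Camera` sensor and ``img`` is an image from it)::

        >>> img = plot_roi(img.copy(), camera.camera_param.get_roi())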
"""
(i1, j1, i2, j2) = roi
img = cv2.rectangle(img, (j1, i1), (j2, i2), color, thickness)
return img
def events2frame(events: List[np.ndarray],
cam_h: int,
cam_w: int,
positive_color: Optional[List] = [255, 255, 255],
negative_color: Optional[List] = [212, 188, 114],
mode: Optional[int] = 2) -> np.ndarray:
""" Convert event data to frame representation.
Args:
        events (List[np.ndarray]): A two-entry list containing collections of
            positive and negative events, respectively.
cam_h (int): Height of the frame representation.
cam_w (int): Width of the frame representation.
positive_color (List): Color of positive events.
negative_color (List): Color of negative events.
        mode (int): Colorization mode; 0 paints event pixels with a fixed
            color per polarity (last write wins), 1 colors by the sign of the
            accumulated polarity, and 2 averages colors over repeated events
            per pixel.
Returns:
np.ndarray: Frame representation of event data.
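
    Example usage (a sketch; assumes ``event_camera`` is a configured
    :class:`EventCamera` and ``events`` is its latest observation)::

        >>> param = event_camera.camera_param
        >>> frame = events2frame(events, param.get_height(), param.get_width())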
"""
if mode == 0:
frame = np.zeros((cam_h, cam_w, 3), dtype=np.uint8)
for color, p_events in zip([positive_color, negative_color], events):
uv = np.concatenate(p_events)[:, :2]
frame[uv[:, 0], uv[:, 1], :] = color
elif mode == 1:
frame_acc = np.zeros((cam_h, cam_w), dtype=np.int8)
for polarity, p_events in zip([1, -1], events):
for sub_p_events in p_events:
uv = sub_p_events[:, :2]
frame_acc[uv[:, 0], uv[:, 1]] += polarity
frame = np.zeros((cam_h, cam_w, 3), dtype=np.uint8)
frame[frame_acc > 0, :] = positive_color
frame[frame_acc < 0, :] = negative_color
elif mode == 2:
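        # Running average: each pixel accumulates the mean color of all events
        # that landed on it, using frame_abs_acc as the per-pixel event count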
frame_abs_acc = np.zeros((cam_h, cam_w), dtype=np.int8)
frame = np.zeros((cam_h, cam_w, 3), dtype=np.uint8)
for polarity, p_events in zip([1, -1], events):
for sub_p_events in p_events:
uv = sub_p_events[:, :2]
add_c = np.array(
positive_color if polarity > 0 else negative_color)[None,
...]
cnt = frame_abs_acc[uv[:, 0], uv[:, 1]][:, None]
frame[uv[:, 0], uv[:, 1]] = (frame[uv[:, 0], uv[:, 1]] * cnt +
add_c) / (cnt + 1)
frame_abs_acc[uv[:, 0], uv[:, 1]] = cnt[:, 0] + 1
else:
raise NotImplementedError('Unknown mode {}'.format(mode))
return frame
def plot_pointcloud(pcd,
color_by="z",
max_dist=None,
cmap="nipy_spectral",
car_dims=None,
ax=None,
scat=None,
s=1):
""" Convert pointcloud to an image for visualization. """
if ax is None:
_, ax = plt.subplots()
if max_dist is not None:
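        # Keep points within the square of half-width max_dist; the sqrt(2)
        # slack retains points out to the square's corners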
pcd = pcd[pcd.dist < (max_dist * np.sqrt(2))]
if color_by == "z":
c = pcd.z
vmin, vmax = (-2.5, 4)
elif color_by == "intensity":
c = np.log(1 + pcd.intensity)
vmin, vmax = (1.7, 4.3)
else:
raise ValueError(f"unsupported color {color_by}")
# Plot points
if scat is None:
scat = ax.scatter(pcd.x,
pcd.y,
c=c,
s=s,
vmin=vmin,
vmax=vmax,
cmap=cmap)
else:
scat.set_offsets(np.stack([pcd.x, pcd.y], axis=1))
scat.set_clim(vmin, vmax)
scat.set_color(getattr(plt.cm, cmap)(scat.norm(c)))
    # Plot car
    if car_dims is not None:
        l_car, w_car = car_dims
        ax.add_patch(
            patches.Rectangle(
                (-l_car / 2, -w_car / 2),
                l_car,
                w_car,
                fill=True  # draw a filled box for the ego car
            ))
    if max_dist is not None:
        ax.set_xlim(-max_dist, max_dist)
        ax.set_ylim(-max_dist, max_dist)
return ax, scat
def fig2img(fig: plt.Figure) -> np.ndarray:
""" Convert a matplotlib figure to a numpy array. """
fig.canvas.draw()
buf = fig.canvas.buffer_rgba()
img = np.asarray(buf)[:, :, :3]
return img
def fit_img_to_ax(fig: plt.Figure, ax: plt.Axes,
img: np.ndarray) -> np.ndarray:
""" Fit an image to an axis in a matplotlib figure. """
bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
w, h = bbox.width, bbox.height
img_h, img_w = img.shape[:2]
new_img_w = img_h * w / h
new_img_h = img_w * h / w
d_img_w = new_img_w - img_w
d_img_h = new_img_h - img_h
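    # Pad symmetrically along whichever dimension is too short so the padded
    # image matches the aspect ratio of the target axes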
    if d_img_h > 0:
        pad_img = np.zeros((int(d_img_h // 2), img_w, 3), dtype=np.uint8)
        new_img = np.concatenate([pad_img, img, pad_img], axis=0)
    elif d_img_w > 0:
        pad_img = np.zeros((img_h, int(d_img_w // 2), 3), dtype=np.uint8)
        new_img = np.concatenate([pad_img, img, pad_img], axis=1)
    else:  # aspect ratios already match; no padding needed
        new_img = img
    return new_img