284 lines
9.8 KiB
Python
284 lines
9.8 KiB
Python
#######################################################################
|
|
# Adapted code from https://github.com/jinlinyi/PerspectiveFields
|
|
#######################################################################
|
|
# Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
|
|
import matplotlib.colors as mplc
|
|
import matplotlib.figure as mplfigure
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
|
|
|
|
|
class VisImage:
    """Matplotlib (Agg) canvas holding an image that primitives can be drawn onto."""

    def __init__(self, img, scale=1.0):
        """
        Args:
            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
            scale (float): factor applied to the output figure size, i.e. the
                rendered image is roughly (H * scale, W * scale) pixels.
        """
        self.img = img
        self.scale = scale
        self.width, self.height = img.shape[1], img.shape[0]
        self._setup_figure(img)

    def _setup_figure(self, img):
        """
        Create the figure, its Agg canvas and a full-bleed axes, then draw ``img``.

        Sets ``self.dpi``, ``self.canvas``, ``self.fig`` and ``self.ax`` as side
        effects; returns nothing.

        Args:
            img: same as in :meth:`__init__`.
        """
        fig = mplfigure.Figure(frameon=False)
        self.dpi = fig.get_dpi()
        # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
        # (https://github.com/matplotlib/matplotlib/issues/15363)
        fig.set_size_inches(
            (self.width * self.scale + 1e-2) / self.dpi,
            (self.height * self.scale + 1e-2) / self.dpi,
        )
        self.canvas = FigureCanvasAgg(fig)
        # axes cover the whole figure with no ticks/frame, so only the image shows
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        ax.axis("off")
        self.fig = fig
        self.ax = ax
        self.reset_image(img)

    def reset_image(self, img):
        """
        Draw ``img`` onto the axes.

        The extent ``(0, W, H, 0)`` makes axes coordinates match the original
        image's pixel coordinates, with the origin at the top-left corner.

        Args:
            img: same as in :meth:`__init__`; cast to uint8 before drawing.
        """
        img = img.astype("uint8")
        self.ax.imshow(
            img, extent=(0, self.width, self.height, 0), interpolation="nearest"
        )

    def save(self, filepath):
        """
        Args:
            filepath (str): a string that contains the absolute path, including the file name, where
                the visualized image will be saved.
        """
        self.fig.savefig(filepath)

    def get_image(self):
        """
        Render the figure and return it as an array.

        Returns:
            ndarray:
                the visualized image of shape (H, W, 3) (RGB) in uint8 type.
                The shape is scaled w.r.t the input image using the given `scale` argument.
        """
        # print_to_buffer renders the figure and returns raw RGBA bytes plus
        # the rendered (width, height) in pixels
        s, (width, height) = self.canvas.print_to_buffer()
        img_rgba = np.frombuffer(s, dtype="uint8").reshape(height, width, 4)
        # Drop the alpha channel. astype copies, so the caller gets a writable
        # array (np.frombuffer over bytes is read-only).
        return img_rgba[:, :, :3].astype("uint8")
|
|
|
|
|
|
class Visualizer:
    """
    Visualizer that draws primitive elements (such as text) onto an RGB image.

    Adapted from detectron2's ``Visualizer``. Only primitive drawing methods are
    defined in this class; all drawing happens on a :class:`VisImage`, which is
    retrieved with :meth:`get_output`.

    This visualizer focuses on high rendering quality rather than performance. It is not
    designed to be used for real-time applications.
    """

    def __init__(self, img_rgb, scale=1.0, font_size_scale=1.0):
        """
        Args:
            img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. Values are clipped to
                [0, 255] and cast to uint8.
            scale (float): output scale forwarded to :class:`VisImage`; text sizes
                are also multiplied by it when drawn.
            font_size_scale: extra scaling of font size on top of default font size
        """
        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
        self.output = VisImage(self.img, scale=scale)
        self.cpu_device = torch.device("cpu")

        # Too small texts are useless, so the default size grows with the image
        # (sqrt(H * W) // 90) but is floored at 10 // scale. draw_text multiplies
        # by scale, so the rendered size stays around 10 or larger.
        self._default_font_size = (
            max(np.sqrt(self.output.height * self.output.width) // 90, 10 // scale)
            * font_size_scale
        )

    # Primitive drawing functions:

    def draw_text(
        self,
        text,
        position,
        *,
        font_size=None,
        color="g",
        horizontal_alignment="center",
        rotation=0,
    ):
        """
        Args:
            text (str): class label
            position (tuple): a tuple of the x and y coordinates to place text on image.
            font_size (float, optional): font size of the text. If not provided (or
                falsy), ``self._default_font_size`` is used. The drawn size is
                ``font_size * self.output.scale``.
            color: color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            horizontal_alignment (str): see `matplotlib.text.Text`
            rotation: rotation angle in degrees CCW

        Returns:
            output (VisImage): image object with text drawn.
        """
        if not font_size:
            font_size = self._default_font_size

        # The text box background is near-black (see bbox below), so the text
        # must not be dark: lift every channel to at least 0.2 and the dominant
        # channel to at least 0.8.
        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
        color[np.argmax(color)] = max(0.8, np.max(color))

        x, y = position
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="sans-serif",
            bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
            verticalalignment="top",
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,
            rotation=rotation,
        )
        return self.output

    def get_output(self):
        """
        Returns:
            output (VisImage): the image output containing the visualizations added
                to the image.
        """
        return self.output
|
|
|
|
|
|
class VisualizerPerspective(Visualizer):
    """Visualizer with extra primitives for perspective fields (arrows, latitude maps)."""

    def draw_arrow(
        self,
        x_pos,
        y_pos,
        x_direct,
        y_direct,
        color=None,
        linestyle="-",
        linewidth=None,
    ):
        """
        Draw arrows (a matplotlib quiver plot) on the image.

        Arrow lengths are in data units (``scale_units="xy"``, ``scale=1``).
        NOTE(review): ``angles`` is left at matplotlib's default ("uv"), so the
        arrow direction is interpreted in screen space; on the y-inverted image
        axes a positive ``y_direct`` points up on screen. Confirm this is intended.

        Args:
            x_pos, y_pos: x / y coordinates of the arrow positions; anything
                accepted by `matplotlib.axes.Axes.quiver` (scalars or arrays).
            x_direct, y_direct: x / y components of the arrows, matching the
                positions.
            color: color of the arrows. Refer to `matplotlib.colors` for a full list of
                formats that are accepted.
            linestyle: currently ignored; kept for API compatibility.
            linewidth: currently ignored; kept for API compatibility. Arrow
                outlines are always drawn with a line width of 0.1.

        Returns:
            output (VisImage): image object with arrows drawn.
        """
        self.output.ax.quiver(
            x_pos,
            y_pos,
            x_direct,
            y_direct,
            color=color,
            scale_units="xy",
            scale=1,
            antialiased=True,
            headaxislength=3.5,
            linewidths=0.1,
        )
        return self.output

    def draw_lati(
        self, latimap, alpha_contourf=0.4, alpha_contour=0.9, contour_only=False
    ):
        """
        Blend a latitude map onto the image as contours.

        Args:
            latimap: 2-D array of shape (H, W). Values are presumably latitudes in
                radians, since the contour levels span [-pi/2, pi/2]; confirm with callers.
            alpha_contourf: opacity of the filled contour bands.
            alpha_contour: opacity of the contour lines.
            contour_only: if True, draw only a thick zero-level contour line
                instead of filled bands plus contour lines.

        Returns:
            output (VisImage): image object with the latitude map drawn.
        """
        height, width = latimap.shape
        y, x = np.mgrid[0:height, 0:width]
        cmap = plt.get_cmap("seismic")
        bands = 20
        # bands - 1 level boundaries evenly spaced over [-pi/2, pi/2]
        levels = np.linspace(-np.pi / 2, np.pi / 2, bands - 1)

        if contour_only:
            # only plot central contour
            self.output.ax.contour(
                x,
                y,
                latimap,
                levels=[0],
                cmap=cmap,
                alpha=alpha_contour,
                antialiased=True,
                linewidths=15,
            )
            return self.output

        pp = self.output.ax.contourf(
            x,
            y,
            latimap,
            levels=levels,
            cmap=cmap,
            alpha=alpha_contourf,
            antialiased=True,
        )
        # linestyles="solid" replaces post-hoc edits of ContourSet.collections,
        # which is deprecated in matplotlib 3.8 and removed in later versions.
        self.output.ax.contour(
            x,
            y,
            latimap,
            levels=pp.levels,
            cmap=cmap,
            alpha=alpha_contour,
            antialiased=True,
            linewidths=5,
            linestyles="solid",
        )
        return self.output
|