KICCO_AI_IMAGE/flux_style_shaper_api/perspective_fields/visualizer.py

284 lines
9.8 KiB
Python

#######################################################################
# Adapted code from https://github.com/jinlinyi/PerspectiveFields
#######################################################################
# Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
import matplotlib.colors as mplc
import matplotlib.figure as mplfigure
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
class VisImage:
    """Off-screen matplotlib figure holding an image plus anything drawn over it."""

    def __init__(self, img, scale=1.0):
        """
        Args:
            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
            scale (float): factor applied to the image size when sizing the figure.
        """
        self.img = img
        self.scale = scale
        self.height, self.width = img.shape[:2]
        self._setup_figure(img)

    def _setup_figure(self, img):
        """
        Build the figure, the Agg canvas and a full-bleed axes, then paint ``img``.

        Nothing is returned; the created objects are stored on the instance as
        ``self.fig``, ``self.ax``, ``self.canvas`` and ``self.dpi``.

        Args:
            img: same as in :meth:`__init__`.
        """
        figure = mplfigure.Figure(frameon=False)
        self.dpi = figure.get_dpi()

        # The extra 1e-2 pixel guards against precision loss caused by
        # matplotlib truncating the figure size
        # (https://github.com/matplotlib/matplotlib/issues/15363).
        width_inches = (self.width * self.scale + 1e-2) / self.dpi
        height_inches = (self.height * self.scale + 1e-2) / self.dpi
        figure.set_size_inches(width_inches, height_inches)

        self.canvas = FigureCanvasAgg(figure)

        # One axes covering the whole figure, with ticks and frame hidden.
        axes = figure.add_axes([0.0, 0.0, 1.0, 1.0])
        axes.axis("off")

        self.fig = figure
        self.ax = axes
        self.reset_image(img)

    def reset_image(self, img):
        """
        Draw ``img`` onto the axes.

        Args:
            img: same as in :meth:`__init__`.
        """
        pixels = img.astype("uint8")
        # extent puts y=0 at the top edge so data coordinates are pixel coordinates.
        self.ax.imshow(
            pixels,
            extent=(0, self.width, self.height, 0),
            interpolation="nearest",
        )

    def save(self, filepath):
        """
        Args:
            filepath (str): destination path, including the file name, handed
                to ``Figure.savefig``.
        """
        self.fig.savefig(filepath)

    def get_image(self):
        """
        Render the canvas and return it without the alpha channel.

        Returns:
            ndarray: the visualized image of shape (H, W, 3) (RGB) in uint8 type.
                The size is whatever the canvas reports, i.e. the input size
                scaled by the ``scale`` argument.
        """
        raw, (width, height) = self.canvas.print_to_buffer()
        rgba = np.frombuffer(raw, dtype="uint8").reshape(height, width, 4)
        # Keep the first three channels (RGB); astype returns a fresh array.
        return rgba[:, :, :3].astype("uint8")
class Visualizer:
    """
    Minimal visualizer that draws on top of an RGB image using matplotlib.

    The image is wrapped in a :class:`VisImage`; drawing methods add artists to
    its axes and return that same ``VisImage`` so the result can be saved or
    rasterized. This trimmed-down version only offers the `draw_text` primitive
    and `get_output`.

    The implementation favours rendering quality over speed and is not meant
    for real-time use.
    """

    def __init__(self, img_rgb, scale=1.0, font_size_scale=1.0):
        """
        Args:
            img_rgb: array-like of shape (H, W, C) in RGB order. Values are
                clipped to [0, 255] and converted to uint8.
            scale: scale factor forwarded to :class:`VisImage`.
            font_size_scale: extra multiplier applied on top of the default
                font size.
        """
        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
        self.output = VisImage(self.img, scale=scale)
        self.cpu_device = torch.device("cpu")

        # Font size grows with the image (sqrt of the pixel area // 90), but
        # tiny text is useless, so it never goes below 10 // scale.
        size_from_area = np.sqrt(self.output.height * self.output.width) // 90
        smallest_allowed = 10 // scale
        self._default_font_size = (
            max(size_from_area, smallest_allowed) * font_size_scale
        )

    # ------------------------------------------------------------------
    # Primitive drawing functions:
    # ------------------------------------------------------------------

    def draw_text(
        self,
        text,
        position,
        *,
        font_size=None,
        color="g",
        horizontal_alignment="center",
        rotation=0,
    ):
        """
        Write ``text`` on the image over a dark, semi-transparent box.

        Args:
            text (str): the string to draw.
            position (tuple): (x, y) coordinates of the text anchor on the image.
            font_size (int, optional): font size of the text. When omitted (or
                falsy) the default computed in ``__init__`` is used.
            color: color of the text. Refer to `matplotlib.colors` for the full
                list of accepted formats. The color is brightened before use.
            horizontal_alignment (str): see `matplotlib.text.Text`.
            rotation: rotation angle in degrees CCW.

        Returns:
            output (VisImage): image object with text drawn.
        """
        font_size = font_size or self._default_font_size

        # The text box is dark, so keep the text light: every channel is lifted
        # to at least 0.2 and the strongest channel to at least 0.8.
        text_rgb = np.maximum(list(mplc.to_rgb(color)), 0.2)
        strongest = np.argmax(text_rgb)
        text_rgb[strongest] = max(0.8, np.max(text_rgb))

        x, y = position
        background = {
            "facecolor": "black",
            "alpha": 0.8,
            "pad": 0.7,
            "edgecolor": "none",
        }
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="sans-serif",
            bbox=background,
            verticalalignment="top",
            horizontalalignment=horizontal_alignment,
            color=text_rgb,
            zorder=10,
            rotation=rotation,
        )
        return self.output

    def get_output(self):
        """
        Returns:
            output (VisImage): the image output containing every visualization
                drawn so far.
        """
        return self.output
class VisualizerPerspective(Visualizer):
    """Visualizer extension that overlays arrow fields and latitude maps."""

    def draw_arrow(
        self,
        x_pos,
        y_pos,
        x_direct,
        y_direct,
        color=None,
        linestyle="-",
        linewidth=None,
    ):
        """
        Draw arrows on the image with ``Axes.quiver``.

        Arrow lengths are expressed in data units (``scale_units="xy"``,
        ``scale=1``).

        Args:
            x_pos: x coordinates of the arrow locations (quiver's ``X``).
            y_pos: y coordinates of the arrow locations (quiver's ``Y``).
            x_direct: x components of the arrow vectors (quiver's ``U``).
            y_direct: y components of the arrow vectors (quiver's ``V``).
            color: color of the arrows. Refer to `matplotlib.colors` for a full
                list of formats that are accepted.
            linestyle: style of the arrow outline. Refer to
                `matplotlib.lines.Line2D` for a full list of accepted formats.
            linewidth (float or None): width of the arrow outline. When it is
                None, 0.1 is used.

        Returns:
            output (VisImage): image object with arrows drawn.
        """
        if linewidth is None:
            # Thin outline used when the caller does not ask for a width.
            linewidth = 0.1
        self.output.ax.quiver(
            x_pos,
            y_pos,
            x_direct,
            y_direct,
            color=color,
            scale_units="xy",
            scale=1,
            antialiased=True,
            headaxislength=3.5,
            linewidths=linewidth,
            linestyles=linestyle,
        )
        return self.output

    def draw_lati(
        self, latimap, alpha_contourf=0.4, alpha_contour=0.9, contour_only=False
    ):
        """
        Blend a latitude map over the image as colored contours.

        Args:
            latimap: 2-D array of latitude values, one per pixel. Contour levels
                span [-pi/2, pi/2], so values are presumably in radians
                (TODO confirm against the caller).
            alpha_contourf (float): opacity of the filled bands.
            alpha_contour (float): opacity of the contour lines.
            contour_only (bool): when True, draw only the zero-level contour as
                a thick line and skip the filled bands.

        Returns:
            output (VisImage): image object with the latitude overlay drawn.
        """
        height, width = latimap.shape
        y, x = np.mgrid[0:height, 0:width]
        cmap = plt.get_cmap("seismic")

        if contour_only:
            # only plot central contour
            self.output.ax.contour(
                x,
                y,
                latimap,
                levels=[0],
                cmap=cmap,
                alpha=alpha_contour,
                antialiased=True,
                linewidths=15,
            )
            return self.output

        bands = 20
        levels = np.linspace(-np.pi / 2, np.pi / 2, bands - 1)
        filled = self.output.ax.contourf(
            x,
            y,
            latimap,
            levels=levels,
            cmap=cmap,
            alpha=alpha_contourf,
            antialiased=True,
        )
        # Request solid lines up front instead of patching the artists through
        # ContourSet.collections, which newer matplotlib releases no longer have.
        self.output.ax.contour(
            x,
            y,
            latimap,
            levels=filled.levels,
            cmap=cmap,
            alpha=alpha_contour,
            antialiased=True,
            linewidths=5,
            linestyles="solid",
        )
        return self.output