diff --git a/flux_style_shaper_api/flux_precisecam_demo.py b/flux_style_shaper_api/flux_precisecam_demo.py
new file mode 100644
index 00000000..44055804
--- /dev/null
+++ b/flux_style_shaper_api/flux_precisecam_demo.py
@@ -0,0 +1,253 @@
+import gradio as gr
+import numpy as np
+import torch
+import os
+from PIL import Image
+from perspective_fields import pano_utils as pu
+from diffusers import (
+    AutoencoderKL,
+    ControlNetModel,
+    FluxControlNetPipeline,
+    FlowMatchEulerDiscreteScheduler,
+    FluxControlNetModel,
+)
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
+from transformers import T5EncoderModel, T5TokenizerFast
+from safetensors.torch import load_file
+
+# Resolve local model paths
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_dir = os.path.dirname(current_dir)
+
+# Build paths with os.path.join so the separators are correct on every platform
+flux_model_dir = os.path.join(project_dir, "models", "diffusion_models")
+
+print("Loading models from:")
+print(f"Flux dir: {flux_model_dir}")
+
+# Create the cache directory if it does not exist
+os.makedirs(os.path.join(project_dir, "cache"), exist_ok=True)
+
+# Load all required pipeline components
+print("Loading components...")
+
+# 1. Scheduler
+scheduler = FlowMatchEulerDiscreteScheduler.from_config(
+    {
+        "num_train_timesteps": 1000
+    }
+)
+
+# 2. ControlNet - use a model that is compatible with FLUX.1-dev
+print("Loading ControlNet compatible with FLUX.1-dev...")
+controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
+controlnet = FluxControlNetModel.from_pretrained(
+    controlnet_model, torch_dtype=torch.bfloat16, cache_dir=os.path.join(project_dir, "cache")
+)
+
+# 3. VAE
+try:
+    vae = AutoencoderKL.from_pretrained(
+        "diffusers/FLUX.1-vae",
+        torch_dtype=torch.bfloat16,  # bfloat16 to match the ControlNet
+        cache_dir=os.path.join(project_dir, "cache")
+    )
+except Exception as e:
+    print(f"Error loading VAE: {e}")
+    raise
+
+# 4. CLIP text encoder and tokenizer
+text_encoder = CLIPTextModel.from_pretrained(
+    "openai/clip-vit-large-patch14",
+    torch_dtype=torch.bfloat16,  # bfloat16 to match the ControlNet
+    cache_dir=os.path.join(project_dir, "cache")
+)
+tokenizer = CLIPTokenizer.from_pretrained(
+    "openai/clip-vit-large-patch14",
+    cache_dir=os.path.join(project_dir, "cache")
+)
+
+# 5. Second text encoder (T5) and its tokenizer
+print("Loading the pre-downloaded T5 model from local disk...")
+# Local model path - assumes the ModelScope T5-XXL checkpoint has been downloaded manually
+t5_model_path = os.path.join(project_dir, "models", "t5_xxl")  # adjust to the actual path
+
+# Check whether the local model path exists
+if not os.path.exists(t5_model_path):
+    print(f"Warning: local model path {t5_model_path} does not exist. Please download the ModelScope T5-XXL model to this path first.")
+    print("Falling back to downloading from HuggingFace, which may be slow...")
+    t5_model_path = "google/t5-v1_1-xxl"
+
+text_encoder_2 = T5EncoderModel.from_pretrained(
+    t5_model_path,  # local path when available
+    torch_dtype=torch.bfloat16,
+    cache_dir=os.path.join(project_dir, "cache")
+)
+tokenizer_2 = T5TokenizerFast.from_pretrained(
+    t5_model_path,  # local path when available
+    cache_dir=os.path.join(project_dir, "cache")
+)
+
+# 6. Feature extractor and image encoder
+feature_extractor = CLIPImageProcessor.from_pretrained(
+    "openai/clip-vit-large-patch14",
+    cache_dir=os.path.join(project_dir, "cache")
+)
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    "openai/clip-vit-large-patch14",
+    torch_dtype=torch.bfloat16,  # bfloat16 to match the ControlNet
+    cache_dir=os.path.join(project_dir, "cache")
+)
+
+# 7. Base model id (the transformer itself is created by the pipeline)
+base_model = "black-forest-labs/FLUX.1-dev"
+
+# The transformer component is not loaded explicitly because doing so caused a type error;
+# the pipeline initializes a transformer of the correct type automatically.
+
+# 8. Pipeline - loaded directly from HuggingFace
+try:
+    print("Loading FLUX model from HuggingFace...")
+    # Cache-related environment variables
+    os.environ["HF_HOME"] = os.path.join(project_dir, "cache")
+    os.environ["TRANSFORMERS_CACHE"] = os.path.join(project_dir, "cache", "transformers")
+    os.environ["DIFFUSERS_CACHE"] = os.path.join(project_dir, "cache", "diffusers")
+    # Cache directory passed to from_pretrained
+    cache_path = os.path.join(project_dir, "cache")
+
+    pipe = FluxControlNetPipeline.from_pretrained(
+        base_model,
+        scheduler=scheduler,
+        vae=vae,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        text_encoder_2=text_encoder_2,
+        tokenizer_2=tokenizer_2,
+        feature_extractor=feature_extractor,
+        image_encoder=image_encoder,
+        controlnet=controlnet,
+        torch_dtype=torch.bfloat16,  # bfloat16 to match the ControlNet
+        low_cpu_mem_usage=False,
+        cache_dir=cache_path,
+        resume_download=True,   # resume an interrupted download
+        force_download=False    # do not re-download if already cached
+    )
+except Exception as e:
+    print(f"Error loading from HuggingFace: {e}")
+    # There is no fallback pipeline; re-raise so `pipe` is never left undefined.
+    raise
+
+# Enable CPU offload to reduce GPU memory usage
+pipe.enable_model_cpu_offload()
+
+
+def inference(prompt, pf_image, n_steps=2, seed=13):
+    """Generates an image based on the given prompt and perspective field image."""
+    pf = Image.fromarray(
+        np.concatenate(
+            [np.expand_dims(pf_image[:, :, i], axis=-1) for i in [2, 1, 0]], axis=2
+        )
+    )
+    pf_condition = pf.resize((1024, 1024))
+    generator = torch.manual_seed(seed)
+    print("Control image size:", pf_condition.size)  # should be (1024, 1024)
+    return pipe(
+        prompt=prompt,
+        num_inference_steps=n_steps,
+        generator=generator,
+        control_image=pf_condition,  # only the ControlNet conditioning input is passed
+        # Unsupported parameters removed:
+        # image=pf_condition,  # would require a base image for img2img
+        # strength=0.8,        # denoising strength
+    ).images[0]
+
+
+def obtain_pf(xi, roll, pitch, vfov):
+    """Computes perspective fields given camera parameters."""
+    w, h = (1024, 1024)
+    equi_img = np.zeros((h, w, 3), dtype=np.uint8)
+    x = -np.sin(np.radians(vfov / 2))
+    z = np.sqrt(1 - x**2)
+    f_px_effective = -0.5 * (w / 2) * (xi + z) / x
+
+    crop, _, _, _, up, lat, _ = pu.crop_distortion(
+        equi_img, f=f_px_effective, xi=xi, H=h, W=w, az=10, el=-pitch, roll=roll
+    )
+
+    gravity = (up + 1) * 127.5
+    latitude = np.expand_dims((np.degrees(lat) + 90) * (255 / 180), axis=-1)
+    pf_image = np.concatenate([gravity, latitude], axis=2).astype(np.uint8)
+
+    # Try to render the blended visualization; fall back to a simple substitute on failure
+    try:
+        blend = pu.draw_perspective_fields(crop, up, np.radians(np.degrees(lat)))
+    except Exception as e:
+        print(f"Error generating blend image: {e}")
+        # Build a substitute blend image with a basic channel-wise visualization
+        blend = np.zeros_like(pf_image)
+        blend[:, :, 0] = (np.degrees(lat) + 90) * (255 / 180)  # latitude in the red channel
+        blend[:, :, 1] = (up[:, :, 0] + 1) * 127.5             # up-vector x component in the green channel
+        blend[:, :, 2] = 128                                   # neutral value in the blue channel
+
+    return pf_image, blend
+
+
+# Gradio UI
+demo = gr.Blocks(theme=gr.themes.Soft())
+
+with demo:
+    gr.Markdown("""---""")
+    gr.Markdown("""# PreciseCam: Precise Camera Control for Text-to-Image Generation""")
+    gr.Markdown("""1. Set the camera parameters (Roll, Pitch, Vertical FOV, ξ)""")
+    gr.Markdown("""2. Click "Compute PF-US" to generate the perspective field image""")
+    gr.Markdown("""3. Enter a prompt for the image generation""")
+    gr.Markdown("""4. Click "Generate Image" to create the final image""")
+    gr.Markdown("""---""")
+
+    with gr.Row():
+        with gr.Column():
+            roll = gr.Slider(-90, 90, 0, label="Roll")
+            pitch = gr.Slider(-90, 90, 1, label="Pitch")
+            vfov = gr.Slider(15, 140, 50, label="Vertical FOV")
+            xi = gr.Slider(0.0, 1, 0.2, label="ξ")
+            prompt = gr.Textbox(
+                lines=4,
+                label="Prompt",
+                show_copy_button=True,
+                value="A colorful autumn park with leaves of orange, red, and yellow scattered across a winding path.",
+            )
+            pf_btn = gr.Button("Compute PF-US", variant="primary")
+            with gr.Row():
+                pf_img = gr.Image(height=1024 // 4, width=1024 // 4, label="PF-US")
+                condition_img = gr.Image(
+                    height=1024 // 4, width=1024 // 4, label="Internal PF-US (RGB)"
+                )
+
+        with gr.Column(scale=2):
+            result_img = gr.Image(label="Generated Image", height=1024 // 2)
+            inf_btn = gr.Button("Generate Image", variant="primary")
+    gr.Markdown("""---""")
+
+    pf_btn.click(
+        obtain_pf, inputs=[xi, roll, pitch, vfov], outputs=[condition_img, pf_img]
+    )
+    inf_btn.click(inference, inputs=[prompt, condition_img], outputs=[result_img])
+
+# Disable proxies via environment variables
+os.environ["HTTP_PROXY"] = ""
+os.environ["HTTPS_PROXY"] = ""
+os.environ["NO_PROXY"] = "localhost,127.0.0.1"
+
+# Launch the Gradio app with an explicit server name and port
+demo.launch(
+    server_name="0.0.0.0",
+    server_port=9946,
+    share=False,
+    favicon_path=None,
+    show_api=False,
+    prevent_thread_lock=True
+)
diff --git a/flux_style_shaper_api/perspective_fields/README.md b/flux_style_shaper_api/perspective_fields/README.md
new file mode 100644
index 00000000..b82f2316
--- /dev/null
+++ b/flux_style_shaper_api/perspective_fields/README.md
@@ -0,0 +1,11 @@
+Code based on the original implementation from [jinlinyi/PerspectiveFields](https://github.com/jinlinyi/PerspectiveFields).
+
+**Citation:**
+```bibtex
+@inproceedings{jin2023perspective,
+  title={Perspective Fields for Single Image Camera Calibration},
+  author={Linyi Jin and Jianming Zhang and Yannick Hold-Geoffroy and Oliver Wang and Kevin Matzen and Matthew Sticha and David F. Fouhey},
+  booktitle={CVPR},
+  year={2023}
+}
+```
\ No newline at end of file
diff --git a/flux_style_shaper_api/perspective_fields/pano_utils.py b/flux_style_shaper_api/perspective_fields/pano_utils.py
new file mode 100644
index 00000000..b528dd10
--- /dev/null
+++ b/flux_style_shaper_api/perspective_fields/pano_utils.py
@@ -0,0 +1,261 @@
+#######################################################################
+# Adapted code from https://github.com/jinlinyi/PerspectiveFields
+#######################################################################
+import torch
+import numpy as np
+from equilib import grid_sample
+from sklearn.preprocessing import normalize
+from perspective_fields.visualizer import VisualizerPerspective
+from numpy.lib.scimath import sqrt as csqrt
+
+
+def deg2rad(deg):
+    """Convert degrees to radians."""
+    return deg * np.pi / 180
+
+
+def diskradius(xi, f):
+    """Compute the disk radius when the image is catadioptric."""
+    return np.sqrt(-(f * f) / (1 - xi * xi))
+
+
+def crop_distortion(image360_path, f, xi, H, W, az, el, roll):
+    """
+    Reference: https://github.com/dompm/spherical-distortion-dataset/blob/main/spherical_distortion/spherical_distortion.py
+    Crop a distorted image with the specified camera parameters.
+
+    Args:
+        image360_path (np.ndarray): equirectangular panorama to crop from
+        f (float): focal length of the cropped image
+        xi (float): spherical distortion parameter
+        H (int): height of the cropped image
+        W (int): width of the cropped image
+        az (float): camera rotation about the camera-frame y-axis (degrees)
+        el (float): camera rotation about the camera-frame x-axis (degrees)
+        roll (float): camera rotation about the camera-frame z-axis (degrees)
+    Returns:
+        im (np.ndarray): cropped, distorted image
+    """
+
+    u0 = W / 2.0
+    v0 = H / 2.0
+
+    grid_x, grid_y = np.meshgrid(list(range(W)), list(range(H)))
+
+    image360 = image360_path.copy()
+
+    ImPano_W = np.shape(image360)[1]
+    ImPano_H = np.shape(image360)[0]
+    x_ref = 1
+    y_ref = 1
+
+    fmin = minfocal(
+        u0, v0, xi, x_ref, y_ref
+    )  # compute the minimal focal length for the image to be catadioptric given xi
+
+    # 1. Projection on the camera plane
+
+    X_Cam = np.divide(grid_x - u0, f)
+    Y_Cam = -np.divide(grid_y - v0, f)
+
+    # 2. Projection on the sphere
+
+    AuxVal = np.multiply(X_Cam, X_Cam) + np.multiply(Y_Cam, Y_Cam)
+
+    alpha_cam = np.real(xi + csqrt(1 + np.multiply((1 - xi * xi), AuxVal)))
+
+    alpha_div = AuxVal + 1
+
+    alpha_cam_div = np.divide(alpha_cam, alpha_div)
+
+    X_Sph = np.multiply(X_Cam, alpha_cam_div)
+    Y_Sph = np.multiply(Y_Cam, alpha_cam_div)
+    Z_Sph = alpha_cam_div - xi
+
+    # 3. Rotation of the sphere
+    coords = np.vstack((X_Sph.ravel(), Y_Sph.ravel(), Z_Sph.ravel()))
+    rot_el = np.array(
+        [
+            1.0,
+            0.0,
+            0.0,
+            0.0,
+            np.cos(deg2rad(el)),
+            -np.sin(deg2rad(el)),
+            0.0,
+            np.sin(deg2rad(el)),
+            np.cos(deg2rad(el)),
+        ]
+    ).reshape((3, 3))
+    rot_az = np.array(
+        [
+            np.cos(deg2rad(az)),
+            0.0,
+            -np.sin(deg2rad(az)),
+            0.0,
+            1.0,
+            0.0,
+            np.sin(deg2rad(az)),
+            0.0,
+            np.cos(deg2rad(az)),
+        ]
+    ).reshape((3, 3))
+    rot_roll = np.array(
+        [
+            np.cos(deg2rad(roll)),
+            np.sin(deg2rad(roll)),
+            0.0,
+            -np.sin(deg2rad(roll)),
+            np.cos(deg2rad(roll)),
+            0.0,
+            0.0,
+            0.0,
+            1.0,
+        ]
+    ).reshape((3, 3))
+    sph = rot_el.dot(rot_roll.T.dot(coords))
+    sph = rot_az.dot(sph)
+
+    sph = sph.reshape((3, H, W)).transpose((1, 2, 0))
+    X_Sph, Y_Sph, Z_Sph = sph[:, :, 0], sph[:, :, 1], sph[:, :, 2]
+
+    # 4. cart 2 sph
+    ntheta = np.arctan2(X_Sph, Z_Sph)
+    nphi = np.arctan2(Y_Sph, np.sqrt(Z_Sph**2 + X_Sph**2))
+
+    pi = np.pi
+
+    # 5. Sphere to pano
+    min_theta = -pi
+    max_theta = pi
+    min_phi = -pi / 2.0
+    max_phi = pi / 2.0
+
+    min_x = 0
+    max_x = ImPano_W - 1.0
+    min_y = 0
+    max_y = ImPano_H - 1.0
+
+    ## for x
+    a = (max_theta - min_theta) / (max_x - min_x)
+    b = max_theta - a * max_x  # from y=ax+b %% -a;
+    nx = (1.0 / a) * (ntheta - b)
+
+    ## for y
+    a = (min_phi - max_phi) / (max_y - min_y)
+    b = max_phi - a * min_y  # from y=ax+b %% -a;
+    ny = (1.0 / a) * (nphi - b)
+    lat = nphi.copy()
+    xy_map = np.stack((nx, ny)).transpose(1, 2, 0)
+
+    # 6. Final step: interpolation and mapping
+    # im = np.array(my_interpol.interp2linear(image360, nx, ny), dtype=np.uint8)
+    im = grid_sample.numpy_grid_sample.default(
+        image360.transpose(2, 0, 1), np.stack((ny, nx))
+    ).transpose(1, 2, 0)
+    if (
+        f < fmin
+    ):  # if it is a catadioptric image, apply a mask and a disk in the middle
+        r = diskradius(xi, f)
+        DIM = im.shape
+        ci = (np.round(DIM[0] / 2), np.round(DIM[1] / 2))
+        xx, yy = np.meshgrid(
+            list(range(DIM[0])) - ci[0], list(range(DIM[1])) - ci[1]
+        )
+        mask = np.double((np.multiply(xx, xx) + np.multiply(yy, yy)) < r * r)
+        mask_3channel = np.stack([mask, mask, mask], axis=-1).transpose((1, 0, 2))
+        im = np.array(np.multiply(im, mask_3channel), dtype=np.uint8)
+
+    col = nphi[:, W // 2]
+    zero_crossings_rows = np.where(np.diff(np.sign(col)))[0]
+    if len(zero_crossings_rows) >= 2:
+        print("WARNING | Number of zero crossings:", len(zero_crossings_rows))
+        zero_crossings_rows = [zero_crossings_rows[0]]
+
+    if len(zero_crossings_rows) == 0:
+        offset = np.nan
+    else:
+        assert col[zero_crossings_rows[0]] >= 0
+        assert col[zero_crossings_rows[0] + 1] <= 0
+        dy = col[zero_crossings_rows[0] + 1] - col[zero_crossings_rows[0]]
+        offset = zero_crossings_rows[0] - col[zero_crossings_rows[0]] / dy
+        assert col[zero_crossings_rows[0]] / dy <= 1.0
+    # Reproject [nx, ny+epsilon] back
+    epsilon = 1e-5
+    end_vector_x = nx.copy()
+    end_vector_y = ny.copy() - epsilon
+    # -5. pano to sphere
+    a = (max_theta - min_theta) / (max_x - min_x)
+    b = max_theta - a * max_x  # from y=ax+b %% -a;
+    ntheta_end = end_vector_x * a + b
+    ## for y
+    a = (min_phi - max_phi) / (max_y - min_y)
+    b = max_phi - a * min_y
+    nphi_end = end_vector_y * a + b
+    # -4. sph 2 cart
+    Y_Sph = np.sin(nphi)
+    X_Sph = np.cos(nphi_end) * np.sin(ntheta_end)
+    Z_Sph = np.cos(nphi_end) * np.cos(ntheta_end)
+    # -3. Reverse rotation of the sphere
+    coords = np.vstack((X_Sph.ravel(), Y_Sph.ravel(), Z_Sph.ravel()))
+    sph = rot_roll.dot(rot_el.T.dot(rot_az.T.dot(coords)))
+    sph = sph.reshape((3, H, W)).transpose((1, 2, 0))
+    X_Sph, Y_Sph, Z_Sph = sph[:, :, 0], sph[:, :, 1], sph[:, :, 2]
+
+    # -1. Projection on the image plane
+
+    X_Cam = X_Sph * f / (xi * csqrt(X_Sph**2 + Y_Sph**2 + Z_Sph**2) + Z_Sph) + u0
+    Y_Cam = -Y_Sph * f / (xi * csqrt(X_Sph**2 + Y_Sph**2 + Z_Sph**2) + Z_Sph) + v0
+    up = np.stack((X_Cam - grid_x, Y_Cam - grid_y)).transpose(1, 2, 0)
+    up = normalize(up.reshape(-1, 2)).reshape(up.shape)
+
+    return im, ntheta, nphi, offset, up, lat, xy_map
+
+
+def minfocal(u0, v0, xi, xref=1, yref=1):
+    """Compute the minimum focal length for the image to be catadioptric given xi."""
+    value = -(1 - xi * xi) * ((xref - u0) * (xref - u0) + (yref - v0) * (yref - v0))
+
+    if value < 0:
+        return 0
+    else:
+        return np.sqrt(value) * 1.0001
+
+
+# ------------------------------------------------
+
+
+def draw_perspective_fields(
+    img_rgb, up, latimap, color=None, density=10, arrow_inv_len=20, return_img=True
+):
+    """Draw a perspective field on top of the input image.
+
+    Args:
+        img_rgb (np.ndarray): input image
+        up (np.ndarray): gravity field (h, w, 2)
+        latimap (np.ndarray): latitude map (h, w) (radians)
+        color ((float, float, float), optional): RGB color for up vectors, in [0, 1].
+            Defaults to None.
+        density (int, optional): value to control the density of up vectors.
+            Each row has (width // density) vectors.
+            Each column has (height // density) vectors.
+            Defaults to 10.
+        arrow_inv_len (int, optional): value to control vector length.
+            Vector length is set to (image plane diagonal // arrow_inv_len).
+            Defaults to 20.
+        return_img (bool, optional): whether to return an np.ndarray or a VisImage.
+
+    Returns:
+        Image blended with the perspective fields.
+    """
+    visualizer = VisualizerPerspective(img_rgb.copy())
+    vis_output = visualizer.draw_lati(latimap)
+    if torch.is_tensor(up):
+        up = up.numpy().transpose(1, 2, 0)
+    im_h, im_w, _ = img_rgb.shape
+    x, y = np.meshgrid(
+        np.arange(0, im_w, im_w // density), np.arange(0, im_h, im_h // density)
+    )
+    x, y = x.ravel(), y.ravel()
+    arrow_len = np.sqrt(im_w**2 + im_h**2) // arrow_inv_len
+    end = up[y, x, :] * arrow_len
+    if color is None:
+        color = (0, 1, 0)
+    vis_output = visualizer.draw_arrow(x, y, end[:, 0], -end[:, 1], color=color)
+    if return_img:
+        return vis_output.get_image()
+    else:
+        return vis_output
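As a quick orientation for reviewers, here is a hedged sketch of the intended call pattern for `crop_distortion` and `draw_perspective_fields` on a real equirectangular panorama (the demo above only ever passes a black panorama, since it needs the perspective field rather than the pixels). The panorama path and camera parameters below are placeholders.

```python
# Illustrative usage of pano_utils on a real panorama; path and parameters are placeholders.
import numpy as np
from PIL import Image
from perspective_fields import pano_utils as pu

pano = np.asarray(Image.open("panorama_equirectangular.jpg").convert("RGB"))

H = W = 512
xi, vfov = 0.2, 60.0                       # spherical distortion and vertical FOV (degrees)
x = -np.sin(np.radians(vfov / 2))
z = np.sqrt(1 - x**2)
f = -0.5 * (W / 2) * (xi + z) / x          # same focal-length recovery as obtain_pf above

crop, _, _, _, up, lat, _ = pu.crop_distortion(
    pano, f=f, xi=xi, H=H, W=W, az=0, el=-10, roll=5
)

# Overlay the latitude map and up-vector field on the crop for visual inspection.
overlay = pu.draw_perspective_fields(crop, up, lat)
Image.fromarray(overlay).save("pf_overlay_debug.png")
```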
+ """ + fig = mplfigure.Figure(frameon=False) + self.dpi = fig.get_dpi() + # add a small 1e-2 to avoid precision lost due to matplotlib's truncation + # (https://github.com/matplotlib/matplotlib/issues/15363) + fig.set_size_inches( + (self.width * self.scale + 1e-2) / self.dpi, + (self.height * self.scale + 1e-2) / self.dpi, + ) + self.canvas = FigureCanvasAgg(fig) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) + ax.axis("off") + self.fig = fig + self.ax = ax + self.reset_image(img) + + def reset_image(self, img): + """ + Args: + img: same as in __init__ + """ + img = img.astype("uint8") + self.ax.imshow( + img, extent=(0, self.width, self.height, 0), interpolation="nearest" + ) + + def save(self, filepath): + """ + Args: + filepath (str): a string that contains the absolute path, including the file name, where + the visualized image will be saved. + """ + self.fig.savefig(filepath) + + def get_image(self): + """ + Returns: + ndarray: + the visualized image of shape (H, W, 3) (RGB) in uint8 type. + The shape is scaled w.r.t the input image using the given `scale` argument. + """ + canvas = self.canvas + s, (width, height) = canvas.print_to_buffer() + # buf = io.BytesIO() # works for cairo backend + # canvas.print_rgba(buf) + # width, height = self.width, self.height + # s = buf.getvalue() + + buffer = np.frombuffer(s, dtype="uint8") + + img_rgba = buffer.reshape(height, width, 4) + rgb, alpha = np.split(img_rgba, [3], axis=2) + return rgb.astype("uint8") + + +class Visualizer: + """ + Visualizer that draws data about detection/segmentation on images. + + It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}` + that draw primitive objects to images, as well as high-level wrappers like + `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}` + that draw composite data in some pre-defined style. + + Note that the exact visualization style for the high-level wrappers are subject to change. + Style such as color, opacity, label contents, visibility of labels, or even the visibility + of objects themselves (e.g. when the object is too small) may change according + to different heuristics, as long as the results still look visually reasonable. + + To obtain a consistent style, you can implement custom drawing functions with the + abovementioned primitive methods instead. If you need more customized visualization + styles, you can process the data yourself following their format documented in + tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not + intend to satisfy everyone's preference on drawing styles. + + This visualizer focuses on high rendering quality rather than performance. It is not + designed to be used for real-time applications. + """ + + def __init__(self, img_rgb, scale=1.0, font_size_scale=1.0): + """ + Args: + img_rgb: a numpy array of shape (H, W, C), where H and W correspond to + the height and width of the image respectively. C is the number of + color channels. The image is required to be in RGB format since that + is a requirement of the Matplotlib library. The image is also expected + to be in the range [0, 255]. 
+ font_size_scale: extra scaling of font size on top of default font size + """ + self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8) + self.output = VisImage(self.img, scale=scale) + self.cpu_device = torch.device("cpu") + + # too small texts are useless, therefore clamp to 9 + self._default_font_size = ( + max(np.sqrt(self.output.height * self.output.width) // 90, 10 // scale) + * font_size_scale + ) + + """ + Primitive drawing functions: + """ + + def draw_text( + self, + text, + position, + *, + font_size=None, + color="g", + horizontal_alignment="center", + rotation=0, + ): + """ + Args: + text (str): class label + position (tuple): a tuple of the x and y coordinates to place text on image. + font_size (int, optional): font of the text. If not provided, a font size + proportional to the image width is calculated and used. + color: color of the text. Refer to `matplotlib.colors` for full list + of formats that are accepted. + horizontal_alignment (str): see `matplotlib.text.Text` + rotation: rotation angle in degrees CCW + + Returns: + output (VisImage): image object with text drawn. + """ + if not font_size: + font_size = self._default_font_size + + # since the text background is dark, we don't want the text to be dark + color = np.maximum(list(mplc.to_rgb(color)), 0.2) + color[np.argmax(color)] = max(0.8, np.max(color)) + + x, y = position + self.output.ax.text( + x, + y, + text, + size=font_size * self.output.scale, + family="sans-serif", + bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}, + verticalalignment="top", + horizontalalignment=horizontal_alignment, + color=color, + zorder=10, + rotation=rotation, + ) + return self.output + + def get_output(self): + """ + Returns: + output (VisImage): the image output containing the visualizations added + to the image. + """ + return self.output + + +class VisualizerPerspective(Visualizer): + def draw_arrow( + self, + x_pos, + y_pos, + x_direct, + y_direct, + color=None, + linestyle="-", + linewidth=None, + ): + """ + Args: + x_data (list[int]): a list containing x values of all the points being drawn. + Length of list should match the length of y_data. + y_data (list[int]): a list containing y values of all the points being drawn. + Length of list should match the length of x_data. + color: color of the line. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + linestyle: style of the line. Refer to `matplotlib.lines.Line2D` + for a full list of formats that are accepted. + linewidth (float or None): width of the line. When it's None, + a default value will be computed and used. + + Returns: + output (VisImage): image object with line drawn. 
+ """ + if linewidth is None: + linewidth = self._default_font_size / 3 + linewidth = max(linewidth, 1) + self.output.ax.quiver( + x_pos, + y_pos, + x_direct, + y_direct, + color=color, + scale_units="xy", + scale=1, + antialiased=True, + headaxislength=3.5, + linewidths=0.1, # , width=0.01 + ) + return self.output + + def draw_lati( + self, latimap, alpha_contourf=0.4, alpha_contour=0.9, contour_only=False + ): + """Blend latitude map""" + height, width = latimap.shape + y, x = np.mgrid[0:height, 0:width] + cmap = plt.get_cmap("seismic") + bands = 20 + levels = np.linspace(-np.pi / 2, np.pi / 2, bands - 1) + if not contour_only: + pp = self.output.ax.contourf( + x, + y, + latimap, + levels=levels, + cmap=cmap, + alpha=alpha_contourf, + antialiased=True, + ) + pp2 = self.output.ax.contour( + x, + y, + latimap, + pp.levels, + cmap=cmap, + alpha=alpha_contour, + antialiased=True, + linewidths=5, + ) + try: + for c in pp2.collections: + c.set_linestyle("solid") + except AttributeError: + pass + else: + # only plot central contour + pp = self.output.ax.contour( + x, + y, + latimap, + levels=[0], + cmap=cmap, + alpha=alpha_contour, + antialiased=True, + linewidths=15, + ) + return self.output