This commit is contained in:
zhxiao 2025-05-23 11:24:02 +08:00
parent bf877c81fe
commit 3777ea04c9
4 changed files with 808 additions and 0 deletions

View File

@@ -0,0 +1,253 @@
import gradio as gr
import numpy as np
import torch
import os
from PIL import Image
from perspective_fields import pano_utils as pu
from diffusers import (
AutoencoderKL,
ControlNetModel,
FluxControlNetPipeline,
FlowMatchEulerDiscreteScheduler,
FluxControlNetModel,
)
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
from transformers import T5EncoderModel, T5TokenizerFast
from safetensors.torch import load_file
# Resolve local model file paths
current_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(current_dir)
# Build paths with os.path.join so the separators are correct on every platform
flux_model_dir = os.path.join(project_dir, "models", "diffusion_models")
print(f"Loading models from:")
print(f"Flux dir: {flux_model_dir}")
# Create the cache directory if it does not exist
os.makedirs(os.path.join(project_dir, "cache"), exist_ok=True)
# Load all required pipeline components
print("Loading components...")
# 1. Load the scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_config(
{
"num_train_timesteps": 1000
}
)
# 2. Load the ControlNet (a model compatible with FLUX.1-dev)
print("Loading ControlNet compatible with FLUX.1-dev...")
controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
controlnet = FluxControlNetModel.from_pretrained(
    controlnet_model,
    torch_dtype=torch.bfloat16,
    cache_dir=os.path.join(project_dir, "cache"),
)
# 3. Load the VAE
try:
    vae = AutoencoderKL.from_pretrained(
        "diffusers/FLUX.1-vae",
        torch_dtype=torch.bfloat16,  # bfloat16 to match the ControlNet
        cache_dir=os.path.join(project_dir, "cache")
    )
except Exception as e:
    print(f"Error loading VAE: {e}")
    print("Falling back to the VAE bundled with the FLUX.1-dev base model...")
    vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="vae",
        torch_dtype=torch.bfloat16,  # bfloat16 to match the ControlNet
        cache_dir=os.path.join(project_dir, "cache")
    )
# 4. Load the CLIP text encoder and tokenizer
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14",
    torch_dtype=torch.bfloat16,  # bfloat16 to match the ControlNet
cache_dir=os.path.join(project_dir, "cache")
)
tokenizer = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14",
cache_dir=os.path.join(project_dir, "cache")
)
# 5. Load the second text encoder and tokenizer (T5)
print("Loading the pre-downloaded T5 model from local disk...")
# Local model path - assumes the ModelScope copy of the T5-XXL model was downloaded manually
t5_model_path = os.path.join(project_dir, "models", "t5_xxl")  # adjust to the actual path
# Check whether the local model path exists
if not os.path.exists(t5_model_path):
    print(f"Warning: local model path {t5_model_path} does not exist. Please download the ModelScope T5-XXL model to this path first.")
    print("Falling back to downloading from HuggingFace, which may be slow...")
t5_model_path = "google/t5-v1_1-xxl"
text_encoder_2 = T5EncoderModel.from_pretrained(
    t5_model_path,  # use the local path
torch_dtype=torch.bfloat16,
cache_dir=os.path.join(project_dir, "cache")
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    t5_model_path,  # use the local path
cache_dir=os.path.join(project_dir, "cache")
)
# 6. Load the feature extractor and image encoder
feature_extractor = CLIPImageProcessor.from_pretrained(
"openai/clip-vit-large-patch14",
cache_dir=os.path.join(project_dir, "cache")
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"openai/clip-vit-large-patch14",
    torch_dtype=torch.bfloat16,  # bfloat16 to match the ControlNet
cache_dir=os.path.join(project_dir, "cache")
)
# 7. Base model (the FLUX transformer)
base_model = "black-forest-labs/FLUX.1-dev"
# Do not load the transformer component explicitly, as doing so causes a type error;
# the pipeline initializes a transformer of the correct type automatically.
# 8. Load the pipeline directly from HuggingFace
try:
print("Loading FLUX model from HuggingFace...")
    # Set cache-related environment variables
os.environ["HF_HOME"] = os.path.join(project_dir, "cache")
os.environ["TRANSFORMERS_CACHE"] = os.path.join(project_dir, "cache", "transformers")
os.environ["DIFFUSERS_CACHE"] = os.path.join(project_dir, "cache", "diffusers")
    # Cache directory passed to from_pretrained
cache_path = os.path.join(project_dir, "cache")
pipe = FluxControlNetPipeline.from_pretrained(
base_model,
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
controlnet=controlnet,
        torch_dtype=torch.bfloat16,  # bfloat16 to match the ControlNet
low_cpu_mem_usage=False,
cache_dir=cache_path,
        resume_download=True,  # resume if a previous download was interrupted
        force_download=False  # do not re-download files that are already cached
)
except Exception as e:
    print(f"Error loading from HuggingFace: {e}")
    # No alternative configuration is implemented; re-raise so `pipe` is never left undefined.
    raise
# Enable CPU offload to reduce GPU memory usage
pipe.enable_model_cpu_offload()
def inference(prompt, pf_image, n_steps=2, seed=13):
"""Generates an image based on the given prompt and perspective field image."""
pf = Image.fromarray(
np.concatenate(
[np.expand_dims(pf_image[:, :, i], axis=-1) for i in [2, 1, 0]], axis=2
)
)
pf_condition = pf.resize((1024, 1024))
generator = torch.manual_seed(seed)
print("Control image size:", pf_condition.size) # 应该是 (1024, 1024)
return pipe(
prompt=prompt,
num_inference_steps=n_steps,
generator=generator,
        control_image=pf_condition,  # only the ControlNet conditioning image is passed
        # Parameters not supported by FluxControlNetPipeline were removed:
        # image=pf_condition,  # img2img would need a base image
        # strength=0.8,  # would control the img2img strength
).images[0]
def obtain_pf(xi, roll, pitch, vfov):
"""Computes perspective fields given camera parameters."""
w, h = (1024, 1024)
equi_img = np.zeros((h, w, 3), dtype=np.uint8)
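    # Derive an effective focal length in pixels from the requested vertical FOV and the
    # distortion parameter xi of the unified spherical camera model.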
x = -np.sin(np.radians(vfov / 2))
z = np.sqrt(1 - x**2)
f_px_effective = -0.5 * (w / 2) * (xi + z) / x
crop, _, _, _, up, lat, _ = pu.crop_distortion(
equi_img, f=f_px_effective, xi=xi, H=h, W=w, az=10, el=-pitch, roll=roll
)
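    # Encode the perspective field as an 8-bit image: the two up (gravity) components in
    # [-1, 1] map to [0, 255], and the latitude in [-90, 90] degrees maps to [0, 255].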
gravity = (up + 1) * 127.5
latitude = np.expand_dims((np.degrees(lat) + 90) * (255 / 180), axis=-1)
pf_image = np.concatenate([gravity, latitude], axis=2).astype(np.uint8)
    # Try to render the blend visualization; fall back to a synthetic one on failure
    try:
        blend = pu.draw_perspective_fields(crop, up, lat)  # lat is already in radians
    except Exception as e:
        print(f"Error generating blend image: {e}")
        # Build a substitute blend image with a basic visualization
        blend = np.zeros_like(pf_image)
        blend[:, :, 0] = (np.degrees(lat) + 90) * (255 / 180)  # latitude as the red channel
        blend[:, :, 1] = (up[:, :, 1] + 1) * 127.5  # vertical up component as the green channel
        blend[:, :, 2] = 128  # neutral value for the blue channel
return pf_image, blend
# Gradio UI
demo = gr.Blocks(theme=gr.themes.Soft())
with demo:
gr.Markdown("""---""")
gr.Markdown("""# PreciseCam: Precise Camera Control for Text-to-Image Generation""")
gr.Markdown("""1. Set the camera parameters (Roll, Pitch, Vertical FOV, ξ)""")
gr.Markdown("""2. Click "Compute PF-US" to generate the perspective field image""")
gr.Markdown("""3. Enter a prompt for the image generation""")
gr.Markdown("""4. Click "Generate Image" to create the final image""")
gr.Markdown("""---""")
with gr.Row():
with gr.Column():
roll = gr.Slider(-90, 90, 0, label="Roll")
pitch = gr.Slider(-90, 90, 1, label="Pitch")
vfov = gr.Slider(15, 140, 50, label="Vertical FOV")
xi = gr.Slider(0.0, 1, 0.2, label="ξ")
prompt = gr.Textbox(
lines=4,
label="Prompt",
show_copy_button=True,
value="A colorful autumn park with leaves of orange, red, and yellow scattered across a winding path.",
)
pf_btn = gr.Button("Compute PF-US", variant="primary")
with gr.Row():
pf_img = gr.Image(height=1024 // 4, width=1024 // 4, label="PF-US")
condition_img = gr.Image(
height=1024 // 4, width=1024 // 4, label="Internal PF-US (RGB)"
)
with gr.Column(scale=2):
result_img = gr.Image(label="Generated Image", height=1024 // 2)
inf_btn = gr.Button("Generate Image", variant="primary")
gr.Markdown("""---""")
pf_btn.click(
obtain_pf, inputs=[xi, roll, pitch, vfov], outputs=[condition_img, pf_img]
)
inf_btn.click(inference, inputs=[prompt, condition_img], outputs=[result_img])
# Clear proxy environment variables so local serving is not routed through a proxy
os.environ["HTTP_PROXY"] = ""
os.environ["HTTPS_PROXY"] = ""
os.environ["NO_PROXY"] = "localhost,127.0.0.1"
# Launch the Gradio app with an explicit server name and port
demo.launch(
server_name="0.0.0.0",
server_port=9946,
share=False,
favicon_path=None,
show_api=False,
prevent_thread_lock=True
)

View File

@@ -0,0 +1,11 @@
Code based on the original implementation from [jinlinyi/PerspectiveFields](https://github.com/jinlinyi/PerspectiveFields).
**Citation:**
```bibtex
@inproceedings{jin2023perspective,
title={Perspective Fields for Single Image Camera Calibration},
author={Linyi Jin and Jianming Zhang and Yannick Hold-Geoffroy and Oliver Wang and Kevin Matzen and Matthew Sticha and David F. Fouhey},
booktitle = {CVPR},
year={2023}
}
```

View File

@@ -0,0 +1,261 @@
#######################################################################
# Adapted code from https://github.com/jinlinyi/PerspectiveFields
#######################################################################
import torch
import numpy as np
from equilib import grid_sample
from sklearn.preprocessing import normalize
from perspective_fields.visualizer import VisualizerPerspective
from numpy.lib.scimath import sqrt as csqrt
def deg2rad(deg):
"""convert degrees to radians"""
return deg * np.pi / 180
def diskradius(xi, f): # compute the disk radius when the image is catadioptric
return np.sqrt(-(f * f) / (1 - xi * xi))
def crop_distortion(image360_path, f, xi, H, W, az, el, roll):
    """
    Reference: https://github.com/dompm/spherical-distortion-dataset/blob/main/spherical_distortion/spherical_distortion.py
    Crop a distorted image with the specified camera parameters.
    Args:
        image360_path (np.ndarray): equirectangular (360) image to crop from
        f (float): focal length of the cropped image, in pixels
        xi (float): distortion parameter of the unified spherical camera model
        H (int): height of the cropped image
        W (int): width of the cropped image
        az (float): camera rotation about the camera-frame y-axis (degrees)
        el (float): camera rotation about the camera-frame x-axis (degrees)
        roll (float): camera rotation about the camera-frame z-axis (degrees)
    Returns:
        im (np.ndarray): cropped, distorted image
        ntheta, nphi (np.ndarray): per-pixel azimuth and latitude of the viewing rays (radians)
        offset (float): sub-pixel row of the horizon along the central column (NaN if absent)
        up (np.ndarray): per-pixel unit up (gravity) vectors, shape (H, W, 2)
        lat (np.ndarray): latitude map (radians)
        xy_map (np.ndarray): panorama sampling coordinates for each crop pixel
    """
u0 = W / 2.0
v0 = H / 2.0
grid_x, grid_y = np.meshgrid(list(range(W)), list(range(H)))
image360 = image360_path.copy()
ImPano_W = np.shape(image360)[1]
ImPano_H = np.shape(image360)[0]
x_ref = 1
y_ref = 1
    fmin = minfocal(
        u0, v0, xi, x_ref, y_ref
    )  # compute the minimal focal length for the image to be catadioptric with the given xi
# 1. Projection on the camera plane
X_Cam = np.divide(grid_x - u0, f)
Y_Cam = -np.divide(grid_y - v0, f)
# 2. Projection on the sphere
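    # Unified spherical camera model: each normalized image-plane point is lifted onto the
    # unit sphere, with xi the offset between the projection centre and the sphere centre.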
AuxVal = np.multiply(X_Cam, X_Cam) + np.multiply(Y_Cam, Y_Cam)
alpha_cam = np.real(xi + csqrt(1 + np.multiply((1 - xi * xi), AuxVal)))
alpha_div = AuxVal + 1
alpha_cam_div = np.divide(alpha_cam, alpha_div)
X_Sph = np.multiply(X_Cam, alpha_cam_div)
Y_Sph = np.multiply(Y_Cam, alpha_cam_div)
Z_Sph = alpha_cam_div - xi
# 3. Rotation of the sphere
coords = np.vstack((X_Sph.ravel(), Y_Sph.ravel(), Z_Sph.ravel()))
rot_el = np.array(
[
1.0,
0.0,
0.0,
0.0,
np.cos(deg2rad(el)),
-np.sin(deg2rad(el)),
0.0,
np.sin(deg2rad(el)),
np.cos(deg2rad(el)),
]
).reshape((3, 3))
rot_az = np.array(
[
np.cos(deg2rad(az)),
0.0,
-np.sin(deg2rad(az)),
0.0,
1.0,
0.0,
np.sin(deg2rad(az)),
0.0,
np.cos(deg2rad(az)),
]
).reshape((3, 3))
rot_roll = np.array(
[
np.cos(deg2rad(roll)),
np.sin(deg2rad(roll)),
0.0,
-np.sin(deg2rad(roll)),
np.cos(deg2rad(roll)),
0.0,
0.0,
0.0,
1.0,
]
).reshape((3, 3))
sph = rot_el.dot(rot_roll.T.dot(coords))
sph = rot_az.dot(sph)
sph = sph.reshape((3, H, W)).transpose((1, 2, 0))
X_Sph, Y_Sph, Z_Sph = sph[:, :, 0], sph[:, :, 1], sph[:, :, 2]
# 4. cart 2 sph
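    # ntheta is the azimuth (longitude) and nphi the elevation (latitude) of each pixel's
    # viewing ray on the unit sphere.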
ntheta = np.arctan2(X_Sph, Z_Sph)
nphi = np.arctan2(Y_Sph, np.sqrt(Z_Sph**2 + X_Sph**2))
pi = np.pi
# 5. Sphere to pano
min_theta = -pi
max_theta = pi
min_phi = -pi / 2.0
max_phi = pi / 2.0
min_x = 0
max_x = ImPano_W - 1.0
min_y = 0
max_y = ImPano_H - 1.0
    ## for x
    a = (max_theta - min_theta) / (max_x - min_x)
    b = max_theta - a * max_x  # from the linear map theta = a * x + b
    nx = (1.0 / a) * (ntheta - b)
    ## for y
    a = (min_phi - max_phi) / (max_y - min_y)
    b = max_phi - a * min_y  # from the linear map phi = a * y + b
    ny = (1.0 / a) * (nphi - b)
lat = nphi.copy()
xy_map = np.stack((nx, ny)).transpose(1, 2, 0)
# 6. Final step interpolation and mapping
# im = np.array(my_interpol.interp2linear(image360, nx, ny), dtype=np.uint8)
im = grid_sample.numpy_grid_sample.default(
image360.transpose(2, 0, 1), np.stack((ny, nx))
).transpose(1, 2, 0)
if (
f < fmin
): # if it is a catadioptric image, apply mask and a disk in the middle
r = diskradius(xi, f)
DIM = im.shape
ci = (np.round(DIM[0] / 2), np.round(DIM[1] / 2))
xx, yy = np.meshgrid(
list(range(DIM[0])) - ci[0], list(range(DIM[1])) - ci[1]
)
mask = np.double((np.multiply(xx, xx) + np.multiply(yy, yy)) < r * r)
mask_3channel = np.stack([mask, mask, mask], axis=-1).transpose((1, 0, 2))
im = np.array(np.multiply(im, mask_3channel), dtype=np.uint8)
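    # Locate the horizon: find where the latitude changes sign along the central image column
    # and interpolate to a sub-pixel row offset.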
col = nphi[:, W // 2]
zero_crossings_rows = np.where(np.diff(np.sign(col)))[0]
if len(zero_crossings_rows) >= 2:
print("WARNING | Number of zero crossings:", len(zero_crossings_rows))
zero_crossings_rows = [zero_crossings_rows[0]]
if len(zero_crossings_rows) == 0:
offset = np.nan
else:
assert col[zero_crossings_rows[0]] >= 0
assert col[zero_crossings_rows[0] + 1] <= 0
dy = col[zero_crossings_rows[0] + 1] - col[zero_crossings_rows[0]]
offset = zero_crossings_rows[0] - col[zero_crossings_rows[0]] / dy
assert col[zero_crossings_rows[0]] / dy <= 1.0
# Reproject [nx, ny+epsilon] back
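    # The up (gravity) field is estimated by finite differences: nudge every pixel slightly
    # towards the zenith in panorama coordinates (ny - epsilon), project the nudged point back
    # into the crop, and take the normalized in-image displacement as the up direction.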
epsilon = 1e-5
end_vector_x = nx.copy()
end_vector_y = ny.copy() - epsilon
# -5. pano to Sphere
a = (max_theta - min_theta) / (max_x - min_x)
    b = max_theta - a * max_x  # from the linear map theta = a * x + b
ntheta_end = end_vector_x * a + b
## for y
a = (min_phi - max_phi) / (max_y - min_y)
b = max_phi - a * min_y
nphi_end = end_vector_y * a + b
# -4. sph 2 cart
Y_Sph = np.sin(nphi)
X_Sph = np.cos(nphi_end) * np.sin(ntheta_end)
Z_Sph = np.cos(nphi_end) * np.cos(ntheta_end)
# -3. Reverse Rotation of the sphere
coords = np.vstack((X_Sph.ravel(), Y_Sph.ravel(), Z_Sph.ravel()))
sph = rot_roll.dot(rot_el.T.dot(rot_az.T.dot(coords)))
sph = sph.reshape((3, H, W)).transpose((1, 2, 0))
X_Sph, Y_Sph, Z_Sph = sph[:, :, 0], sph[:, :, 1], sph[:, :, 2]
# -1. Projection on the image plane
X_Cam = X_Sph * f / (xi * csqrt(X_Sph**2 + Y_Sph**2 + Z_Sph**2) + Z_Sph) + u0
Y_Cam = -Y_Sph * f / (xi * csqrt(X_Sph**2 + Y_Sph**2 + Z_Sph**2) + Z_Sph) + v0
up = np.stack((X_Cam - grid_x, Y_Cam - grid_y)).transpose(1, 2, 0)
up = normalize(up.reshape(-1, 2)).reshape(up.shape)
return im, ntheta, nphi, offset, up, lat, xy_map
def minfocal(u0, v0, xi, xref=1, yref=1):
"""compute the minimum focal for the image to be catadioptric given xi"""
value = -(1 - xi * xi) * ((xref - u0) * (xref - u0) + (yref - v0) * (yref - v0))
if value < 0:
return 0
else:
return np.sqrt(value) * 1.0001
#------------------------------------------------
def draw_perspective_fields(
img_rgb, up, latimap, color=None, density=10, arrow_inv_len=20, return_img=True
):
"""draw perspective field on top of input image
Args:
img_rgb (np.ndarray): input image
up (np.ndarray): gravity field (h, w, 2)
latimap (np.ndarray): latitude map (h, w) (radians)
color ((float, float, float), optional): RGB color for up vectors. [0, 1]
Defaults to None.
density (int, optional): Value to control density of up vectors.
Each row has (width // density) vectors.
Each column has (height // density) vectors.
Defaults to 10.
arrow_inv_len (int, optional): Value to control vector length
Vector length set to (image plane diagonal // arrow_inv_len).
Defaults to 20.
        return_img (bool, optional): whether to return an np.ndarray (True) or a VisImage (False).
    Returns:
        image blended with perspective fields.
"""
visualizer = VisualizerPerspective(img_rgb.copy())
vis_output = visualizer.draw_lati(latimap)
if torch.is_tensor(up):
up = up.numpy().transpose(1, 2, 0)
im_h, im_w, _ = img_rgb.shape
x, y = np.meshgrid(
np.arange(0, im_w, im_w // density), np.arange(0, im_h, im_h // density)
)
x, y = x.ravel(), y.ravel()
arrow_len = np.sqrt(im_w**2 + im_h**2) // arrow_inv_len
end = up[y, x, :] * arrow_len
if color is None:
color = (0, 1, 0)
vis_output = visualizer.draw_arrow(x, y, end[:, 0], -end[:, 1], color=color)
if return_img:
return vis_output.get_image()
else:
return vis_output

View File

@@ -0,0 +1,283 @@
#######################################################################
# Adapted code from https://github.com/jinlinyi/PerspectiveFields
#######################################################################
# Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
import matplotlib.colors as mplc
import matplotlib.figure as mplfigure
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
class VisImage:
def __init__(self, img, scale=1.0):
"""
Args:
img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
scale (float): scale the input image
"""
self.img = img
self.scale = scale
self.width, self.height = img.shape[1], img.shape[0]
self._setup_figure(img)
def _setup_figure(self, img):
"""
Args:
Same as in :meth:`__init__()`.
Returns:
fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
"""
fig = mplfigure.Figure(frameon=False)
self.dpi = fig.get_dpi()
# add a small 1e-2 to avoid precision lost due to matplotlib's truncation
# (https://github.com/matplotlib/matplotlib/issues/15363)
fig.set_size_inches(
(self.width * self.scale + 1e-2) / self.dpi,
(self.height * self.scale + 1e-2) / self.dpi,
)
self.canvas = FigureCanvasAgg(fig)
# self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
ax.axis("off")
self.fig = fig
self.ax = ax
self.reset_image(img)
def reset_image(self, img):
"""
Args:
img: same as in __init__
"""
img = img.astype("uint8")
self.ax.imshow(
img, extent=(0, self.width, self.height, 0), interpolation="nearest"
)
def save(self, filepath):
"""
Args:
filepath (str): a string that contains the absolute path, including the file name, where
the visualized image will be saved.
"""
self.fig.savefig(filepath)
def get_image(self):
"""
Returns:
ndarray:
the visualized image of shape (H, W, 3) (RGB) in uint8 type.
The shape is scaled w.r.t the input image using the given `scale` argument.
"""
canvas = self.canvas
s, (width, height) = canvas.print_to_buffer()
# buf = io.BytesIO() # works for cairo backend
# canvas.print_rgba(buf)
# width, height = self.width, self.height
# s = buf.getvalue()
buffer = np.frombuffer(s, dtype="uint8")
img_rgba = buffer.reshape(height, width, 4)
rgb, alpha = np.split(img_rgba, [3], axis=2)
return rgb.astype("uint8")
class Visualizer:
"""
Visualizer that draws data about detection/segmentation on images.
It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
that draw primitive objects to images, as well as high-level wrappers like
`draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
that draw composite data in some pre-defined style.
Note that the exact visualization style for the high-level wrappers are subject to change.
Style such as color, opacity, label contents, visibility of labels, or even the visibility
of objects themselves (e.g. when the object is too small) may change according
to different heuristics, as long as the results still look visually reasonable.
To obtain a consistent style, you can implement custom drawing functions with the
abovementioned primitive methods instead. If you need more customized visualization
styles, you can process the data yourself following their format documented in
tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
intend to satisfy everyone's preference on drawing styles.
This visualizer focuses on high rendering quality rather than performance. It is not
designed to be used for real-time applications.
"""
def __init__(self, img_rgb, scale=1.0, font_size_scale=1.0):
"""
Args:
img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
the height and width of the image respectively. C is the number of
color channels. The image is required to be in RGB format since that
is a requirement of the Matplotlib library. The image is also expected
to be in the range [0, 255].
font_size_scale: extra scaling of font size on top of default font size
"""
self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
self.output = VisImage(self.img, scale=scale)
self.cpu_device = torch.device("cpu")
        # too small texts are useless, therefore clamp the default font size to a minimum
self._default_font_size = (
max(np.sqrt(self.output.height * self.output.width) // 90, 10 // scale)
* font_size_scale
)
"""
Primitive drawing functions:
"""
def draw_text(
self,
text,
position,
*,
font_size=None,
color="g",
horizontal_alignment="center",
rotation=0,
):
"""
Args:
text (str): class label
position (tuple): a tuple of the x and y coordinates to place text on image.
            font_size (int, optional): font size of the text. If not provided, a font size
proportional to the image width is calculated and used.
color: color of the text. Refer to `matplotlib.colors` for full list
of formats that are accepted.
horizontal_alignment (str): see `matplotlib.text.Text`
rotation: rotation angle in degrees CCW
Returns:
output (VisImage): image object with text drawn.
"""
if not font_size:
font_size = self._default_font_size
# since the text background is dark, we don't want the text to be dark
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
color[np.argmax(color)] = max(0.8, np.max(color))
x, y = position
self.output.ax.text(
x,
y,
text,
size=font_size * self.output.scale,
family="sans-serif",
bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
verticalalignment="top",
horizontalalignment=horizontal_alignment,
color=color,
zorder=10,
rotation=rotation,
)
return self.output
def get_output(self):
"""
Returns:
output (VisImage): the image output containing the visualizations added
to the image.
"""
return self.output
class VisualizerPerspective(Visualizer):
def draw_arrow(
self,
x_pos,
y_pos,
x_direct,
y_direct,
color=None,
linestyle="-",
linewidth=None,
):
"""
Args:
            x_pos (list[int]): x coordinates of the arrow start points.
                Length should match the length of y_pos.
            y_pos (list[int]): y coordinates of the arrow start points.
                Length should match the length of x_pos.
            x_direct, y_direct (list[float]): x and y components of the arrow directions.
color: color of the line. Refer to `matplotlib.colors` for a full list of
formats that are accepted.
linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
for a full list of formats that are accepted.
linewidth (float or None): width of the line. When it's None,
a default value will be computed and used.
Returns:
            output (VisImage): image object with arrows drawn.
"""
if linewidth is None:
linewidth = self._default_font_size / 3
linewidth = max(linewidth, 1)
self.output.ax.quiver(
x_pos,
y_pos,
x_direct,
y_direct,
color=color,
scale_units="xy",
scale=1,
antialiased=True,
headaxislength=3.5,
linewidths=0.1, # , width=0.01
)
return self.output
def draw_lati(
self, latimap, alpha_contourf=0.4, alpha_contour=0.9, contour_only=False
):
"""Blend latitude map"""
height, width = latimap.shape
y, x = np.mgrid[0:height, 0:width]
cmap = plt.get_cmap("seismic")
bands = 20
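        # Discretize the latitude range [-pi/2, pi/2] into (bands - 1) evenly spaced contour
        # levels, drawn with the diverging "seismic" colormap.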
levels = np.linspace(-np.pi / 2, np.pi / 2, bands - 1)
if not contour_only:
pp = self.output.ax.contourf(
x,
y,
latimap,
levels=levels,
cmap=cmap,
alpha=alpha_contourf,
antialiased=True,
)
pp2 = self.output.ax.contour(
x,
y,
latimap,
pp.levels,
cmap=cmap,
alpha=alpha_contour,
antialiased=True,
linewidths=5,
)
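            # ContourSet.collections was deprecated and later removed in newer Matplotlib
            # releases, so the solid-linestyle fix-up below is guarded with try/except.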
try:
for c in pp2.collections:
c.set_linestyle("solid")
except AttributeError:
pass
else:
# only plot central contour
pp = self.output.ax.contour(
x,
y,
latimap,
levels=[0],
cmap=cmap,
alpha=alpha_contour,
antialiased=True,
linewidths=15,
)
return self.output