parent bf877c81fe
commit 3777ea04c9
@@ -0,0 +1,253 @@
import gradio as gr
import numpy as np
import torch
import os
from PIL import Image

from perspective_fields import pano_utils as pu
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    FluxControlNetPipeline,
    FlowMatchEulerDiscreteScheduler,
    FluxControlNetModel,
)
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
from transformers import T5EncoderModel, T5TokenizerFast
from safetensors.torch import load_file
# Paths to the local model files
current_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(current_dir)

# Build paths with os.path.join so the separators are correct on every OS
flux_model_dir = os.path.join(project_dir, "models", "diffusion_models")

print("Loading models from:")
print(f"Flux dir: {flux_model_dir}")

# Create the cache directory if it does not exist
os.makedirs(os.path.join(project_dir, "cache"), exist_ok=True)

# Load all required components
print("Loading components...")

# 1. Load the scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    {"num_train_timesteps": 1000}
)

# 5. Load the ControlNet - a model compatible with FLUX.1-dev
print("Loading ControlNet compatible with FLUX.1-dev...")
controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
controlnet = FluxControlNetModel.from_pretrained(
    controlnet_model,
    torch_dtype=torch.bfloat16,
    cache_dir=os.path.join(project_dir, "cache"),
)

# 2. Load the VAE
try:
    vae = AutoencoderKL.from_pretrained(
        "diffusers/FLUX.1-vae",
        torch_dtype=torch.bfloat16,  # use bfloat16 to match the controlnet
        cache_dir=os.path.join(project_dir, "cache"),
    )
except Exception as e:
    print(f"Error loading VAE: {e}")
    print("Retrying the VAE download...")
    # Retry the same repository; the first attempt may have failed on a
    # transient network error
    vae = AutoencoderKL.from_pretrained(
        "diffusers/FLUX.1-vae",
        torch_dtype=torch.bfloat16,  # use bfloat16 to match the controlnet
        cache_dir=os.path.join(project_dir, "cache"),
    )

# 3. Load the text encoder and tokenizer
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=torch.bfloat16,  # use bfloat16 to match the controlnet
    cache_dir=os.path.join(project_dir, "cache"),
)
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14",
    cache_dir=os.path.join(project_dir, "cache"),
)

# 4. Load the second text encoder and tokenizer
print("Loading the pre-downloaded T5 model from local disk...")
# Local model path - assumes the ModelScope model has been downloaded manually
t5_model_path = os.path.join(project_dir, "models", "t5_xxl")  # adjust to the actual path

# Check that the model path exists
if not os.path.exists(t5_model_path):
    print(f"Warning: local model path {t5_model_path} does not exist. "
          "Please download the ModelScope T5-XXL model to this path first.")
    print("Falling back to downloading from HuggingFace, which may be slow...")
    t5_model_path = "google/t5-v1_1-xxl"

text_encoder_2 = T5EncoderModel.from_pretrained(
    t5_model_path,  # local path when available
    torch_dtype=torch.bfloat16,
    cache_dir=os.path.join(project_dir, "cache"),
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    t5_model_path,  # local path when available
    cache_dir=os.path.join(project_dir, "cache"),
)

# 6. Load the feature extractor and image encoder
feature_extractor = CLIPImageProcessor.from_pretrained(
    "openai/clip-vit-large-patch14",
    cache_dir=os.path.join(project_dir, "cache"),
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=torch.bfloat16,  # use bfloat16 to match the controlnet
    cache_dir=os.path.join(project_dir, "cache"),
)

# 7. Base model (the transformer is loaded by the pipeline itself)
base_model = "black-forest-labs/FLUX.1-dev"

# Do not load the transformer component separately - doing so causes a type
# error; the pipeline initializes a transformer of the correct type on its own.

# 8. Load the pipeline directly from HuggingFace
try:
    print("Loading FLUX model from HuggingFace...")
    # Cache-related environment variables
    os.environ["HF_HOME"] = os.path.join(project_dir, "cache")
    os.environ["TRANSFORMERS_CACHE"] = os.path.join(project_dir, "cache", "transformers")
    os.environ["DIFFUSERS_CACHE"] = os.path.join(project_dir, "cache", "diffusers")
    # Cache path, also passed explicitly below
    cache_path = os.path.join(project_dir, "cache")

    pipe = FluxControlNetPipeline.from_pretrained(
        base_model,
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
        controlnet=controlnet,
        torch_dtype=torch.bfloat16,  # use bfloat16 to match the controlnet
        low_cpu_mem_usage=False,
        cache_dir=cache_path,
        resume_download=True,  # resume if a download was interrupted
        force_download=False,  # do not re-download if already cached
    )
except Exception as e:
    print(f"Error loading from HuggingFace: {e}")
    # Without a pipeline there is nothing to fall back to, so re-raise
    # instead of continuing with an undefined `pipe`.
    raise

# Enable CPU offload to reduce GPU memory usage
pipe.enable_model_cpu_offload()

def inference(prompt, pf_image, n_steps=2, seed=13):
    """Generates an image based on the given prompt and perspective field image."""
    # Reorder the channels (BGR -> RGB) before handing the image to the pipeline
    pf = Image.fromarray(
        np.concatenate(
            [np.expand_dims(pf_image[:, :, i], axis=-1) for i in [2, 1, 0]], axis=2
        )
    )
    pf_condition = pf.resize((1024, 1024))
    generator = torch.manual_seed(seed)
    print("Control image size:", pf_condition.size)  # should be (1024, 1024)
    return pipe(
        prompt=prompt,
        num_inference_steps=n_steps,
        generator=generator,
        control_image=pf_condition,  # the only ControlNet input this pipeline takes
        # The parameters below are not supported by FluxControlNetPipeline:
        # image=pf_condition,  # base image for img2img
        # strength=0.8,        # denoising strength
    ).images[0]


def obtain_pf(xi, roll, pitch, vfov):
    """Computes perspective fields given camera parameters."""
    w, h = (1024, 1024)
    equi_img = np.zeros((h, w, 3), dtype=np.uint8)
    x = -np.sin(np.radians(vfov / 2))
    z = np.sqrt(1 - x**2)
    f_px_effective = -0.5 * (w / 2) * (xi + z) / x

    crop, _, _, _, up, lat, _ = pu.crop_distortion(
        equi_img, f=f_px_effective, xi=xi, H=h, W=w, az=10, el=-pitch, roll=roll
    )

    gravity = (up + 1) * 127.5
    latitude = np.expand_dims((np.degrees(lat) + 90) * (255 / 180), axis=-1)
    pf_image = np.concatenate([gravity, latitude], axis=2).astype(np.uint8)

    # Try to render the blend image; fall back to a synthetic one on failure
    try:
        blend = pu.draw_perspective_fields(crop, up, lat)  # lat is already in radians
    except Exception as e:
        print(f"Error generating blend image: {e}")
        # Build a substitute blend image with a basic visualization
        blend = np.zeros_like(pf_image)
        blend[:, :, 0] = (np.degrees(lat) + 90) * (255 / 180)  # latitude as the red channel
        blend[:, :, 1] = (up[:, :, 0] + 1) * 127.5  # horizontal up component as the green channel
        blend[:, :, 2] = 128  # neutral blue channel

    return pf_image, blend

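
# A minimal smoke-test sketch (an assumption, not part of the original app):
# run with the hypothetical PF_SMOKE_TEST environment variable set, and
# "smoke_test.png" is a hypothetical output path. It mirrors the Gradio
# wiring below, where obtain_pf's first output feeds inference.
if os.environ.get("PF_SMOKE_TEST"):
    _pf_image, _blend = obtain_pf(xi=0.2, roll=0.0, pitch=1.0, vfov=50.0)
    _img = inference("A colorful autumn park", _pf_image, n_steps=2, seed=13)
    _img.save("smoke_test.png")
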
# Gradio UI
demo = gr.Blocks(theme=gr.themes.Soft())

with demo:
    gr.Markdown("""---""")
    gr.Markdown("""# PreciseCam: Precise Camera Control for Text-to-Image Generation""")
    gr.Markdown("""1. Set the camera parameters (Roll, Pitch, Vertical FOV, ξ)""")
    gr.Markdown("""2. Click "Compute PF-US" to generate the perspective field image""")
    gr.Markdown("""3. Enter a prompt for the image generation""")
    gr.Markdown("""4. Click "Generate Image" to create the final image""")
    gr.Markdown("""---""")

    with gr.Row():
        with gr.Column():
            roll = gr.Slider(-90, 90, 0, label="Roll")
            pitch = gr.Slider(-90, 90, 1, label="Pitch")
            vfov = gr.Slider(15, 140, 50, label="Vertical FOV")
            xi = gr.Slider(0.0, 1, 0.2, label="ξ")
            prompt = gr.Textbox(
                lines=4,
                label="Prompt",
                show_copy_button=True,
                value="A colorful autumn park with leaves of orange, red, and yellow scattered across a winding path.",
            )
            pf_btn = gr.Button("Compute PF-US", variant="primary")
            with gr.Row():
                pf_img = gr.Image(height=1024 // 4, width=1024 // 4, label="PF-US")
                condition_img = gr.Image(
                    height=1024 // 4, width=1024 // 4, label="Internal PF-US (RGB)"
                )

        with gr.Column(scale=2):
            result_img = gr.Image(label="Generated Image", height=1024 // 2)
            inf_btn = gr.Button("Generate Image", variant="primary")
    gr.Markdown("""---""")

    pf_btn.click(
        obtain_pf, inputs=[xi, roll, pitch, vfov], outputs=[condition_img, pf_img]
    )
    inf_btn.click(inference, inputs=[prompt, condition_img], outputs=[result_img])

# Disable proxies via environment variables
os.environ["HTTP_PROXY"] = ""
os.environ["HTTPS_PROXY"] = ""
os.environ["NO_PROXY"] = "localhost,127.0.0.1"

# Launch the Gradio app with explicit server options
demo.launch(
    server_name="0.0.0.0",
    server_port=9946,
    share=False,
    favicon_path=None,
    show_api=False,
    prevent_thread_lock=True
)

@@ -0,0 +1,11 @@
Code based on the original implementation from [jinlinyi/PerspectiveFields](https://github.com/jinlinyi/PerspectiveFields).

**Citation:**
```bibtex
@inproceedings{jin2023perspective,
  title={Perspective Fields for Single Image Camera Calibration},
  author={Linyi Jin and Jianming Zhang and Yannick Hold-Geoffroy and Oliver Wang and Kevin Matzen and Matthew Sticha and David F. Fouhey},
  booktitle={CVPR},
  year={2023}
}
```
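
A minimal usage sketch of the adapted utilities (mirroring the Gradio app in this commit; the focal length value here is a hypothetical placeholder, which the app instead derives from the vertical FOV and ξ):

```python
import numpy as np
from perspective_fields import pano_utils as pu

# Blank equirectangular input, as in the demo app
equi_img = np.zeros((1024, 1024, 3), dtype=np.uint8)

# f=800.0 is a hypothetical focal length
crop, _, _, _, up, lat, _ = pu.crop_distortion(
    equi_img, f=800.0, xi=0.2, H=1024, W=1024, az=10, el=-1, roll=0
)
blend = pu.draw_perspective_fields(crop, up, lat)  # lat is in radians
```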

@@ -0,0 +1,261 @@
#######################################################################
# Adapted code from https://github.com/jinlinyi/PerspectiveFields
#######################################################################
import torch
import numpy as np
from equilib import grid_sample
from sklearn.preprocessing import normalize
from perspective_fields.visualizer import VisualizerPerspective
from numpy.lib.scimath import sqrt as csqrt


def deg2rad(deg):
    """Convert degrees to radians."""
    return deg * np.pi / 180


def diskradius(xi, f):
    """Compute the disk radius when the image is catadioptric."""
    return np.sqrt(-(f * f) / (1 - xi * xi))


def crop_distortion(image360_path, f, xi, H, W, az, el, roll):
    """
    Reference: https://github.com/dompm/spherical-distortion-dataset/blob/main/spherical_distortion/spherical_distortion.py
    Crop a distorted image with the specified camera parameters.

    Args:
        image360_path (np.ndarray): the equirectangular image to crop from
        f (float): focal length of the cropped image
        xi (float): distortion parameter of the unified sphere camera model
        H (int): height of the cropped image
        W (int): width of the cropped image
        az: camera rotation about the camera-frame y-axis of the cropped image (degrees)
        el: camera rotation about the camera-frame x-axis of the cropped image (degrees)
        roll: camera rotation about the camera-frame z-axis of the cropped image (degrees)
    Returns:
        im (np.ndarray): cropped, distorted image
    """

    u0 = W / 2.0
    v0 = H / 2.0

    grid_x, grid_y = np.meshgrid(list(range(W)), list(range(H)))

    image360 = image360_path.copy()

    ImPano_W = np.shape(image360)[1]
    ImPano_H = np.shape(image360)[0]
    x_ref = 1
    y_ref = 1

    # compute the minimal focal length for the image to be catadioptric
    # with the given xi
    fmin = minfocal(u0, v0, xi, x_ref, y_ref)

    # 1. Projection onto the camera plane
    X_Cam = np.divide(grid_x - u0, f)
    Y_Cam = -np.divide(grid_y - v0, f)

    # 2. Projection onto the sphere
    AuxVal = np.multiply(X_Cam, X_Cam) + np.multiply(Y_Cam, Y_Cam)
    alpha_cam = np.real(xi + csqrt(1 + np.multiply((1 - xi * xi), AuxVal)))
    alpha_div = AuxVal + 1
    alpha_cam_div = np.divide(alpha_cam, alpha_div)

    X_Sph = np.multiply(X_Cam, alpha_cam_div)
    Y_Sph = np.multiply(Y_Cam, alpha_cam_div)
    Z_Sph = alpha_cam_div - xi

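    # Steps 1-2 invert the unified sphere camera model, in which a unit-sphere
    # point (Xs, Ys, Zs) projects through a center shifted by xi along the
    # optical axis:
    #     x = Xs / (Zs + xi),  y = Ys / (Zs + xi)
    # alpha_cam_div is the scale that lifts a normalized pixel (x, y) back
    # onto the unit sphere (the positive root of the resulting quadratic).
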
    # 3. Rotation of the sphere
    coords = np.vstack((X_Sph.ravel(), Y_Sph.ravel(), Z_Sph.ravel()))
    rot_el = np.array(
        [
            1.0, 0.0, 0.0,
            0.0, np.cos(deg2rad(el)), -np.sin(deg2rad(el)),
            0.0, np.sin(deg2rad(el)), np.cos(deg2rad(el)),
        ]
    ).reshape((3, 3))
    rot_az = np.array(
        [
            np.cos(deg2rad(az)), 0.0, -np.sin(deg2rad(az)),
            0.0, 1.0, 0.0,
            np.sin(deg2rad(az)), 0.0, np.cos(deg2rad(az)),
        ]
    ).reshape((3, 3))
    rot_roll = np.array(
        [
            np.cos(deg2rad(roll)), np.sin(deg2rad(roll)), 0.0,
            -np.sin(deg2rad(roll)), np.cos(deg2rad(roll)), 0.0,
            0.0, 0.0, 1.0,
        ]
    ).reshape((3, 3))
    sph = rot_el.dot(rot_roll.T.dot(coords))
    sph = rot_az.dot(sph)

    sph = sph.reshape((3, H, W)).transpose((1, 2, 0))
    X_Sph, Y_Sph, Z_Sph = sph[:, :, 0], sph[:, :, 1], sph[:, :, 2]

    # 4. Cartesian to spherical
    ntheta = np.arctan2(X_Sph, Z_Sph)
    nphi = np.arctan2(Y_Sph, np.sqrt(Z_Sph**2 + X_Sph**2))

    pi = np.pi

    # 5. Sphere to pano
    min_theta = -pi
    max_theta = pi
    min_phi = -pi / 2.0
    max_phi = pi / 2.0

    min_x = 0
    max_x = ImPano_W - 1.0
    min_y = 0
    max_y = ImPano_H - 1.0

    ## for x
    a = (max_theta - min_theta) / (max_x - min_x)
    b = max_theta - a * max_x  # from y = a*x + b
    nx = (1.0 / a) * (ntheta - b)

    ## for y
    a = (min_phi - max_phi) / (max_y - min_y)
    b = max_phi - a * min_y  # from y = a*x + b
    ny = (1.0 / a) * (nphi - b)
    lat = nphi.copy()
    xy_map = np.stack((nx, ny)).transpose(1, 2, 0)

    # 6. Final step: interpolation and mapping
    # im = np.array(my_interpol.interp2linear(image360, nx, ny), dtype=np.uint8)
    im = grid_sample.numpy_grid_sample.default(
        image360.transpose(2, 0, 1), np.stack((ny, nx))
    ).transpose(1, 2, 0)
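    # Note: equilib's numpy grid_sample takes the image channels-first and the
    # sampling grid in (y, x) order, hence the transpose and np.stack((ny, nx))
    # above.
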
    # if it is a catadioptric image, apply a mask with a disk in the middle
    if f < fmin:
        r = diskradius(xi, f)
        DIM = im.shape
        ci = (np.round(DIM[0] / 2), np.round(DIM[1] / 2))
        xx, yy = np.meshgrid(
            list(range(DIM[0])) - ci[0], list(range(DIM[1])) - ci[1]
        )
        mask = np.double((np.multiply(xx, xx) + np.multiply(yy, yy)) < r * r)
        mask_3channel = np.stack([mask, mask, mask], axis=-1).transpose((1, 0, 2))
        im = np.array(np.multiply(im, mask_3channel), dtype=np.uint8)

    col = nphi[:, W // 2]
    zero_crossings_rows = np.where(np.diff(np.sign(col)))[0]
    if len(zero_crossings_rows) >= 2:
        print("WARNING | Number of zero crossings:", len(zero_crossings_rows))
        zero_crossings_rows = [zero_crossings_rows[0]]

    if len(zero_crossings_rows) == 0:
        offset = np.nan
    else:
        assert col[zero_crossings_rows[0]] >= 0
        assert col[zero_crossings_rows[0] + 1] <= 0
        dy = col[zero_crossings_rows[0] + 1] - col[zero_crossings_rows[0]]
        offset = zero_crossings_rows[0] - col[zero_crossings_rows[0]] / dy
        assert col[zero_crossings_rows[0]] / dy <= 1.0
    # Reproject [nx, ny + epsilon] back
    epsilon = 1e-5
    end_vector_x = nx.copy()
    end_vector_y = ny.copy() - epsilon
    # -5. Pano to sphere
    a = (max_theta - min_theta) / (max_x - min_x)
    b = max_theta - a * max_x  # from y = a*x + b
    ntheta_end = end_vector_x * a + b
    ## for y
    a = (min_phi - max_phi) / (max_y - min_y)
    b = max_phi - a * min_y
    nphi_end = end_vector_y * a + b
    # -4. Spherical to Cartesian
    Y_Sph = np.sin(nphi)
    X_Sph = np.cos(nphi_end) * np.sin(ntheta_end)
    Z_Sph = np.cos(nphi_end) * np.cos(ntheta_end)
    # -3. Reverse rotation of the sphere
    coords = np.vstack((X_Sph.ravel(), Y_Sph.ravel(), Z_Sph.ravel()))
    sph = rot_roll.dot(rot_el.T.dot(rot_az.T.dot(coords)))
    sph = sph.reshape((3, H, W)).transpose((1, 2, 0))
    X_Sph, Y_Sph, Z_Sph = sph[:, :, 0], sph[:, :, 1], sph[:, :, 2]

    # -1. Projection onto the image plane
    X_Cam = X_Sph * f / (xi * csqrt(X_Sph**2 + Y_Sph**2 + Z_Sph**2) + Z_Sph) + u0
    Y_Cam = -Y_Sph * f / (xi * csqrt(X_Sph**2 + Y_Sph**2 + Z_Sph**2) + Z_Sph) + v0
    up = np.stack((X_Cam - grid_x, Y_Cam - grid_y)).transpose(1, 2, 0)
    up = normalize(up.reshape(-1, 2)).reshape(up.shape)

    return im, ntheta, nphi, offset, up, lat, xy_map


def minfocal(u0, v0, xi, xref=1, yref=1):
    """Compute the minimum focal length for the image to be catadioptric, given xi."""
    value = -(1 - xi * xi) * ((xref - u0) * (xref - u0) + (yref - v0) * (yref - v0))

    if value < 0:
        return 0
    else:
        return np.sqrt(value) * 1.0001

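
# Sanity note: for 0 <= xi <= 1 the expression under the square root is
# non-positive, so minfocal returns 0 and the catadioptric disk mask in
# crop_distortion never triggers; only xi > 1 yields a positive minimum focal.
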
# ------------------------------------------------


def draw_perspective_fields(
    img_rgb, up, latimap, color=None, density=10, arrow_inv_len=20, return_img=True
):
    """Draw the perspective field on top of the input image.

    Args:
        img_rgb (np.ndarray): input image
        up (np.ndarray): gravity field (h, w, 2)
        latimap (np.ndarray): latitude map (h, w) (radians)
        color ((float, float, float), optional): RGB color for up vectors, in [0, 1].
            Defaults to None.
        density (int, optional): controls the density of up vectors.
            Each row has (width // density) vectors;
            each column has (height // density) vectors.
            Defaults to 10.
        arrow_inv_len (int, optional): controls the vector length.
            Vector length is set to (image plane diagonal // arrow_inv_len).
            Defaults to 20.
        return_img (bool, optional): whether to return an np.ndarray or a VisImage

    Returns:
        image blended with perspective fields.
    """
    visualizer = VisualizerPerspective(img_rgb.copy())
    vis_output = visualizer.draw_lati(latimap)
    if torch.is_tensor(up):
        up = up.numpy().transpose(1, 2, 0)
    im_h, im_w, _ = img_rgb.shape
    x, y = np.meshgrid(
        np.arange(0, im_w, im_w // density), np.arange(0, im_h, im_h // density)
    )
    x, y = x.ravel(), y.ravel()
    arrow_len = np.sqrt(im_w**2 + im_h**2) // arrow_inv_len
    end = up[y, x, :] * arrow_len
    if color is None:
        color = (0, 1, 0)
    vis_output = visualizer.draw_arrow(x, y, end[:, 0], -end[:, 1], color=color)
    if return_img:
        return vis_output.get_image()
    else:
        return vis_output

@@ -0,0 +1,283 @@
#######################################################################
# Adapted code from https://github.com/jinlinyi/PerspectiveFields
#######################################################################
# Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
import matplotlib.colors as mplc
import matplotlib.figure as mplfigure
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg


class VisImage:
    def __init__(self, img, scale=1.0):
        """
        Args:
            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
            scale (float): scale the input image
        """
        self.img = img
        self.scale = scale
        self.width, self.height = img.shape[1], img.shape[0]
        self._setup_figure(img)

    def _setup_figure(self, img):
        """
        Args:
            Same as in :meth:`__init__()`.

        Returns:
            fig (matplotlib.pyplot.figure): top-level container for all the image plot elements.
            ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
        """
        fig = mplfigure.Figure(frameon=False)
        self.dpi = fig.get_dpi()
        # add a small 1e-2 to avoid precision loss due to matplotlib's truncation
        # (https://github.com/matplotlib/matplotlib/issues/15363)
        fig.set_size_inches(
            (self.width * self.scale + 1e-2) / self.dpi,
            (self.height * self.scale + 1e-2) / self.dpi,
        )
        self.canvas = FigureCanvasAgg(fig)
        # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        ax.axis("off")
        self.fig = fig
        self.ax = ax
        self.reset_image(img)

    def reset_image(self, img):
        """
        Args:
            img: same as in __init__
        """
        img = img.astype("uint8")
        self.ax.imshow(
            img, extent=(0, self.width, self.height, 0), interpolation="nearest"
        )

    def save(self, filepath):
        """
        Args:
            filepath (str): the absolute path, including the file name, where
                the visualized image will be saved.
        """
        self.fig.savefig(filepath)

    def get_image(self):
        """
        Returns:
            ndarray:
                the visualized image of shape (H, W, 3) (RGB) in uint8 type.
                The shape is scaled w.r.t the input image using the given `scale` argument.
        """
        canvas = self.canvas
        s, (width, height) = canvas.print_to_buffer()
        # buf = io.BytesIO()  # works for the cairo backend
        # canvas.print_rgba(buf)
        # width, height = self.width, self.height
        # s = buf.getvalue()

        buffer = np.frombuffer(s, dtype="uint8")

        img_rgba = buffer.reshape(height, width, 4)
        rgb, alpha = np.split(img_rgba, [3], axis=2)
        return rgb.astype("uint8")

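
# Round-trip sketch for VisImage (shapes assume a random uint8 RGB array):
#   vis = VisImage(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
#   arr = vis.get_image()  # (64, 64, 3) uint8 RGB, scaled by `scale`
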
class Visualizer:
    """
    Visualizer that draws data about detection/segmentation on images.

    It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
    that draw primitive objects to images, as well as high-level wrappers like
    `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
    that draw composite data in some pre-defined style.

    Note that the exact visualization style for the high-level wrappers is subject to change.
    Style such as color, opacity, label contents, visibility of labels, or even the visibility
    of objects themselves (e.g. when the object is too small) may change according
    to different heuristics, as long as the results still look visually reasonable.

    To obtain a consistent style, you can implement custom drawing functions with the
    above-mentioned primitive methods instead. If you need more customized visualization
    styles, you can process the data yourself following their format documented in
    tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
    intend to satisfy everyone's preference on drawing styles.

    This visualizer focuses on high rendering quality rather than performance. It is not
    designed for real-time applications.
    """

    def __init__(self, img_rgb, scale=1.0, font_size_scale=1.0):
        """
        Args:
            img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            font_size_scale: extra scaling of the font size on top of the default font size
        """
        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
        self.output = VisImage(self.img, scale=scale)
        self.cpu_device = torch.device("cpu")

        # texts that are too small are useless, so clamp the default font size from below
        self._default_font_size = (
            max(np.sqrt(self.output.height * self.output.width) // 90, 10 // scale)
            * font_size_scale
        )

    """
    Primitive drawing functions:
    """

    def draw_text(
        self,
        text,
        position,
        *,
        font_size=None,
        color="g",
        horizontal_alignment="center",
        rotation=0,
    ):
        """
        Args:
            text (str): class label
            position (tuple): the x and y coordinates at which to place the text on the image.
            font_size (int, optional): font size of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color: color of the text. Refer to `matplotlib.colors` for the full list
                of accepted formats.
            horizontal_alignment (str): see `matplotlib.text.Text`
            rotation: rotation angle in degrees CCW

        Returns:
            output (VisImage): image object with text drawn.
        """
        if not font_size:
            font_size = self._default_font_size

        # since the text background is dark, we don't want the text to be dark
        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
        color[np.argmax(color)] = max(0.8, np.max(color))

        x, y = position
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="sans-serif",
            bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
            verticalalignment="top",
            horizontalalignment=horizontal_alignment,
            color=color,
            zorder=10,
            rotation=rotation,
        )
        return self.output

    def get_output(self):
        """
        Returns:
            output (VisImage): the image output containing the visualizations added
                to the image.
        """
        return self.output


class VisualizerPerspective(Visualizer):
    def draw_arrow(
        self,
        x_pos,
        y_pos,
        x_direct,
        y_direct,
        color=None,
        linestyle="-",
        linewidth=None,
    ):
        """
        Args:
            x_pos (list[int]): x values of all the arrow origins being drawn.
                Length should match the length of y_pos.
            y_pos (list[int]): y values of all the arrow origins being drawn.
                Length should match the length of x_pos.
            x_direct, y_direct: x and y components of the arrow directions.
            color: color of the arrows. Refer to `matplotlib.colors` for a full
                list of accepted formats.
            linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
                for a full list of accepted formats. (Currently unused by `quiver`.)
            linewidth (float or None): width of the line. When it's None,
                a default value will be computed and used.

        Returns:
            output (VisImage): image object with arrows drawn.
        """
        if linewidth is None:
            linewidth = self._default_font_size / 3
        # note: linewidth is clamped here, but the quiver call below uses a
        # fixed linewidths=0.1
        linewidth = max(linewidth, 1)
        self.output.ax.quiver(
            x_pos,
            y_pos,
            x_direct,
            y_direct,
            color=color,
            scale_units="xy",
            scale=1,
            antialiased=True,
            headaxislength=3.5,
            linewidths=0.1,
        )
        return self.output

    def draw_lati(
        self, latimap, alpha_contourf=0.4, alpha_contour=0.9, contour_only=False
    ):
        """Blend the latitude map into the output image."""
        height, width = latimap.shape
        y, x = np.mgrid[0:height, 0:width]
        cmap = plt.get_cmap("seismic")
        bands = 20
        levels = np.linspace(-np.pi / 2, np.pi / 2, bands - 1)
        if not contour_only:
            pp = self.output.ax.contourf(
                x,
                y,
                latimap,
                levels=levels,
                cmap=cmap,
                alpha=alpha_contourf,
                antialiased=True,
            )
            pp2 = self.output.ax.contour(
                x,
                y,
                latimap,
                pp.levels,
                cmap=cmap,
                alpha=alpha_contour,
                antialiased=True,
                linewidths=5,
            )
            try:
                # ContourSet.collections was removed in newer matplotlib versions
                for c in pp2.collections:
                    c.set_linestyle("solid")
            except AttributeError:
                pass
        else:
            # only plot the central contour
            pp = self.output.ax.contour(
                x,
                y,
                latimap,
                levels=[0],
                cmap=cmap,
                alpha=alpha_contour,
                antialiased=True,
                linewidths=15,
            )
        return self.output
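
# Usage sketch, mirroring draw_perspective_fields in pano_utils:
#   vis = VisualizerPerspective(img_rgb.copy())
#   vis.draw_lati(latimap)             # blend latitude contour bands
#   out = vis.draw_arrow(x, y, u, -v)  # overlay up-vector arrows
#   arr = out.get_image()              # final uint8 RGB array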