254 lines
9.1 KiB
Python
254 lines
9.1 KiB
Python
import gradio as gr
|
||
import numpy as np
|
||
import torch
|
||
import os
|
||
from PIL import Image
|
||
from perspective_fields import pano_utils as pu
|
||
from diffusers import (
|
||
AutoencoderKL,
|
||
ControlNetModel,
|
||
FluxControlNetPipeline,
|
||
FlowMatchEulerDiscreteScheduler,
|
||
FluxControlNetModel,
|
||
)
|
||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
|
||
from transformers import T5EncoderModel, T5TokenizerFast
|
||
from safetensors.torch import load_file
|
||
# Resolve local model file paths relative to this script's location.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(current_dir)

# Build paths with os.path.join so separators are correct on every OS.
flux_model_dir = os.path.join(project_dir, "models", "diffusion_models")

print("Loading models from:")
print(f"Flux dir: {flux_model_dir}")

# Make sure the download cache directory exists before any model loads.
os.makedirs(os.path.join(project_dir, "cache"), exist_ok=True)

# Load every component the pipeline needs.
print("Loading components...")

# 1. Scheduler (flow-matching Euler, as used by FLUX).
scheduler = FlowMatchEulerDiscreteScheduler.from_config({"num_train_timesteps": 1000})
|
||
|
||
# 5. ControlNet — a canny ControlNet known to be compatible with FLUX.1-dev.
print("Loading ControlNet compatible with FLUX.1-dev...")
controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
controlnet = FluxControlNetModel.from_pretrained(
    controlnet_model,
    torch_dtype=torch.bfloat16,  # bfloat16 to match the rest of the pipeline
    cache_dir=os.path.join(project_dir, "cache"),
)


def _load_vae():
    """Load the FLUX VAE in bfloat16, caching under the project cache dir."""
    return AutoencoderKL.from_pretrained(
        "diffusers/FLUX.1-vae",
        torch_dtype=torch.bfloat16,  # bfloat16 to match the controlnet
        cache_dir=os.path.join(project_dir, "cache"),
    )


# 2. VAE. The previous "fallback" repeated the exact same call as the primary
# path, so the intent was a single retry (useful for transient download
# failures). Make that explicit: try once, retry once, and let a second
# failure propagate to the caller.
try:
    vae = _load_vae()
except Exception as e:
    print(f"Error loading VAE: {e}")
    print("Loading default VAE...")
    vae = _load_vae()  # one retry; a second failure raises
|
||
|
||
# 3. First text encoder and tokenizer (CLIP ViT-L/14).
_hf_cache = os.path.join(project_dir, "cache")

text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=torch.bfloat16,  # bfloat16 to match the controlnet
    cache_dir=_hf_cache,
)
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14",
    cache_dir=_hf_cache,
)

# 4. Second text encoder and tokenizer (T5-XXL), preferably from a local copy.
print("从本地加载预先下载的T5模型...")
# Local path where a manually pre-downloaded (ModelScope) T5-XXL model is
# expected; adjust to the actual location if it differs.
t5_model_path = os.path.join(project_dir, "models", "t5_xxl")

# Fall back to downloading from HuggingFace when no local copy exists.
if not os.path.exists(t5_model_path):
    print(f"警告: 本地模型路径 {t5_model_path} 不存在。请先手动下载魔搭的T5-XXL模型到此路径。")
    print("继续尝试从HuggingFace下载,但这可能会很慢...")
    t5_model_path = "google/t5-v1_1-xxl"

text_encoder_2 = T5EncoderModel.from_pretrained(
    t5_model_path,  # local path when available, HF repo id otherwise
    torch_dtype=torch.bfloat16,
    cache_dir=_hf_cache,
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    t5_model_path,
    cache_dir=_hf_cache,
)
|
||
|
||
|
||
|
||
# 6. Feature extractor and image encoder (both from CLIP ViT-L/14).
_clip_repo = "openai/clip-vit-large-patch14"

feature_extractor = CLIPImageProcessor.from_pretrained(
    _clip_repo,
    cache_dir=os.path.join(project_dir, "cache"),
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    _clip_repo,
    torch_dtype=torch.bfloat16,  # bfloat16 to match the controlnet
    cache_dir=os.path.join(project_dir, "cache"),
)
|
||
|
||
# 7. Base model id. The transformer component is NOT loaded separately —
# doing so caused type errors; the pipeline initializes the correct
# transformer type itself.
base_model = "black-forest-labs/FLUX.1-dev"

# 8. Assemble the full pipeline from the components loaded above.
try:
    print("Loading FLUX model from HuggingFace...")
    # Point every HuggingFace cache at the project-local cache directory.
    cache_path = os.path.join(project_dir, "cache")
    os.environ["HF_HOME"] = cache_path
    os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_path, "transformers")
    os.environ["DIFFUSERS_CACHE"] = os.path.join(cache_path, "diffusers")

    pipe = FluxControlNetPipeline.from_pretrained(
        base_model,
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
        controlnet=controlnet,
        torch_dtype=torch.bfloat16,  # bfloat16 to match the controlnet
        low_cpu_mem_usage=False,
        cache_dir=cache_path,
        resume_download=True,   # resume an interrupted download
        force_download=False,   # reuse cached files when present
    )
except Exception as e:
    # BUG FIX: the original branch printed "Trying alternative
    # configuration..." but never built one, leaving `pipe` undefined so the
    # very next statement crashed with a confusing NameError. Fail fast with
    # the real cause instead.
    print(f"Error loading from HuggingFace: {e}")
    raise RuntimeError("Failed to load FluxControlNetPipeline") from e

# Offload idle submodules to CPU to reduce GPU memory usage.
pipe.enable_model_cpu_offload()
|
||
|
||
def inference(prompt, pf_image, n_steps=2, seed=13):
    """Generates an image based on the given prompt and perspective field image."""
    # Reorder channels BGR -> RGB; fancy indexing produces a contiguous copy,
    # equivalent to the concat-of-expand_dims the original used.
    pf = Image.fromarray(pf_image[:, :, [2, 1, 0]])
    pf_condition = pf.resize((1024, 1024))
    generator = torch.manual_seed(seed)
    print("Control image size:", pf_condition.size)  # expected: (1024, 1024)
    # Only the ControlNet conditioning input is passed; img2img-style
    # parameters (image=, strength=) are not supported by this pipeline call.
    result = pipe(
        prompt=prompt,
        num_inference_steps=n_steps,
        generator=generator,
        control_image=pf_condition,
    )
    return result.images[0]
|
||
|
||
|
||
def obtain_pf(xi, roll, pitch, vfov):
    """Computes perspective fields given camera parameters."""
    width, height = 1024, 1024
    equi_img = np.zeros((height, width, 3), dtype=np.uint8)

    # Effective focal length in pixels under the unified spherical camera
    # model, derived from the vertical FOV and the distortion parameter xi.
    sin_term = -np.sin(np.radians(vfov / 2))
    cos_term = np.sqrt(1 - sin_term**2)
    f_px_effective = -0.5 * (width / 2) * (xi + cos_term) / sin_term

    crop, _, _, _, up, lat, _ = pu.crop_distortion(
        equi_img, f=f_px_effective, xi=xi, H=height, W=width, az=10, el=-pitch, roll=roll
    )

    # Encode the up-vector field and latitude into one 3-channel uint8 image.
    gravity = (up + 1) * 127.5
    latitude = np.expand_dims((np.degrees(lat) + 90) * (255 / 180), axis=-1)
    pf_image = np.concatenate([gravity, latitude], axis=2).astype(np.uint8)

    # Try to render the blend visualization; on failure, synthesize a
    # substitute so the UI still has something to show.
    try:
        blend = pu.draw_perspective_fields(crop, up, np.radians(np.degrees(lat)))
    except Exception as e:
        print(f"Error generating blend image: {e}")
        blend = np.zeros_like(pf_image)
        # Basic fallback visualization: latitude in the red channel, the
        # up field in green, and a neutral blue channel.
        blend[:, :, 0] = (np.degrees(lat) + 90) * (255 / 180)
        blend[:, :, 1] = (up + 1) * 127.5
        blend[:, :, 2] = 128

    return pf_image, blend
|
||
|
||
|
||
# Gradio UI
demo = gr.Blocks(theme=gr.themes.Soft())

with demo:
    # Header and step-by-step instructions.
    for md_text in (
        """---""",
        """# PreciseCam: Precise Camera Control for Text-to-Image Generation""",
        """1. Set the camera parameters (Roll, Pitch, Vertical FOV, ξ)""",
        """2. Click "Compute PF-US" to generate the perspective field image""",
        """3. Enter a prompt for the image generation""",
        """4. Click "Generate Image" to create the final image""",
        """---""",
    ):
        gr.Markdown(md_text)

    with gr.Row():
        # Left column: camera parameters, prompt, and PF previews.
        with gr.Column():
            roll = gr.Slider(-90, 90, 0, label="Roll")
            pitch = gr.Slider(-90, 90, 1, label="Pitch")
            vfov = gr.Slider(15, 140, 50, label="Vertical FOV")
            xi = gr.Slider(0.0, 1, 0.2, label="ξ")
            prompt = gr.Textbox(
                lines=4,
                label="Prompt",
                show_copy_button=True,
                value="A colorful autumn park with leaves of orange, red, and yellow scattered across a winding path.",
            )
            pf_btn = gr.Button("Compute PF-US", variant="primary")
            with gr.Row():
                pf_img = gr.Image(height=1024 // 4, width=1024 // 4, label="PF-US")
                condition_img = gr.Image(
                    height=1024 // 4, width=1024 // 4, label="Internal PF-US (RGB)"
                )

        # Right column: generated result.
        with gr.Column(scale=2):
            result_img = gr.Image(label="Generated Image", height=1024 // 2)
            inf_btn = gr.Button("Generate Image", variant="primary")
    gr.Markdown("""---""")

    # Wire the buttons to their callbacks.
    pf_btn.click(
        obtain_pf, inputs=[xi, roll, pitch, vfov], outputs=[condition_img, pf_img]
    )
    inf_btn.click(inference, inputs=[prompt, condition_img], outputs=[result_img])

# Disable any HTTP(S) proxy for the locally served app.
os.environ["HTTP_PROXY"] = ""
os.environ["HTTPS_PROXY"] = ""
os.environ["NO_PROXY"] = "localhost,127.0.0.1"

# Launch the Gradio app on all interfaces without blocking the main thread.
demo.launch(
    server_name="0.0.0.0",
    server_port=9946,
    share=False,
    favicon_path=None,
    show_api=False,
    prevent_thread_lock=True,
)
|