# KICCO_AI_IMAGE/flux_style_shaper_api/flux_precisecam_demo.py

import gradio as gr
import numpy as np
import torch
import os
from PIL import Image
from perspective_fields import pano_utils as pu
from diffusers import (
    AutoencoderKL,
    FluxControlNetPipeline,
    FlowMatchEulerDiscreteScheduler,
    FluxControlNetModel,
)
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
from transformers import T5EncoderModel, T5TokenizerFast
from safetensors.torch import load_file
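
# Note: `perspective_fields.pano_utils` is assumed to come from the
# PerspectiveFields project; this script only relies on its crop_distortion()
# and draw_perspective_fields() helpers.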
# Resolve local model file paths
current_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(current_dir)
# Build paths with os.path.join so the separators are correct on every OS
flux_model_dir = os.path.join(project_dir, "models", "diffusion_models")
print("Loading models from:")
print(f"Flux dir: {flux_model_dir}")
# Create the cache directory if it does not exist
os.makedirs(os.path.join(project_dir, "cache"), exist_ok=True)
# Load all required components
print("Loading components...")

# 1. Load the scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    {"num_train_timesteps": 1000}
)
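
# FLUX.1 is a flow-matching model, so it uses FlowMatchEulerDiscreteScheduler
# rather than a DDPM-style scheduler; num_train_timesteps only controls how
# the sigma schedule is discretized. A minimal inspection sketch (assuming
# the diffusers API imported above):
#
#   scheduler.set_timesteps(num_inference_steps=4)
#   print(scheduler.timesteps)  # four timesteps spaced over the sigma schedule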
# 2. Load a ControlNet compatible with FLUX.1-dev
print("Loading ControlNet compatible with FLUX.1-dev...")
controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
controlnet = FluxControlNetModel.from_pretrained(
    controlnet_model,
    torch_dtype=torch.bfloat16,
    cache_dir=os.path.join(project_dir, "cache"),
)
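
# Note: the InstantX checkpoint above was trained on Canny-edge conditions;
# this demo repurposes it by feeding perspective-field images as the
# condition instead. The pipeline accepts any 3-channel condition image.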
# 3. Load the VAE
try:
    vae = AutoencoderKL.from_pretrained(
        "diffusers/FLUX.1-vae",
        torch_dtype=torch.bfloat16,  # use bfloat16 to match the ControlNet
        cache_dir=os.path.join(project_dir, "cache"),
    )
except Exception as e:
    print(f"Error loading VAE: {e}")
    print("Loading the VAE bundled with the base checkpoint instead...")
    # The original fallback repeated the identical call; loading from the
    # base model's "vae" subfolder gives the retry a distinct source.
    vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="vae",
        torch_dtype=torch.bfloat16,
        cache_dir=os.path.join(project_dir, "cache"),
    )
# 4. Load the CLIP text encoder and tokenizer
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=torch.bfloat16,  # use bfloat16 to match the ControlNet
    cache_dir=os.path.join(project_dir, "cache"),
)
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14",
    cache_dir=os.path.join(project_dir, "cache"),
)
# 5. Load the second text encoder and tokenizer (T5-XXL)
print("Loading the pre-downloaded T5 model from a local path...")
# Local model path - assumes the ModelScope T5-XXL model was downloaded manually
t5_model_path = os.path.join(project_dir, "models", "t5_xxl")  # adjust to your layout
# Check whether the local model path exists
if not os.path.exists(t5_model_path):
    print(f"Warning: local model path {t5_model_path} does not exist. "
          "Download the ModelScope T5-XXL model to this path first.")
    print("Falling back to downloading from HuggingFace; this may be slow...")
    t5_model_path = "google/t5-v1_1-xxl"
text_encoder_2 = T5EncoderModel.from_pretrained(
    t5_model_path,  # local path when available
    torch_dtype=torch.bfloat16,
    cache_dir=os.path.join(project_dir, "cache"),
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    t5_model_path,  # local path when available
    cache_dir=os.path.join(project_dir, "cache"),
)
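
# A hedged sketch of pre-downloading the T5 weights with the modelscope
# package (the model id is a placeholder; substitute the mirror you use):
#
#   from modelscope import snapshot_download
#   snapshot_download("<t5-xxl-model-id>", cache_dir=t5_model_path)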
# 6. Load the feature extractor and image encoder
feature_extractor = CLIPImageProcessor.from_pretrained(
    "openai/clip-vit-large-patch14",
    cache_dir=os.path.join(project_dir, "cache"),
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=torch.bfloat16,  # use bfloat16 to match the ControlNet
    cache_dir=os.path.join(project_dir, "cache"),
)
# 7. The FLUX transformer itself
base_model = "black-forest-labs/FLUX.1-dev"
# Do not pass a transformer component explicitly: doing so caused a type
# error, and the pipeline initializes a transformer of the correct class itself.

# 8. Build the pipeline, loading the remaining weights from HuggingFace
try:
    print("Loading FLUX model from HuggingFace...")
    # Point the cache-related environment variables at the project cache
    os.environ["HF_HOME"] = os.path.join(project_dir, "cache")
    os.environ["TRANSFORMERS_CACHE"] = os.path.join(project_dir, "cache", "transformers")
    os.environ["DIFFUSERS_CACHE"] = os.path.join(project_dir, "cache", "diffusers")
    # Also pass the cache path explicitly
    cache_path = os.path.join(project_dir, "cache")
    pipe = FluxControlNetPipeline.from_pretrained(
        base_model,
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
        controlnet=controlnet,
        torch_dtype=torch.bfloat16,  # use bfloat16 to match the ControlNet
        low_cpu_mem_usage=False,
        cache_dir=cache_path,
        resume_download=True,  # resume if a previous download was interrupted
        force_download=False,  # do not re-download files already cached
    )
except Exception as e:
    print(f"Error loading from HuggingFace: {e}")
    # No alternative configuration exists, so re-raise rather than fall
    # through to an undefined `pipe` (the original only printed and continued)
    raise

# Offload idle submodules to the CPU to reduce GPU memory usage
pipe.enable_model_cpu_offload()
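
# enable_model_cpu_offload() keeps whole submodules on the CPU and moves each
# to the GPU only while it runs. If memory is still tight, diffusers also
# offers a slower, more aggressive per-layer variant (an alternative, not
# what this script uses):
#
#   pipe.enable_sequential_cpu_offload()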
def inference(prompt, pf_image, n_steps=2, seed=13):
    """Generates an image based on the given prompt and perspective field image."""
    # Reverse the channel order of the perspective-field array before
    # handing it to the pipeline as a PIL image
    pf = Image.fromarray(
        np.concatenate(
            [np.expand_dims(pf_image[:, :, i], axis=-1) for i in [2, 1, 0]], axis=2
        )
    )
    pf_condition = pf.resize((1024, 1024))
    generator = torch.manual_seed(seed)
    print("Control image size:", pf_condition.size)  # should be (1024, 1024)
    return pipe(
        prompt=prompt,
        num_inference_steps=n_steps,
        generator=generator,
        control_image=pf_condition,  # the ControlNet condition input
        # Parameters this pipeline does not support were removed:
        # image=pf_condition,  # would supply a base image for img2img
        # strength=0.8,        # would control the img2img strength
    ).images[0]
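
# A minimal usage sketch outside the Gradio UI (hypothetical call, assuming a
# perspective-field array produced by obtain_pf below):
#
#   pf_image, _ = obtain_pf(xi=0.2, roll=0.0, pitch=1.0, vfov=50.0)
#   image = inference("A colorful autumn park", pf_image, n_steps=2, seed=13)
#   image.save("precisecam_demo.png")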
def obtain_pf(xi, roll, pitch, vfov):
    """Computes perspective fields given camera parameters."""
    w, h = (1024, 1024)
    equi_img = np.zeros((h, w, 3), dtype=np.uint8)
    x = -np.sin(np.radians(vfov / 2))
    z = np.sqrt(1 - x**2)
    f_px_effective = -0.5 * (w / 2) * (xi + z) / x
    crop, _, _, _, up, lat, _ = pu.crop_distortion(
        equi_img, f=f_px_effective, xi=xi, H=h, W=w, az=10, el=-pitch, roll=roll
    )
    gravity = (up + 1) * 127.5
    latitude = np.expand_dims((np.degrees(lat) + 90) * (255 / 180), axis=-1)
    pf_image = np.concatenate([gravity, latitude], axis=2).astype(np.uint8)
    # Try to draw the blend visualization; fall back to a substitute if it fails
    try:
        # lat is already in radians (the original converted it to degrees and
        # back via np.radians(np.degrees(lat)), a no-op)
        blend = pu.draw_perspective_fields(crop, up, lat)
    except Exception as e:
        print(f"Error generating blend image: {e}")
        # Create a substitute blend image with a basic visualization, reusing
        # the already-encoded channels to keep the shapes consistent
        blend = np.zeros_like(pf_image)
        blend[:, :, 0] = pf_image[:, :, 2]  # latitude as the red channel
        blend[:, :, 1] = pf_image[:, :, 0]  # first up-vector component as the green channel
        blend[:, :, 2] = 128  # neutral value for the blue channel
    return pf_image, blend
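
# Encoding note (derived from the code above): channels 0-1 of pf_image hold
# the up-vector (gravity) field remapped from [-1, 1] to [0, 255], and channel
# 2 holds the latitude remapped from [-90 deg, 90 deg] to [0, 255]; inference()
# reverses the channel order before conditioning the ControlNet.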
# Gradio UI
demo = gr.Blocks(theme=gr.themes.Soft())
with demo:
    gr.Markdown("""---""")
    gr.Markdown("""# PreciseCam: Precise Camera Control for Text-to-Image Generation""")
    gr.Markdown("""1. Set the camera parameters (Roll, Pitch, Vertical FOV, ξ)""")
    gr.Markdown("""2. Click "Compute PF-US" to generate the perspective field image""")
    gr.Markdown("""3. Enter a prompt for the image generation""")
    gr.Markdown("""4. Click "Generate Image" to create the final image""")
    gr.Markdown("""---""")
    with gr.Row():
        with gr.Column():
            roll = gr.Slider(-90, 90, 0, label="Roll")
            pitch = gr.Slider(-90, 90, 1, label="Pitch")
            vfov = gr.Slider(15, 140, 50, label="Vertical FOV")
            xi = gr.Slider(0.0, 1, 0.2, label="ξ")
            prompt = gr.Textbox(
                lines=4,
                label="Prompt",
                show_copy_button=True,
                value="A colorful autumn park with leaves of orange, red, and yellow scattered across a winding path.",
            )
            pf_btn = gr.Button("Compute PF-US", variant="primary")
            with gr.Row():
                pf_img = gr.Image(height=1024 // 4, width=1024 // 4, label="PF-US")
                condition_img = gr.Image(
                    height=1024 // 4, width=1024 // 4, label="Internal PF-US (RGB)"
                )
        with gr.Column(scale=2):
            result_img = gr.Image(label="Generated Image", height=1024 // 2)
            inf_btn = gr.Button("Generate Image", variant="primary")
    gr.Markdown("""---""")
    pf_btn.click(
        obtain_pf, inputs=[xi, roll, pitch, vfov], outputs=[condition_img, pf_img]
    )
    inf_btn.click(inference, inputs=[prompt, condition_img], outputs=[result_img])
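
# Wiring note: obtain_pf returns (pf_image, blend), so "Internal PF-US (RGB)"
# (condition_img) receives the raw encoded field that inference() consumes,
# while "PF-US" (pf_img) shows the human-readable blend visualization.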
# Disable proxies via environment variables
os.environ["HTTP_PROXY"] = ""
os.environ["HTTPS_PROXY"] = ""
os.environ["NO_PROXY"] = "localhost,127.0.0.1"

# Launch the Gradio app with an explicit server name and options
demo.launch(
    server_name="0.0.0.0",
    server_port=9946,
    share=False,
    favicon_path=None,
    show_api=False,
    prevent_thread_lock=True  # launch() returns immediately instead of blocking
)
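
# Caution: with prevent_thread_lock=True, launch() returns at once, so a
# standalone run of this script would exit and stop the server. A minimal
# keep-alive sketch if that is not intended (an assumption, not in the original):
#
#   import time
#   try:
#       while True:
#           time.sleep(1)
#   except KeyboardInterrupt:
#       demo.close()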