# KICCO_AI_IMAGE/flux_style_shaper_api/flux_precisecam_demo.py

import gradio as gr
import numpy as np
import torch
import os
from PIL import Image
from perspective_fields import pano_utils as pu
from diffusers import (
    AutoencoderKL,
    FluxControlNetPipeline,
    FlowMatchEulerDiscreteScheduler,
    FluxControlNetModel,
)
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
    T5TokenizerFast,
)
# Resolve local model file paths relative to this script
current_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(current_dir)
# Use os.path.join so the path separators are correct on every platform
flux_model_dir = os.path.join(project_dir, "models", "diffusion_models")
print("Loading models from:")
print(f"Flux dir: {flux_model_dir}")
# Create the cache directory if it does not exist
cache_path = os.path.join(project_dir, "cache")
os.makedirs(cache_path, exist_ok=True)
# Load all required components
print("Loading components...")
# 1. Scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    {"num_train_timesteps": 1000}
)
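# Flow matching integrates a learned velocity field from noise to image, so
# even very low step counts (this demo defaults to 2) produce rough previews;
# quality typically keeps improving up to a few dozen steps (a rule of thumb,
# not a figure measured for this setup).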
# 2. ControlNet compatible with FLUX.1-dev
print("Loading ControlNet compatible with FLUX.1-dev...")
controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
controlnet = FluxControlNetModel.from_pretrained(
    controlnet_model,
    torch_dtype=torch.bfloat16,
    cache_dir=cache_path,
)
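# Note: this checkpoint was trained for Canny edge conditioning; the demo
# feeds it perspective-field images instead, so the camera control should be
# treated as approximate rather than as a PF-trained ControlNet.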
# 3. VAE (bfloat16 to match the ControlNet)
vae = AutoencoderKL.from_pretrained(
    "diffusers/FLUX.1-vae",
    torch_dtype=torch.bfloat16,
    cache_dir=cache_path,
)
# 4. Primary text encoder and tokenizer (CLIP, bfloat16 to match the ControlNet)
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=torch.bfloat16,
    cache_dir=cache_path,
)
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14",
    cache_dir=cache_path,
)
# 5. Secondary text encoder and tokenizer (T5-XXL)
print("Loading the pre-downloaded T5 model from local disk...")
# Local model path; assumes the ModelScope T5-XXL snapshot was downloaded
# here manually (adjust to match your actual layout)
t5_model_path = os.path.join(project_dir, "models", "t5_xxl")
# Check whether the local model path exists
if not os.path.exists(t5_model_path):
    print(f"Warning: local model path {t5_model_path} does not exist. "
          "Download the ModelScope T5-XXL model to this path first.")
    print("Falling back to downloading from HuggingFace, which may be slow...")
    t5_model_path = "google/t5-v1_1-xxl"
text_encoder_2 = T5EncoderModel.from_pretrained(
    t5_model_path,  # local path when available
    torch_dtype=torch.bfloat16,
    cache_dir=cache_path,
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    t5_model_path,
    cache_dir=cache_path,
)
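# Note: T5EncoderModel keeps only the encoder stack of T5-XXL, which is all
# FLUX conditions on; the download is still several gigabytes, so the local
# snapshot path above is the fast path.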
# 6. Feature extractor and image encoder (CLIP vision tower)
feature_extractor = CLIPImageProcessor.from_pretrained(
    "openai/clip-vit-large-patch14",
    cache_dir=cache_path,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=torch.bfloat16,  # bfloat16 to match the ControlNet
    cache_dir=cache_path,
)
# 7. Base model. The transformer component is not loaded manually because
# that causes a type error; the pipeline initializes the correct transformer
# class itself.
base_model = "black-forest-labs/FLUX.1-dev"
# 8. Assemble the pipeline, loading the rest directly from HuggingFace
try:
    print("Loading FLUX model from HuggingFace...")
    # Point the HuggingFace caches at the project cache directory
    os.environ["HF_HOME"] = cache_path
    os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_path, "transformers")
    os.environ["DIFFUSERS_CACHE"] = os.path.join(cache_path, "diffusers")
    pipe = FluxControlNetPipeline.from_pretrained(
        base_model,
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
        # The Flux pipeline does not define these two components; diffusers
        # should warn and ignore the extras rather than fail
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
        controlnet=controlnet,
        torch_dtype=torch.bfloat16,  # bfloat16 to match the ControlNet
        low_cpu_mem_usage=False,
        cache_dir=cache_path,
        resume_download=True,   # resume if a previous download was interrupted
        force_download=False,   # reuse the cache when present
    )
except Exception as e:
    print(f"Error loading from HuggingFace: {e}")
    raise
# Offload idle submodules to the CPU to reduce GPU memory use
pipe.enable_model_cpu_offload()
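# enable_model_cpu_offload() streams each submodule to the GPU only for its
# forward pass, trading speed for a much smaller VRAM footprint. A sketch of
# the faster alternative when memory allows (FLUX.1-dev in bf16 needs roughly
# 24 GB or more of VRAM; an estimate, not a measured figure):
#
#     if torch.cuda.is_available():
#         pipe.to("cuda")  # keep all weights resident on the GPU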
def inference(prompt, pf_image, n_steps=2, seed=13):
    """Generates an image based on the given prompt and perspective field image."""
    # Reverse the channel order of the incoming PF image
    pf = Image.fromarray(
        np.concatenate(
            [np.expand_dims(pf_image[:, :, i], axis=-1) for i in [2, 1, 0]], axis=2
        )
    )
    pf_condition = pf.resize((1024, 1024))
    generator = torch.manual_seed(seed)
    print("Control image size:", pf_condition.size)  # should be (1024, 1024)
    # Only the ControlNet conditioning input is passed; img2img-style
    # parameters (image=, strength=) are not supported here
    return pipe(
        prompt=prompt,
        num_inference_steps=n_steps,
        generator=generator,
        control_image=pf_condition,
    ).images[0]
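# The pipeline call also accepts controlnet_conditioning_scale (default 1.0);
# a sketch of how it could be used, assuming the standard diffusers parameter:
#
#     pipe(..., control_image=pf_condition, controlnet_conditioning_scale=0.7)
#
# Lower values weaken the perspective-field guidance relative to the prompt.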
def obtain_pf(xi, roll, pitch, vfov):
    """Computes perspective fields given camera parameters."""
    w, h = (1024, 1024)
    equi_img = np.zeros((h, w, 3), dtype=np.uint8)
    # Effective focal length (in pixels) from the vertical FOV and xi
    x = -np.sin(np.radians(vfov / 2))
    z = np.sqrt(1 - x**2)
    f_px_effective = -0.5 * (w / 2) * (xi + z) / x
    crop, _, _, _, up, lat, _ = pu.crop_distortion(
        equi_img, f=f_px_effective, xi=xi, H=h, W=w, az=10, el=-pitch, roll=roll
    )
    # Encode the up-vector field (2 channels) and latitude (1 channel) as RGB
    gravity = (up + 1) * 127.5
    latitude = np.expand_dims((np.degrees(lat) + 90) * (255 / 180), axis=-1)
    pf_image = np.concatenate([gravity, latitude], axis=2).astype(np.uint8)
    # Try to render the blend visualization; fall back to a simple substitute
    try:
        blend = pu.draw_perspective_fields(crop, up, lat)  # lat is already in radians
    except Exception as e:
        print(f"Error generating blend image: {e}")
        # Build a basic substitute visualization
        blend = np.zeros_like(pf_image)
        blend[:, :, 0] = (np.degrees(lat) + 90) * (255 / 180)  # latitude as red
        blend[:, :, 1] = (up[:, :, 0] + 1) * 127.5  # one up component as green
        blend[:, :, 2] = 128  # neutral blue channel
    return pf_image, blend
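# A minimal smoke test, assuming the models loaded above (hypothetical
# values; uncomment to run once before serving the UI):
#
#     pf_image, _ = obtain_pf(xi=0.2, roll=0.0, pitch=1.0, vfov=50.0)
#     image = inference("a test scene", pf_image, n_steps=1, seed=0)
#     image.save("smoke_test.png")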
# Gradio UI
demo = gr.Blocks(theme=gr.themes.Soft())
with demo:
    gr.Markdown("""---""")
    gr.Markdown("""# PreciseCam: Precise Camera Control for Text-to-Image Generation""")
    gr.Markdown("""1. Set the camera parameters (Roll, Pitch, Vertical FOV, ξ)""")
    gr.Markdown("""2. Click "Compute PF-US" to generate the perspective field image""")
    gr.Markdown("""3. Enter a prompt for the image generation""")
    gr.Markdown("""4. Click "Generate Image" to create the final image""")
    gr.Markdown("""---""")
    with gr.Row():
        with gr.Column():
            roll = gr.Slider(-90, 90, 0, label="Roll")
            pitch = gr.Slider(-90, 90, 1, label="Pitch")
            vfov = gr.Slider(15, 140, 50, label="Vertical FOV")
            xi = gr.Slider(0.0, 1, 0.2, label="ξ")
            prompt = gr.Textbox(
                lines=4,
                label="Prompt",
                show_copy_button=True,
                value="A colorful autumn park with leaves of orange, red, and yellow scattered across a winding path.",
            )
            pf_btn = gr.Button("Compute PF-US", variant="primary")
            with gr.Row():
                pf_img = gr.Image(height=1024 // 4, width=1024 // 4, label="PF-US")
                condition_img = gr.Image(
                    height=1024 // 4, width=1024 // 4, label="Internal PF-US (RGB)"
                )
        with gr.Column(scale=2):
            result_img = gr.Image(label="Generated Image", height=1024 // 2)
            inf_btn = gr.Button("Generate Image", variant="primary")
    gr.Markdown("""---""")
    pf_btn.click(
        obtain_pf, inputs=[xi, roll, pitch, vfov], outputs=[condition_img, pf_img]
    )
    inf_btn.click(inference, inputs=[prompt, condition_img], outputs=[result_img])
# Clear proxy environment variables so the local server is reachable directly
os.environ["HTTP_PROXY"] = ""
os.environ["HTTPS_PROXY"] = ""
os.environ["NO_PROXY"] = "localhost,127.0.0.1"
# Launch the Gradio app, binding to all interfaces
demo.launch(
    server_name="0.0.0.0",
    server_port=9946,
    share=False,
    favicon_path=None,
    show_api=False,
    prevent_thread_lock=False,  # block here so the server process stays alive
)