import gradio as gr import numpy as np import torch import os from PIL import Image from perspective_fields import pano_utils as pu from diffusers import ( AutoencoderKL, ControlNetModel, FluxControlNetPipeline, FlowMatchEulerDiscreteScheduler, FluxControlNetModel, ) from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection from transformers import T5EncoderModel, T5TokenizerFast from safetensors.torch import load_file # 加载本地模型文件路径 current_dir = os.path.dirname(os.path.abspath(__file__)) project_dir = os.path.dirname(current_dir) # 修正路径格式,使用os.path.join确保路径分隔符正确 flux_model_dir = os.path.join(project_dir, "models", "diffusion_models") print(f"Loading models from:") print(f"Flux dir: {flux_model_dir}") # 创建缓存目录(如果不存在) os.makedirs(os.path.join(project_dir, "cache"), exist_ok=True) # 加载所有需要的组件 print("Loading components...") # 1. 加载调度器 scheduler = FlowMatchEulerDiscreteScheduler.from_config( { "num_train_timesteps": 1000 } ) # 5. 加载ControlNet - 使用与FLUX.1-dev兼容的模型 print("Loading ControlNet compatible with FLUX.1-dev...") controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny" controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16,cache_dir=os.path.join(project_dir, "cache")) # 2. 加载VAE try: vae = AutoencoderKL.from_pretrained( "diffusers/FLUX.1-vae", torch_dtype=torch.bfloat16, # 使用bfloat16匹配controlnet cache_dir=os.path.join(project_dir, "cache") ) except Exception as e: print(f"Error loading VAE: {e}") print("Loading default VAE...") vae = AutoencoderKL.from_pretrained( "diffusers/FLUX.1-vae", torch_dtype=torch.bfloat16, # 使用bfloat16匹配controlnet cache_dir=os.path.join(project_dir, "cache") ) # 3. 加载文本编码器和标记器 text_encoder = CLIPTextModel.from_pretrained( "openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16, # 使用bfloat16匹配controlnet cache_dir=os.path.join(project_dir, "cache") ) tokenizer = CLIPTokenizer.from_pretrained( "openai/clip-vit-large-patch14", cache_dir=os.path.join(project_dir, "cache") ) # 4. 加载第二个文本编码器和标记器 print("从本地加载预先下载的T5模型...") # 指定本地模型路径 - 假设已经手动下载了魔搭的模型 t5_model_path = os.path.join(project_dir, "models", "t5_xxl") # 根据实际路径调整 # 检查模型路径是否存在 if not os.path.exists(t5_model_path): print(f"警告: 本地模型路径 {t5_model_path} 不存在。请先手动下载魔搭的T5-XXL模型到此路径。") print("继续尝试从HuggingFace下载,但这可能会很慢...") t5_model_path = "google/t5-v1_1-xxl" text_encoder_2 = T5EncoderModel.from_pretrained( t5_model_path, # 使用本地路径 torch_dtype=torch.bfloat16, cache_dir=os.path.join(project_dir, "cache") ) tokenizer_2 = T5TokenizerFast.from_pretrained( t5_model_path, # 使用本地路径 cache_dir=os.path.join(project_dir, "cache") ) # 6. 加载特征提取器和图像编码器 feature_extractor = CLIPImageProcessor.from_pretrained( "openai/clip-vit-large-patch14", cache_dir=os.path.join(project_dir, "cache") ) image_encoder = CLIPVisionModelWithProjection.from_pretrained( "openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16, # 使用bfloat16匹配controlnet cache_dir=os.path.join(project_dir, "cache") ) # 7. 加载PipelineImageProcessor (transformer) base_model = "black-forest-labs/FLUX.1-dev" # 不加载transformer组件,因为这会导致类型错误 # pipeline会自动初始化正确类型的transformer # 8. 加载Pipeline - 直接从HuggingFace加载 try: print("Loading FLUX model from HuggingFace...") # 设置缓存相关环境变量 os.environ["HF_HOME"] = os.path.join(project_dir, "cache") os.environ["TRANSFORMERS_CACHE"] = os.path.join(project_dir, "cache", "transformers") os.environ["DIFFUSERS_CACHE"] = os.path.join(project_dir, "cache", "diffusers") # 设置缓存路径参数 cache_path = os.path.join(project_dir, "cache") pipe = FluxControlNetPipeline.from_pretrained( base_model, scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, feature_extractor=feature_extractor, image_encoder=image_encoder, controlnet=controlnet, torch_dtype=torch.bfloat16, # 使用bfloat16匹配controlnet low_cpu_mem_usage=False, cache_dir=cache_path, resume_download=True, # 如果下载中断,尝试继续下载 force_download=False # 如果已缓存,不重新下载 ) except Exception as e: print(f"Error loading from HuggingFace: {e}") print("Trying alternative configuration...") # 启用CPU卸载以减少GPU内存使用 pipe.enable_model_cpu_offload() def inference(prompt, pf_image, n_steps=2, seed=13): """Generates an image based on the given prompt and perspective field image.""" pf = Image.fromarray( np.concatenate( [np.expand_dims(pf_image[:, :, i], axis=-1) for i in [2, 1, 0]], axis=2 ) ) pf_condition = pf.resize((1024, 1024)) generator = torch.manual_seed(seed) print("Control image size:", pf_condition.size) # 应该是 (1024, 1024) return pipe( prompt=prompt, num_inference_steps=n_steps, generator=generator, control_image=pf_condition, # 只保留ControlNet的输入参数 # 移除不支持的参数 # image=pf_condition, # 需要为img2img提供一个基础图像 # strength=0.8, # 控制处理强度 ).images[0] def obtain_pf(xi, roll, pitch, vfov): """Computes perspective fields given camera parameters.""" w, h = (1024, 1024) equi_img = np.zeros((h, w, 3), dtype=np.uint8) x = -np.sin(np.radians(vfov / 2)) z = np.sqrt(1 - x**2) f_px_effective = -0.5 * (w / 2) * (xi + z) / x crop, _, _, _, up, lat, _ = pu.crop_distortion( equi_img, f=f_px_effective, xi=xi, H=h, W=w, az=10, el=-pitch, roll=roll ) gravity = (up + 1) * 127.5 latitude = np.expand_dims((np.degrees(lat) + 90) * (255 / 180), axis=-1) pf_image = np.concatenate([gravity, latitude], axis=2).astype(np.uint8) # 尝试生成blend图像,如果失败则创建一个替代的blend图像 try: blend = pu.draw_perspective_fields(crop, up, np.radians(np.degrees(lat))) except Exception as e: print(f"Error generating blend image: {e}") # 创建一个替代的blend图像 blend = np.zeros_like(pf_image) # 添加一些基本的可视化 blend[:,:,0] = (np.degrees(lat) + 90) * (255 / 180) # 使用latitude作为红色通道 blend[:,:,1] = (up + 1) * 127.5 # 使用up作为绿色通道 blend[:,:,2] = 128 # 蓝色通道设置为中性值 return pf_image, blend # Gradio UI demo = gr.Blocks(theme=gr.themes.Soft()) with demo: gr.Markdown("""---""") gr.Markdown("""# PreciseCam: Precise Camera Control for Text-to-Image Generation""") gr.Markdown("""1. Set the camera parameters (Roll, Pitch, Vertical FOV, ξ)""") gr.Markdown("""2. Click "Compute PF-US" to generate the perspective field image""") gr.Markdown("""3. Enter a prompt for the image generation""") gr.Markdown("""4. Click "Generate Image" to create the final image""") gr.Markdown("""---""") with gr.Row(): with gr.Column(): roll = gr.Slider(-90, 90, 0, label="Roll") pitch = gr.Slider(-90, 90, 1, label="Pitch") vfov = gr.Slider(15, 140, 50, label="Vertical FOV") xi = gr.Slider(0.0, 1, 0.2, label="ξ") prompt = gr.Textbox( lines=4, label="Prompt", show_copy_button=True, value="A colorful autumn park with leaves of orange, red, and yellow scattered across a winding path.", ) pf_btn = gr.Button("Compute PF-US", variant="primary") with gr.Row(): pf_img = gr.Image(height=1024 // 4, width=1024 // 4, label="PF-US") condition_img = gr.Image( height=1024 // 4, width=1024 // 4, label="Internal PF-US (RGB)" ) with gr.Column(scale=2): result_img = gr.Image(label="Generated Image", height=1024 // 2) inf_btn = gr.Button("Generate Image", variant="primary") gr.Markdown("""---""") pf_btn.click( obtain_pf, inputs=[xi, roll, pitch, vfov], outputs=[condition_img, pf_img] ) inf_btn.click(inference, inputs=[prompt, condition_img], outputs=[result_img]) # 设置环境变量,禁用代理 os.environ["HTTP_PROXY"] = "" os.environ["HTTPS_PROXY"] = "" os.environ["NO_PROXY"] = "localhost,127.0.0.1" # 启动Gradio应用,指定内联选项和服务器名称 demo.launch( server_name="0.0.0.0", server_port=9946, share=False, favicon_path=None, show_api=False, prevent_thread_lock=True )