""" FLUX风格塑形MCP资源模块 --------------------- 将FLUX风格塑形API集成到MCP服务器中 作为资源和工具提供给MCP客户端 """ import os import sys import random import logging import tempfile import shutil from pathlib import Path from typing import Dict, Any, Union, Tuple, Optional import base64 from io import BytesIO import torch from PIL import Image from huggingface_hub import hf_hub_download from dotenv import load_dotenv # 设置日志 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) # 加载环境变量 load_dotenv(override=True) # 获取ComfyUI根目录 def get_comfyui_dir(): """获取ComfyUI根目录路径""" current_dir = os.path.dirname(os.path.abspath(__file__)) comfyui_dir = os.path.dirname(os.path.dirname(current_dir)) # 假设mcp_server在ComfyUI根目录下 return comfyui_dir # 下载所需模型 def download_models(): """检查并下载所需模型""" logger.info("检查并下载所需模型...") # 获取ComfyUI根目录 comfyui_dir = get_comfyui_dir() # 获取Hugging Face令牌 hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN") if not hf_token: logger.warning("未找到Hugging Face访问令牌!受限模型可能无法下载。") logger.warning("请设置HUGGING_FACE_HUB_TOKEN环境变量或在.env文件中添加此变量。") else: logger.info("已找到Hugging Face访问令牌") models = [ {"repo_id": "black-forest-labs/FLUX.1-Redux-dev", "filename": "flux1-redux-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/style_models")}, {"repo_id": "black-forest-labs/FLUX.1-Depth-dev", "filename": "flux1-depth-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/diffusion_models")}, {"repo_id": "Comfy-Org/sigclip_vision_384", "filename": "sigclip_vision_patch14_384.safetensors", "local_dir": os.path.join(comfyui_dir, "models/clip_vision")}, {"repo_id": "Kijai/DepthAnythingV2-safetensors", "filename": "depth_anything_v2_vitl_fp32.safetensors", "local_dir": os.path.join(comfyui_dir, "models/depthanything")}, {"repo_id": "black-forest-labs/FLUX.1-dev", "filename": "ae.safetensors", "local_dir": os.path.join(comfyui_dir, "models/vae/FLUX1")}, {"repo_id": "comfyanonymous/flux_text_encoders", "filename": "clip_l.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders")}, {"repo_id": "comfyanonymous/flux_text_encoders", "filename": "t5xxl_fp16.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders/t5")} ] for model in models: try: local_path = os.path.join(model["local_dir"], model["filename"]) if not os.path.exists(local_path): logger.info(f"下载 {model['filename']} 到 {model['local_dir']}") os.makedirs(model["local_dir"], exist_ok=True) hf_hub_download( repo_id=model["repo_id"], filename=model["filename"], local_dir=model["local_dir"], token=hf_token # 使用令牌进行下载 ) else: logger.info(f"模型 {model['filename']} 已存在于 {model['local_dir']}") except Exception as e: if "403 Client Error" in str(e) and "Access to model" in str(e) and "is restricted" in str(e): logger.error(f"下载 {model['filename']} 时权限错误:您没有访问权限") logger.error("请按照以下步骤解决:") logger.error("1. 访问 https://huggingface.co/settings/tokens 创建访问令牌") logger.error(f"2. 访问 https://huggingface.co/{model['repo_id']} 页面申请访问权限") logger.error("3. 设置环境变量 HUGGING_FACE_HUB_TOKEN 为您的访问令牌") else: logger.error(f"下载 {model['filename']} 时出错: {e}") # 如果是关键模型,则抛出异常阻止继续 if model["filename"] in ["flux1-redux-dev.safetensors", "flux1-depth-dev.safetensors", "ae.safetensors"]: raise RuntimeError(f"无法下载关键模型 {model['filename']},程序无法继续运行") # 从索引获取值的辅助函数 def get_value_at_index(obj: Union[Dict, list], index: int) -> Any: """从索引获取值的辅助函数""" try: return obj[index] except KeyError: return obj["result"][index] # 添加ComfyUI目录到系统路径 def add_comfyui_directory_to_sys_path() -> None: """将ComfyUI目录添加到系统路径""" comfyui_path = get_comfyui_dir() if comfyui_path is not None and os.path.isdir(comfyui_path): sys.path.append(comfyui_path) logger.info(f"'{comfyui_path}' 已添加到 sys.path") # 添加额外模型路径 def add_extra_model_paths() -> None: """添加额外模型路径配置""" comfyui_dir = get_comfyui_dir() sys.path.append(comfyui_dir) # 确保能导入ComfyUI模块 try: # 尝试直接从main.py导入load_extra_path_config from main import load_extra_path_config extra_model_paths = os.path.join(comfyui_dir, "extra_model_paths.yaml") if os.path.exists(extra_model_paths): logger.info(f"找到extra_model_paths配置: {extra_model_paths}") load_extra_path_config(extra_model_paths) else: logger.info("未找到extra_model_paths.yaml文件,将使用默认模型路径") except ImportError as e: logger.warning(f"无法导入load_extra_path_config: {e}") logger.info("将使用默认模型路径继续运行") # 导入自定义节点 def import_custom_nodes() -> None: """导入ComfyUI自定义节点""" comfyui_dir = get_comfyui_dir() sys.path.append(comfyui_dir) try: import asyncio import execution from nodes import init_extra_nodes import server loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) server_instance = server.PromptServer(loop) execution.PromptQueue(server_instance) init_extra_nodes() logger.info("自定义节点导入成功") except Exception as e: logger.error(f"导入自定义节点时出错: {e}") raise # 设置ComfyUI环境 def setup_environment(): """设置ComfyUI环境,加载路径和模型""" logger.info("设置ComfyUI环境...") # 初始化路径 add_comfyui_directory_to_sys_path() add_extra_model_paths() # 下载所需模型 download_models() # 导入自定义节点 try: import_custom_nodes() logger.info("自定义节点导入成功") except Exception as e: logger.error(f"导入自定义节点时出错: {e}") raise RuntimeError("无法导入自定义节点,程序无法继续运行") # 保存临时图像文件 def save_image_to_temp(image_data: Union[str, bytes]) -> str: """ 将图像数据保存到临时文件 参数: image_data: 图像数据,可以是base64编码字符串或字节数据 返回: str: 临时文件路径 """ # 如果是base64编码字符串,先解码 if isinstance(image_data, str) and image_data.startswith(('data:image', 'data:application')): # 提取实际的base64编码部分 content_type, base64_data = image_data.split(';base64,') image_data = base64.b64decode(base64_data) # 创建临时文件并保存图像 suffix = '.png' # 默认文件扩展名 with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: if isinstance(image_data, str): # 如果仍然是字符串,假设是文件路径 shutil.copyfile(image_data, tmp.name) else: # 否则是字节数据 tmp.write(image_data) tmp_path = tmp.name return tmp_path # 图像生成主函数 def generate_image(prompt: str, structure_image: Union[str, bytes], style_image: Union[str, bytes], depth_strength: float = 15.0, style_strength: float = 0.5) -> Tuple[str, bytes]: """ 生成风格化图像 参数: prompt: 文本提示 structure_image: 结构图像数据(可以是base64编码字符串、字节数据或文件路径) style_image: 风格图像数据(可以是base64编码字符串、字节数据或文件路径) depth_strength: 深度强度 style_strength: 风格强度 返回: Tuple[str, bytes]: 返回生成图像的路径和图像数据 """ # 保存输入图像到临时文件 structure_image_path = save_image_to_temp(structure_image) style_image_path = save_image_to_temp(style_image) try: # 从这里开始复制原始生成函数的实现 logger.info("开始图像生成过程...") # ====================== 第1阶段:导入必要组件 ====================== # 从ComfyUI的nodes.py导入所有需要的节点类型 from nodes import ( StyleModelLoader, VAEEncode, NODE_CLASS_MAPPINGS, LoadImage, CLIPVisionLoader, SaveImage, VAELoader, CLIPVisionEncode, DualCLIPLoader, EmptyLatentImage, VAEDecode, UNETLoader, CLIPTextEncode, ) # ====================== 第2阶段:初始化常量 ====================== intconstant = NODE_CLASS_MAPPINGS["INTConstant"]() CONST_1024 = intconstant.get_value(value=1024) logger.info("加载模型...") # ====================== 第3阶段:加载模型 ====================== dualcliploader = DualCLIPLoader() CLIP_MODEL = dualcliploader.load_clip( clip_name1="t5/t5xxl_fp16.safetensors", clip_name2="clip_l.safetensors", type="flux", ) vaeloader = VAELoader() VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors") unetloader = UNETLoader() UNET_MODEL = unetloader.load_unet( unet_name="flux1-depth-dev.safetensors", weight_dtype="default" ) clipvisionloader = CLIPVisionLoader() CLIP_VISION_MODEL = clipvisionloader.load_clip( clip_name="sigclip_vision_patch14_384.safetensors" ) stylemodelloader = StyleModelLoader() STYLE_MODEL = stylemodelloader.load_style_model( style_model_name="flux1-redux-dev.safetensors" ) ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]() SAMPLER = ksamplerselect.get_sampler(sampler_name="euler") cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]() downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]() DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel( model="depth_anything_v2_vitl_fp32.safetensors" ) # ====================== 第4阶段:导入辅助节点 ====================== cliptextencode = CLIPTextEncode() loadimage = LoadImage() vaeencode = VAEEncode() fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]() instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]() clipvisionencode = CLIPVisionEncode() stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]() emptylatentimage = EmptyLatentImage() basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]() basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]() randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]() samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]() vaedecode = VAEDecode() cr_text = NODE_CLASS_MAPPINGS["CR Text"]() saveimage = SaveImage() getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]() depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]() imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]() # ====================== 第5阶段:预加载模型到GPU ====================== model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL] from comfy import model_management model_management.load_models_gpu([ loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders ]) logger.info("处理输入图像...") # ====================== 第6阶段:准备输入图像 ====================== comfyui_dir = get_comfyui_dir() input_dir = os.path.join(comfyui_dir, "input") os.makedirs(input_dir, exist_ok=True) # 创建临时文件名以避免冲突 structure_filename = f"structure_{random.randint(1000, 9999)}.png" style_filename = f"style_{random.randint(1000, 9999)}.png" # 将上传的图像复制到输入目录 structure_input_path = os.path.join(input_dir, structure_filename) style_input_path = os.path.join(input_dir, style_filename) shutil.copy(structure_image_path, structure_input_path) shutil.copy(style_image_path, style_input_path) # ====================== 第7阶段:图像生成流程 ====================== with torch.inference_mode(): # ------- 7.1: 设置CLIP ------- clip_switch = cr_clip_input_switch.switch( Input=1, clip1=get_value_at_index(CLIP_MODEL, 0), clip2=get_value_at_index(CLIP_MODEL, 0), ) # ------- 7.2: 文本编码 ------- text_encoded = cliptextencode.encode( text=prompt, clip=get_value_at_index(clip_switch, 0), ) empty_text = cliptextencode.encode( text="", clip=get_value_at_index(clip_switch, 0), ) logger.info("处理结构图像...") # ------- 7.3: 处理结构图像 ------- structure_img = loadimage.load_image(image=structure_filename) resized_img = imageresize.execute( width=get_value_at_index(CONST_1024, 0), height=get_value_at_index(CONST_1024, 0), interpolation="bicubic", method="keep proportion", condition="always", multiple_of=16, image=get_value_at_index(structure_img, 0), ) size_info = getimagesizeandcount.getsize( image=get_value_at_index(resized_img, 0) ) vae_encoded = vaeencode.encode( pixels=get_value_at_index(size_info, 0), vae=get_value_at_index(VAE_MODEL, 0), ) logger.info("处理深度...") # ------- 7.4: 深度处理 ------- depth_processed = depthanything_v2.process( da_model=get_value_at_index(DEPTH_MODEL, 0), images=get_value_at_index(size_info, 0), ) flux_guided = fluxguidance.append( guidance=depth_strength, conditioning=get_value_at_index(text_encoded, 0), ) logger.info("处理风格图像...") # ------- 7.5: 风格处理 ------- style_img = loadimage.load_image(image=style_filename) style_encoded = clipvisionencode.encode( crop="center", clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0), image=get_value_at_index(style_img, 0), ) logger.info("设置条件...") # ------- 7.6: 设置条件 ------- conditioning = instructpixtopixconditioning.encode( positive=get_value_at_index(flux_guided, 0), negative=get_value_at_index(empty_text, 0), vae=get_value_at_index(VAE_MODEL, 0), pixels=get_value_at_index(depth_processed, 0), ) style_applied = stylemodelapplyadvanced.apply_stylemodel( strength=style_strength, conditioning=get_value_at_index(conditioning, 0), style_model=get_value_at_index(STYLE_MODEL, 0), clip_vision_output=get_value_at_index(style_encoded, 0), ) # ------- 7.7: 创建潜在空间 ------- empty_latent = emptylatentimage.generate( width=get_value_at_index(resized_img, 1), height=get_value_at_index(resized_img, 2), batch_size=1, ) logger.info("设置采样...") # ------- 7.8: 设置扩散引导 ------- guided = basicguider.get_guider( model=get_value_at_index(UNET_MODEL, 0), conditioning=get_value_at_index(style_applied, 0), ) schedule = basicscheduler.get_sigmas( scheduler="simple", steps=28, denoise=1, model=get_value_at_index(UNET_MODEL, 0), ) noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64)) logger.info("采样中...") # ------- 7.9: 执行采样 ------- sampled = samplercustomadvanced.sample( noise=get_value_at_index(noise, 0), guider=get_value_at_index(guided, 0), sampler=get_value_at_index(SAMPLER, 0), sigmas=get_value_at_index(schedule, 0), latent_image=get_value_at_index(empty_latent, 0), ) logger.info("解码图像...") # ------- 7.10: 解码结果 ------- decoded = vaedecode.decode( samples=get_value_at_index(sampled, 0), vae=get_value_at_index(VAE_MODEL, 0), ) # ------- 7.11: 保存图像 ------- prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux") saved = saveimage.save_images( filename_prefix=get_value_at_index(prefix, 0), images=get_value_at_index(decoded, 0), ) # 获取输出路径 saved_path = os.path.join(comfyui_dir, "output", saved['ui']['images'][0]['filename']) logger.info(f"图像保存到 {saved_path}") # 读取生成的图像数据 with open(saved_path, "rb") as f: image_data = f.read() # 清理临时文件 try: os.remove(structure_input_path) os.remove(style_input_path) os.remove(structure_image_path) os.remove(style_image_path) except Exception as e: logger.warning(f"清理临时文件时出错: {e}") # 返回生成图像的路径和数据 return saved_path, image_data except Exception as e: # 清理临时文件 try: os.remove(structure_image_path) os.remove(style_image_path) except: pass logger.error(f"生成图像时出错: {e}") raise # 图像数据转换为base64 def image_to_base64(image_path: str) -> str: """ 将图像文件转换为base64编码字符串 参数: image_path: 图像文件路径 返回: str: base64编码的图像数据,包含MIME类型前缀 """ with open(image_path, "rb") as f: image_data = f.read() encoded = base64.b64encode(image_data).decode("utf-8") mime_type = "image/png" # 默认MIME类型 # 根据文件扩展名确定MIME类型 ext = os.path.splitext(image_path)[1].lower() if ext == '.jpg' or ext == '.jpeg': mime_type = "image/jpeg" elif ext == '.webp': mime_type = "image/webp" return f"data:{mime_type};base64,{encoded}" # 初始化FLUX环境 def init_flux(): """初始化FLUX环境,设置路径和导入模型""" try: setup_environment() logger.info("FLUX环境初始化成功") return True except Exception as e: logger.error(f"FLUX环境初始化失败: {e}") return False # 集成到MCP服务器的函数 def register_flux_resources(server): """ 将FLUX风格塑形功能注册到MCP服务器 参数: server: MCP服务器实例 """ # 初始化FLUX环境 initialized = init_flux() if not initialized: logger.error("FLUX环境初始化失败,无法注册资源") return # 注册资源: 服务状态 @server.register_resource("flux://status") def get_flux_status() -> str: """返回FLUX服务状态""" return "FLUX风格塑形服务已启动并运行正常" # 注册工具: 生成风格化图像 @server.register_tool("生成风格化图像") def generate_styled_image(prompt: str, structure_image_base64: str, style_image_base64: str, depth_strength: float = 15.0, style_strength: float = 0.5) -> Dict[str, Any]: """ 生成风格化图像工具 参数: prompt: 文本提示,用于指导生成过程 structure_image_base64: 结构图像的base64编码,提供基本构图 style_image_base64: 风格图像的base64编码,提供艺术风格 depth_strength: 深度强度,控制结构保持程度 style_strength: 风格强度,控制风格应用程度 返回: Dict: 包含生成图像的base64编码和其他信息 """ try: # 生成图像 saved_path, _ = generate_image( prompt=prompt, structure_image=structure_image_base64, style_image=style_image_base64, depth_strength=depth_strength, style_strength=style_strength ) # 将生成的图像转换为base64 output_base64 = image_to_base64(saved_path) # 返回结果 return { "status": "success", "message": "图像生成成功", "image_base64": output_base64, "parameters": { "prompt": prompt, "depth_strength": depth_strength, "style_strength": style_strength } } except Exception as e: logger.error(f"生成图像失败: {e}") return { "status": "error", "message": f"生成图像失败: {str(e)}" } # 注册提示: 创建风格转移提示 @server.register_prompt("风格转移提示") def style_transfer_prompt(subject: str, style: str) -> str: """创建风格转移的提示模板""" return f"将{subject}转换为{style}风格" logger.info("FLUX风格塑形资源和工具已注册到MCP服务器")