KICCO_AI_IMAGE/mcp_server/flux_style_resource.py

593 lines
23 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FLUX风格塑形MCP资源模块
---------------------
将FLUX风格塑形API集成到MCP服务器中
作为资源和工具提供给MCP客户端
"""
import os
import sys
import random
import logging
import tempfile
import shutil
from pathlib import Path
from typing import Dict, Any, Union, Tuple, Optional
import base64
from io import BytesIO
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from dotenv import load_dotenv
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# 加载环境变量
load_dotenv(override=True)
# 获取ComfyUI根目录
def get_comfyui_dir():
"""获取ComfyUI根目录路径"""
current_dir = os.path.dirname(os.path.abspath(__file__))
comfyui_dir = os.path.dirname(os.path.dirname(current_dir)) # 假设mcp_server在ComfyUI根目录下
return comfyui_dir
# 下载所需模型
def download_models():
"""检查并下载所需模型"""
logger.info("检查并下载所需模型...")
# 获取ComfyUI根目录
comfyui_dir = get_comfyui_dir()
# 获取Hugging Face令牌
hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if not hf_token:
logger.warning("未找到Hugging Face访问令牌受限模型可能无法下载。")
logger.warning("请设置HUGGING_FACE_HUB_TOKEN环境变量或在.env文件中添加此变量。")
else:
logger.info("已找到Hugging Face访问令牌")
models = [
{"repo_id": "black-forest-labs/FLUX.1-Redux-dev", "filename": "flux1-redux-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/style_models")},
{"repo_id": "black-forest-labs/FLUX.1-Depth-dev", "filename": "flux1-depth-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/diffusion_models")},
{"repo_id": "Comfy-Org/sigclip_vision_384", "filename": "sigclip_vision_patch14_384.safetensors", "local_dir": os.path.join(comfyui_dir, "models/clip_vision")},
{"repo_id": "Kijai/DepthAnythingV2-safetensors", "filename": "depth_anything_v2_vitl_fp32.safetensors", "local_dir": os.path.join(comfyui_dir, "models/depthanything")},
{"repo_id": "black-forest-labs/FLUX.1-dev", "filename": "ae.safetensors", "local_dir": os.path.join(comfyui_dir, "models/vae/FLUX1")},
{"repo_id": "comfyanonymous/flux_text_encoders", "filename": "clip_l.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders")},
{"repo_id": "comfyanonymous/flux_text_encoders", "filename": "t5xxl_fp16.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders/t5")}
]
for model in models:
try:
local_path = os.path.join(model["local_dir"], model["filename"])
if not os.path.exists(local_path):
logger.info(f"下载 {model['filename']}{model['local_dir']}")
os.makedirs(model["local_dir"], exist_ok=True)
hf_hub_download(
repo_id=model["repo_id"],
filename=model["filename"],
local_dir=model["local_dir"],
token=hf_token # 使用令牌进行下载
)
else:
logger.info(f"模型 {model['filename']} 已存在于 {model['local_dir']}")
except Exception as e:
if "403 Client Error" in str(e) and "Access to model" in str(e) and "is restricted" in str(e):
logger.error(f"下载 {model['filename']} 时权限错误:您没有访问权限")
logger.error("请按照以下步骤解决:")
logger.error("1. 访问 https://huggingface.co/settings/tokens 创建访问令牌")
logger.error(f"2. 访问 https://huggingface.co/{model['repo_id']} 页面申请访问权限")
logger.error("3. 设置环境变量 HUGGING_FACE_HUB_TOKEN 为您的访问令牌")
else:
logger.error(f"下载 {model['filename']} 时出错: {e}")
# 如果是关键模型,则抛出异常阻止继续
if model["filename"] in ["flux1-redux-dev.safetensors", "flux1-depth-dev.safetensors", "ae.safetensors"]:
raise RuntimeError(f"无法下载关键模型 {model['filename']},程序无法继续运行")
# 从索引获取值的辅助函数
def get_value_at_index(obj: Union[Dict, list], index: int) -> Any:
"""从索引获取值的辅助函数"""
try:
return obj[index]
except KeyError:
return obj["result"][index]
# 添加ComfyUI目录到系统路径
def add_comfyui_directory_to_sys_path() -> None:
"""将ComfyUI目录添加到系统路径"""
comfyui_path = get_comfyui_dir()
if comfyui_path is not None and os.path.isdir(comfyui_path):
sys.path.append(comfyui_path)
logger.info(f"'{comfyui_path}' 已添加到 sys.path")
# 添加额外模型路径
def add_extra_model_paths() -> None:
"""添加额外模型路径配置"""
comfyui_dir = get_comfyui_dir()
sys.path.append(comfyui_dir) # 确保能导入ComfyUI模块
try:
# 尝试直接从main.py导入load_extra_path_config
from main import load_extra_path_config
extra_model_paths = os.path.join(comfyui_dir, "extra_model_paths.yaml")
if os.path.exists(extra_model_paths):
logger.info(f"找到extra_model_paths配置: {extra_model_paths}")
load_extra_path_config(extra_model_paths)
else:
logger.info("未找到extra_model_paths.yaml文件将使用默认模型路径")
except ImportError as e:
logger.warning(f"无法导入load_extra_path_config: {e}")
logger.info("将使用默认模型路径继续运行")
# 导入自定义节点
def import_custom_nodes() -> None:
"""导入ComfyUI自定义节点"""
comfyui_dir = get_comfyui_dir()
sys.path.append(comfyui_dir)
try:
import asyncio
import execution
from nodes import init_extra_nodes
import server
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_instance = server.PromptServer(loop)
execution.PromptQueue(server_instance)
init_extra_nodes()
logger.info("自定义节点导入成功")
except Exception as e:
logger.error(f"导入自定义节点时出错: {e}")
raise
# 设置ComfyUI环境
def setup_environment():
"""设置ComfyUI环境加载路径和模型"""
logger.info("设置ComfyUI环境...")
# 初始化路径
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
# 下载所需模型
download_models()
# 导入自定义节点
try:
import_custom_nodes()
logger.info("自定义节点导入成功")
except Exception as e:
logger.error(f"导入自定义节点时出错: {e}")
raise RuntimeError("无法导入自定义节点,程序无法继续运行")
# 保存临时图像文件
def save_image_to_temp(image_data: Union[str, bytes]) -> str:
"""
将图像数据保存到临时文件
参数:
image_data: 图像数据可以是base64编码字符串或字节数据
返回:
str: 临时文件路径
"""
# 如果是base64编码字符串先解码
if isinstance(image_data, str) and image_data.startswith(('data:image', 'data:application')):
# 提取实际的base64编码部分
content_type, base64_data = image_data.split(';base64,')
image_data = base64.b64decode(base64_data)
# 创建临时文件并保存图像
suffix = '.png' # 默认文件扩展名
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
if isinstance(image_data, str):
# 如果仍然是字符串,假设是文件路径
shutil.copyfile(image_data, tmp.name)
else:
# 否则是字节数据
tmp.write(image_data)
tmp_path = tmp.name
return tmp_path
# 图像生成主函数
def generate_image(prompt: str, structure_image: Union[str, bytes], style_image: Union[str, bytes],
depth_strength: float = 15.0, style_strength: float = 0.5) -> Tuple[str, bytes]:
"""
生成风格化图像
参数:
prompt: 文本提示
structure_image: 结构图像数据可以是base64编码字符串、字节数据或文件路径
style_image: 风格图像数据可以是base64编码字符串、字节数据或文件路径
depth_strength: 深度强度
style_strength: 风格强度
返回:
Tuple[str, bytes]: 返回生成图像的路径和图像数据
"""
# 保存输入图像到临时文件
structure_image_path = save_image_to_temp(structure_image)
style_image_path = save_image_to_temp(style_image)
try:
# 从这里开始复制原始生成函数的实现
logger.info("开始图像生成过程...")
# ====================== 第1阶段导入必要组件 ======================
# 从ComfyUI的nodes.py导入所有需要的节点类型
from nodes import (
StyleModelLoader, VAEEncode, NODE_CLASS_MAPPINGS, LoadImage, CLIPVisionLoader,
SaveImage, VAELoader, CLIPVisionEncode, DualCLIPLoader, EmptyLatentImage,
VAEDecode, UNETLoader, CLIPTextEncode,
)
# ====================== 第2阶段初始化常量 ======================
intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
CONST_1024 = intconstant.get_value(value=1024)
logger.info("加载模型...")
# ====================== 第3阶段加载模型 ======================
dualcliploader = DualCLIPLoader()
CLIP_MODEL = dualcliploader.load_clip(
clip_name1="t5/t5xxl_fp16.safetensors",
clip_name2="clip_l.safetensors",
type="flux",
)
vaeloader = VAELoader()
VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
unetloader = UNETLoader()
UNET_MODEL = unetloader.load_unet(
unet_name="flux1-depth-dev.safetensors",
weight_dtype="default"
)
clipvisionloader = CLIPVisionLoader()
CLIP_VISION_MODEL = clipvisionloader.load_clip(
clip_name="sigclip_vision_patch14_384.safetensors"
)
stylemodelloader = StyleModelLoader()
STYLE_MODEL = stylemodelloader.load_style_model(
style_model_name="flux1-redux-dev.safetensors"
)
ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
model="depth_anything_v2_vitl_fp32.safetensors"
)
# ====================== 第4阶段导入辅助节点 ======================
cliptextencode = CLIPTextEncode()
loadimage = LoadImage()
vaeencode = VAEEncode()
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
clipvisionencode = CLIPVisionEncode()
stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
emptylatentimage = EmptyLatentImage()
basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
vaedecode = VAEDecode()
cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
saveimage = SaveImage()
getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
# ====================== 第5阶段预加载模型到GPU ======================
model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]
from comfy import model_management
model_management.load_models_gpu([
loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
])
logger.info("处理输入图像...")
# ====================== 第6阶段准备输入图像 ======================
comfyui_dir = get_comfyui_dir()
input_dir = os.path.join(comfyui_dir, "input")
os.makedirs(input_dir, exist_ok=True)
# 创建临时文件名以避免冲突
structure_filename = f"structure_{random.randint(1000, 9999)}.png"
style_filename = f"style_{random.randint(1000, 9999)}.png"
# 将上传的图像复制到输入目录
structure_input_path = os.path.join(input_dir, structure_filename)
style_input_path = os.path.join(input_dir, style_filename)
shutil.copy(structure_image_path, structure_input_path)
shutil.copy(style_image_path, style_input_path)
# ====================== 第7阶段图像生成流程 ======================
with torch.inference_mode():
# ------- 7.1: 设置CLIP -------
clip_switch = cr_clip_input_switch.switch(
Input=1,
clip1=get_value_at_index(CLIP_MODEL, 0),
clip2=get_value_at_index(CLIP_MODEL, 0),
)
# ------- 7.2: 文本编码 -------
text_encoded = cliptextencode.encode(
text=prompt,
clip=get_value_at_index(clip_switch, 0),
)
empty_text = cliptextencode.encode(
text="",
clip=get_value_at_index(clip_switch, 0),
)
logger.info("处理结构图像...")
# ------- 7.3: 处理结构图像 -------
structure_img = loadimage.load_image(image=structure_filename)
resized_img = imageresize.execute(
width=get_value_at_index(CONST_1024, 0),
height=get_value_at_index(CONST_1024, 0),
interpolation="bicubic",
method="keep proportion",
condition="always",
multiple_of=16,
image=get_value_at_index(structure_img, 0),
)
size_info = getimagesizeandcount.getsize(
image=get_value_at_index(resized_img, 0)
)
vae_encoded = vaeencode.encode(
pixels=get_value_at_index(size_info, 0),
vae=get_value_at_index(VAE_MODEL, 0),
)
logger.info("处理深度...")
# ------- 7.4: 深度处理 -------
depth_processed = depthanything_v2.process(
da_model=get_value_at_index(DEPTH_MODEL, 0),
images=get_value_at_index(size_info, 0),
)
flux_guided = fluxguidance.append(
guidance=depth_strength,
conditioning=get_value_at_index(text_encoded, 0),
)
logger.info("处理风格图像...")
# ------- 7.5: 风格处理 -------
style_img = loadimage.load_image(image=style_filename)
style_encoded = clipvisionencode.encode(
crop="center",
clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
image=get_value_at_index(style_img, 0),
)
logger.info("设置条件...")
# ------- 7.6: 设置条件 -------
conditioning = instructpixtopixconditioning.encode(
positive=get_value_at_index(flux_guided, 0),
negative=get_value_at_index(empty_text, 0),
vae=get_value_at_index(VAE_MODEL, 0),
pixels=get_value_at_index(depth_processed, 0),
)
style_applied = stylemodelapplyadvanced.apply_stylemodel(
strength=style_strength,
conditioning=get_value_at_index(conditioning, 0),
style_model=get_value_at_index(STYLE_MODEL, 0),
clip_vision_output=get_value_at_index(style_encoded, 0),
)
# ------- 7.7: 创建潜在空间 -------
empty_latent = emptylatentimage.generate(
width=get_value_at_index(resized_img, 1),
height=get_value_at_index(resized_img, 2),
batch_size=1,
)
logger.info("设置采样...")
# ------- 7.8: 设置扩散引导 -------
guided = basicguider.get_guider(
model=get_value_at_index(UNET_MODEL, 0),
conditioning=get_value_at_index(style_applied, 0),
)
schedule = basicscheduler.get_sigmas(
scheduler="simple",
steps=28,
denoise=1,
model=get_value_at_index(UNET_MODEL, 0),
)
noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))
logger.info("采样中...")
# ------- 7.9: 执行采样 -------
sampled = samplercustomadvanced.sample(
noise=get_value_at_index(noise, 0),
guider=get_value_at_index(guided, 0),
sampler=get_value_at_index(SAMPLER, 0),
sigmas=get_value_at_index(schedule, 0),
latent_image=get_value_at_index(empty_latent, 0),
)
logger.info("解码图像...")
# ------- 7.10: 解码结果 -------
decoded = vaedecode.decode(
samples=get_value_at_index(sampled, 0),
vae=get_value_at_index(VAE_MODEL, 0),
)
# ------- 7.11: 保存图像 -------
prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
saved = saveimage.save_images(
filename_prefix=get_value_at_index(prefix, 0),
images=get_value_at_index(decoded, 0),
)
# 获取输出路径
saved_path = os.path.join(comfyui_dir, "output", saved['ui']['images'][0]['filename'])
logger.info(f"图像保存到 {saved_path}")
# 读取生成的图像数据
with open(saved_path, "rb") as f:
image_data = f.read()
# 清理临时文件
try:
os.remove(structure_input_path)
os.remove(style_input_path)
os.remove(structure_image_path)
os.remove(style_image_path)
except Exception as e:
logger.warning(f"清理临时文件时出错: {e}")
# 返回生成图像的路径和数据
return saved_path, image_data
except Exception as e:
# 清理临时文件
try:
os.remove(structure_image_path)
os.remove(style_image_path)
except:
pass
logger.error(f"生成图像时出错: {e}")
raise
# 图像数据转换为base64
def image_to_base64(image_path: str) -> str:
"""
将图像文件转换为base64编码字符串
参数:
image_path: 图像文件路径
返回:
str: base64编码的图像数据包含MIME类型前缀
"""
with open(image_path, "rb") as f:
image_data = f.read()
encoded = base64.b64encode(image_data).decode("utf-8")
mime_type = "image/png" # 默认MIME类型
# 根据文件扩展名确定MIME类型
ext = os.path.splitext(image_path)[1].lower()
if ext == '.jpg' or ext == '.jpeg':
mime_type = "image/jpeg"
elif ext == '.webp':
mime_type = "image/webp"
return f"data:{mime_type};base64,{encoded}"
# 初始化FLUX环境
def init_flux():
"""初始化FLUX环境设置路径和导入模型"""
try:
setup_environment()
logger.info("FLUX环境初始化成功")
return True
except Exception as e:
logger.error(f"FLUX环境初始化失败: {e}")
return False
# 集成到MCP服务器的函数
def register_flux_resources(server):
"""
将FLUX风格塑形功能注册到MCP服务器
参数:
server: MCP服务器实例
"""
# 初始化FLUX环境
initialized = init_flux()
if not initialized:
logger.error("FLUX环境初始化失败无法注册资源")
return
# 注册资源: 服务状态
@server.register_resource("flux://status")
def get_flux_status() -> str:
"""返回FLUX服务状态"""
return "FLUX风格塑形服务已启动并运行正常"
# 注册工具: 生成风格化图像
@server.register_tool("生成风格化图像")
def generate_styled_image(prompt: str, structure_image_base64: str, style_image_base64: str,
depth_strength: float = 15.0, style_strength: float = 0.5) -> Dict[str, Any]:
"""
生成风格化图像工具
参数:
prompt: 文本提示,用于指导生成过程
structure_image_base64: 结构图像的base64编码提供基本构图
style_image_base64: 风格图像的base64编码提供艺术风格
depth_strength: 深度强度,控制结构保持程度
style_strength: 风格强度,控制风格应用程度
返回:
Dict: 包含生成图像的base64编码和其他信息
"""
try:
# 生成图像
saved_path, _ = generate_image(
prompt=prompt,
structure_image=structure_image_base64,
style_image=style_image_base64,
depth_strength=depth_strength,
style_strength=style_strength
)
# 将生成的图像转换为base64
output_base64 = image_to_base64(saved_path)
# 返回结果
return {
"status": "success",
"message": "图像生成成功",
"image_base64": output_base64,
"parameters": {
"prompt": prompt,
"depth_strength": depth_strength,
"style_strength": style_strength
}
}
except Exception as e:
logger.error(f"生成图像失败: {e}")
return {
"status": "error",
"message": f"生成图像失败: {str(e)}"
}
# 注册提示: 创建风格转移提示
@server.register_prompt("风格转移提示")
def style_transfer_prompt(subject: str, style: str) -> str:
"""创建风格转移的提示模板"""
return f"{subject}转换为{style}风格"
logger.info("FLUX风格塑形资源和工具已注册到MCP服务器")