593 lines
23 KiB
Python
593 lines
23 KiB
Python
|
"""
|
|||
|
FLUX风格塑形MCP资源模块
|
|||
|
---------------------
|
|||
|
将FLUX风格塑形API集成到MCP服务器中
|
|||
|
作为资源和工具提供给MCP客户端
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import random
|
|||
|
import logging
|
|||
|
import tempfile
|
|||
|
import shutil
|
|||
|
from pathlib import Path
|
|||
|
from typing import Dict, Any, Union, Tuple, Optional
|
|||
|
import base64
|
|||
|
from io import BytesIO
|
|||
|
|
|||
|
import torch
|
|||
|
from PIL import Image
|
|||
|
from huggingface_hub import hf_hub_download
|
|||
|
from dotenv import load_dotenv
|
|||
|
|
|||
|
# 设置日志
|
|||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|||
|
logger = logging.getLogger(__name__)
|
|||
|
|
|||
|
# 加载环境变量
|
|||
|
load_dotenv(override=True)
|
|||
|
|
|||
|
# 获取ComfyUI根目录
|
|||
|
def get_comfyui_dir():
|
|||
|
"""获取ComfyUI根目录路径"""
|
|||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|||
|
comfyui_dir = os.path.dirname(os.path.dirname(current_dir)) # 假设mcp_server在ComfyUI根目录下
|
|||
|
return comfyui_dir
|
|||
|
|
|||
|
# 下载所需模型
|
|||
|
def download_models():
|
|||
|
"""检查并下载所需模型"""
|
|||
|
logger.info("检查并下载所需模型...")
|
|||
|
# 获取ComfyUI根目录
|
|||
|
comfyui_dir = get_comfyui_dir()
|
|||
|
|
|||
|
# 获取Hugging Face令牌
|
|||
|
hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
|
|||
|
if not hf_token:
|
|||
|
logger.warning("未找到Hugging Face访问令牌!受限模型可能无法下载。")
|
|||
|
logger.warning("请设置HUGGING_FACE_HUB_TOKEN环境变量或在.env文件中添加此变量。")
|
|||
|
else:
|
|||
|
logger.info("已找到Hugging Face访问令牌")
|
|||
|
|
|||
|
models = [
|
|||
|
{"repo_id": "black-forest-labs/FLUX.1-Redux-dev", "filename": "flux1-redux-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/style_models")},
|
|||
|
{"repo_id": "black-forest-labs/FLUX.1-Depth-dev", "filename": "flux1-depth-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/diffusion_models")},
|
|||
|
{"repo_id": "Comfy-Org/sigclip_vision_384", "filename": "sigclip_vision_patch14_384.safetensors", "local_dir": os.path.join(comfyui_dir, "models/clip_vision")},
|
|||
|
{"repo_id": "Kijai/DepthAnythingV2-safetensors", "filename": "depth_anything_v2_vitl_fp32.safetensors", "local_dir": os.path.join(comfyui_dir, "models/depthanything")},
|
|||
|
{"repo_id": "black-forest-labs/FLUX.1-dev", "filename": "ae.safetensors", "local_dir": os.path.join(comfyui_dir, "models/vae/FLUX1")},
|
|||
|
{"repo_id": "comfyanonymous/flux_text_encoders", "filename": "clip_l.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders")},
|
|||
|
{"repo_id": "comfyanonymous/flux_text_encoders", "filename": "t5xxl_fp16.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders/t5")}
|
|||
|
]
|
|||
|
|
|||
|
for model in models:
|
|||
|
try:
|
|||
|
local_path = os.path.join(model["local_dir"], model["filename"])
|
|||
|
if not os.path.exists(local_path):
|
|||
|
logger.info(f"下载 {model['filename']} 到 {model['local_dir']}")
|
|||
|
os.makedirs(model["local_dir"], exist_ok=True)
|
|||
|
hf_hub_download(
|
|||
|
repo_id=model["repo_id"],
|
|||
|
filename=model["filename"],
|
|||
|
local_dir=model["local_dir"],
|
|||
|
token=hf_token # 使用令牌进行下载
|
|||
|
)
|
|||
|
else:
|
|||
|
logger.info(f"模型 {model['filename']} 已存在于 {model['local_dir']}")
|
|||
|
except Exception as e:
|
|||
|
if "403 Client Error" in str(e) and "Access to model" in str(e) and "is restricted" in str(e):
|
|||
|
logger.error(f"下载 {model['filename']} 时权限错误:您没有访问权限")
|
|||
|
logger.error("请按照以下步骤解决:")
|
|||
|
logger.error("1. 访问 https://huggingface.co/settings/tokens 创建访问令牌")
|
|||
|
logger.error(f"2. 访问 https://huggingface.co/{model['repo_id']} 页面申请访问权限")
|
|||
|
logger.error("3. 设置环境变量 HUGGING_FACE_HUB_TOKEN 为您的访问令牌")
|
|||
|
else:
|
|||
|
logger.error(f"下载 {model['filename']} 时出错: {e}")
|
|||
|
|
|||
|
# 如果是关键模型,则抛出异常阻止继续
|
|||
|
if model["filename"] in ["flux1-redux-dev.safetensors", "flux1-depth-dev.safetensors", "ae.safetensors"]:
|
|||
|
raise RuntimeError(f"无法下载关键模型 {model['filename']},程序无法继续运行")
|
|||
|
|
|||
|
# 从索引获取值的辅助函数
|
|||
|
def get_value_at_index(obj: Union[Dict, list], index: int) -> Any:
|
|||
|
"""从索引获取值的辅助函数"""
|
|||
|
try:
|
|||
|
return obj[index]
|
|||
|
except KeyError:
|
|||
|
return obj["result"][index]
|
|||
|
|
|||
|
# 添加ComfyUI目录到系统路径
|
|||
|
def add_comfyui_directory_to_sys_path() -> None:
|
|||
|
"""将ComfyUI目录添加到系统路径"""
|
|||
|
comfyui_path = get_comfyui_dir()
|
|||
|
if comfyui_path is not None and os.path.isdir(comfyui_path):
|
|||
|
sys.path.append(comfyui_path)
|
|||
|
logger.info(f"'{comfyui_path}' 已添加到 sys.path")
|
|||
|
|
|||
|
# 添加额外模型路径
|
|||
|
def add_extra_model_paths() -> None:
|
|||
|
"""添加额外模型路径配置"""
|
|||
|
comfyui_dir = get_comfyui_dir()
|
|||
|
sys.path.append(comfyui_dir) # 确保能导入ComfyUI模块
|
|||
|
|
|||
|
try:
|
|||
|
# 尝试直接从main.py导入load_extra_path_config
|
|||
|
from main import load_extra_path_config
|
|||
|
|
|||
|
extra_model_paths = os.path.join(comfyui_dir, "extra_model_paths.yaml")
|
|||
|
if os.path.exists(extra_model_paths):
|
|||
|
logger.info(f"找到extra_model_paths配置: {extra_model_paths}")
|
|||
|
load_extra_path_config(extra_model_paths)
|
|||
|
else:
|
|||
|
logger.info("未找到extra_model_paths.yaml文件,将使用默认模型路径")
|
|||
|
except ImportError as e:
|
|||
|
logger.warning(f"无法导入load_extra_path_config: {e}")
|
|||
|
logger.info("将使用默认模型路径继续运行")
|
|||
|
|
|||
|
# 导入自定义节点
|
|||
|
def import_custom_nodes() -> None:
|
|||
|
"""导入ComfyUI自定义节点"""
|
|||
|
comfyui_dir = get_comfyui_dir()
|
|||
|
sys.path.append(comfyui_dir)
|
|||
|
|
|||
|
try:
|
|||
|
import asyncio
|
|||
|
import execution
|
|||
|
from nodes import init_extra_nodes
|
|||
|
import server
|
|||
|
loop = asyncio.new_event_loop()
|
|||
|
asyncio.set_event_loop(loop)
|
|||
|
server_instance = server.PromptServer(loop)
|
|||
|
execution.PromptQueue(server_instance)
|
|||
|
init_extra_nodes()
|
|||
|
logger.info("自定义节点导入成功")
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"导入自定义节点时出错: {e}")
|
|||
|
raise
|
|||
|
|
|||
|
# 设置ComfyUI环境
|
|||
|
def setup_environment():
|
|||
|
"""设置ComfyUI环境,加载路径和模型"""
|
|||
|
logger.info("设置ComfyUI环境...")
|
|||
|
# 初始化路径
|
|||
|
add_comfyui_directory_to_sys_path()
|
|||
|
add_extra_model_paths()
|
|||
|
|
|||
|
# 下载所需模型
|
|||
|
download_models()
|
|||
|
|
|||
|
# 导入自定义节点
|
|||
|
try:
|
|||
|
import_custom_nodes()
|
|||
|
logger.info("自定义节点导入成功")
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"导入自定义节点时出错: {e}")
|
|||
|
raise RuntimeError("无法导入自定义节点,程序无法继续运行")
|
|||
|
|
|||
|
# 保存临时图像文件
|
|||
|
def save_image_to_temp(image_data: Union[str, bytes]) -> str:
|
|||
|
"""
|
|||
|
将图像数据保存到临时文件
|
|||
|
|
|||
|
参数:
|
|||
|
image_data: 图像数据,可以是base64编码字符串或字节数据
|
|||
|
|
|||
|
返回:
|
|||
|
str: 临时文件路径
|
|||
|
"""
|
|||
|
# 如果是base64编码字符串,先解码
|
|||
|
if isinstance(image_data, str) and image_data.startswith(('data:image', 'data:application')):
|
|||
|
# 提取实际的base64编码部分
|
|||
|
content_type, base64_data = image_data.split(';base64,')
|
|||
|
image_data = base64.b64decode(base64_data)
|
|||
|
|
|||
|
# 创建临时文件并保存图像
|
|||
|
suffix = '.png' # 默认文件扩展名
|
|||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
|||
|
if isinstance(image_data, str):
|
|||
|
# 如果仍然是字符串,假设是文件路径
|
|||
|
shutil.copyfile(image_data, tmp.name)
|
|||
|
else:
|
|||
|
# 否则是字节数据
|
|||
|
tmp.write(image_data)
|
|||
|
tmp_path = tmp.name
|
|||
|
|
|||
|
return tmp_path
|
|||
|
|
|||
|
# 图像生成主函数
|
|||
|
def generate_image(prompt: str, structure_image: Union[str, bytes], style_image: Union[str, bytes],
|
|||
|
depth_strength: float = 15.0, style_strength: float = 0.5) -> Tuple[str, bytes]:
|
|||
|
"""
|
|||
|
生成风格化图像
|
|||
|
|
|||
|
参数:
|
|||
|
prompt: 文本提示
|
|||
|
structure_image: 结构图像数据(可以是base64编码字符串、字节数据或文件路径)
|
|||
|
style_image: 风格图像数据(可以是base64编码字符串、字节数据或文件路径)
|
|||
|
depth_strength: 深度强度
|
|||
|
style_strength: 风格强度
|
|||
|
|
|||
|
返回:
|
|||
|
Tuple[str, bytes]: 返回生成图像的路径和图像数据
|
|||
|
"""
|
|||
|
# 保存输入图像到临时文件
|
|||
|
structure_image_path = save_image_to_temp(structure_image)
|
|||
|
style_image_path = save_image_to_temp(style_image)
|
|||
|
|
|||
|
try:
|
|||
|
# 从这里开始复制原始生成函数的实现
|
|||
|
logger.info("开始图像生成过程...")
|
|||
|
|
|||
|
# ====================== 第1阶段:导入必要组件 ======================
|
|||
|
# 从ComfyUI的nodes.py导入所有需要的节点类型
|
|||
|
from nodes import (
|
|||
|
StyleModelLoader, VAEEncode, NODE_CLASS_MAPPINGS, LoadImage, CLIPVisionLoader,
|
|||
|
SaveImage, VAELoader, CLIPVisionEncode, DualCLIPLoader, EmptyLatentImage,
|
|||
|
VAEDecode, UNETLoader, CLIPTextEncode,
|
|||
|
)
|
|||
|
|
|||
|
# ====================== 第2阶段:初始化常量 ======================
|
|||
|
intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
|
|||
|
CONST_1024 = intconstant.get_value(value=1024)
|
|||
|
|
|||
|
logger.info("加载模型...")
|
|||
|
|
|||
|
# ====================== 第3阶段:加载模型 ======================
|
|||
|
dualcliploader = DualCLIPLoader()
|
|||
|
CLIP_MODEL = dualcliploader.load_clip(
|
|||
|
clip_name1="t5/t5xxl_fp16.safetensors",
|
|||
|
clip_name2="clip_l.safetensors",
|
|||
|
type="flux",
|
|||
|
)
|
|||
|
|
|||
|
vaeloader = VAELoader()
|
|||
|
VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
|
|||
|
|
|||
|
unetloader = UNETLoader()
|
|||
|
UNET_MODEL = unetloader.load_unet(
|
|||
|
unet_name="flux1-depth-dev.safetensors",
|
|||
|
weight_dtype="default"
|
|||
|
)
|
|||
|
|
|||
|
clipvisionloader = CLIPVisionLoader()
|
|||
|
CLIP_VISION_MODEL = clipvisionloader.load_clip(
|
|||
|
clip_name="sigclip_vision_patch14_384.safetensors"
|
|||
|
)
|
|||
|
|
|||
|
stylemodelloader = StyleModelLoader()
|
|||
|
STYLE_MODEL = stylemodelloader.load_style_model(
|
|||
|
style_model_name="flux1-redux-dev.safetensors"
|
|||
|
)
|
|||
|
|
|||
|
ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
|
|||
|
SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
|
|||
|
|
|||
|
cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
|
|||
|
downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
|
|||
|
DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
|
|||
|
model="depth_anything_v2_vitl_fp32.safetensors"
|
|||
|
)
|
|||
|
|
|||
|
# ====================== 第4阶段:导入辅助节点 ======================
|
|||
|
cliptextencode = CLIPTextEncode()
|
|||
|
loadimage = LoadImage()
|
|||
|
vaeencode = VAEEncode()
|
|||
|
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
|
|||
|
instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
|
|||
|
clipvisionencode = CLIPVisionEncode()
|
|||
|
stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
|
|||
|
emptylatentimage = EmptyLatentImage()
|
|||
|
basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
|
|||
|
basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
|
|||
|
randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
|
|||
|
samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
|
|||
|
vaedecode = VAEDecode()
|
|||
|
cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
|
|||
|
saveimage = SaveImage()
|
|||
|
getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
|
|||
|
depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
|
|||
|
imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
|
|||
|
|
|||
|
# ====================== 第5阶段:预加载模型到GPU ======================
|
|||
|
model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]
|
|||
|
|
|||
|
from comfy import model_management
|
|||
|
model_management.load_models_gpu([
|
|||
|
loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
|
|||
|
])
|
|||
|
|
|||
|
logger.info("处理输入图像...")
|
|||
|
|
|||
|
# ====================== 第6阶段:准备输入图像 ======================
|
|||
|
comfyui_dir = get_comfyui_dir()
|
|||
|
input_dir = os.path.join(comfyui_dir, "input")
|
|||
|
os.makedirs(input_dir, exist_ok=True)
|
|||
|
|
|||
|
# 创建临时文件名以避免冲突
|
|||
|
structure_filename = f"structure_{random.randint(1000, 9999)}.png"
|
|||
|
style_filename = f"style_{random.randint(1000, 9999)}.png"
|
|||
|
|
|||
|
# 将上传的图像复制到输入目录
|
|||
|
structure_input_path = os.path.join(input_dir, structure_filename)
|
|||
|
style_input_path = os.path.join(input_dir, style_filename)
|
|||
|
|
|||
|
shutil.copy(structure_image_path, structure_input_path)
|
|||
|
shutil.copy(style_image_path, style_input_path)
|
|||
|
|
|||
|
# ====================== 第7阶段:图像生成流程 ======================
|
|||
|
with torch.inference_mode():
|
|||
|
# ------- 7.1: 设置CLIP -------
|
|||
|
clip_switch = cr_clip_input_switch.switch(
|
|||
|
Input=1,
|
|||
|
clip1=get_value_at_index(CLIP_MODEL, 0),
|
|||
|
clip2=get_value_at_index(CLIP_MODEL, 0),
|
|||
|
)
|
|||
|
|
|||
|
# ------- 7.2: 文本编码 -------
|
|||
|
text_encoded = cliptextencode.encode(
|
|||
|
text=prompt,
|
|||
|
clip=get_value_at_index(clip_switch, 0),
|
|||
|
)
|
|||
|
empty_text = cliptextencode.encode(
|
|||
|
text="",
|
|||
|
clip=get_value_at_index(clip_switch, 0),
|
|||
|
)
|
|||
|
|
|||
|
logger.info("处理结构图像...")
|
|||
|
|
|||
|
# ------- 7.3: 处理结构图像 -------
|
|||
|
structure_img = loadimage.load_image(image=structure_filename)
|
|||
|
|
|||
|
resized_img = imageresize.execute(
|
|||
|
width=get_value_at_index(CONST_1024, 0),
|
|||
|
height=get_value_at_index(CONST_1024, 0),
|
|||
|
interpolation="bicubic",
|
|||
|
method="keep proportion",
|
|||
|
condition="always",
|
|||
|
multiple_of=16,
|
|||
|
image=get_value_at_index(structure_img, 0),
|
|||
|
)
|
|||
|
|
|||
|
size_info = getimagesizeandcount.getsize(
|
|||
|
image=get_value_at_index(resized_img, 0)
|
|||
|
)
|
|||
|
|
|||
|
vae_encoded = vaeencode.encode(
|
|||
|
pixels=get_value_at_index(size_info, 0),
|
|||
|
vae=get_value_at_index(VAE_MODEL, 0),
|
|||
|
)
|
|||
|
|
|||
|
logger.info("处理深度...")
|
|||
|
|
|||
|
# ------- 7.4: 深度处理 -------
|
|||
|
depth_processed = depthanything_v2.process(
|
|||
|
da_model=get_value_at_index(DEPTH_MODEL, 0),
|
|||
|
images=get_value_at_index(size_info, 0),
|
|||
|
)
|
|||
|
|
|||
|
flux_guided = fluxguidance.append(
|
|||
|
guidance=depth_strength,
|
|||
|
conditioning=get_value_at_index(text_encoded, 0),
|
|||
|
)
|
|||
|
|
|||
|
logger.info("处理风格图像...")
|
|||
|
|
|||
|
# ------- 7.5: 风格处理 -------
|
|||
|
style_img = loadimage.load_image(image=style_filename)
|
|||
|
|
|||
|
style_encoded = clipvisionencode.encode(
|
|||
|
crop="center",
|
|||
|
clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
|
|||
|
image=get_value_at_index(style_img, 0),
|
|||
|
)
|
|||
|
|
|||
|
logger.info("设置条件...")
|
|||
|
|
|||
|
# ------- 7.6: 设置条件 -------
|
|||
|
conditioning = instructpixtopixconditioning.encode(
|
|||
|
positive=get_value_at_index(flux_guided, 0),
|
|||
|
negative=get_value_at_index(empty_text, 0),
|
|||
|
vae=get_value_at_index(VAE_MODEL, 0),
|
|||
|
pixels=get_value_at_index(depth_processed, 0),
|
|||
|
)
|
|||
|
|
|||
|
style_applied = stylemodelapplyadvanced.apply_stylemodel(
|
|||
|
strength=style_strength,
|
|||
|
conditioning=get_value_at_index(conditioning, 0),
|
|||
|
style_model=get_value_at_index(STYLE_MODEL, 0),
|
|||
|
clip_vision_output=get_value_at_index(style_encoded, 0),
|
|||
|
)
|
|||
|
|
|||
|
# ------- 7.7: 创建潜在空间 -------
|
|||
|
empty_latent = emptylatentimage.generate(
|
|||
|
width=get_value_at_index(resized_img, 1),
|
|||
|
height=get_value_at_index(resized_img, 2),
|
|||
|
batch_size=1,
|
|||
|
)
|
|||
|
|
|||
|
logger.info("设置采样...")
|
|||
|
|
|||
|
# ------- 7.8: 设置扩散引导 -------
|
|||
|
guided = basicguider.get_guider(
|
|||
|
model=get_value_at_index(UNET_MODEL, 0),
|
|||
|
conditioning=get_value_at_index(style_applied, 0),
|
|||
|
)
|
|||
|
|
|||
|
schedule = basicscheduler.get_sigmas(
|
|||
|
scheduler="simple",
|
|||
|
steps=28,
|
|||
|
denoise=1,
|
|||
|
model=get_value_at_index(UNET_MODEL, 0),
|
|||
|
)
|
|||
|
|
|||
|
noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))
|
|||
|
|
|||
|
logger.info("采样中...")
|
|||
|
|
|||
|
# ------- 7.9: 执行采样 -------
|
|||
|
sampled = samplercustomadvanced.sample(
|
|||
|
noise=get_value_at_index(noise, 0),
|
|||
|
guider=get_value_at_index(guided, 0),
|
|||
|
sampler=get_value_at_index(SAMPLER, 0),
|
|||
|
sigmas=get_value_at_index(schedule, 0),
|
|||
|
latent_image=get_value_at_index(empty_latent, 0),
|
|||
|
)
|
|||
|
|
|||
|
logger.info("解码图像...")
|
|||
|
|
|||
|
# ------- 7.10: 解码结果 -------
|
|||
|
decoded = vaedecode.decode(
|
|||
|
samples=get_value_at_index(sampled, 0),
|
|||
|
vae=get_value_at_index(VAE_MODEL, 0),
|
|||
|
)
|
|||
|
|
|||
|
# ------- 7.11: 保存图像 -------
|
|||
|
prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
|
|||
|
|
|||
|
saved = saveimage.save_images(
|
|||
|
filename_prefix=get_value_at_index(prefix, 0),
|
|||
|
images=get_value_at_index(decoded, 0),
|
|||
|
)
|
|||
|
|
|||
|
# 获取输出路径
|
|||
|
saved_path = os.path.join(comfyui_dir, "output", saved['ui']['images'][0]['filename'])
|
|||
|
logger.info(f"图像保存到 {saved_path}")
|
|||
|
|
|||
|
# 读取生成的图像数据
|
|||
|
with open(saved_path, "rb") as f:
|
|||
|
image_data = f.read()
|
|||
|
|
|||
|
# 清理临时文件
|
|||
|
try:
|
|||
|
os.remove(structure_input_path)
|
|||
|
os.remove(style_input_path)
|
|||
|
os.remove(structure_image_path)
|
|||
|
os.remove(style_image_path)
|
|||
|
except Exception as e:
|
|||
|
logger.warning(f"清理临时文件时出错: {e}")
|
|||
|
|
|||
|
# 返回生成图像的路径和数据
|
|||
|
return saved_path, image_data
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
# 清理临时文件
|
|||
|
try:
|
|||
|
os.remove(structure_image_path)
|
|||
|
os.remove(style_image_path)
|
|||
|
except:
|
|||
|
pass
|
|||
|
logger.error(f"生成图像时出错: {e}")
|
|||
|
raise
|
|||
|
|
|||
|
# 图像数据转换为base64
|
|||
|
def image_to_base64(image_path: str) -> str:
|
|||
|
"""
|
|||
|
将图像文件转换为base64编码字符串
|
|||
|
|
|||
|
参数:
|
|||
|
image_path: 图像文件路径
|
|||
|
|
|||
|
返回:
|
|||
|
str: base64编码的图像数据,包含MIME类型前缀
|
|||
|
"""
|
|||
|
with open(image_path, "rb") as f:
|
|||
|
image_data = f.read()
|
|||
|
|
|||
|
encoded = base64.b64encode(image_data).decode("utf-8")
|
|||
|
mime_type = "image/png" # 默认MIME类型
|
|||
|
|
|||
|
# 根据文件扩展名确定MIME类型
|
|||
|
ext = os.path.splitext(image_path)[1].lower()
|
|||
|
if ext == '.jpg' or ext == '.jpeg':
|
|||
|
mime_type = "image/jpeg"
|
|||
|
elif ext == '.webp':
|
|||
|
mime_type = "image/webp"
|
|||
|
|
|||
|
return f"data:{mime_type};base64,{encoded}"
|
|||
|
|
|||
|
# 初始化FLUX环境
|
|||
|
def init_flux():
|
|||
|
"""初始化FLUX环境,设置路径和导入模型"""
|
|||
|
try:
|
|||
|
setup_environment()
|
|||
|
logger.info("FLUX环境初始化成功")
|
|||
|
return True
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"FLUX环境初始化失败: {e}")
|
|||
|
return False
|
|||
|
|
|||
|
# 集成到MCP服务器的函数
|
|||
|
def register_flux_resources(server):
|
|||
|
"""
|
|||
|
将FLUX风格塑形功能注册到MCP服务器
|
|||
|
|
|||
|
参数:
|
|||
|
server: MCP服务器实例
|
|||
|
"""
|
|||
|
# 初始化FLUX环境
|
|||
|
initialized = init_flux()
|
|||
|
if not initialized:
|
|||
|
logger.error("FLUX环境初始化失败,无法注册资源")
|
|||
|
return
|
|||
|
|
|||
|
# 注册资源: 服务状态
|
|||
|
@server.register_resource("flux://status")
|
|||
|
def get_flux_status() -> str:
|
|||
|
"""返回FLUX服务状态"""
|
|||
|
return "FLUX风格塑形服务已启动并运行正常"
|
|||
|
|
|||
|
# 注册工具: 生成风格化图像
|
|||
|
@server.register_tool("生成风格化图像")
|
|||
|
def generate_styled_image(prompt: str, structure_image_base64: str, style_image_base64: str,
|
|||
|
depth_strength: float = 15.0, style_strength: float = 0.5) -> Dict[str, Any]:
|
|||
|
"""
|
|||
|
生成风格化图像工具
|
|||
|
|
|||
|
参数:
|
|||
|
prompt: 文本提示,用于指导生成过程
|
|||
|
structure_image_base64: 结构图像的base64编码,提供基本构图
|
|||
|
style_image_base64: 风格图像的base64编码,提供艺术风格
|
|||
|
depth_strength: 深度强度,控制结构保持程度
|
|||
|
style_strength: 风格强度,控制风格应用程度
|
|||
|
|
|||
|
返回:
|
|||
|
Dict: 包含生成图像的base64编码和其他信息
|
|||
|
"""
|
|||
|
try:
|
|||
|
# 生成图像
|
|||
|
saved_path, _ = generate_image(
|
|||
|
prompt=prompt,
|
|||
|
structure_image=structure_image_base64,
|
|||
|
style_image=style_image_base64,
|
|||
|
depth_strength=depth_strength,
|
|||
|
style_strength=style_strength
|
|||
|
)
|
|||
|
|
|||
|
# 将生成的图像转换为base64
|
|||
|
output_base64 = image_to_base64(saved_path)
|
|||
|
|
|||
|
# 返回结果
|
|||
|
return {
|
|||
|
"status": "success",
|
|||
|
"message": "图像生成成功",
|
|||
|
"image_base64": output_base64,
|
|||
|
"parameters": {
|
|||
|
"prompt": prompt,
|
|||
|
"depth_strength": depth_strength,
|
|||
|
"style_strength": style_strength
|
|||
|
}
|
|||
|
}
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"生成图像失败: {e}")
|
|||
|
return {
|
|||
|
"status": "error",
|
|||
|
"message": f"生成图像失败: {str(e)}"
|
|||
|
}
|
|||
|
|
|||
|
# 注册提示: 创建风格转移提示
|
|||
|
@server.register_prompt("风格转移提示")
|
|||
|
def style_transfer_prompt(subject: str, style: str) -> str:
|
|||
|
"""创建风格转移的提示模板"""
|
|||
|
return f"将{subject}转换为{style}风格"
|
|||
|
|
|||
|
logger.info("FLUX风格塑形资源和工具已注册到MCP服务器")
|