"""
|
||
FLUX风格塑形MCP资源模块
|
||
---------------------
|
||
将FLUX风格塑形API集成到MCP服务器中
|
||
作为资源和工具提供给MCP客户端
|
||
"""
|
||
|
||
import os
import sys
import random
import logging
import tempfile
import shutil
from pathlib import Path
from typing import Dict, Any, Union, Tuple, Optional
import base64
from io import BytesIO

import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from dotenv import load_dotenv

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv(override=True)

# Locate the ComfyUI root directory
def get_comfyui_dir():
    """Return the path of the ComfyUI root directory."""
    current_dir = os.path.dirname(os.path.abspath(__file__))
    comfyui_dir = os.path.dirname(os.path.dirname(current_dir))  # assumes mcp_server sits directly under the ComfyUI root
    return comfyui_dir

# Download the required models
def download_models():
    """Check for and download the required models."""
    logger.info("Checking for and downloading required models...")
    # Get the ComfyUI root directory
    comfyui_dir = get_comfyui_dir()

    # Get the Hugging Face token
    hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if not hf_token:
        logger.warning("No Hugging Face access token found! Gated models may fail to download.")
        logger.warning("Set the HUGGING_FACE_HUB_TOKEN environment variable or add it to the .env file.")
    else:
        logger.info("Hugging Face access token found")

    models = [
        {"repo_id": "black-forest-labs/FLUX.1-Redux-dev", "filename": "flux1-redux-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/style_models")},
        {"repo_id": "black-forest-labs/FLUX.1-Depth-dev", "filename": "flux1-depth-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/diffusion_models")},
        {"repo_id": "Comfy-Org/sigclip_vision_384", "filename": "sigclip_vision_patch14_384.safetensors", "local_dir": os.path.join(comfyui_dir, "models/clip_vision")},
        {"repo_id": "Kijai/DepthAnythingV2-safetensors", "filename": "depth_anything_v2_vitl_fp32.safetensors", "local_dir": os.path.join(comfyui_dir, "models/depthanything")},
        {"repo_id": "black-forest-labs/FLUX.1-dev", "filename": "ae.safetensors", "local_dir": os.path.join(comfyui_dir, "models/vae/FLUX1")},
        {"repo_id": "comfyanonymous/flux_text_encoders", "filename": "clip_l.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders")},
        {"repo_id": "comfyanonymous/flux_text_encoders", "filename": "t5xxl_fp16.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders/t5")}
    ]

    for model in models:
        try:
            local_path = os.path.join(model["local_dir"], model["filename"])
            if not os.path.exists(local_path):
                logger.info(f"Downloading {model['filename']} to {model['local_dir']}")
                os.makedirs(model["local_dir"], exist_ok=True)
                hf_hub_download(
                    repo_id=model["repo_id"],
                    filename=model["filename"],
                    local_dir=model["local_dir"],
                    token=hf_token  # use the token for the download
                )
            else:
                logger.info(f"Model {model['filename']} already exists in {model['local_dir']}")
        except Exception as e:
            if "403 Client Error" in str(e) and "Access to model" in str(e) and "is restricted" in str(e):
                logger.error(f"Permission error while downloading {model['filename']}: you do not have access")
                logger.error("Please resolve this as follows:")
                logger.error("1. Visit https://huggingface.co/settings/tokens to create an access token")
                logger.error(f"2. Visit https://huggingface.co/{model['repo_id']} and request access to the model")
                logger.error("3. Set the HUGGING_FACE_HUB_TOKEN environment variable to your access token")
            else:
                logger.error(f"Error downloading {model['filename']}: {e}")

            # If this is a critical model, raise to stop execution
            if model["filename"] in ["flux1-redux-dev.safetensors", "flux1-depth-dev.safetensors", "ae.safetensors"]:
                raise RuntimeError(f"Failed to download critical model {model['filename']}; the program cannot continue")

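# Example .env entry read by load_dotenv() above (illustrative only; the token
# value is a placeholder, not a real credential):
#
#   HUGGING_FACE_HUB_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
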
# Helper to fetch a value from a node output by index
def get_value_at_index(obj: Union[Dict, list], index: int) -> Any:
    """Return the value at the given index of a sequence, or of a dict's "result" entry."""
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]

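# Minimal illustration of get_value_at_index (the shapes below mirror the two
# cases handled above: an indexable node output, or a dict carrying a "result"
# tuple; the names model/clip are placeholders):
#
#   get_value_at_index((model, clip), 0)              # -> model
#   get_value_at_index({"result": (model, clip)}, 1)  # -> clip
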
# Add the ComfyUI directory to the system path
def add_comfyui_directory_to_sys_path() -> None:
    """Add the ComfyUI directory to sys.path."""
    comfyui_path = get_comfyui_dir()
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        logger.info(f"'{comfyui_path}' added to sys.path")

# Add extra model paths
def add_extra_model_paths() -> None:
    """Load the extra model path configuration."""
    comfyui_dir = get_comfyui_dir()
    sys.path.append(comfyui_dir)  # make sure ComfyUI modules are importable

    try:
        # Try to import load_extra_path_config directly from main.py
        from main import load_extra_path_config

        extra_model_paths = os.path.join(comfyui_dir, "extra_model_paths.yaml")
        if os.path.exists(extra_model_paths):
            logger.info(f"Found extra_model_paths config: {extra_model_paths}")
            load_extra_path_config(extra_model_paths)
        else:
            logger.info("extra_model_paths.yaml not found; default model paths will be used")
    except ImportError as e:
        logger.warning(f"Could not import load_extra_path_config: {e}")
        logger.info("Continuing with default model paths")

# Import custom nodes
def import_custom_nodes() -> None:
    """Import ComfyUI custom nodes."""
    comfyui_dir = get_comfyui_dir()
    sys.path.append(comfyui_dir)

    try:
        import asyncio
        import execution
        from nodes import init_extra_nodes
        import server
        # Create the minimal server/queue objects that node initialization relies on
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        server_instance = server.PromptServer(loop)
        execution.PromptQueue(server_instance)
        init_extra_nodes()
        logger.info("Custom nodes imported successfully")
    except Exception as e:
        logger.error(f"Error importing custom nodes: {e}")
        raise

# Set up the ComfyUI environment
def setup_environment():
    """Set up the ComfyUI environment: load paths and models."""
    logger.info("Setting up the ComfyUI environment...")
    # Initialize paths
    add_comfyui_directory_to_sys_path()
    add_extra_model_paths()

    # Download the required models
    download_models()

    # Import custom nodes
    try:
        import_custom_nodes()
        logger.info("Custom nodes imported successfully")
    except Exception as e:
        logger.error(f"Error importing custom nodes: {e}")
        raise RuntimeError("Failed to import custom nodes; the program cannot continue")

# Save image data to a temporary file
def save_image_to_temp(image_data: Union[str, bytes]) -> str:
    """
    Save image data to a temporary file.

    Parameters:
        image_data: Image data as a base64 data-URL string, raw bytes, or a file path.

    Returns:
        str: Path of the temporary file.
    """
    # If the data is a base64 data URL, decode it first
    if isinstance(image_data, str) and image_data.startswith(('data:image', 'data:application')):
        # Extract the actual base64-encoded payload
        content_type, base64_data = image_data.split(';base64,')
        image_data = base64.b64decode(base64_data)

    # Create a temporary file and save the image
    suffix = '.png'  # default file extension
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        if isinstance(image_data, str):
            # Still a string, so treat it as a file path
            shutil.copyfile(image_data, tmp.name)
        else:
            # Otherwise it is raw bytes
            tmp.write(image_data)
        tmp_path = tmp.name

    return tmp_path

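# Usage sketch for save_image_to_temp (the inputs shown are placeholders for
# the three accepted forms handled above):
#
#   save_image_to_temp("data:image/png;base64,iVBORw0KG...")  # base64 data URL
#   save_image_to_temp(b"\x89PNG\r\n...")                     # raw bytes
#   save_image_to_temp("/path/to/input.png")                  # existing file path
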
# Main image generation function
def generate_image(prompt: str, structure_image: Union[str, bytes], style_image: Union[str, bytes],
                   depth_strength: float = 15.0, style_strength: float = 0.5) -> Tuple[str, bytes]:
    """
    Generate a stylized image.

    Parameters:
        prompt: Text prompt.
        structure_image: Structure image data (base64 string, raw bytes, or file path).
        style_image: Style image data (base64 string, raw bytes, or file path).
        depth_strength: Depth strength.
        style_strength: Style strength.

    Returns:
        Tuple[str, bytes]: Path of the generated image and its raw data.
    """
    # Save the input images to temporary files
    structure_image_path = save_image_to_temp(structure_image)
    style_image_path = save_image_to_temp(style_image)

    try:
        # The implementation below mirrors the original generation function
        logger.info("Starting the image generation process...")

        # ====================== Stage 1: import required components ======================
        # Import all required node types from ComfyUI's nodes.py
        from nodes import (
            StyleModelLoader, VAEEncode, NODE_CLASS_MAPPINGS, LoadImage, CLIPVisionLoader,
            SaveImage, VAELoader, CLIPVisionEncode, DualCLIPLoader, EmptyLatentImage,
            VAEDecode, UNETLoader, CLIPTextEncode,
        )

        # ====================== Stage 2: initialize constants ======================
        intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
        CONST_1024 = intconstant.get_value(value=1024)

        logger.info("Loading models...")

        # ====================== Stage 3: load models ======================
        dualcliploader = DualCLIPLoader()
        CLIP_MODEL = dualcliploader.load_clip(
            clip_name1="t5/t5xxl_fp16.safetensors",
            clip_name2="clip_l.safetensors",
            type="flux",
        )

        vaeloader = VAELoader()
        VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")

        unetloader = UNETLoader()
        UNET_MODEL = unetloader.load_unet(
            unet_name="flux1-depth-dev.safetensors",
            weight_dtype="default"
        )

        clipvisionloader = CLIPVisionLoader()
        CLIP_VISION_MODEL = clipvisionloader.load_clip(
            clip_name="sigclip_vision_patch14_384.safetensors"
        )

        stylemodelloader = StyleModelLoader()
        STYLE_MODEL = stylemodelloader.load_style_model(
            style_model_name="flux1-redux-dev.safetensors"
        )

        ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
        SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")

        cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
        downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
        DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
            model="depth_anything_v2_vitl_fp32.safetensors"
        )

        # ====================== Stage 4: instantiate helper nodes ======================
        cliptextencode = CLIPTextEncode()
        loadimage = LoadImage()
        vaeencode = VAEEncode()
        fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
        instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
        clipvisionencode = CLIPVisionEncode()
        stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
        emptylatentimage = EmptyLatentImage()
        basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
        basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
        randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
        samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
        vaedecode = VAEDecode()
        cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
        saveimage = SaveImage()
        getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
        depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
        imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()

        # ====================== Stage 5: preload models onto the GPU ======================
        model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]

        from comfy import model_management
        model_management.load_models_gpu([
            loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
        ])

        logger.info("Processing input images...")

        # ====================== Stage 6: prepare input images ======================
        comfyui_dir = get_comfyui_dir()
        input_dir = os.path.join(comfyui_dir, "input")
        os.makedirs(input_dir, exist_ok=True)

        # Create randomized file names to avoid collisions
        structure_filename = f"structure_{random.randint(1000, 9999)}.png"
        style_filename = f"style_{random.randint(1000, 9999)}.png"

        # Copy the uploaded images into the input directory
        structure_input_path = os.path.join(input_dir, structure_filename)
        style_input_path = os.path.join(input_dir, style_filename)

        shutil.copy(structure_image_path, structure_input_path)
        shutil.copy(style_image_path, style_input_path)

        # ====================== Stage 7: image generation pipeline ======================
        with torch.inference_mode():
            # ------- 7.1: set up CLIP -------
            clip_switch = cr_clip_input_switch.switch(
                Input=1,
                clip1=get_value_at_index(CLIP_MODEL, 0),
                clip2=get_value_at_index(CLIP_MODEL, 0),
            )

            # ------- 7.2: text encoding -------
            text_encoded = cliptextencode.encode(
                text=prompt,
                clip=get_value_at_index(clip_switch, 0),
            )
            empty_text = cliptextencode.encode(
                text="",
                clip=get_value_at_index(clip_switch, 0),
            )

logger.info("处理结构图像...")
|
||
|
||
# ------- 7.3: 处理结构图像 -------
|
||
structure_img = loadimage.load_image(image=structure_filename)
|
||
|
||
resized_img = imageresize.execute(
|
||
width=get_value_at_index(CONST_1024, 0),
|
||
height=get_value_at_index(CONST_1024, 0),
|
||
interpolation="bicubic",
|
||
method="keep proportion",
|
||
condition="always",
|
||
multiple_of=16,
|
||
image=get_value_at_index(structure_img, 0),
|
||
)
|
||
|
||
size_info = getimagesizeandcount.getsize(
|
||
image=get_value_at_index(resized_img, 0)
|
||
)
|
||
|
||
vae_encoded = vaeencode.encode(
|
||
pixels=get_value_at_index(size_info, 0),
|
||
vae=get_value_at_index(VAE_MODEL, 0),
|
||
)
|
||
|
||
logger.info("处理深度...")
|
||
|
||
# ------- 7.4: 深度处理 -------
|
||
depth_processed = depthanything_v2.process(
|
||
da_model=get_value_at_index(DEPTH_MODEL, 0),
|
||
images=get_value_at_index(size_info, 0),
|
||
)
|
||
|
||
flux_guided = fluxguidance.append(
|
||
guidance=depth_strength,
|
||
conditioning=get_value_at_index(text_encoded, 0),
|
||
)
|
||
|
||
logger.info("处理风格图像...")
|
||
|
||
# ------- 7.5: 风格处理 -------
|
||
style_img = loadimage.load_image(image=style_filename)
|
||
|
||
style_encoded = clipvisionencode.encode(
|
||
crop="center",
|
||
clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
|
||
image=get_value_at_index(style_img, 0),
|
||
)
|
||
|
||
logger.info("设置条件...")
|
||
|
||
# ------- 7.6: 设置条件 -------
|
||
conditioning = instructpixtopixconditioning.encode(
|
||
positive=get_value_at_index(flux_guided, 0),
|
||
negative=get_value_at_index(empty_text, 0),
|
||
vae=get_value_at_index(VAE_MODEL, 0),
|
||
pixels=get_value_at_index(depth_processed, 0),
|
||
)
|
||
|
||
style_applied = stylemodelapplyadvanced.apply_stylemodel(
|
||
strength=style_strength,
|
||
conditioning=get_value_at_index(conditioning, 0),
|
||
style_model=get_value_at_index(STYLE_MODEL, 0),
|
||
clip_vision_output=get_value_at_index(style_encoded, 0),
|
||
)
|
||
|
||
            # ------- 7.7: create the latent -------
            empty_latent = emptylatentimage.generate(
                width=get_value_at_index(resized_img, 1),
                height=get_value_at_index(resized_img, 2),
                batch_size=1,
            )

            logger.info("Setting up sampling...")

            # ------- 7.8: set up diffusion guidance -------
            guided = basicguider.get_guider(
                model=get_value_at_index(UNET_MODEL, 0),
                conditioning=get_value_at_index(style_applied, 0),
            )

            schedule = basicscheduler.get_sigmas(
                scheduler="simple",
                steps=28,
                denoise=1,
                model=get_value_at_index(UNET_MODEL, 0),
            )

            noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))

logger.info("采样中...")
|
||
|
||
# ------- 7.9: 执行采样 -------
|
||
sampled = samplercustomadvanced.sample(
|
||
noise=get_value_at_index(noise, 0),
|
||
guider=get_value_at_index(guided, 0),
|
||
sampler=get_value_at_index(SAMPLER, 0),
|
||
sigmas=get_value_at_index(schedule, 0),
|
||
latent_image=get_value_at_index(empty_latent, 0),
|
||
)
|
||
|
||
logger.info("解码图像...")
|
||
|
||
# ------- 7.10: 解码结果 -------
|
||
decoded = vaedecode.decode(
|
||
samples=get_value_at_index(sampled, 0),
|
||
vae=get_value_at_index(VAE_MODEL, 0),
|
||
)
|
||
|
||
            # ------- 7.11: save the image -------
            prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")

            saved = saveimage.save_images(
                filename_prefix=get_value_at_index(prefix, 0),
                images=get_value_at_index(decoded, 0),
            )

        # Determine the output path
        saved_path = os.path.join(comfyui_dir, "output", saved['ui']['images'][0]['filename'])
        logger.info(f"Image saved to {saved_path}")

        # Read the generated image data
        with open(saved_path, "rb") as f:
            image_data = f.read()

        # Clean up temporary files
        try:
            os.remove(structure_input_path)
            os.remove(style_input_path)
            os.remove(structure_image_path)
            os.remove(style_image_path)
        except Exception as e:
            logger.warning(f"Error cleaning up temporary files: {e}")

        # Return the path and data of the generated image
        return saved_path, image_data

    except Exception as e:
        # Clean up temporary files
        try:
            os.remove(structure_image_path)
            os.remove(style_image_path)
        except Exception:
            pass
        logger.error(f"Error generating image: {e}")
        raise

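# Call sketch for generate_image (the prompt and paths are placeholders):
#
#   saved_path, image_bytes = generate_image(
#       prompt="a watercolor landscape",
#       structure_image="/path/to/structure.png",
#       style_image="/path/to/style.png",
#       depth_strength=15.0,
#       style_strength=0.5,
#   )
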
# Convert image data to base64
def image_to_base64(image_path: str) -> str:
    """
    Convert an image file to a base64-encoded string.

    Parameters:
        image_path: Path of the image file.

    Returns:
        str: Base64-encoded image data with a MIME-type data-URL prefix.
    """
    with open(image_path, "rb") as f:
        image_data = f.read()

    encoded = base64.b64encode(image_data).decode("utf-8")
    mime_type = "image/png"  # default MIME type

    # Determine the MIME type from the file extension
    ext = os.path.splitext(image_path)[1].lower()
    if ext in ('.jpg', '.jpeg'):
        mime_type = "image/jpeg"
    elif ext == '.webp':
        mime_type = "image/webp"

    return f"data:{mime_type};base64,{encoded}"

# Initialize the FLUX environment
def init_flux():
    """Initialize the FLUX environment: set up paths and load models."""
    try:
        setup_environment()
        logger.info("FLUX environment initialized successfully")
        return True
    except Exception as e:
        logger.error(f"FLUX environment initialization failed: {e}")
        return False

# Integration point for the MCP server
def register_flux_resources(server):
    """
    Register the FLUX style shaping features with an MCP server.

    Parameters:
        server: The MCP server instance.
    """
    # Initialize the FLUX environment
    initialized = init_flux()
    if not initialized:
        logger.error("FLUX environment initialization failed; resources will not be registered")
        return

    # Register resource: service status
    @server.register_resource("flux://status")
    def get_flux_status() -> str:
        """Return the FLUX service status."""
        return "The FLUX style shaping service is up and running"

    # Register tool: generate a stylized image
    @server.register_tool("Generate styled image")
    def generate_styled_image(prompt: str, structure_image_base64: str, style_image_base64: str,
                              depth_strength: float = 15.0, style_strength: float = 0.5) -> Dict[str, Any]:
        """
        Tool for generating a stylized image.

        Parameters:
            prompt: Text prompt guiding the generation process.
            structure_image_base64: Base64-encoded structure image providing the basic composition.
            style_image_base64: Base64-encoded style image providing the artistic style.
            depth_strength: Depth strength, controlling how strongly the structure is preserved.
            style_strength: Style strength, controlling how strongly the style is applied.

        Returns:
            Dict: The base64-encoded generated image plus additional information.
        """
        try:
            # Generate the image
            saved_path, _ = generate_image(
                prompt=prompt,
                structure_image=structure_image_base64,
                style_image=style_image_base64,
                depth_strength=depth_strength,
                style_strength=style_strength
            )

            # Convert the generated image to base64
            output_base64 = image_to_base64(saved_path)

            # Return the result
            return {
                "status": "success",
                "message": "Image generated successfully",
                "image_base64": output_base64,
                "parameters": {
                    "prompt": prompt,
                    "depth_strength": depth_strength,
                    "style_strength": style_strength
                }
            }
        except Exception as e:
            logger.error(f"Image generation failed: {e}")
            return {
                "status": "error",
                "message": f"Image generation failed: {str(e)}"
            }

    # Register prompt: create a style-transfer prompt
    @server.register_prompt("Style transfer prompt")
    def style_transfer_prompt(subject: str, style: str) -> str:
        """Create a prompt template for style transfer."""
        return f"Transform {subject} into {style} style"

    logger.info("FLUX style shaping resources and tools registered with the MCP server")