初始化项目

This commit is contained in:
zhxiao 2025-04-24 19:25:29 +08:00
parent 92cdc692f4
commit 3451f4cd87
19 changed files with 2698 additions and 7 deletions

1
.gitignore vendored
View File

@ -5,7 +5,6 @@ __pycache__/
!/input/example.png
/models/
/temp/
/custom_nodes/
!custom_nodes/example_node.py.example
extra_model_paths.yaml
/.vs

34
flux_style_shaper_api/.gitignore vendored Normal file
View File

@ -0,0 +1,34 @@
# Python虚拟环境
venv/
env/
ENV/
# Python缓存文件
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
# 日志文件
*.log
# 本地环境文件
.env
.env.local
# 生成的图像
生成图像_*.png
# 临时文件
tmp/
temp/
.DS_Store
Thumbs.db
# 模型文件 (由于大小原因通常不纳入版本控制)
# 直接过滤ComfyUI的整个模型目录
../models/
# 如果项目内部有模型文件也过滤
models/

View File

@ -0,0 +1,258 @@
# FLUX风格塑形API
一个基于ComfyUI和FLUX模型的图像风格塑形API服务可将结构图像和风格图像结合生成新的艺术图像。
## 概述
该API提供了FLUX风格塑形工作流的接口将结构和风格图像组合生成新的风格化图像。实现使用
- Flux[dev] Redux 进行风格转移
- Flux[dev] Depth 实现结构感知
- DepthAnything V2 进行深度估计
- SigCLIP Vision 实现图像理解
## 前置条件
在使用此API之前请确保
1. 已经安装并正确配置ComfyUI
2. 已安装FLUX模型所需的自定义节点到ComfyUI中
- 此API依赖ComfyUI的自定义节点包括但不限于提供FLUX支持的节点
- 运行前请确保您的ComfyUI环境中包含所有必要的自定义节点
3. 您的系统有足够的GPU内存运行FLUX模型建议至少8GB
4. 已安装兼容的CUDA版本和对应的PyTorch版本
- 对于CUDA 12.x包括CUDA 12.6需要PyTorch 2.2.0+cu121
- 对于CUDA 11.8/11.7/11.6需要PyTorch 2.0.1+cu118/cu117/cu116
- 使用我们提供的`create_venv.bat`或`create_venv.sh`脚本可以简化安装过程
## 模型说明
本API需要使用以下模型**首次运行时会自动下载**(可能需要一些时间):
| 模型文件 | 大小 | 用途 | 保存位置 |
|---------|------|------|---------|
| flux1-redux-dev.safetensors | ~1.3GB | 风格模型 | models/style_models |
| flux1-depth-dev.safetensors | ~1.4GB | 扩散模型 | models/diffusion_models |
| sigclip_vision_patch14_384.safetensors | ~0.3GB | CLIP视觉模型 | models/clip_vision |
| depth_anything_v2_vitl_fp32.safetensors | ~0.3GB | 深度模型 | models/depthanything |
| ae.safetensors | ~0.2GB | VAE模型 | models/vae/FLUX1 |
| clip_l.safetensors | ~0.2GB | CLIP文本编码器 | models/text_encoders |
| t5xxl_fp16.safetensors | ~5.0GB | T5文本编码器 | models/text_encoders/t5 |
**注意事项:**
- 所有模型总共约需要**8-9GB**磁盘空间
- 首次运行时会自动从Hugging Face下载这些模型
- 下载后这些模型将保存在ComfyUI的标准模型目录中
- 如果下载中断,下次运行时会继续完成下载
## Hugging Face访问令牌设置
**重要**FLUX模型是受限模型需要Hugging Face访问令牌才能下载。请按照以下步骤设置
1. **获取访问令牌**
- 访问 [Hugging Face Token设置页面](https://huggingface.co/settings/tokens)
- 创建新的访问令牌选择Read权限即可
- 复制生成的令牌
2. **申请模型访问权限**
- 访问以下模型页面并点击"Access repository"申请访问:
- [FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev)
- [FLUX.1-Depth-dev](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev)
- [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
3. **设置访问令牌**(选择一种方法):
**方法1创建.env文件**
```
# 在flux_style_shaper_api目录中创建.env文件写入
HUGGING_FACE_HUB_TOKEN=你的访问令牌
```
**方法2设置环境变量**
```bash
# Windows命令提示符
set HUGGING_FACE_HUB_TOKEN=你的访问令牌
# Windows PowerShell
$env:HUGGING_FACE_HUB_TOKEN="你的访问令牌"
# Linux/Mac
export HUGGING_FACE_HUB_TOKEN="你的访问令牌"
```
## 安装
### 方法一:使用虚拟环境(推荐)
1. 确保已安装并正常运行ComfyUI
2. 克隆或下载此目录到ComfyUI根目录
3. 使用提供的脚本创建虚拟环境:
**Windows:**
```
cd flux_style_shaper_api
create_venv.bat
```
**Linux/MacOS:**
```
cd flux_style_shaper_api
chmod +x create_venv.sh
./create_venv.sh
```
或者手动创建虚拟环境:
**Windows:**
```
cd flux_style_shaper_api
python -m venv venv
venv\Scripts\activate
pip install -r requirements.txt
```
**Linux/MacOS:**
```
cd flux_style_shaper_api
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```
### 方法二:直接安装
1. 确保已安装并正常运行ComfyUI
2. 克隆或下载此目录到ComfyUI根目录
3. 安装所需的Python包
```
cd flux_style_shaper_api
pip install -r requirements.txt
```
## 使用方法
### 使用虚拟环境启动(推荐)
1. 激活虚拟环境并启动服务:
**Windows:**
```
cd flux_style_shaper_api
venv\Scripts\activate
python flux_style_shaper_api.py
```
**Linux/MacOS:**
```
cd flux_style_shaper_api
source venv/bin/activate
python flux_style_shaper_api.py
```
### 使用启动脚本
1. 使用提供的启动脚本:
**Windows:**
```
cd flux_style_shaper_api
startup.bat
```
**Linux/MacOS:**
```
cd flux_style_shaper_api
./startup.sh
```
2. 服务将:
- 下载任何缺失的模型(首次运行可能需要一些时间)
- 安装所需的自定义节点(如果需要)
- 启动FastAPI服务
3. API端点
- 健康检查: `GET /health`
- 图像生成: `POST /generate`
4. Swagger文档访问 `http://localhost:8000/docs` 查看交互式API文档
## API接口说明
### 图像生成 `/generate`
**请求方式**: POST (multipart/form-data)
**参数**:
- `prompt`: 文本提示(可选)
- `structure_image`: 结构图像(必须)
- `style_image`: 风格图像(必须)
- `depth_strength`: 深度强度控制结构保留程度默认15.0
- `style_strength`: 风格强度控制风格应用程度默认0.5
**响应**:
- 成功:返回生成的图像文件
- 失败返回错误信息JSON
## 使用示例
### Python客户端
```python
import requests
# API端点
url = "http://localhost:8000/generate"
# 参数
data = {
"prompt": "一个漂亮的风景", # 可选
"depth_strength": 15.0, # 可选默认15
"style_strength": 0.5 # 可选默认0.5
}
# 图像文件
files = {
"structure_image": open("结构图像.jpg", "rb"),
"style_image": open("风格图像.jpg", "rb")
}
# 发送请求
response = requests.post(url, data=data, files=files)
# 保存生成的图像
if response.status_code == 200:
with open("生成图像.png", "wb") as f:
f.write(response.content)
print("图像已保存到 生成图像.png")
else:
print(f"生成失败: {response.json()}")
```
### cURL示例
```bash
curl -X POST "http://localhost:8000/generate" \
-F "prompt=一个漂亮的风景" \
-F "depth_strength=15.0" \
-F "style_strength=0.5" \
-F "structure_image=@结构图像.jpg" \
-F "style_image=@风格图像.jpg" \
--output 生成图像.png
```
## 所需模型
该服务会自动从Hugging Face下载以下模型如果它们不存在
- flux1-redux-dev.safetensors (风格模型)
- flux1-depth-dev.safetensors (扩散模型)
- sigclip_vision_patch14_384.safetensors (CLIP视觉模型)
- depth_anything_v2_vitl_fp32.safetensors (深度模型)
- ae.safetensors (FLUX的VAE)
- clip_l.safetensors (CLIP文本编码器)
- t5xxl_fp16.safetensors (FLUX的T5文本编码器)
## 鸣谢
- 原始ComfyUI工作流由[Nathan Shipley](https://x.com/CitizenPlain)提供
- FLUX模型由[Black Forest Labs](https://github.com/black-forest-labs)开发
- DepthAnything由[Kijai](https://github.com/kijai)开发

View File

@ -0,0 +1,108 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
FLUX风格塑形API的Python客户端示例
"""
import requests
import argparse
import os
from datetime import datetime
def generate_image(server_url, prompt, structure_image_path, style_image_path, depth_strength=15.0, style_strength=0.5):
"""
调用API生成风格化图像
参数:
server_url (str): API服务器地址
prompt (str): 文本提示
structure_image_path (str): 结构图像的路径
style_image_path (str): 风格图像的路径
depth_strength (float): 深度强度
style_strength (float): 风格强度
返回:
str: 生成图像的保存路径
"""
# API端点
url = f"{server_url}/generate"
# 参数
data = {
"prompt": prompt,
"depth_strength": depth_strength,
"style_strength": style_strength
}
# 图像文件
files = {
"structure_image": open(structure_image_path, "rb"),
"style_image": open(style_image_path, "rb")
}
print(f"正在发送请求到 {url}...")
print(f"参数: 提示词='{prompt}', 深度强度={depth_strength}, 风格强度={style_strength}")
print(f"结构图像: {structure_image_path}")
print(f"风格图像: {style_image_path}")
try:
# 发送请求
response = requests.post(url, data=data, files=files)
# 检查响应
if response.status_code == 200:
# 生成输出文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_filename = f"生成图像_{timestamp}.png"
# 保存生成的图像
with open(output_filename, "wb") as f:
f.write(response.content)
print(f"成功! 图像已保存到 {output_filename}")
return output_filename
else:
print(f"生成失败: {response.json()}")
return None
except Exception as e:
print(f"请求出错: {e}")
return None
finally:
# 关闭文件
for f in files.values():
f.close()
def main():
# 解析命令行参数
parser = argparse.ArgumentParser(description="FLUX风格塑形API客户端")
parser.add_argument("--server", default="http://localhost:8000", help="API服务器地址")
parser.add_argument("--prompt", default="", help="文本提示")
parser.add_argument("--structure", required=True, help="结构图像路径")
parser.add_argument("--style", required=True, help="风格图像路径")
parser.add_argument("--depth-strength", type=float, default=15.0, help="深度强度 (默认: 15.0)")
parser.add_argument("--style-strength", type=float, default=0.5, help="风格强度 (默认: 0.5)")
args = parser.parse_args()
# 检查文件是否存在
if not os.path.exists(args.structure):
print(f"错误: 结构图像文件不存在 {args.structure}")
return
if not os.path.exists(args.style):
print(f"错误: 风格图像文件不存在 {args.style}")
return
# 调用API
generate_image(
server_url=args.server,
prompt=args.prompt,
structure_image_path=args.structure,
style_image_path=args.style,
depth_strength=args.depth_strength,
style_strength=args.style_strength
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,90 @@
@echo off
echo 正在为FLUX风格塑形API创建Python虚拟环境...
echo.
rem 检查Python是否已安装
python --version > nul 2>&1
IF %ERRORLEVEL% NEQ 0 (
echo 错误未检测到Python安装。请安装Python 3.8或更高版本。
pause
exit /b 1
)
rem 检查是否已存在虚拟环境
IF EXIST venv (
echo 虚拟环境已存在。是否重新创建? (Y/N)
set /p recreate=
if /I "%recreate%"=="Y" (
echo 正在删除旧的虚拟环境...
rmdir /s /q venv
) else (
echo 操作已取消。
pause
exit /b 0
)
)
echo 正在创建新的虚拟环境...
python -m venv venv
IF %ERRORLEVEL% NEQ 0 (
echo 创建虚拟环境失败。请检查Python版本并确保已安装venv模块。
pause
exit /b 1
)
echo 正在激活虚拟环境并安装依赖...
call venv\Scripts\activate
pip install --upgrade pip
echo.
echo 现在需要安装PyTorch。请选择您系统上安装的CUDA版本:
echo 1. CUDA 12.x (最新适用于RTX 40系列等新显卡)
echo 2. CUDA 11.8
echo 3. CUDA 11.7
echo 4. CUDA 11.6
echo 5. 无CUDA (CPU版本不推荐)
echo 6. 手动选择其他版本
set /p cuda_choice="请选择 (1-6): "
if "%cuda_choice%"=="1" (
echo 正在安装PyTorch (CUDA 12.x)...
pip install torch==2.2.0+cu121 torchvision==0.17.0+cu121 --index-url https://download.pytorch.org/whl/cu121
) else if "%cuda_choice%"=="2" (
echo 正在安装PyTorch (CUDA 11.8)...
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
) else if "%cuda_choice%"=="3" (
echo 正在安装PyTorch (CUDA 11.7)...
pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 --index-url https://download.pytorch.org/whl/cu117
) else if "%cuda_choice%"=="4" (
echo 正在安装PyTorch (CUDA 11.6)...
pip install torch==2.0.1+cu116 torchvision==0.15.2+cu116 --index-url https://download.pytorch.org/whl/cu116
) else if "%cuda_choice%"=="5" (
echo 正在安装PyTorch (CPU版本)...
echo 警告: CPU版本将无法使用GPU加速不推荐用于图像生成
pip install torch==2.0.1 torchvision==0.15.2
) else if "%cuda_choice%"=="6" (
echo 请访问 https://pytorch.org/get-started/locally/ 选择合适的PyTorch版本
echo 安装完成后,请运行: pip install -r requirements.txt
pause
exit /b 0
) else (
echo 无效的选择将默认安装CUDA 12.x版本 (适用于最新显卡)
pip install torch==2.2.0+cu121 torchvision==0.17.0+cu121 --index-url https://download.pytorch.org/whl/cu121
)
echo 正在安装其他依赖...
pip install -r requirements.txt
echo.
echo 验证PyTorch安装是否支持CUDA...
python -c "import torch; print('CUDA是否可用:', torch.cuda.is_available()); print('CUDA版本:', torch.version.cuda if torch.cuda.is_available() else '不可用')"
echo.
echo 虚拟环境创建成功!您可以通过以下命令激活它:
echo venv\Scripts\activate
echo.
echo 或者直接运行 startup.bat 脚本启动API服务。
echo.
pause

View File

@ -0,0 +1,96 @@
#!/bin/bash
echo "正在为FLUX风格塑形API创建Python虚拟环境..."
echo ""
# 检查Python是否已安装
if ! command -v python3 &> /dev/null; then
echo "错误未检测到Python安装。请安装Python 3.8或更高版本。"
exit 1
fi
# 检查是否已存在虚拟环境
if [ -d "venv" ]; then
echo "虚拟环境已存在。是否重新创建? (y/n)"
read recreate
if [[ "$recreate" =~ ^[Yy]$ ]]; then
echo "正在删除旧的虚拟环境..."
rm -rf venv
else
echo "操作已取消。"
exit 0
fi
fi
echo "正在创建新的虚拟环境..."
python3 -m venv venv
if [ $? -ne 0 ]; then
echo "创建虚拟环境失败。请检查Python版本并确保已安装venv模块。"
exit 1
fi
echo "正在激活虚拟环境并安装依赖..."
source venv/bin/activate
pip install --upgrade pip
echo ""
echo "现在需要安装PyTorch。请选择您系统上安装的CUDA版本:"
echo "1. CUDA 12.x (最新适用于RTX 40系列等新显卡)"
echo "2. CUDA 11.8"
echo "3. CUDA 11.7"
echo "4. CUDA 11.6"
echo "5. 无CUDA (CPU版本不推荐)"
echo "6. 手动选择其他版本"
read -p "请选择 (1-6): " cuda_choice
case $cuda_choice in
1)
echo "正在安装PyTorch (CUDA 12.x)..."
pip install torch==2.2.0+cu121 torchvision==0.17.0+cu121 --index-url https://download.pytorch.org/whl/cu121
;;
2)
echo "正在安装PyTorch (CUDA 11.8)..."
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
;;
3)
echo "正在安装PyTorch (CUDA 11.7)..."
pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 --index-url https://download.pytorch.org/whl/cu117
;;
4)
echo "正在安装PyTorch (CUDA 11.6)..."
pip install torch==2.0.1+cu116 torchvision==0.15.2+cu116 --index-url https://download.pytorch.org/whl/cu116
;;
5)
echo "正在安装PyTorch (CPU版本)..."
echo "警告: CPU版本将无法使用GPU加速不推荐用于图像生成"
pip install torch==2.0.1 torchvision==0.15.2
;;
6)
echo "请访问 https://pytorch.org/get-started/locally/ 选择合适的PyTorch版本"
echo "安装完成后,请运行: pip install -r requirements.txt"
exit 0
;;
*)
echo "无效的选择将默认安装CUDA 12.x版本 (适用于最新显卡)"
pip install torch==2.2.0+cu121 torchvision==0.17.0+cu121 --index-url https://download.pytorch.org/whl/cu121
;;
esac
echo "正在安装其他依赖..."
pip install -r requirements.txt
echo ""
echo "验证PyTorch安装是否支持CUDA..."
python -c "import torch; print('CUDA是否可用:', torch.cuda.is_available()); print('CUDA版本:', torch.version.cuda if torch.cuda.is_available() else '不可用')"
echo ""
echo "虚拟环境创建成功!您可以通过以下命令激活它:"
echo "source venv/bin/activate"
echo ""
echo "或者直接运行 ./startup.sh 脚本启动API服务。"
echo ""
# 设置脚本为可执行
chmod +x startup.sh
chmod +x client_example.py

View File

@ -0,0 +1,585 @@
import os
import random
import sys
from typing import Sequence, Mapping, Any, Union
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
import logging
from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks
from fastapi.responses import FileResponse
import uvicorn
import shutil
from pathlib import Path
import tempfile
from dotenv import load_dotenv
# 加载环境变量(从.env文件
load_dotenv(override=True)
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# 创建FastAPI应用
app = FastAPI(title="FLUX风格塑形API", description="一个基于ComfyUI和FLUX模型的图像风格塑形API")
# 如果不存在则下载所需模型
def download_models():
logger.info("检查并下载所需模型...")
# 获取ComfyUI根目录
comfyui_dir = get_comfyui_dir()
# 获取Hugging Face令牌
hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if not hf_token:
logger.warning("未找到Hugging Face访问令牌受限模型可能无法下载。")
logger.warning("请设置HUGGING_FACE_HUB_TOKEN环境变量或在.env文件中添加此变量。")
logger.warning("访问 https://huggingface.co/settings/tokens 获取访问令牌")
logger.warning("然后前往 https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev 等页面请求访问权限")
else:
logger.info("已找到Hugging Face访问令牌")
models = [
{"repo_id": "black-forest-labs/FLUX.1-Redux-dev", "filename": "flux1-redux-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/style_models")},
{"repo_id": "black-forest-labs/FLUX.1-Depth-dev", "filename": "flux1-depth-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/diffusion_models")},
{"repo_id": "Comfy-Org/sigclip_vision_384", "filename": "sigclip_vision_patch14_384.safetensors", "local_dir": os.path.join(comfyui_dir, "models/clip_vision")},
{"repo_id": "Kijai/DepthAnythingV2-safetensors", "filename": "depth_anything_v2_vitl_fp32.safetensors", "local_dir": os.path.join(comfyui_dir, "models/depthanything")},
{"repo_id": "black-forest-labs/FLUX.1-dev", "filename": "ae.safetensors", "local_dir": os.path.join(comfyui_dir, "models/vae/FLUX1")},
{"repo_id": "comfyanonymous/flux_text_encoders", "filename": "clip_l.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders")},
{"repo_id": "comfyanonymous/flux_text_encoders", "filename": "t5xxl_fp16.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders/t5")}
]
for model in models:
try:
local_path = os.path.join(model["local_dir"], model["filename"])
if not os.path.exists(local_path):
logger.info(f"下载 {model['filename']}{model['local_dir']}")
os.makedirs(model["local_dir"], exist_ok=True)
hf_hub_download(
repo_id=model["repo_id"],
filename=model["filename"],
local_dir=model["local_dir"],
token=hf_token # 使用令牌进行下载
)
else:
logger.info(f"模型 {model['filename']} 已存在于 {model['local_dir']}")
except Exception as e:
if "403 Client Error" in str(e) and "Access to model" in str(e) and "is restricted" in str(e):
logger.error(f"下载 {model['filename']} 时权限错误:您没有访问权限")
logger.error("请按照以下步骤解决:")
logger.error("1. 访问 https://huggingface.co/settings/tokens 创建访问令牌")
logger.error(f"2. 访问 https://huggingface.co/{model['repo_id']} 页面申请访问权限")
logger.error("3. 设置环境变量 HUGGING_FACE_HUB_TOKEN 为您的访问令牌")
logger.error(" 或在项目目录创建 .env 文件并添加HUGGING_FACE_HUB_TOKEN=您的令牌")
else:
logger.error(f"下载 {model['filename']} 时出错: {e}")
# 如果是关键模型,则抛出异常阻止继续
if model["filename"] in ["flux1-redux-dev.safetensors", "flux1-depth-dev.safetensors", "ae.safetensors"]:
raise RuntimeError(f"无法下载关键模型 {model['filename']},程序无法继续运行")
else:
logger.warning(f"将尝试在没有 {model['filename']} 的情况下继续运行,但功能可能受限")
# 获取ComfyUI根目录
def get_comfyui_dir():
# 从当前目录向上一级查找ComfyUI根目录
current_dir = os.path.dirname(os.path.abspath(__file__))
comfyui_dir = os.path.dirname(current_dir) # 当前目录的上一级目录
return comfyui_dir
# 从索引获取值的辅助函数
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
try:
return obj[index]
except KeyError:
return obj["result"][index]
# 从原始脚本添加所有必要的设置函数
def find_path(name: str, path: str = None) -> str:
if path is None:
path = get_comfyui_dir()
if name in os.listdir(path):
path_name = os.path.join(path, name)
logger.info(f"{name} 找到: {path_name}")
return path_name
parent_directory = os.path.dirname(path)
if parent_directory == path:
return None
return find_path(name, parent_directory)
def add_comfyui_directory_to_sys_path() -> None:
comfyui_path = get_comfyui_dir()
if comfyui_path is not None and os.path.isdir(comfyui_path):
sys.path.append(comfyui_path)
logger.info(f"'{comfyui_path}' 已添加到 sys.path")
def add_extra_model_paths() -> None:
comfyui_dir = get_comfyui_dir()
sys.path.append(comfyui_dir) # 确保能导入ComfyUI模块
try:
# 尝试直接从main.py导入load_extra_path_config
from main import load_extra_path_config
extra_model_paths = os.path.join(comfyui_dir, "extra_model_paths.yaml")
if os.path.exists(extra_model_paths):
logger.info(f"找到extra_model_paths配置: {extra_model_paths}")
load_extra_path_config(extra_model_paths)
else:
logger.info("未找到extra_model_paths.yaml文件将使用默认模型路径")
except ImportError as e:
# 如果导入失败,记录警告但继续运行
logger.warning(f"无法导入load_extra_path_config: {e}")
logger.info("将使用默认模型路径继续运行")
# 这不是致命错误,可以继续运行
def import_custom_nodes() -> None:
comfyui_dir = get_comfyui_dir()
sys.path.append(comfyui_dir)
try:
import asyncio
import execution
from nodes import init_extra_nodes
import server
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_instance = server.PromptServer(loop)
execution.PromptQueue(server_instance)
init_extra_nodes()
logger.info("自定义节点导入成功")
except Exception as e:
logger.error(f"导入自定义节点时出错: {e}")
raise
# 设置ComfyUI环境
def setup_environment():
"""设置ComfyUI环境加载路径和模型"""
logger.info("设置ComfyUI环境...")
# 初始化路径
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
# 下载所需模型
download_models()
# 导入自定义节点
try:
import_custom_nodes()
logger.info("自定义节点导入成功")
except Exception as e:
logger.error(f"导入自定义节点时出错: {e}")
logger.error("请确保您已经安装了所需的ComfyUI自定义节点然后再运行此程序。")
logger.error("生成图像过程需要多个自定义节点包括但不限于支持FLUX模型的节点。")
raise RuntimeError("无法导入自定义节点,程序无法继续运行")
# 主要图像生成函数
def generate_image(prompt, structure_image_path, style_image_path, depth_strength=15, style_strength=0.5):
"""
主要生成函数处理输入并返回生成图像的路径
该函数实现了FLUX风格塑形的核心流程主要包含以下步骤
1. 加载各种模型CLIPVAEUNET等
2. 处理结构图像与风格图像
3. 通过深度估计增强结构感知
4. 应用风格模型实现风格转移
5. 采样生成最终图像
参数:
prompt (str): 文本提示用于指导生成过程
structure_image_path (str): 结构图像文件路径提供基本构图
style_image_path (str): 风格图像文件路径提供艺术风格
depth_strength (float): 深度强度控制结构保持程度
style_strength (float): 风格强度控制风格应用程度
返回:
str: 生成图像的保存路径
"""
logger.info("开始图像生成过程...")
# ====================== 第1阶段导入必要组件 ======================
# 从ComfyUI的nodes.py导入所有需要的节点类型
# 这些节点代表工作流中的不同操作组件
from nodes import (
StyleModelLoader, VAEEncode, NODE_CLASS_MAPPINGS, LoadImage, CLIPVisionLoader,
SaveImage, VAELoader, CLIPVisionEncode, DualCLIPLoader, EmptyLatentImage,
VAEDecode, UNETLoader, CLIPTextEncode,
)
# ====================== 第2阶段初始化常量 ======================
# 创建一个整数常量节点并设置值为1024通常用于图像尺寸
intconstant = NODE_CLASS_MAPPINGS["INTConstant"]() # 创建整数常量节点实例
CONST_1024 = intconstant.get_value(value=1024) # 设置并获取值为1024的常量
logger.info("加载模型...")
# ====================== 第3阶段加载模型 ======================
# 加载CLIP模型 - 用于文本和图像的理解
# 使用DualCLIPLoader加载两个CLIP模型T5和CLIP-L这是FLUX模型的标准配置
# T5是一个大型文本编码器CLIP-L用于更好的图像-文本对齐
dualcliploader = DualCLIPLoader()
CLIP_MODEL = dualcliploader.load_clip(
clip_name1="t5/t5xxl_fp16.safetensors", # T5文本编码器
clip_name2="clip_l.safetensors", # CLIP-L编码器
type="flux", # 使用FLUX配置
)
# 加载VAE (变分自编码器) - 用于图像压缩和解压缩
# VAE将图像转换为潜在空间表示反之亦然
# FLUX使用特定的VAE以获得更好的重建质量
vaeloader = VAELoader()
VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
# 加载UNET - 这是扩散模型的核心,负责去噪过程
# FLUX-Depth是专门针对深度感知训练的扩散模型
unetloader = UNETLoader()
UNET_MODEL = unetloader.load_unet(
unet_name="flux1-depth-dev.safetensors", # 使用FLUX-Depth模型
weight_dtype="default" # 使用默认精度
)
# 加载CLIP Vision模型 - 专门用于图像理解
# 这对于捕获风格图像的特征非常重要
# SigCLIP是一个改进的CLIP视觉模型分辨率为384x384
clipvisionloader = CLIPVisionLoader()
CLIP_VISION_MODEL = clipvisionloader.load_clip(
clip_name="sigclip_vision_patch14_384.safetensors"
)
# 加载Style Model (风格模型) - 用于风格转移
# FLUX-Redux是一个专门用于风格塑形的模型
stylemodelloader = StyleModelLoader()
STYLE_MODEL = stylemodelloader.load_style_model(
style_model_name="flux1-redux-dev.safetensors"
)
# 初始化采样器 - 控制从噪声到图像的转换过程
# Euler采样器是一种基于欧拉方法的简单但有效的采样算法
ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
# 初始化深度模型 - 用于从图像估计深度
# DepthAnything V2是一个高质量的深度估计模型
# 它能从单一图像提取精确的深度信息,增强结构感知
cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
model="depth_anything_v2_vitl_fp32.safetensors"
)
# ====================== 第4阶段导入辅助节点 ======================
# 这些都是工作流中需要的各种处理节点
# 每个节点代表生成过程中的不同操作
cliptextencode = CLIPTextEncode() # 用于编码文本提示
loadimage = LoadImage() # 加载图像文件
vaeencode = VAEEncode() # 将图像编码为潜在表示
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]() # FLUX特有的引导节点
instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]() # 条件设置
clipvisionencode = CLIPVisionEncode() # 编码图像以提取视觉特征
stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]() # 应用风格
emptylatentimage = EmptyLatentImage() # 创建空白潜在图像
basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]() # 基本引导器
basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]() # 采样调度器
randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]() # 生成随机噪声
samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]() # 高级采样器
vaedecode = VAEDecode() # 将潜在表示解码为图像
cr_text = NODE_CLASS_MAPPINGS["CR Text"]() # 文本处理
saveimage = SaveImage() # 保存生成的图像
getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]() # 获取图像尺寸
depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]() # 深度处理
imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]() # 图像缩放
# ====================== 第5阶段预加载模型到GPU ======================
# 预加载模型到GPU以提高性能
# 这减少了处理过程中的GPU内存碎片和延迟
model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]
from comfy import model_management
model_management.load_models_gpu([
loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
])
logger.info("处理输入图像...")
# ====================== 第6阶段准备输入图像 ======================
# 将结构图像和风格图像复制到ComfyUI的input目录
# 这是必须的因为ComfyUI的LoadImage节点只能从该目录加载图像
comfyui_dir = get_comfyui_dir()
input_dir = os.path.join(comfyui_dir, "input")
os.makedirs(input_dir, exist_ok=True)
# 创建临时文件名以避免冲突
structure_filename = f"structure_{random.randint(1000, 9999)}.png"
style_filename = f"style_{random.randint(1000, 9999)}.png"
# 将上传的图像复制到输入目录
structure_input_path = os.path.join(input_dir, structure_filename)
style_input_path = os.path.join(input_dir, style_filename)
shutil.copy(structure_image_path, structure_input_path)
shutil.copy(style_image_path, style_input_path)
# ====================== 第7阶段图像生成流程 ======================
# 使用torch.inference_mode()以减少内存使用并提高速度
# 这告诉PyTorch不需要存储梯度信息
with torch.inference_mode():
# ------- 7.1: 设置CLIP -------
# 初始化CLIP开关使用同一个CLIP模型作为两个输入
# 这是因为某些节点需要两个CLIP模型作为输入但我们只需要使用一个
clip_switch = cr_clip_input_switch.switch(
Input=1, # 选择使用第一个CLIP输入
clip1=get_value_at_index(CLIP_MODEL, 0),
clip2=get_value_at_index(CLIP_MODEL, 0),
)
# ------- 7.2: 文本编码 -------
# 使用CLIP模型编码提示文本创建条件嵌入
# 这些嵌入用于引导扩散过程,告诉模型"生成什么"
text_encoded = cliptextencode.encode(
text=prompt, # 用户提供的文本提示
clip=get_value_at_index(clip_switch, 0), # 使用CLIP模型
)
# 同时创建一个空文本编码,用于无条件引导(负面提示)
empty_text = cliptextencode.encode(
text="", # 空字符串作为负面提示
clip=get_value_at_index(clip_switch, 0),
)
logger.info("处理结构图像...")
# ------- 7.3: 处理结构图像 -------
# 加载结构图像 - 这是生成过程的基础
structure_img = loadimage.load_image(image=structure_filename)
# 调整图像大小到1024x1024保持比例
# 这确保图像尺寸适合模型,同时保持图像的原始比例
resized_img = imageresize.execute(
width=get_value_at_index(CONST_1024, 0), # 宽度设为1024
height=get_value_at_index(CONST_1024, 0), # 高度设为1024
interpolation="bicubic", # 使用双三次插值获得更好的质量
method="keep proportion", # 保持图像比例
condition="always", # 总是执行调整
multiple_of=16, # 确保尺寸是16的倍数扩散模型的要求
image=get_value_at_index(structure_img, 0),
)
# 获取调整后图像的尺寸信息
size_info = getimagesizeandcount.getsize(
image=get_value_at_index(resized_img, 0)
)
# 使用VAE编码图像到潜在空间
# VAE将图像压缩到低维潜在空间这是扩散模型操作的空间
vae_encoded = vaeencode.encode(
pixels=get_value_at_index(size_info, 0), # 图像像素
vae=get_value_at_index(VAE_MODEL, 0), # VAE模型
)
logger.info("处理深度...")
# ------- 7.4: 深度处理 -------
# 使用DepthAnything模型从图像提取深度信息
# 深度图帮助模型理解图像的3D结构
depth_processed = depthanything_v2.process(
da_model=get_value_at_index(DEPTH_MODEL, 0), # 深度模型
images=get_value_at_index(size_info, 0), # 输入图像
)
# 应用Flux引导 - 这将深度信息和文本提示结合起来
# 深度强度参数控制深度信息对最终结果的影响程度
flux_guided = fluxguidance.append(
guidance=depth_strength, # 深度强度参数
conditioning=get_value_at_index(text_encoded, 0), # 文本条件
)
logger.info("处理风格图像...")
# ------- 7.5: 风格处理 -------
# 加载风格图像 - 这提供了艺术风格
style_img = loadimage.load_image(image=style_filename)
# 用CLIP Vision编码风格图像
# 这提取风格图像的视觉特征,用于后续的风格转移
style_encoded = clipvisionencode.encode(
crop="center", # 中心裁剪
clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0), # CLIP视觉模型
image=get_value_at_index(style_img, 0), # 风格图像
)
logger.info("设置条件...")
# ------- 7.6: 设置条件 -------
# InstructPixToPixConditioning将文本条件、VAE和深度信息结合起来
# 这创建了一个复合条件,指导扩散过程
conditioning = instructpixtopixconditioning.encode(
positive=get_value_at_index(flux_guided, 0), # 正面条件(包含深度引导)
negative=get_value_at_index(empty_text, 0), # 负面条件(空文本)
vae=get_value_at_index(VAE_MODEL, 0), # VAE模型
pixels=get_value_at_index(depth_processed, 0), # 深度处理后的图像
)
# 应用风格 - 将风格特征应用到条件上
# 风格强度参数控制风格应用的程度
style_applied = stylemodelapplyadvanced.apply_stylemodel(
strength=style_strength, # 风格强度参数
conditioning=get_value_at_index(conditioning, 0), # 条件
style_model=get_value_at_index(STYLE_MODEL, 0), # 风格模型
clip_vision_output=get_value_at_index(style_encoded, 0), # 风格图像特征
)
# ------- 7.7: 创建潜在空间 -------
# 创建空的潜在图像 - 这是扩散过程的起点
# 尺寸与处理后的结构图像匹配
empty_latent = emptylatentimage.generate(
width=get_value_at_index(resized_img, 1), # 宽度
height=get_value_at_index(resized_img, 2), # 高度
batch_size=1, # 生成一张图像
)
logger.info("设置采样...")
# ------- 7.8: 设置扩散引导 -------
# 设置引导器 - 将UNET模型和条件结合起来
# 引导器负责根据条件引导扩散过程
guided = basicguider.get_guider(
model=get_value_at_index(UNET_MODEL, 0), # UNET模型
conditioning=get_value_at_index(style_applied, 0), # 应用了风格的条件
)
# 设置调度器 - 控制采样步骤和去噪强度
# 调度器决定了如何从噪声逐步生成图像
schedule = basicscheduler.get_sigmas(
scheduler="simple", # 简单调度器
steps=28, # 采样步数
denoise=1, # 去噪强度
model=get_value_at_index(UNET_MODEL, 0), # UNET模型
)
# 生成随机噪声 - 作为扩散过程的起点
# 不同的噪声种子会生成不同的图像
noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))
logger.info("采样中...")
# ------- 7.9: 执行采样 -------
# 使用高级采样器从噪声生成图像
# 这是扩散模型的核心过程,将噪声逐步转换为图像
sampled = samplercustomadvanced.sample(
noise=get_value_at_index(noise, 0), # 随机噪声
guider=get_value_at_index(guided, 0), # 引导器
sampler=get_value_at_index(SAMPLER, 0), # 采样器
sigmas=get_value_at_index(schedule, 0), # 调度器提供的sigma值
latent_image=get_value_at_index(empty_latent, 0), # 初始潜在图像
)
logger.info("解码图像...")
# ------- 7.10: 解码结果 -------
# 使用VAE将潜在表示解码回像素空间
# 这是生成过程的最后一步,将潜在表示转换为可见图像
decoded = vaedecode.decode(
samples=get_value_at_index(sampled, 0), # 采样结果
vae=get_value_at_index(VAE_MODEL, 0), # VAE模型
)
# ------- 7.11: 保存图像 -------
# 设置保存的文件名前缀
prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
# 保存生成的图像到输出目录
saved = saveimage.save_images(
filename_prefix=get_value_at_index(prefix, 0), # 文件名前缀
images=get_value_at_index(decoded, 0), # 解码后的图像
)
# 获取输出路径
comfyui_dir = get_comfyui_dir()
saved_path = os.path.join(comfyui_dir, "output", saved['ui']['images'][0]['filename'])
logger.info(f"图像保存到 {saved_path}")
# 清理临时文件
try:
os.remove(structure_input_path)
os.remove(style_input_path)
except Exception as e:
logger.warning(f"清理临时文件时出错: {e}")
# 返回生成图像的路径
return saved_path
# 临时文件保存
def save_upload_file_tmp(upload_file: UploadFile) -> str:
try:
suffix = Path(upload_file.filename).suffix
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
shutil.copyfileobj(upload_file.file, tmp)
tmp_path = tmp.name
return tmp_path
finally:
upload_file.file.close()
# 清理临时文件
def cleanup_temp_file(filepath: str):
try:
os.unlink(filepath)
except Exception as e:
logger.warning(f"清理临时文件失败 {filepath}: {e}")
# API路由健康检查
@app.get("/health")
async def health_check():
return {"status": "ok", "message": "服务运行正常"}
# API路由图像生成
@app.post("/generate")
async def create_image(
background_tasks: BackgroundTasks,
prompt: str = Form(""),
structure_image: UploadFile = File(...),
style_image: UploadFile = File(...),
depth_strength: float = Form(15.0),
style_strength: float = Form(0.5)
):
# 保存上传的文件
structure_path = save_upload_file_tmp(structure_image)
style_path = save_upload_file_tmp(style_image)
# 添加清理任务
background_tasks.add_task(cleanup_temp_file, structure_path)
background_tasks.add_task(cleanup_temp_file, style_path)
try:
# 生成图像
output_path = generate_image(
prompt=prompt,
structure_image_path=structure_path,
style_image_path=style_path,
depth_strength=depth_strength,
style_strength=style_strength
)
# 返回生成的图像
return FileResponse(
path=output_path,
media_type="image/png",
filename=os.path.basename(output_path)
)
except Exception as e:
logger.error(f"生成图像时出错: {e}")
return {"error": f"生成图像失败: {str(e)}"}
# 初始化时设置环境
@app.on_event("startup")
async def startup_event():
setup_environment()
logger.info("API服务已启动并准备好接受请求")
# 主函数
def main():
# 启动FastAPI服务器
uvicorn.run(app, host="0.0.0.0", port=8000)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,60 @@
@echo off
echo 正在启动FLUX风格塑形API服务...
echo.
rem 检查是否存在虚拟环境
IF EXIST venv (
echo 检测到虚拟环境,正在激活...
call venv\Scripts\activate
) ELSE (
echo 虚拟环境不存在,是否创建? (Y/N)
set /p create_venv=
if /I "%create_venv%"=="Y" (
echo 正在创建虚拟环境...
call create_venv.bat
echo 虚拟环境已创建并激活
) else (
echo 将使用系统Python环境
)
)
rem 安装依赖
echo 正在检查依赖...
pip install -r requirements.txt
rem 检查CUDA是否可用
echo 检查CUDA支持...
python -c "import torch; is_cuda = torch.cuda.is_available(); print('CUDA是否可用:', is_cuda); cuda_ver = torch.version.cuda if is_cuda else '不可用'; print('CUDA版本:', cuda_ver); exit(0 if is_cuda else 1)" > nul 2>&1
if %ERRORLEVEL% NEQ 0 (
echo 警告: 未检测到可用的CUDA支持。这将显著影响图像生成性能。
echo 如果您有支持CUDA的GPU请尝试运行create_venv.bat脚本重新安装适合您系统的PyTorch版本。
echo.
echo 是否继续使用CPU模式运行? (Y/N)
set /p continue_cpu=
if /I NOT "%continue_cpu%"=="Y" (
echo 操作已取消。
goto cleanup
)
echo 将使用CPU模式继续运行请注意性能会较低...
echo.
) else (
echo CUDA支持正常将使用GPU加速
echo.
)
echo.
echo 正在启动API服务...
echo.
echo API将在 http://localhost:8000 上运行
echo 可访问 http://localhost:8000/docs 查看API文档
echo.
echo 按Ctrl+C可停止服务
echo.
python flux_style_shaper_api.py
:cleanup
rem 如果使用了虚拟环境,退出时取消激活
IF EXIST venv (
call deactivate
)

View File

@ -0,0 +1,41 @@
#!/bin/bash
echo "正在启动FLUX风格塑形API服务..."
echo ""
# 检查是否存在虚拟环境
if [ -d "venv" ]; then
echo "检测到虚拟环境,正在激活..."
source venv/bin/activate
else
echo "虚拟环境不存在,是否创建? (y/n)"
read create_venv
if [[ "$create_venv" =~ ^[Yy]$ ]]; then
echo "正在创建虚拟环境..."
python -m venv venv
source venv/bin/activate
echo "虚拟环境已创建并激活"
else
echo "将使用系统Python环境"
fi
fi
# 安装依赖
echo "正在检查依赖..."
pip install -r requirements.txt
echo ""
echo "正在启动API服务..."
echo ""
echo "API将在 http://localhost:8000 上运行"
echo "可访问 http://localhost:8000/docs 查看API文档"
echo ""
echo "按Ctrl+C可停止服务"
echo ""
python flux_style_shaper_api.py
# 如果使用了虚拟环境,退出时取消激活
if [ -d "venv" ]; then
deactivate
fi

94
mcp_server/README.md Normal file
View File

@ -0,0 +1,94 @@
# MCP 服务器 - FLUX风格塑形
基于[Model Context Protocol (MCP)](https://github.com/modelcontextprotocol/python-sdk) 实现的服务器提供FLUX风格塑形功能。
## 功能特点
- 利用MCP协议提供标准化的API接口
- 支持FLUX风格塑形功能将结构图像和风格图像结合生成新图像
- 提供资源、工具和提示模板三种MCP核心组件
## 安装步骤
1. 安装依赖项
```bash
# 安装所有依赖
pip install -r requirements.txt
```
2. 获取必要的模型访问权限
FLUX模型需要从Hugging Face下载部分模型需要授权访问
- 访问 https://huggingface.co/settings/tokens 创建访问令牌
- 访问以下模型页面申请访问权限:
- https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev
- https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev
- https://huggingface.co/black-forest-labs/FLUX.1-dev
- 创建`.env`文件并添加你的令牌:
```
HUGGING_FACE_HUB_TOKEN=your_token_here
```
## 使用方法
### 启动服务器
```bash
python mcp_server/flux_starter.py --host 0.0.0.0 --port 8189
```
可用参数:
- `--host`: 监听地址默认127.0.0.1
- `--port`: 监听端口默认8189
- `--name`: 服务器名称默认FLUX-MCP
### 客户端演示
服务器启动后,可以使用演示客户端测试功能:
```bash
python mcp_server/flux_demo_client.py --structure 图像路径.jpg --style 风格图像路径.jpg
```
可用参数:
- `--host`: 服务器地址默认127.0.0.1
- `--port`: 服务器端口默认8189
- `--structure`: 结构图像路径(必需)
- `--style`: 风格图像路径(必需)
- `--prompt`: 文本提示(默认为空)
- `--depth-strength`: 深度强度默认15.0
- `--style-strength`: 风格强度默认0.5
- `--output`: 输出目录(默认:./output
## MCP资源和工具
本服务器提供以下MCP组件
### 资源 (Resources)
- `flux://status`: 获取FLUX服务状态
### 工具 (Tools)
- `生成风格化图像`: 生成风格化图像
- 参数:
- `prompt`: 文本提示
- `structure_image_base64`: 结构图像Base64编码
- `style_image_base64`: 风格图像Base64编码
- `depth_strength`: 深度强度默认15.0
- `style_strength`: 风格强度默认0.5
### 提示 (Prompts)
- `风格转移提示`: 生成风格转移提示模板
- 参数:
- `subject`: 主题
- `style`: 风格
## 开发信息
- 需要ComfyUI环境作为基础
- 使用`mcp` Python SDK版本 1.0.0 或更高
- 完全中文注释,便于理解和维护

12
mcp_server/__init__.py Normal file
View File

@ -0,0 +1,12 @@
"""
MCP服务器模块
------------
此模块提供了Model Context Protocol (MCP)服务器的实现
用于管理模型连接和资源交互
基于 https://github.com/modelcontextprotocol/python-sdk
"""
from .mcp import MCPServer, run, run_server
__all__ = ['MCPServer', 'run', 'run_server']

96
mcp_server/example.py Normal file
View File

@ -0,0 +1,96 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MCP服务器示例脚本
---------------
这个脚本演示了如何启动和使用MCP服务器
"""
import logging
import asyncio
from typing import Dict, Any
# 设置日志级别
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 导入MCP服务器
from mcp_server import MCPServer
# 自定义资源和工具示例
async def setup_custom_handlers(server: MCPServer):
"""设置自定义资源和工具"""
# 注册自定义资源
@server.register_resource("comfyui://status")
def get_comfy_status() -> str:
"""获取ComfyUI状态信息"""
return "ComfyUI 正在运行中"
# 注册自定义工具
@server.register_tool()
def generate_image(prompt: str, width: int = 512, height: int = 512) -> Dict[str, Any]:
"""生成图像的工具(示例)"""
# 这里只是示例实际实现需要与ComfyUI集成
return {
"status": "success",
"prompt": prompt,
"dimensions": f"{width}x{height}",
"message": "图像生成请求已提交"
}
# 注册自定义提示
@server.register_prompt("图像生成")
def image_generation_prompt(style: str, subject: str) -> str:
"""创建图像生成提示"""
return f"请创建一个{style}风格的{subject}图像。"
async def register_models(server: MCPServer):
"""注册示例模型"""
# 注册示例模型
model1 = server.register_model({
"name": "stable-diffusion-v1.5",
"type": "diffusion",
"description": "Stable Diffusion v1.5图像生成模型"
})
model2 = server.register_model({
"name": "llama-3-8b",
"type": "llm",
"description": "Llama 3 8B语言模型"
})
logging.info(f"已注册模型: {model1}, {model2}")
return [model1, model2]
async def main():
"""主函数"""
# 创建MCP服务器实例
server = MCPServer(server_name="ComfyUI-MCP示例")
# 设置自定义处理器
await setup_custom_handlers(server)
# 注册模型
models = await register_models(server)
# 打印服务器信息
logging.info(f"服务器统计信息: {server.get_stats()}")
# 启动服务器默认地址为127.0.0.1:8189
try:
logging.info("正在启动MCP服务器...")
await server.start()
except KeyboardInterrupt:
logging.info("服务器被用户中断")
finally:
# 清理资源
for model_id in models:
server.unregister_model(model_id)
logging.info("服务器已关闭")
if __name__ == "__main__":
# 运行主函数
asyncio.run(main())

View File

@ -0,0 +1,245 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
FLUX风格塑形MCP客户端演示脚本
--------------------------
演示如何使用MCP客户端连接到服务器并使用FLUX风格塑形功能
"""
import asyncio
import argparse
import logging
import sys
import os
import base64
from typing import Dict, Any
from pathlib import Path
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# 尝试导入MCP客户端库
try:
from mcp import ClientSession, HttpServerParameters, types
from mcp.client.http import http_client
except ImportError:
logger.error("未找到MCP客户端库。请安装Model Context Protocol (MCP) Python SDK。")
logger.error("安装命令: pip install mcp")
sys.exit(1)
async def save_base64_image(base64_data: str, output_path: str) -> str:
"""
将base64编码的图像数据保存到文件
参数:
base64_data: 包含MIME类型的base64编码图像数据
output_path: 输出目录路径
返回:
str: 保存的文件路径
"""
# 确保输出目录存在
os.makedirs(output_path, exist_ok=True)
# 从base64字符串提取MIME类型和数据
if ";" in base64_data and "," in base64_data:
mime_part, data_part = base64_data.split(",", 1)
content_type = mime_part.split(";")[0].split(":")[1]
# 确定文件扩展名
if content_type == "image/jpeg":
ext = ".jpg"
elif content_type == "image/png":
ext = ".png"
elif content_type == "image/webp":
ext = ".webp"
else:
ext = ".png" # 默认扩展名
else:
# 如果格式不正确假设为PNG
data_part = base64_data
ext = ".png"
# 解码base64数据
image_data = base64.b64decode(data_part)
# 创建输出文件名
timestamp = asyncio.get_event_loop().time()
filename = f"flux_output_{int(timestamp)}{ext}"
file_path = os.path.join(output_path, filename)
# 保存图像
with open(file_path, "wb") as f:
f.write(image_data)
logger.info(f"图像已保存到: {file_path}")
return file_path
async def read_image_to_base64(image_path: str) -> str:
"""
读取图像文件并转换为base64编码
参数:
image_path: 图像文件路径
返回:
str: base64编码的图像数据
"""
if not os.path.exists(image_path):
raise FileNotFoundError(f"图像文件不存在: {image_path}")
# 读取图像文件
with open(image_path, "rb") as f:
image_data = f.read()
# 确定MIME类型
ext = os.path.splitext(image_path)[1].lower()
if ext in [".jpg", ".jpeg"]:
mime_type = "image/jpeg"
elif ext == ".png":
mime_type = "image/png"
elif ext == ".webp":
mime_type = "image/webp"
else:
mime_type = "image/png" # 默认MIME类型
# 编码为base64
encoded = base64.b64encode(image_data).decode("utf-8")
return f"data:{mime_type};base64,{encoded}"
async def run_demo(host: str, port: int, structure_image: str, style_image: str, prompt: str,
depth_strength: float, style_strength: float, output_dir: str):
"""
运行FLUX风格塑形演示
参数:
host: 服务器主机地址
port: 服务器端口
structure_image: 结构图像路径
style_image: 风格图像路径
prompt: 文本提示
depth_strength: 深度强度
style_strength: 风格强度
output_dir: 输出目录
"""
# 创建服务器参数
server_params = HttpServerParameters(
host=host,
port=port
)
# 连接到服务器
logger.info(f"正在连接到MCP服务器: {host}:{port}")
async with http_client(server_params) as (read, write):
# 创建客户端会话
async with ClientSession(read, write) as session:
# 初始化连接
logger.info("正在初始化MCP客户端会话...")
await session.initialize()
# 检查FLUX状态
try:
logger.info("正在检查FLUX服务状态...")
content, mime_type = await session.read_resource("flux://status")
logger.info(f"FLUX服务状态: {content}")
except Exception as e:
logger.error(f"获取FLUX状态失败: {e}")
return
# 列出可用工具
logger.info("正在获取可用工具列表...")
tools = await session.list_tools()
logger.info(f"可用工具: {[tool.name for tool in tools]}")
# 转换输入图像为base64
logger.info("正在读取输入图像...")
structure_image_base64 = await read_image_to_base64(structure_image)
style_image_base64 = await read_image_to_base64(style_image)
# 调用生成风格化图像工具
logger.info("正在调用生成风格化图像工具...")
try:
result = await session.call_tool(
"生成风格化图像",
arguments={
"prompt": prompt,
"structure_image_base64": structure_image_base64,
"style_image_base64": style_image_base64,
"depth_strength": depth_strength,
"style_strength": style_strength
}
)
# 处理结果
if result.get("status") == "success":
logger.info("图像生成成功")
# 保存生成的图像
image_base64 = result.get("image_base64")
if image_base64:
await save_base64_image(image_base64, output_dir)
else:
logger.warning("结果中不包含图像数据")
else:
logger.error(f"图像生成失败: {result.get('message', '未知错误')}")
except Exception as e:
logger.error(f"调用生成风格化图像工具失败: {e}")
async def main_async():
"""异步主函数"""
# 解析命令行参数
parser = argparse.ArgumentParser(description='FLUX风格塑形MCP客户端演示')
parser.add_argument('--host', type=str, default='127.0.0.1', help='MCP服务器地址')
parser.add_argument('--port', type=int, default=8189, help='MCP服务器端口')
parser.add_argument('--structure', type=str, required=True, help='结构图像路径')
parser.add_argument('--style', type=str, required=True, help='风格图像路径')
parser.add_argument('--prompt', type=str, default='', help='文本提示')
parser.add_argument('--depth-strength', type=float, default=15.0, help='深度强度')
parser.add_argument('--style-strength', type=float, default=0.5, help='风格强度')
parser.add_argument('--output', type=str, default='./output', help='输出目录')
args = parser.parse_args()
# 检查输入文件是否存在
if not os.path.exists(args.structure):
logger.error(f"结构图像文件不存在: {args.structure}")
return 1
if not os.path.exists(args.style):
logger.error(f"风格图像文件不存在: {args.style}")
return 1
# 运行演示
try:
await run_demo(
host=args.host,
port=args.port,
structure_image=args.structure,
style_image=args.style,
prompt=args.prompt,
depth_strength=args.depth_strength,
style_strength=args.style_strength,
output_dir=args.output
)
return 0
except KeyboardInterrupt:
logger.info("用户中断操作")
return 0
except Exception as e:
logger.error(f"演示运行出错: {e}")
return 1
def main():
"""主函数"""
try:
return asyncio.run(main_async())
except KeyboardInterrupt:
logger.info("用户中断操作")
return 0
except Exception as e:
logger.error(f"程序运行出错: {e}")
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
FLUX风格塑形MCP服务器启动脚本
--------------------------
启动带有FLUX风格塑形功能的MCP服务器
"""
import logging
import argparse
import sys
import asyncio
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# 导入MCP服务器
from mcp_server import MCPServer
from flux_style_resource import register_flux_resources
async def main_async():
"""异步主函数"""
# 解析命令行参数
parser = argparse.ArgumentParser(description='启动带有FLUX风格塑形功能的MCP服务器')
parser.add_argument('--host', type=str, default='127.0.0.1', help='服务器监听地址')
parser.add_argument('--port', type=int, default=8189, help='服务器监听端口')
parser.add_argument('--name', type=str, default='FLUX-MCP', help='服务器名称')
args = parser.parse_args()
try:
# 创建MCP服务器实例
logger.info(f"正在初始化MCP服务器: {args.name}...")
server = MCPServer(server_name=args.name)
# 注册FLUX风格塑形资源和工具
logger.info("正在注册FLUX风格塑形资源...")
register_flux_resources(server)
# 启动服务器
logger.info(f"正在启动MCP服务器监听地址: {args.host}:{args.port}")
await server.start(host=args.host, port=args.port)
except KeyboardInterrupt:
logger.info("收到中断信号,服务器关闭")
except Exception as e:
logger.error(f"服务器启动失败: {e}")
return 1
return 0
def main():
"""主函数"""
try:
return asyncio.run(main_async())
except KeyboardInterrupt:
logger.info("收到中断信号,服务器关闭")
return 0
except Exception as e:
logger.error(f"服务器运行出错: {e}")
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@ -0,0 +1,593 @@
"""
FLUX风格塑形MCP资源模块
---------------------
将FLUX风格塑形API集成到MCP服务器中
作为资源和工具提供给MCP客户端
"""
import os
import sys
import random
import logging
import tempfile
import shutil
from pathlib import Path
from typing import Dict, Any, Union, Tuple, Optional
import base64
from io import BytesIO
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from dotenv import load_dotenv
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# 加载环境变量
load_dotenv(override=True)
# 获取ComfyUI根目录
def get_comfyui_dir():
"""获取ComfyUI根目录路径"""
current_dir = os.path.dirname(os.path.abspath(__file__))
comfyui_dir = os.path.dirname(os.path.dirname(current_dir)) # 假设mcp_server在ComfyUI根目录下
return comfyui_dir
# 下载所需模型
def download_models():
"""检查并下载所需模型"""
logger.info("检查并下载所需模型...")
# 获取ComfyUI根目录
comfyui_dir = get_comfyui_dir()
# 获取Hugging Face令牌
hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if not hf_token:
logger.warning("未找到Hugging Face访问令牌受限模型可能无法下载。")
logger.warning("请设置HUGGING_FACE_HUB_TOKEN环境变量或在.env文件中添加此变量。")
else:
logger.info("已找到Hugging Face访问令牌")
models = [
{"repo_id": "black-forest-labs/FLUX.1-Redux-dev", "filename": "flux1-redux-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/style_models")},
{"repo_id": "black-forest-labs/FLUX.1-Depth-dev", "filename": "flux1-depth-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/diffusion_models")},
{"repo_id": "Comfy-Org/sigclip_vision_384", "filename": "sigclip_vision_patch14_384.safetensors", "local_dir": os.path.join(comfyui_dir, "models/clip_vision")},
{"repo_id": "Kijai/DepthAnythingV2-safetensors", "filename": "depth_anything_v2_vitl_fp32.safetensors", "local_dir": os.path.join(comfyui_dir, "models/depthanything")},
{"repo_id": "black-forest-labs/FLUX.1-dev", "filename": "ae.safetensors", "local_dir": os.path.join(comfyui_dir, "models/vae/FLUX1")},
{"repo_id": "comfyanonymous/flux_text_encoders", "filename": "clip_l.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders")},
{"repo_id": "comfyanonymous/flux_text_encoders", "filename": "t5xxl_fp16.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders/t5")}
]
for model in models:
try:
local_path = os.path.join(model["local_dir"], model["filename"])
if not os.path.exists(local_path):
logger.info(f"下载 {model['filename']}{model['local_dir']}")
os.makedirs(model["local_dir"], exist_ok=True)
hf_hub_download(
repo_id=model["repo_id"],
filename=model["filename"],
local_dir=model["local_dir"],
token=hf_token # 使用令牌进行下载
)
else:
logger.info(f"模型 {model['filename']} 已存在于 {model['local_dir']}")
except Exception as e:
if "403 Client Error" in str(e) and "Access to model" in str(e) and "is restricted" in str(e):
logger.error(f"下载 {model['filename']} 时权限错误:您没有访问权限")
logger.error("请按照以下步骤解决:")
logger.error("1. 访问 https://huggingface.co/settings/tokens 创建访问令牌")
logger.error(f"2. 访问 https://huggingface.co/{model['repo_id']} 页面申请访问权限")
logger.error("3. 设置环境变量 HUGGING_FACE_HUB_TOKEN 为您的访问令牌")
else:
logger.error(f"下载 {model['filename']} 时出错: {e}")
# 如果是关键模型,则抛出异常阻止继续
if model["filename"] in ["flux1-redux-dev.safetensors", "flux1-depth-dev.safetensors", "ae.safetensors"]:
raise RuntimeError(f"无法下载关键模型 {model['filename']},程序无法继续运行")
# 从索引获取值的辅助函数
def get_value_at_index(obj: Union[Dict, list], index: int) -> Any:
"""从索引获取值的辅助函数"""
try:
return obj[index]
except KeyError:
return obj["result"][index]
# 添加ComfyUI目录到系统路径
def add_comfyui_directory_to_sys_path() -> None:
"""将ComfyUI目录添加到系统路径"""
comfyui_path = get_comfyui_dir()
if comfyui_path is not None and os.path.isdir(comfyui_path):
sys.path.append(comfyui_path)
logger.info(f"'{comfyui_path}' 已添加到 sys.path")
# 添加额外模型路径
def add_extra_model_paths() -> None:
"""添加额外模型路径配置"""
comfyui_dir = get_comfyui_dir()
sys.path.append(comfyui_dir) # 确保能导入ComfyUI模块
try:
# 尝试直接从main.py导入load_extra_path_config
from main import load_extra_path_config
extra_model_paths = os.path.join(comfyui_dir, "extra_model_paths.yaml")
if os.path.exists(extra_model_paths):
logger.info(f"找到extra_model_paths配置: {extra_model_paths}")
load_extra_path_config(extra_model_paths)
else:
logger.info("未找到extra_model_paths.yaml文件将使用默认模型路径")
except ImportError as e:
logger.warning(f"无法导入load_extra_path_config: {e}")
logger.info("将使用默认模型路径继续运行")
# 导入自定义节点
def import_custom_nodes() -> None:
"""导入ComfyUI自定义节点"""
comfyui_dir = get_comfyui_dir()
sys.path.append(comfyui_dir)
try:
import asyncio
import execution
from nodes import init_extra_nodes
import server
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_instance = server.PromptServer(loop)
execution.PromptQueue(server_instance)
init_extra_nodes()
logger.info("自定义节点导入成功")
except Exception as e:
logger.error(f"导入自定义节点时出错: {e}")
raise
# 设置ComfyUI环境
def setup_environment():
"""设置ComfyUI环境加载路径和模型"""
logger.info("设置ComfyUI环境...")
# 初始化路径
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
# 下载所需模型
download_models()
# 导入自定义节点
try:
import_custom_nodes()
logger.info("自定义节点导入成功")
except Exception as e:
logger.error(f"导入自定义节点时出错: {e}")
raise RuntimeError("无法导入自定义节点,程序无法继续运行")
# 保存临时图像文件
def save_image_to_temp(image_data: Union[str, bytes]) -> str:
"""
将图像数据保存到临时文件
参数:
image_data: 图像数据可以是base64编码字符串或字节数据
返回:
str: 临时文件路径
"""
# 如果是base64编码字符串先解码
if isinstance(image_data, str) and image_data.startswith(('data:image', 'data:application')):
# 提取实际的base64编码部分
content_type, base64_data = image_data.split(';base64,')
image_data = base64.b64decode(base64_data)
# 创建临时文件并保存图像
suffix = '.png' # 默认文件扩展名
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
if isinstance(image_data, str):
# 如果仍然是字符串,假设是文件路径
shutil.copyfile(image_data, tmp.name)
else:
# 否则是字节数据
tmp.write(image_data)
tmp_path = tmp.name
return tmp_path
# 图像生成主函数
def generate_image(prompt: str, structure_image: Union[str, bytes], style_image: Union[str, bytes],
depth_strength: float = 15.0, style_strength: float = 0.5) -> Tuple[str, bytes]:
"""
生成风格化图像
参数:
prompt: 文本提示
structure_image: 结构图像数据可以是base64编码字符串字节数据或文件路径
style_image: 风格图像数据可以是base64编码字符串字节数据或文件路径
depth_strength: 深度强度
style_strength: 风格强度
返回:
Tuple[str, bytes]: 返回生成图像的路径和图像数据
"""
# 保存输入图像到临时文件
structure_image_path = save_image_to_temp(structure_image)
style_image_path = save_image_to_temp(style_image)
try:
# 从这里开始复制原始生成函数的实现
logger.info("开始图像生成过程...")
# ====================== 第1阶段导入必要组件 ======================
# 从ComfyUI的nodes.py导入所有需要的节点类型
from nodes import (
StyleModelLoader, VAEEncode, NODE_CLASS_MAPPINGS, LoadImage, CLIPVisionLoader,
SaveImage, VAELoader, CLIPVisionEncode, DualCLIPLoader, EmptyLatentImage,
VAEDecode, UNETLoader, CLIPTextEncode,
)
# ====================== 第2阶段初始化常量 ======================
intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
CONST_1024 = intconstant.get_value(value=1024)
logger.info("加载模型...")
# ====================== 第3阶段加载模型 ======================
dualcliploader = DualCLIPLoader()
CLIP_MODEL = dualcliploader.load_clip(
clip_name1="t5/t5xxl_fp16.safetensors",
clip_name2="clip_l.safetensors",
type="flux",
)
vaeloader = VAELoader()
VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
unetloader = UNETLoader()
UNET_MODEL = unetloader.load_unet(
unet_name="flux1-depth-dev.safetensors",
weight_dtype="default"
)
clipvisionloader = CLIPVisionLoader()
CLIP_VISION_MODEL = clipvisionloader.load_clip(
clip_name="sigclip_vision_patch14_384.safetensors"
)
stylemodelloader = StyleModelLoader()
STYLE_MODEL = stylemodelloader.load_style_model(
style_model_name="flux1-redux-dev.safetensors"
)
ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
model="depth_anything_v2_vitl_fp32.safetensors"
)
# ====================== 第4阶段导入辅助节点 ======================
cliptextencode = CLIPTextEncode()
loadimage = LoadImage()
vaeencode = VAEEncode()
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
clipvisionencode = CLIPVisionEncode()
stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
emptylatentimage = EmptyLatentImage()
basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
vaedecode = VAEDecode()
cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
saveimage = SaveImage()
getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
# ====================== 第5阶段预加载模型到GPU ======================
model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]
from comfy import model_management
model_management.load_models_gpu([
loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
])
logger.info("处理输入图像...")
# ====================== 第6阶段准备输入图像 ======================
comfyui_dir = get_comfyui_dir()
input_dir = os.path.join(comfyui_dir, "input")
os.makedirs(input_dir, exist_ok=True)
# 创建临时文件名以避免冲突
structure_filename = f"structure_{random.randint(1000, 9999)}.png"
style_filename = f"style_{random.randint(1000, 9999)}.png"
# 将上传的图像复制到输入目录
structure_input_path = os.path.join(input_dir, structure_filename)
style_input_path = os.path.join(input_dir, style_filename)
shutil.copy(structure_image_path, structure_input_path)
shutil.copy(style_image_path, style_input_path)
# ====================== 第7阶段图像生成流程 ======================
with torch.inference_mode():
# ------- 7.1: 设置CLIP -------
clip_switch = cr_clip_input_switch.switch(
Input=1,
clip1=get_value_at_index(CLIP_MODEL, 0),
clip2=get_value_at_index(CLIP_MODEL, 0),
)
# ------- 7.2: 文本编码 -------
text_encoded = cliptextencode.encode(
text=prompt,
clip=get_value_at_index(clip_switch, 0),
)
empty_text = cliptextencode.encode(
text="",
clip=get_value_at_index(clip_switch, 0),
)
logger.info("处理结构图像...")
# ------- 7.3: 处理结构图像 -------
structure_img = loadimage.load_image(image=structure_filename)
resized_img = imageresize.execute(
width=get_value_at_index(CONST_1024, 0),
height=get_value_at_index(CONST_1024, 0),
interpolation="bicubic",
method="keep proportion",
condition="always",
multiple_of=16,
image=get_value_at_index(structure_img, 0),
)
size_info = getimagesizeandcount.getsize(
image=get_value_at_index(resized_img, 0)
)
vae_encoded = vaeencode.encode(
pixels=get_value_at_index(size_info, 0),
vae=get_value_at_index(VAE_MODEL, 0),
)
logger.info("处理深度...")
# ------- 7.4: 深度处理 -------
depth_processed = depthanything_v2.process(
da_model=get_value_at_index(DEPTH_MODEL, 0),
images=get_value_at_index(size_info, 0),
)
flux_guided = fluxguidance.append(
guidance=depth_strength,
conditioning=get_value_at_index(text_encoded, 0),
)
logger.info("处理风格图像...")
# ------- 7.5: 风格处理 -------
style_img = loadimage.load_image(image=style_filename)
style_encoded = clipvisionencode.encode(
crop="center",
clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
image=get_value_at_index(style_img, 0),
)
logger.info("设置条件...")
# ------- 7.6: 设置条件 -------
conditioning = instructpixtopixconditioning.encode(
positive=get_value_at_index(flux_guided, 0),
negative=get_value_at_index(empty_text, 0),
vae=get_value_at_index(VAE_MODEL, 0),
pixels=get_value_at_index(depth_processed, 0),
)
style_applied = stylemodelapplyadvanced.apply_stylemodel(
strength=style_strength,
conditioning=get_value_at_index(conditioning, 0),
style_model=get_value_at_index(STYLE_MODEL, 0),
clip_vision_output=get_value_at_index(style_encoded, 0),
)
# ------- 7.7: 创建潜在空间 -------
empty_latent = emptylatentimage.generate(
width=get_value_at_index(resized_img, 1),
height=get_value_at_index(resized_img, 2),
batch_size=1,
)
logger.info("设置采样...")
# ------- 7.8: 设置扩散引导 -------
guided = basicguider.get_guider(
model=get_value_at_index(UNET_MODEL, 0),
conditioning=get_value_at_index(style_applied, 0),
)
schedule = basicscheduler.get_sigmas(
scheduler="simple",
steps=28,
denoise=1,
model=get_value_at_index(UNET_MODEL, 0),
)
noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))
logger.info("采样中...")
# ------- 7.9: 执行采样 -------
sampled = samplercustomadvanced.sample(
noise=get_value_at_index(noise, 0),
guider=get_value_at_index(guided, 0),
sampler=get_value_at_index(SAMPLER, 0),
sigmas=get_value_at_index(schedule, 0),
latent_image=get_value_at_index(empty_latent, 0),
)
logger.info("解码图像...")
# ------- 7.10: 解码结果 -------
decoded = vaedecode.decode(
samples=get_value_at_index(sampled, 0),
vae=get_value_at_index(VAE_MODEL, 0),
)
# ------- 7.11: 保存图像 -------
prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
saved = saveimage.save_images(
filename_prefix=get_value_at_index(prefix, 0),
images=get_value_at_index(decoded, 0),
)
# 获取输出路径
saved_path = os.path.join(comfyui_dir, "output", saved['ui']['images'][0]['filename'])
logger.info(f"图像保存到 {saved_path}")
# 读取生成的图像数据
with open(saved_path, "rb") as f:
image_data = f.read()
# 清理临时文件
try:
os.remove(structure_input_path)
os.remove(style_input_path)
os.remove(structure_image_path)
os.remove(style_image_path)
except Exception as e:
logger.warning(f"清理临时文件时出错: {e}")
# 返回生成图像的路径和数据
return saved_path, image_data
except Exception as e:
# 清理临时文件
try:
os.remove(structure_image_path)
os.remove(style_image_path)
except:
pass
logger.error(f"生成图像时出错: {e}")
raise
# 图像数据转换为base64
def image_to_base64(image_path: str) -> str:
"""
将图像文件转换为base64编码字符串
参数:
image_path: 图像文件路径
返回:
str: base64编码的图像数据包含MIME类型前缀
"""
with open(image_path, "rb") as f:
image_data = f.read()
encoded = base64.b64encode(image_data).decode("utf-8")
mime_type = "image/png" # 默认MIME类型
# 根据文件扩展名确定MIME类型
ext = os.path.splitext(image_path)[1].lower()
if ext == '.jpg' or ext == '.jpeg':
mime_type = "image/jpeg"
elif ext == '.webp':
mime_type = "image/webp"
return f"data:{mime_type};base64,{encoded}"
# 初始化FLUX环境
def init_flux():
"""初始化FLUX环境设置路径和导入模型"""
try:
setup_environment()
logger.info("FLUX环境初始化成功")
return True
except Exception as e:
logger.error(f"FLUX环境初始化失败: {e}")
return False
# 集成到MCP服务器的函数
def register_flux_resources(server):
"""
将FLUX风格塑形功能注册到MCP服务器
参数:
server: MCP服务器实例
"""
# 初始化FLUX环境
initialized = init_flux()
if not initialized:
logger.error("FLUX环境初始化失败无法注册资源")
return
# 注册资源: 服务状态
@server.register_resource("flux://status")
def get_flux_status() -> str:
"""返回FLUX服务状态"""
return "FLUX风格塑形服务已启动并运行正常"
# 注册工具: 生成风格化图像
@server.register_tool("生成风格化图像")
def generate_styled_image(prompt: str, structure_image_base64: str, style_image_base64: str,
depth_strength: float = 15.0, style_strength: float = 0.5) -> Dict[str, Any]:
"""
生成风格化图像工具
参数:
prompt: 文本提示用于指导生成过程
structure_image_base64: 结构图像的base64编码提供基本构图
style_image_base64: 风格图像的base64编码提供艺术风格
depth_strength: 深度强度控制结构保持程度
style_strength: 风格强度控制风格应用程度
返回:
Dict: 包含生成图像的base64编码和其他信息
"""
try:
# 生成图像
saved_path, _ = generate_image(
prompt=prompt,
structure_image=structure_image_base64,
style_image=style_image_base64,
depth_strength=depth_strength,
style_strength=style_strength
)
# 将生成的图像转换为base64
output_base64 = image_to_base64(saved_path)
# 返回结果
return {
"status": "success",
"message": "图像生成成功",
"image_base64": output_base64,
"parameters": {
"prompt": prompt,
"depth_strength": depth_strength,
"style_strength": style_strength
}
}
except Exception as e:
logger.error(f"生成图像失败: {e}")
return {
"status": "error",
"message": f"生成图像失败: {str(e)}"
}
# 注册提示: 创建风格转移提示
@server.register_prompt("风格转移提示")
def style_transfer_prompt(subject: str, style: str) -> str:
"""创建风格转移的提示模板"""
return f"{subject}转换为{style}风格"
logger.info("FLUX风格塑形资源和工具已注册到MCP服务器")

245
mcp_server/mcp.py Normal file
View File

@ -0,0 +1,245 @@
import asyncio
import logging
from typing import Dict, List, Optional, Any, Callable, Union
import os
# 导入MCP SDK相关库
from mcp.server.fastmcp import FastMCP
from mcp.types import Resource, Tool, Prompt
from mcp import types
class MCPServer:
"""
Model Context Protocol 服务器
用于管理模型连接和资源交互
"""
instance = None
def __init__(self, server_name: str = "ComfyUI-MCP"):
"""
初始化MCP服务器
参数:
server_name: 服务器名称
"""
MCPServer.instance = self
# 初始化FastMCP实例
self.mcp = FastMCP(server_name)
# 存储注册的资源、工具和提示
self.registered_resources: Dict[str, Callable] = {}
self.registered_tools: Dict[str, Callable] = {}
self.registered_prompts: Dict[str, Callable] = {}
# 存储模型连接信息
self.active_models: Dict[str, Dict[str, Any]] = {}
self.model_count: int = 0
# 状态标志
self.is_running: bool = False
# 设置基本资源和工具
self._setup_default_handlers()
logging.info(f"[MCP服务器] 已初始化,服务器名称: {server_name}")
def _setup_default_handlers(self):
"""设置默认的资源处理器和工具"""
# 注册状态资源
@self.mcp.resource("status://server")
def get_server_status() -> str:
"""返回服务器状态信息"""
return f"""
服务器状态:
- 运行中: {self.is_running}
- 活跃模型数: {len(self.active_models)}
- 已注册资源数: {len(self.registered_resources)}
- 已注册工具数: {len(self.registered_tools)}
- 已注册提示数: {len(self.registered_prompts)}
"""
# 注册模型列表资源
@self.mcp.resource("models://list")
def list_models() -> str:
"""返回当前连接的模型列表"""
if not self.active_models:
return "当前没有活跃的模型连接"
result = "活跃模型列表:\n"
for model_id, model_info in self.active_models.items():
result += f"- ID: {model_id}, 名称: {model_info.get('name', 'Unknown')}\n"
return result
# 注册Echo工具
@self.mcp.tool()
def echo(message: str) -> str:
"""简单的Echo工具用于测试连接"""
return f"MCP服务器回声: {message}"
# 注册系统信息工具
@self.mcp.tool()
def system_info() -> dict:
"""返回系统信息"""
import platform
return {
"os": platform.system(),
"python_version": platform.python_version(),
"hostname": platform.node(),
"cpu": platform.processor()
}
def register_resource(self, uri_pattern: str):
"""
注册一个资源处理函数
参数:
uri_pattern: 资源URI模式
"""
def decorator(func):
self.registered_resources[uri_pattern] = func
self.mcp.resource(uri_pattern)(func)
logging.info(f"[MCP服务器] 已注册资源: {uri_pattern}")
return func
return decorator
def register_tool(self, name: Optional[str] = None):
"""
注册一个工具函数
参数:
name: 工具名称可选
"""
def decorator(func):
tool_name = name or func.__name__
self.registered_tools[tool_name] = func
self.mcp.tool(name=tool_name)(func)
logging.info(f"[MCP服务器] 已注册工具: {tool_name}")
return func
return decorator
def register_prompt(self, name: Optional[str] = None):
"""
注册一个提示模板
参数:
name: 提示名称可选
"""
def decorator(func):
prompt_name = name or func.__name__
self.registered_prompts[prompt_name] = func
self.mcp.prompt(name=prompt_name)(func)
logging.info(f"[MCP服务器] 已注册提示: {prompt_name}")
return func
return decorator
def register_model(self, model_info: Dict[str, Any]) -> str:
"""
注册一个模型到MCP服务器
参数:
model_info: 模型信息字典
返回:
model_id: 模型ID
"""
model_id = f"model_{self.model_count}"
self.model_count += 1
self.active_models[model_id] = {
"id": model_id,
"registered_at": asyncio.get_event_loop().time(),
**model_info
}
logging.info(f"[MCP服务器] 已注册模型: {model_id}")
return model_id
def unregister_model(self, model_id: str) -> bool:
"""
从MCP服务器注销一个模型
参数:
model_id: 模型ID
返回:
成功与否
"""
if model_id in self.active_models:
del self.active_models[model_id]
logging.info(f"[MCP服务器] 已注销模型: {model_id}")
return True
logging.warning(f"[MCP服务器] 尝试注销未知模型: {model_id}")
return False
async def start(self, host: str = "127.0.0.1", port: int = 8189):
"""
启动MCP服务器
参数:
host: 主机地址
port: 端口号
"""
self.is_running = True
try:
# 启动FastMCP HTTP服务器
uvicorn_config = {
"host": host,
"port": port,
"log_level": "info"
}
logging.info(f"[MCP服务器] 正在启动... 地址: {host}:{port}")
# 使用FastMCP的HTTP服务器启动方法
await self.mcp.serve(**uvicorn_config)
except Exception as e:
self.is_running = False
logging.error(f"[MCP服务器] 启动失败: {str(e)}")
raise
finally:
self.is_running = False
def get_stats(self) -> Dict[str, Any]:
"""获取服务器统计信息"""
return {
"is_running": self.is_running,
"active_models_count": len(self.active_models),
"resources_count": len(self.registered_resources),
"tools_count": len(self.registered_tools),
"prompts_count": len(self.registered_prompts)
}
# 运行MCP服务器的异步函数
async def run_server(server_name: str = "ComfyUI-MCP", host: str = "127.0.0.1", port: int = 8189):
"""
运行MCP服务器的便捷函数
参数:
server_name: 服务器名称
host: 主机地址
port: 端口号
"""
server = MCPServer(server_name)
await server.start(host, port)
return server
# 同步运行MCP服务器的函数
def run(server_name: str = "ComfyUI-MCP", host: str = "127.0.0.1", port: int = 8189):
"""
同步运行MCP服务器
参数:
server_name: 服务器名称
host: 主机地址
port: 端口号
"""
loop = asyncio.get_event_loop()
try:
return loop.run_until_complete(run_server(server_name, host, port))
except KeyboardInterrupt:
logging.info("[MCP服务器] 收到中断信号,正在关闭...")
finally:
loop.close()

56
mcp_server/run.py Normal file
View File

@ -0,0 +1,56 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MCP服务器启动脚本
--------------
通过命令行启动MCP服务器
"""
import argparse
import logging
import sys
from mcp_server import run
def main():
"""主函数,解析命令行参数并启动服务器"""
# 配置命令行参数解析
parser = argparse.ArgumentParser(description='启动Model Context Protocol (MCP)服务器')
# 服务器配置参数
parser.add_argument('--name', type=str, default='ComfyUI-MCP',
help='服务器名称 (默认: ComfyUI-MCP)')
parser.add_argument('--host', type=str, default='127.0.0.1',
help='服务器监听地址 (默认: 127.0.0.1)')
parser.add_argument('--port', type=int, default=8189,
help='服务器监听端口 (默认: 8189)')
parser.add_argument('--verbose', action='store_true',
help='启用详细日志输出')
# 解析命令行参数
args = parser.parse_args()
# 配置日志级别
log_level = logging.DEBUG if args.verbose else logging.INFO
logging.basicConfig(
level=log_level,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 记录启动信息
logging.info(f"正在启动MCP服务器 '{args.name}'{args.host}:{args.port}")
try:
# 启动服务器
run(server_name=args.name, host=args.host, port=args.port)
except KeyboardInterrupt:
logging.info("服务器被用户中断")
return 0
except Exception as e:
logging.error(f"服务器启动失败: {str(e)}")
return 1
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,24 +1,41 @@
# ComfyUI 核心依赖
comfyui-frontend-package==1.16.9
comfyui-workflow-templates==0.1.3
torch
torchsde
torchvision
torchaudio
numpy>=1.25.0
numpy>=1.25.0 # flux版本是1.24.3,但我们使用更新版本
einops
transformers>=4.28.1
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
aiohttp>=3.11.8
aiohttp>=3.11.8 # 使用更新版本的要求
yarl>=1.18.0
pyyaml
Pillow
Pillow>=9.3.0 # 使用flux版本的最低要求
scipy
tqdm
tqdm>=4.65.0 # 使用flux版本的最低要求
psutil
#non essential dependencies:
# FLUX风格塑形API依赖
fastapi>=0.102.0
uvicorn>=0.23.0
python-multipart>=0.0.6
huggingface_hub>=0.19.0
GitPython>=3.1.30
python-dotenv>=1.0.0
numba
colour-science
rembg
pixeloe
transparent-background
# MCP服务器依赖
mcp>=1.0.0 # Model Context Protocol Python SDK
# 非必要依赖
kornia>=0.7.1
spandrel
soundfile

View File

@ -18,7 +18,6 @@ from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
import aiohttp
from aiohttp import web
import logging