Initialize project

parent 92cdc692f4
commit 3451f4cd87
@@ -5,7 +5,6 @@ __pycache__/
!/input/example.png
/models/
/temp/
/custom_nodes/
!custom_nodes/example_node.py.example
extra_model_paths.yaml
/.vs
@@ -0,0 +1,34 @@
# Python virtual environments
venv/
env/
ENV/

# Python cache files
__pycache__/
*.py[cod]
*$py.class
*.so
.Python

# Log files
*.log

# Local environment files
.env
.env.local

# Generated images
生成图像_*.png

# Temporary files
tmp/
temp/
.DS_Store
Thumbs.db

# Model files (usually excluded from version control because of their size)
# Exclude ComfyUI's entire models directory
../models/

# Also exclude any model files inside this project
models/
@@ -0,0 +1,258 @@
# FLUX Style Shaper API

An image style-shaping API service built on ComfyUI and the FLUX models. It combines a structure image and a style image to generate a new, stylized image.

## Overview

The API exposes the FLUX style-shaping workflow as a service: it takes a structure image and a style image and produces a new stylized image. The implementation uses:

- Flux[dev] Redux for style transfer
- Flux[dev] Depth for structure awareness
- DepthAnything V2 for depth estimation
- SigCLIP Vision for image understanding

## Prerequisites

Before using this API, make sure that:

1. ComfyUI is installed and configured correctly
2. The custom nodes required by the FLUX models are installed in ComfyUI
   - This API depends on ComfyUI custom nodes, including (but not limited to) the nodes that provide FLUX support
   - Before running, verify that your ComfyUI environment contains all required custom nodes
3. Your system has enough GPU memory to run the FLUX models (at least 8 GB recommended)
4. A compatible CUDA version and matching PyTorch build are installed
   - For CUDA 12.x (including CUDA 12.6): PyTorch 2.2.0+cu121
   - For CUDA 11.8/11.7/11.6: PyTorch 2.0.1+cu118/cu117/cu116
   - The bundled `create_venv.bat` / `create_venv.sh` scripts simplify this installation
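To check quickly whether a usable PyTorch build is already present before running the scripts, a probe like the following can help. This is a minimal sketch using only the standard library; `torch_status` is an illustrative helper name, not part of this project.

```python
# Minimal sketch: report whether torch is importable and whether CUDA is visible.
# torch_status() is an illustrative helper, not part of this project's code.
import importlib.util


def torch_status() -> str:
    """Return a one-line summary of the local PyTorch/CUDA situation."""
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed"
    import torch  # imported lazily so the probe also works without torch present
    if torch.cuda.is_available():
        return f"torch {torch.__version__} with CUDA {torch.version.cuda}"
    return f"torch {torch.__version__} (CPU only)"


if __name__ == "__main__":
    print(torch_status())
```

If this reports a CPU-only build on a CUDA machine, reinstall PyTorch with the matching `+cuXXX` wheel listed above.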
## Models

This API uses the following models, which are **downloaded automatically on the first run** (this may take a while):

| Model file | Size | Purpose | Location |
|---------|------|------|---------|
| flux1-redux-dev.safetensors | ~1.3GB | Style model | models/style_models |
| flux1-depth-dev.safetensors | ~1.4GB | Diffusion model | models/diffusion_models |
| sigclip_vision_patch14_384.safetensors | ~0.3GB | CLIP vision model | models/clip_vision |
| depth_anything_v2_vitl_fp32.safetensors | ~0.3GB | Depth model | models/depthanything |
| ae.safetensors | ~0.2GB | VAE model | models/vae/FLUX1 |
| clip_l.safetensors | ~0.2GB | CLIP text encoder | models/text_encoders |
| t5xxl_fp16.safetensors | ~5.0GB | T5 text encoder | models/text_encoders/t5 |

**Notes:**

- All models together need roughly **8-9GB** of disk space
- They are downloaded automatically from Hugging Face on the first run
- Once downloaded, the models live in ComfyUI's standard model directories
- If a download is interrupted, it resumes on the next run
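The table above can be checked locally before starting the service. The sketch below mirrors the location column relative to the ComfyUI root; `missing_models` is an illustrative helper, not part of the project's code.

```python
# Sketch: list which of the required model files are still missing on disk.
# The (directory, filename) pairs mirror the table above, relative to the
# ComfyUI root; adjust the root path for your own setup.
import os

REQUIRED_MODELS = [
    ("models/style_models", "flux1-redux-dev.safetensors"),
    ("models/diffusion_models", "flux1-depth-dev.safetensors"),
    ("models/clip_vision", "sigclip_vision_patch14_384.safetensors"),
    ("models/depthanything", "depth_anything_v2_vitl_fp32.safetensors"),
    ("models/vae/FLUX1", "ae.safetensors"),
    ("models/text_encoders", "clip_l.safetensors"),
    ("models/text_encoders/t5", "t5xxl_fp16.safetensors"),
]


def missing_models(comfyui_dir):
    """Return the filenames from REQUIRED_MODELS not yet present on disk."""
    return [
        filename
        for subdir, filename in REQUIRED_MODELS
        if not os.path.exists(os.path.join(comfyui_dir, subdir, filename))
    ]
```

Anything the helper reports as missing will be fetched by the service on its next start.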
## Hugging Face Access Token

**Important**: the FLUX models are gated, so a Hugging Face access token is required to download them. Set one up as follows:

1. **Get an access token**:
   - Open the [Hugging Face token settings page](https://huggingface.co/settings/tokens)
   - Create a new access token (Read permission is sufficient)
   - Copy the generated token

2. **Request access to the models**:
   - Visit the following model pages and click "Access repository":
     - [FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev)
     - [FLUX.1-Depth-dev](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev)
     - [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)

3. **Set the access token** (choose one method):

**Method 1: create a .env file**:
```
# Create a .env file in the flux_style_shaper_api directory containing:
HUGGING_FACE_HUB_TOKEN=your_access_token
```

**Method 2: set an environment variable**:
```bash
# Windows Command Prompt
set HUGGING_FACE_HUB_TOKEN=your_access_token

# Windows PowerShell
$env:HUGGING_FACE_HUB_TOKEN="your_access_token"

# Linux/Mac
export HUGGING_FACE_HUB_TOKEN="your_access_token"
```
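Either way, the token ends up in the process environment under the name `HUGGING_FACE_HUB_TOKEN`, which is the variable the API reads. A quick standard-library check (a sketch; `get_hf_token` is an illustrative helper, and note that the API itself additionally loads a `.env` file via python-dotenv, which this probe does not):

```python
# Sketch: confirm the Hugging Face token is visible to the current process.
# This only checks the plain environment variable; the API also reads a .env
# file via python-dotenv.
import os


def get_hf_token():
    """Return the configured token, or None if it is not set."""
    return os.environ.get("HUGGING_FACE_HUB_TOKEN")


if __name__ == "__main__":
    print("token found" if get_hf_token() else "HUGGING_FACE_HUB_TOKEN is not set")
```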
## Installation

### Method 1: virtual environment (recommended)

1. Make sure ComfyUI is installed and runs correctly
2. Clone or download this directory into the ComfyUI root directory
3. Create a virtual environment with the provided script:

**Windows:**
```
cd flux_style_shaper_api
create_venv.bat
```

**Linux/MacOS:**
```
cd flux_style_shaper_api
chmod +x create_venv.sh
./create_venv.sh
```

Or create the virtual environment manually:

**Windows:**
```
cd flux_style_shaper_api
python -m venv venv
venv\Scripts\activate
pip install -r requirements.txt
```

**Linux/MacOS:**
```
cd flux_style_shaper_api
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

### Method 2: direct installation

1. Make sure ComfyUI is installed and runs correctly
2. Clone or download this directory into the ComfyUI root directory
3. Install the required Python packages:
```
cd flux_style_shaper_api
pip install -r requirements.txt
```

## Usage

### Start inside the virtual environment (recommended)

1. Activate the virtual environment and start the service:

**Windows:**
```
cd flux_style_shaper_api
venv\Scripts\activate
python flux_style_shaper_api.py
```

**Linux/MacOS:**
```
cd flux_style_shaper_api
source venv/bin/activate
python flux_style_shaper_api.py
```

### Start with the startup scripts

1. Run the provided startup script:

**Windows:**
```
cd flux_style_shaper_api
startup.bat
```

**Linux/MacOS:**
```
cd flux_style_shaper_api
./startup.sh
```

2. The service will:
   - Download any missing models (the first run may take a while)
   - Install the required custom nodes (if needed)
   - Start the FastAPI service

3. API endpoints:
   - Health check: `GET /health`
   - Image generation: `POST /generate`

4. Swagger docs: open `http://localhost:8000/docs` for interactive API documentation
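Since the first start may spend a long time downloading models, it can be useful to wait until `/health` answers before sending generation requests. A standard-library sketch (assuming the default address `http://localhost:8000` used throughout this README; `wait_for_health` is an illustrative helper):

```python
# Sketch: poll GET /health until the API answers or a timeout expires.
# wait_for_health() is an illustrative helper, not part of this project.
import time
import urllib.error
import urllib.request


def wait_for_health(base_url="http://localhost:8000", timeout_s=60.0, interval_s=2.0):
    """Return True once /health responds with HTTP 200, False on timeout."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{base_url}/health", timeout=5) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # server not up yet; retry after a short pause
        time.sleep(interval_s)
    return False
```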
## API Reference

### Image generation `/generate`

**Method**: POST (multipart/form-data)

**Parameters**:

- `prompt`: text prompt (optional)
- `structure_image`: structure image (required)
- `style_image`: style image (required)
- `depth_strength`: depth strength, controls how strongly the structure is preserved (default 15.0)
- `style_strength`: style strength, controls how strongly the style is applied (default 0.5)

**Response**:

- Success: the generated image file
- Failure: a JSON error message

## Examples

### Python client

```python
import requests

# API endpoint
url = "http://localhost:8000/generate"

# Parameters
data = {
    "prompt": "a beautiful landscape",  # optional
    "depth_strength": 15.0,             # optional, default 15
    "style_strength": 0.5               # optional, default 0.5
}

# Image files
files = {
    "structure_image": open("structure.jpg", "rb"),
    "style_image": open("style.jpg", "rb")
}

# Send the request
response = requests.post(url, data=data, files=files)

# Save the generated image
if response.status_code == 200:
    with open("output.png", "wb") as f:
        f.write(response.content)
    print("Image saved to output.png")
else:
    print(f"Generation failed: {response.json()}")
```

### cURL example

```bash
curl -X POST "http://localhost:8000/generate" \
  -F "prompt=a beautiful landscape" \
  -F "depth_strength=15.0" \
  -F "style_strength=0.5" \
  -F "structure_image=@structure.jpg" \
  -F "style_image=@style.jpg" \
  --output output.png
```

## Required models

The service automatically downloads the following models from Hugging Face if they are not already present:

- flux1-redux-dev.safetensors (style model)
- flux1-depth-dev.safetensors (diffusion model)
- sigclip_vision_patch14_384.safetensors (CLIP vision model)
- depth_anything_v2_vitl_fp32.safetensors (depth model)
- ae.safetensors (FLUX VAE)
- clip_l.safetensors (CLIP text encoder)
- t5xxl_fp16.safetensors (FLUX T5 text encoder)

## Credits

- Original ComfyUI workflow by [Nathan Shipley](https://x.com/CitizenPlain)
- FLUX models by [Black Forest Labs](https://github.com/black-forest-labs)
- DepthAnything by [Kijai](https://github.com/kijai)
@@ -0,0 +1,108 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Example Python client for the FLUX Style Shaper API
"""

import requests
import argparse
import os
from datetime import datetime

def generate_image(server_url, prompt, structure_image_path, style_image_path, depth_strength=15.0, style_strength=0.5):
    """
    Call the API to generate a stylized image.

    Args:
        server_url (str): API server address
        prompt (str): text prompt
        structure_image_path (str): path to the structure image
        style_image_path (str): path to the style image
        depth_strength (float): depth strength
        style_strength (float): style strength

    Returns:
        str: path the generated image was saved to
    """
    # API endpoint
    url = f"{server_url}/generate"

    # Parameters
    data = {
        "prompt": prompt,
        "depth_strength": depth_strength,
        "style_strength": style_strength
    }

    # Image files
    files = {
        "structure_image": open(structure_image_path, "rb"),
        "style_image": open(style_image_path, "rb")
    }

    print(f"Sending request to {url}...")
    print(f"Parameters: prompt='{prompt}', depth_strength={depth_strength}, style_strength={style_strength}")
    print(f"Structure image: {structure_image_path}")
    print(f"Style image: {style_image_path}")

    try:
        # Send the request
        response = requests.post(url, data=data, files=files)

        # Check the response
        if response.status_code == 200:
            # Build the output filename
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_filename = f"生成图像_{timestamp}.png"

            # Save the generated image
            with open(output_filename, "wb") as f:
                f.write(response.content)

            print(f"Success! Image saved to {output_filename}")
            return output_filename
        else:
            print(f"Generation failed: {response.json()}")
            return None
    except Exception as e:
        print(f"Request error: {e}")
        return None
    finally:
        # Close the files
        for f in files.values():
            f.close()

def main():
    # Parse the command-line arguments
    parser = argparse.ArgumentParser(description="FLUX Style Shaper API client")
    parser.add_argument("--server", default="http://localhost:8000", help="API server address")
    parser.add_argument("--prompt", default="", help="text prompt")
    parser.add_argument("--structure", required=True, help="path to the structure image")
    parser.add_argument("--style", required=True, help="path to the style image")
    parser.add_argument("--depth-strength", type=float, default=15.0, help="depth strength (default: 15.0)")
    parser.add_argument("--style-strength", type=float, default=0.5, help="style strength (default: 0.5)")

    args = parser.parse_args()

    # Check that the input files exist
    if not os.path.exists(args.structure):
        print(f"Error: structure image not found: {args.structure}")
        return

    if not os.path.exists(args.style):
        print(f"Error: style image not found: {args.style}")
        return

    # Call the API
    generate_image(
        server_url=args.server,
        prompt=args.prompt,
        structure_image_path=args.structure,
        style_image_path=args.style,
        depth_strength=args.depth_strength,
        style_strength=args.style_strength
    )

if __name__ == "__main__":
    main()
@@ -0,0 +1,90 @@
@echo off
setlocal enabledelayedexpansion
echo Creating a Python virtual environment for the FLUX Style Shaper API...
echo.

rem Check that Python is installed
python --version > nul 2>&1
IF %ERRORLEVEL% NEQ 0 (
    echo Error: no Python installation detected. Please install Python 3.8 or newer.
    pause
    exit /b 1
)

rem Check for an existing virtual environment
IF EXIST venv (
    echo A virtual environment already exists. Recreate it? ^(Y/N^)
    set /p recreate=
    rem recreate is set inside this parenthesized block, so it needs delayed expansion
    if /I "!recreate!"=="Y" (
        echo Removing the old virtual environment...
        rmdir /s /q venv
    ) else (
        echo Operation cancelled.
        pause
        exit /b 0
    )
)

echo Creating a new virtual environment...
python -m venv venv

IF %ERRORLEVEL% NEQ 0 (
    echo Failed to create the virtual environment. Check your Python version and make sure the venv module is installed.
    pause
    exit /b 1
)

echo Activating the virtual environment and installing dependencies...
call venv\Scripts\activate
pip install --upgrade pip

echo.
echo PyTorch must now be installed. Select the CUDA version installed on your system:
echo 1. CUDA 12.x (latest; for RTX 40-series and other new GPUs)
echo 2. CUDA 11.8
echo 3. CUDA 11.7
echo 4. CUDA 11.6
echo 5. No CUDA (CPU build, not recommended)
echo 6. Choose another version manually
set /p cuda_choice="Select (1-6): "

if "%cuda_choice%"=="1" (
    echo Installing PyTorch ^(CUDA 12.x^)...
    pip install torch==2.2.0+cu121 torchvision==0.17.0+cu121 --index-url https://download.pytorch.org/whl/cu121
) else if "%cuda_choice%"=="2" (
    echo Installing PyTorch ^(CUDA 11.8^)...
    pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
) else if "%cuda_choice%"=="3" (
    echo Installing PyTorch ^(CUDA 11.7^)...
    pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 --index-url https://download.pytorch.org/whl/cu117
) else if "%cuda_choice%"=="4" (
    echo Installing PyTorch ^(CUDA 11.6^)...
    pip install torch==2.0.1+cu116 torchvision==0.15.2+cu116 --index-url https://download.pytorch.org/whl/cu116
) else if "%cuda_choice%"=="5" (
    echo Installing PyTorch ^(CPU build^)...
    echo Warning: the CPU build cannot use GPU acceleration and is not recommended for image generation!
    pip install torch==2.0.1 torchvision==0.15.2
) else if "%cuda_choice%"=="6" (
    echo Visit https://pytorch.org/get-started/locally/ to pick a suitable PyTorch build
    echo After installing it, run: pip install -r requirements.txt
    pause
    exit /b 0
) else (
    echo Invalid choice; defaulting to the CUDA 12.x build for the newest GPUs
    pip install torch==2.2.0+cu121 torchvision==0.17.0+cu121 --index-url https://download.pytorch.org/whl/cu121
)

echo Installing the remaining dependencies...
pip install -r requirements.txt

echo.
echo Verifying that the PyTorch installation supports CUDA...
python -c "import torch; print('CUDA available:', torch.cuda.is_available()); print('CUDA version:', torch.version.cuda if torch.cuda.is_available() else 'unavailable')"

echo.
echo Virtual environment created successfully! Activate it with:
echo venv\Scripts\activate
echo.
echo Or run the startup.bat script directly to start the API service.
echo.

pause
@@ -0,0 +1,96 @@
#!/bin/bash

echo "Creating a Python virtual environment for the FLUX Style Shaper API..."
echo ""

# Check that Python is installed
if ! command -v python3 &> /dev/null; then
    echo "Error: no Python installation detected. Please install Python 3.8 or newer."
    exit 1
fi

# Check for an existing virtual environment
if [ -d "venv" ]; then
    echo "A virtual environment already exists. Recreate it? (y/n)"
    read recreate
    if [[ "$recreate" =~ ^[Yy]$ ]]; then
        echo "Removing the old virtual environment..."
        rm -rf venv
    else
        echo "Operation cancelled."
        exit 0
    fi
fi

echo "Creating a new virtual environment..."
python3 -m venv venv

if [ $? -ne 0 ]; then
    echo "Failed to create the virtual environment. Check your Python version and make sure the venv module is installed."
    exit 1
fi

echo "Activating the virtual environment and installing dependencies..."
source venv/bin/activate
pip install --upgrade pip

echo ""
echo "PyTorch must now be installed. Select the CUDA version installed on your system:"
echo "1. CUDA 12.x (latest; for RTX 40-series and other new GPUs)"
echo "2. CUDA 11.8"
echo "3. CUDA 11.7"
echo "4. CUDA 11.6"
echo "5. No CUDA (CPU build, not recommended)"
echo "6. Choose another version manually"
read -p "Select (1-6): " cuda_choice

case $cuda_choice in
    1)
        echo "Installing PyTorch (CUDA 12.x)..."
        pip install torch==2.2.0+cu121 torchvision==0.17.0+cu121 --index-url https://download.pytorch.org/whl/cu121
        ;;
    2)
        echo "Installing PyTorch (CUDA 11.8)..."
        pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
        ;;
    3)
        echo "Installing PyTorch (CUDA 11.7)..."
        pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 --index-url https://download.pytorch.org/whl/cu117
        ;;
    4)
        echo "Installing PyTorch (CUDA 11.6)..."
        pip install torch==2.0.1+cu116 torchvision==0.15.2+cu116 --index-url https://download.pytorch.org/whl/cu116
        ;;
    5)
        echo "Installing PyTorch (CPU build)..."
        echo "Warning: the CPU build cannot use GPU acceleration and is not recommended for image generation!"
        pip install torch==2.0.1 torchvision==0.15.2
        ;;
    6)
        echo "Visit https://pytorch.org/get-started/locally/ to pick a suitable PyTorch build"
        echo "After installing it, run: pip install -r requirements.txt"
        exit 0
        ;;
    *)
        echo "Invalid choice; defaulting to the CUDA 12.x build (for the newest GPUs)"
        pip install torch==2.2.0+cu121 torchvision==0.17.0+cu121 --index-url https://download.pytorch.org/whl/cu121
        ;;
esac

echo "Installing the remaining dependencies..."
pip install -r requirements.txt

echo ""
echo "Verifying that the PyTorch installation supports CUDA..."
python -c "import torch; print('CUDA available:', torch.cuda.is_available()); print('CUDA version:', torch.version.cuda if torch.cuda.is_available() else 'unavailable')"

echo ""
echo "Virtual environment created successfully! Activate it with:"
echo "source venv/bin/activate"
echo ""
echo "Or run ./startup.sh directly to start the API service."
echo ""

# Make the other scripts executable
chmod +x startup.sh
chmod +x client_example.py
@@ -0,0 +1,585 @@
import os
import random
import sys
from typing import Sequence, Mapping, Any, Union
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
import logging
from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks
from fastapi.responses import FileResponse
import uvicorn
import shutil
from pathlib import Path
import tempfile
from dotenv import load_dotenv

# Load environment variables (from a .env file)
load_dotenv(override=True)

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Create the FastAPI app
app = FastAPI(title="FLUX Style Shaper API", description="An image style-shaping API built on ComfyUI and the FLUX models")

# Download the required models if they are missing
def download_models():
    logger.info("Checking for and downloading required models...")
    # Locate the ComfyUI root directory
    comfyui_dir = get_comfyui_dir()

    # Get the Hugging Face token
    hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if not hf_token:
        logger.warning("No Hugging Face access token found! Gated models may fail to download.")
        logger.warning("Set the HUGGING_FACE_HUB_TOKEN environment variable or add it to a .env file.")
        logger.warning("Get an access token at https://huggingface.co/settings/tokens")
        logger.warning("Then request access on pages such as https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev")
    else:
        logger.info("Hugging Face access token found")

    models = [
        {"repo_id": "black-forest-labs/FLUX.1-Redux-dev", "filename": "flux1-redux-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/style_models")},
        {"repo_id": "black-forest-labs/FLUX.1-Depth-dev", "filename": "flux1-depth-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/diffusion_models")},
        {"repo_id": "Comfy-Org/sigclip_vision_384", "filename": "sigclip_vision_patch14_384.safetensors", "local_dir": os.path.join(comfyui_dir, "models/clip_vision")},
        {"repo_id": "Kijai/DepthAnythingV2-safetensors", "filename": "depth_anything_v2_vitl_fp32.safetensors", "local_dir": os.path.join(comfyui_dir, "models/depthanything")},
        {"repo_id": "black-forest-labs/FLUX.1-dev", "filename": "ae.safetensors", "local_dir": os.path.join(comfyui_dir, "models/vae/FLUX1")},
        {"repo_id": "comfyanonymous/flux_text_encoders", "filename": "clip_l.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders")},
        {"repo_id": "comfyanonymous/flux_text_encoders", "filename": "t5xxl_fp16.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders/t5")}
    ]

    for model in models:
        try:
            local_path = os.path.join(model["local_dir"], model["filename"])
            if not os.path.exists(local_path):
                logger.info(f"Downloading {model['filename']} to {model['local_dir']}")
                os.makedirs(model["local_dir"], exist_ok=True)
                hf_hub_download(
                    repo_id=model["repo_id"],
                    filename=model["filename"],
                    local_dir=model["local_dir"],
                    token=hf_token  # authenticate the download
                )
            else:
                logger.info(f"Model {model['filename']} already exists in {model['local_dir']}")
        except Exception as e:
            if "403 Client Error" in str(e) and "Access to model" in str(e) and "is restricted" in str(e):
                logger.error(f"Permission error downloading {model['filename']}: you do not have access")
                logger.error("To fix this:")
                logger.error("1. Create an access token at https://huggingface.co/settings/tokens")
                logger.error(f"2. Request access on https://huggingface.co/{model['repo_id']}")
                logger.error("3. Set the HUGGING_FACE_HUB_TOKEN environment variable to your token")
                logger.error("   or create a .env file in the project directory containing: HUGGING_FACE_HUB_TOKEN=your_token")
            else:
                logger.error(f"Error downloading {model['filename']}: {e}")

            # Abort if a critical model cannot be downloaded
            if model["filename"] in ["flux1-redux-dev.safetensors", "flux1-depth-dev.safetensors", "ae.safetensors"]:
                raise RuntimeError(f"Could not download critical model {model['filename']}; cannot continue")
            else:
                logger.warning(f"Continuing without {model['filename']}; functionality may be limited")

# Locate the ComfyUI root directory
def get_comfyui_dir():
    # The ComfyUI root is one level above this file's directory
    current_dir = os.path.dirname(os.path.abspath(__file__))
    comfyui_dir = os.path.dirname(current_dir)  # parent of the current directory
    return comfyui_dir

# Helper to fetch a value by index
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]

# Setup helpers carried over from the original script
def find_path(name: str, path: str = None) -> str:
    if path is None:
        path = get_comfyui_dir()
    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        logger.info(f"{name} found: {path_name}")
        return path_name
    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None
    return find_path(name, parent_directory)

def add_comfyui_directory_to_sys_path() -> None:
    comfyui_path = get_comfyui_dir()
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        logger.info(f"'{comfyui_path}' added to sys.path")

def add_extra_model_paths() -> None:
    comfyui_dir = get_comfyui_dir()
    sys.path.append(comfyui_dir)  # make sure ComfyUI modules can be imported

    try:
        # Try to import load_extra_path_config directly from main.py
        from main import load_extra_path_config

        extra_model_paths = os.path.join(comfyui_dir, "extra_model_paths.yaml")
        if os.path.exists(extra_model_paths):
            logger.info(f"Found extra_model_paths config: {extra_model_paths}")
            load_extra_path_config(extra_model_paths)
        else:
            logger.info("No extra_model_paths.yaml found; using the default model paths")
    except ImportError as e:
        # Log a warning but keep going if the import fails
        logger.warning(f"Could not import load_extra_path_config: {e}")
        logger.info("Continuing with the default model paths")
        # Not a fatal error; execution can continue

def import_custom_nodes() -> None:
    comfyui_dir = get_comfyui_dir()
    sys.path.append(comfyui_dir)

    try:
        import asyncio
        import execution
        from nodes import init_extra_nodes
        import server
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        server_instance = server.PromptServer(loop)
        execution.PromptQueue(server_instance)
        init_extra_nodes()
        logger.info("Custom nodes imported successfully")
    except Exception as e:
        logger.error(f"Error importing custom nodes: {e}")
        raise

# Set up the ComfyUI environment
def setup_environment():
    """Set up the ComfyUI environment: paths and models."""
    logger.info("Setting up the ComfyUI environment...")
    # Initialize paths
    add_comfyui_directory_to_sys_path()
    add_extra_model_paths()

    # Download the required models
    download_models()

    # Import custom nodes
    try:
        import_custom_nodes()
        logger.info("Custom nodes imported successfully")
    except Exception as e:
        logger.error(f"Error importing custom nodes: {e}")
        logger.error("Make sure the required ComfyUI custom nodes are installed before running this program.")
        logger.error("Image generation needs several custom nodes, including (but not limited to) the FLUX support nodes.")
        raise RuntimeError("Could not import custom nodes; cannot continue")

# Main image generation function
def generate_image(prompt, structure_image_path, style_image_path, depth_strength=15, style_strength=0.5):
    """
    Main generation function: processes the inputs and returns the path of the generated image.

    It implements the core FLUX style-shaping pipeline:
    1. Load the models (CLIP, VAE, UNET, ...)
    2. Process the structure and style images
    3. Strengthen structure awareness via depth estimation
    4. Apply the style model for style transfer
    5. Sample the final image

    Args:
        prompt (str): text prompt guiding the generation
        structure_image_path (str): path to the structure image providing the basic composition
        style_image_path (str): path to the style image providing the artistic style
        depth_strength (float): depth strength, controls how strongly the structure is preserved
        style_strength (float): style strength, controls how strongly the style is applied

    Returns:
        str: path the generated image was saved to
    """
    logger.info("Starting image generation...")

    # ====================== Stage 1: import the required components ======================
    # Import the node types we need from ComfyUI's nodes.py;
    # each node is one operation in the workflow
    from nodes import (
        StyleModelLoader, VAEEncode, NODE_CLASS_MAPPINGS, LoadImage, CLIPVisionLoader,
        SaveImage, VAELoader, CLIPVisionEncode, DualCLIPLoader, EmptyLatentImage,
        VAEDecode, UNETLoader, CLIPTextEncode,
    )

    # ====================== Stage 2: initialize constants ======================
    # Create an integer constant node set to 1024 (commonly used as the image size)
    intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()  # integer constant node instance
    CONST_1024 = intconstant.get_value(value=1024)  # constant with value 1024

    logger.info("Loading models...")

    # ====================== Stage 3: load models ======================

    # Load the CLIP models - used for understanding text and images.
    # DualCLIPLoader loads two CLIP models, T5 and CLIP-L, the standard FLUX configuration:
    # T5 is a large text encoder; CLIP-L improves image-text alignment
    dualcliploader = DualCLIPLoader()
    CLIP_MODEL = dualcliploader.load_clip(
        clip_name1="t5/t5xxl_fp16.safetensors",  # T5 text encoder
        clip_name2="clip_l.safetensors",  # CLIP-L encoder
        type="flux",  # FLUX configuration
    )

    # Load the VAE (variational autoencoder) - compresses and decompresses images.
    # The VAE maps images to latent-space representations and back;
    # FLUX uses a dedicated VAE for better reconstruction quality
    vaeloader = VAELoader()
    VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")

    # Load the UNET - the core of the diffusion model, responsible for denoising.
    # FLUX-Depth is a diffusion model trained specifically for depth awareness
    unetloader = UNETLoader()
    UNET_MODEL = unetloader.load_unet(
        unet_name="flux1-depth-dev.safetensors",  # the FLUX-Depth model
        weight_dtype="default"  # default precision
    )

    # Load the CLIP Vision model - dedicated to image understanding.
    # Important for capturing the style image's features;
    # SigCLIP is an improved CLIP vision model at 384x384 resolution
    clipvisionloader = CLIPVisionLoader()
    CLIP_VISION_MODEL = clipvisionloader.load_clip(
        clip_name="sigclip_vision_patch14_384.safetensors"
    )

    # Load the style model - used for style transfer.
    # FLUX-Redux is a model dedicated to style shaping
    stylemodelloader = StyleModelLoader()
    STYLE_MODEL = stylemodelloader.load_style_model(
        style_model_name="flux1-redux-dev.safetensors"
    )

    # Initialize the sampler - controls the noise-to-image conversion.
    # The Euler sampler is a simple but effective algorithm based on Euler's method
    ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
    SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")

    # Initialize the depth model - estimates depth from an image.
    # DepthAnything V2 is a high-quality depth estimation model that extracts
    # accurate depth from a single image, strengthening structure awareness
    cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
    downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
    DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
        model="depth_anything_v2_vitl_fp32.safetensors"
    )

    # ====================== Stage 4: import helper nodes ======================
    # The various processing nodes needed by the workflow;
    # each one is a different operation in the generation process
    cliptextencode = CLIPTextEncode()  # encodes text prompts
    loadimage = LoadImage()  # loads image files
    vaeencode = VAEEncode()  # encodes images into latent representations
    fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()  # FLUX-specific guidance node
    instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()  # conditioning setup
    clipvisionencode = CLIPVisionEncode()  # encodes images into visual features
    stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()  # applies the style
    emptylatentimage = EmptyLatentImage()  # creates a blank latent image
    basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()  # basic guider
    basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()  # sampling scheduler
    randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()  # generates random noise
    samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()  # advanced sampler
    vaedecode = VAEDecode()  # decodes latents back into images
    cr_text = NODE_CLASS_MAPPINGS["CR Text"]()  # text handling
    saveimage = SaveImage()  # saves the generated image
    getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()  # reads image dimensions
    depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()  # depth processing
    imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()  # image resizing

    # ====================== Stage 5: preload models onto the GPU ======================
    # Preloading the models on the GPU improves performance by reducing
    # GPU memory fragmentation and latency during processing
    model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]

    from comfy import model_management
    model_management.load_models_gpu([
        loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
    ])

    logger.info("Processing input images...")

    # ====================== Stage 6: prepare the input images ======================
    # Copy the structure and style images into ComfyUI's input directory;
    # this is required because ComfyUI's LoadImage node only loads from there
    comfyui_dir = get_comfyui_dir()
    input_dir = os.path.join(comfyui_dir, "input")
    os.makedirs(input_dir, exist_ok=True)

    # Use randomized filenames to avoid collisions
    structure_filename = f"structure_{random.randint(1000, 9999)}.png"
    style_filename = f"style_{random.randint(1000, 9999)}.png"

    # Copy the uploaded images into the input directory
    structure_input_path = os.path.join(input_dir, structure_filename)
    style_input_path = os.path.join(input_dir, style_filename)

    shutil.copy(structure_image_path, structure_input_path)
    shutil.copy(style_image_path, style_input_path)

    # ====================== Stage 7: image generation pipeline ======================
    # torch.inference_mode() reduces memory use and improves speed
    # by telling PyTorch not to store gradient information
    with torch.inference_mode():
        # ------- 7.1: set up CLIP -------
        # Initialize the CLIP switch with the same CLIP model on both inputs:
        # some nodes expect two CLIP models, but we only need one
        clip_switch = cr_clip_input_switch.switch(
            Input=1,  # select the first CLIP input
            clip1=get_value_at_index(CLIP_MODEL, 0),
            clip2=get_value_at_index(CLIP_MODEL, 0),
        )

        # ------- 7.2: text encoding -------
        # Encode the prompt with the CLIP model to create conditioning embeddings;
        # these guide the diffusion process, telling the model what to generate
        text_encoded = cliptextencode.encode(
            text=prompt,  # the user's text prompt
            clip=get_value_at_index(clip_switch, 0),  # the CLIP model
        )
        # Also create an empty text encoding for unconditional guidance (negative prompt)
        empty_text = cliptextencode.encode(
            text="",  # empty string as the negative prompt
            clip=get_value_at_index(clip_switch, 0),
        )

        logger.info("Processing the structure image...")

        # ------- 7.3: process the structure image -------
        # Load the structure image - the basis of the generation process
        structure_img = loadimage.load_image(image=structure_filename)

        # Resize the image to 1024x1024 while keeping its aspect ratio;
        # this makes the image fit the model without distorting its proportions
        resized_img = imageresize.execute(
            width=get_value_at_index(CONST_1024, 0),  # width 1024
            height=get_value_at_index(CONST_1024, 0),  # height 1024
            interpolation="bicubic",  # bicubic interpolation for better quality
            method="keep proportion",  # keep the aspect ratio
            condition="always",  # always resize
            multiple_of=16,  # dimensions must be multiples of 16 (diffusion model requirement)
            image=get_value_at_index(structure_img, 0),
        )

        # Read the dimensions of the resized image
        size_info = getimagesizeandcount.getsize(
            image=get_value_at_index(resized_img, 0)
        )

        # Encode the image into latent space with the VAE;
        # the VAE compresses the image into the low-dimensional space the diffusion model works in
        vae_encoded = vaeencode.encode(
            pixels=get_value_at_index(size_info, 0),  # image pixels
            vae=get_value_at_index(VAE_MODEL, 0),  # VAE model
        )

        logger.info("Processing depth...")

        # ------- 7.4: depth processing -------
        # Extract depth information with the DepthAnything model;
        # the depth map helps the model understand the image's 3D structure
        depth_processed = depthanything_v2.process(
            da_model=get_value_at_index(DEPTH_MODEL, 0),  # depth model
            images=get_value_at_index(size_info, 0),  # input image
        )

        # Apply Flux guidance - combines the depth information with the text prompt;
        # the depth strength parameter controls how much depth affects the final result
        flux_guided = fluxguidance.append(
            guidance=depth_strength,  # depth strength parameter
            conditioning=get_value_at_index(text_encoded, 0),  # text conditioning
        )

        logger.info("Processing the style image...")

        # ------- 7.5: style processing -------
        # Load the style image - this supplies the artistic style
        style_img = loadimage.load_image(image=style_filename)

        # Encode the style image with CLIP Vision;
        # this extracts the visual features used later for style transfer
        style_encoded = clipvisionencode.encode(
            crop="center",  # center crop
            clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),  # CLIP vision model
            image=get_value_at_index(style_img, 0),  # style image
        )

        logger.info("Setting up conditioning...")

        # ------- 7.6: set up conditioning -------
        # InstructPixToPixConditioning combines the text conditioning, VAE, and depth
        # information into a composite condition that guides the diffusion process
        conditioning = instructpixtopixconditioning.encode(
            positive=get_value_at_index(flux_guided, 0),  # positive conditioning (with depth guidance)
            negative=get_value_at_index(empty_text, 0),  # negative conditioning (empty text)
            vae=get_value_at_index(VAE_MODEL, 0),  # VAE model
            pixels=get_value_at_index(depth_processed, 0),  # depth-processed image
        )

        # Apply the style - attach the style features to the conditioning;
        # the style strength parameter controls how strongly the style is applied
        style_applied = stylemodelapplyadvanced.apply_stylemodel(
            strength=style_strength,  # style strength parameter
            conditioning=get_value_at_index(conditioning, 0),  # conditioning
            style_model=get_value_at_index(STYLE_MODEL, 0),  # style model
            clip_vision_output=get_value_at_index(style_encoded, 0),  # style image features
        )

        # ------- 7.7: create the latent space -------
        # Create an empty latent image - the starting point of the diffusion process;
        # its size matches the processed structure image
        empty_latent = emptylatentimage.generate(
            width=get_value_at_index(resized_img, 1),  # width
            height=get_value_at_index(resized_img, 2),  # height
            batch_size=1,  # generate one image
        )

        logger.info("Setting up sampling...")

        # ------- 7.8: set up diffusion guidance -------
        # Set up the guider - combines the UNET model with the conditioning;
        # the guider steers the diffusion process according to the conditions
        guided = basicguider.get_guider(
            model=get_value_at_index(UNET_MODEL, 0),  # UNET model
            conditioning=get_value_at_index(style_applied, 0),  # style-applied conditioning
        )

        # Set up the scheduler - controls the sampling steps and denoising strength;
        # the scheduler decides how the image emerges from noise step by step
        schedule = basicscheduler.get_sigmas(
|
||||
scheduler="simple", # 简单调度器
|
||||
steps=28, # 采样步数
|
||||
denoise=1, # 去噪强度
|
||||
model=get_value_at_index(UNET_MODEL, 0), # UNET模型
|
||||
)
|
||||
|
||||
# 生成随机噪声 - 作为扩散过程的起点
|
||||
# 不同的噪声种子会生成不同的图像
|
||||
noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))
|
||||
|
||||
logger.info("采样中...")
|
||||
|
||||
# ------- 7.9: 执行采样 -------
|
||||
# 使用高级采样器从噪声生成图像
|
||||
# 这是扩散模型的核心过程,将噪声逐步转换为图像
|
||||
sampled = samplercustomadvanced.sample(
|
||||
noise=get_value_at_index(noise, 0), # 随机噪声
|
||||
guider=get_value_at_index(guided, 0), # 引导器
|
||||
sampler=get_value_at_index(SAMPLER, 0), # 采样器
|
||||
sigmas=get_value_at_index(schedule, 0), # 调度器提供的sigma值
|
||||
latent_image=get_value_at_index(empty_latent, 0), # 初始潜在图像
|
||||
)
|
||||
|
||||
logger.info("解码图像...")
|
||||
|
||||
# ------- 7.10: 解码结果 -------
|
||||
# 使用VAE将潜在表示解码回像素空间
|
||||
# 这是生成过程的最后一步,将潜在表示转换为可见图像
|
||||
decoded = vaedecode.decode(
|
||||
samples=get_value_at_index(sampled, 0), # 采样结果
|
||||
vae=get_value_at_index(VAE_MODEL, 0), # VAE模型
|
||||
)
|
||||
|
||||
# ------- 7.11: 保存图像 -------
|
||||
# 设置保存的文件名前缀
|
||||
prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
|
||||
|
||||
# 保存生成的图像到输出目录
|
||||
saved = saveimage.save_images(
|
||||
filename_prefix=get_value_at_index(prefix, 0), # 文件名前缀
|
||||
images=get_value_at_index(decoded, 0), # 解码后的图像
|
||||
)
|
||||
|
||||
# 获取输出路径
|
||||
comfyui_dir = get_comfyui_dir()
|
||||
saved_path = os.path.join(comfyui_dir, "output", saved['ui']['images'][0]['filename'])
|
||||
logger.info(f"图像保存到 {saved_path}")
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
os.remove(structure_input_path)
|
||||
os.remove(style_input_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理临时文件时出错: {e}")
|
||||
|
||||
# 返回生成图像的路径
|
||||
return saved_path
|
||||
|
||||
# Save an uploaded file to a temporary location
def save_upload_file_tmp(upload_file: UploadFile) -> str:
    try:
        suffix = Path(upload_file.filename).suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            shutil.copyfileobj(upload_file.file, tmp)
            tmp_path = tmp.name
        return tmp_path
    finally:
        upload_file.file.close()
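The helper above follows a common pattern: stream the upload into a `NamedTemporaryFile` created with `delete=False` so the file survives the `with` block, and close the source stream in a `finally` clause. A minimal sketch of the same pattern with plain streams (the function name and input are illustrative, with no FastAPI dependency):

```python
import shutil
import tempfile
from io import BytesIO
from pathlib import Path

def save_stream_tmp(stream, original_name: str) -> str:
    """Copy a binary stream into a temp file, keeping the original suffix."""
    suffix = Path(original_name).suffix  # e.g. ".png"
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            shutil.copyfileobj(stream, tmp)  # copy in chunks, never reading it all into memory
            return tmp.name
    finally:
        stream.close()  # always release the source, even on error

path = save_stream_tmp(BytesIO(b"\x89PNG fake bytes"), "upload.png")
print(Path(path).suffix)  # .png
```

Keeping the original suffix matters here because the downstream loader infers the image format from the file extension.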
# Remove a temporary file
def cleanup_temp_file(filepath: str):
    try:
        os.unlink(filepath)
    except Exception as e:
        logger.warning(f"Failed to clean up temporary file {filepath}: {e}")

# API route: health check
@app.get("/health")
async def health_check():
    return {"status": "ok", "message": "Service is running normally"}

# API route: image generation
@app.post("/generate")
async def create_image(
    background_tasks: BackgroundTasks,
    prompt: str = Form(""),
    structure_image: UploadFile = File(...),
    style_image: UploadFile = File(...),
    depth_strength: float = Form(15.0),
    style_strength: float = Form(0.5)
):
    # Save the uploaded files
    structure_path = save_upload_file_tmp(structure_image)
    style_path = save_upload_file_tmp(style_image)

    # Schedule cleanup of the temporary files
    background_tasks.add_task(cleanup_temp_file, structure_path)
    background_tasks.add_task(cleanup_temp_file, style_path)

    try:
        # Generate the image
        output_path = generate_image(
            prompt=prompt,
            structure_image_path=structure_path,
            style_image_path=style_path,
            depth_strength=depth_strength,
            style_strength=style_strength
        )

        # Return the generated image
        return FileResponse(
            path=output_path,
            media_type="image/png",
            filename=os.path.basename(output_path)
        )
    except Exception as e:
        logger.error(f"Error generating image: {e}")
        return {"error": f"Image generation failed: {str(e)}"}
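The endpoint above accepts multipart form data. For reference, a client call might be shaped like the following sketch; the URL, file paths, and placeholder bytes are assumptions, and the actual `requests.post` line is shown commented out rather than executed:

```python
# import requests  # third-party; needed only to actually send the request

# Multipart file parts: (filename, content, content type).
# The placeholder bytes stand in for real image files.
files = {
    "structure_image": ("structure.png", b"<png bytes>", "image/png"),
    "style_image": ("style.png", b"<png bytes>", "image/png"),
}
# Plain form fields matching the Form(...) parameters of /generate.
data = {
    "prompt": "a watercolor landscape",
    "depth_strength": "15.0",
    "style_strength": "0.5",
}
# resp = requests.post("http://localhost:8000/generate", files=files, data=data)
# open("result.png", "wb").write(resp.content)  # the endpoint streams back a PNG
```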

# Set up the environment at startup
@app.on_event("startup")
async def startup_event():
    setup_environment()
    logger.info("API service started and ready to accept requests")

# Entry point
def main():
    # Start the FastAPI server
    uvicorn.run(app, host="0.0.0.0", port=8000)

if __name__ == "__main__":
    main()
@@ -0,0 +1,60 @@
@echo off
setlocal EnableDelayedExpansion
echo Starting the FLUX Style Shaping API service...
echo.

rem Check whether a virtual environment exists
IF EXIST venv (
    echo Virtual environment detected, activating...
    call venv\Scripts\activate
) ELSE (
    echo No virtual environment found. Create one? [Y/N]
    set /p create_venv=
    rem Delayed expansion: read !create_venv! after the prompt, not at parse time
    if /I "!create_venv!"=="Y" (
        echo Creating virtual environment...
        call create_venv.bat
        echo Virtual environment created and activated
    ) else (
        echo Using the system Python environment
    )
)

rem Install dependencies
echo Checking dependencies...
pip install -r requirements.txt

rem Check whether CUDA is available
echo Checking CUDA support...
python -c "import torch; is_cuda = torch.cuda.is_available(); print('CUDA available:', is_cuda); cuda_ver = torch.version.cuda if is_cuda else 'unavailable'; print('CUDA version:', cuda_ver); exit(0 if is_cuda else 1)" > nul 2>&1
if %ERRORLEVEL% NEQ 0 (
    echo WARNING: No usable CUDA support detected. This will significantly slow down image generation.
    echo If you have a CUDA-capable GPU, try running create_venv.bat to reinstall a PyTorch build that matches your system.
    echo.
    echo Continue in CPU mode? [Y/N]
    set /p continue_cpu=
    if /I NOT "!continue_cpu!"=="Y" (
        echo Operation cancelled.
        goto cleanup
    )
    echo Continuing in CPU mode; expect reduced performance...
    echo.
) else (
    echo CUDA support is working; GPU acceleration will be used!
    echo.
)

echo.
echo Starting the API service...
echo.
echo The API will run at http://localhost:8000
echo Visit http://localhost:8000/docs for the API documentation
echo.
echo Press Ctrl+C to stop the service
echo.

python flux_style_shaper_api.py

:cleanup
rem If a virtual environment was used, deactivate it on exit
IF EXIST venv (
    call deactivate
)
@@ -0,0 +1,41 @@
#!/bin/bash

echo "Starting the FLUX Style Shaping API service..."
echo ""

# Check whether a virtual environment exists
if [ -d "venv" ]; then
    echo "Virtual environment detected, activating..."
    source venv/bin/activate
else
    echo "No virtual environment found. Create one? (y/n)"
    read create_venv
    if [[ "$create_venv" =~ ^[Yy]$ ]]; then
        echo "Creating virtual environment..."
        python -m venv venv
        source venv/bin/activate
        echo "Virtual environment created and activated"
    else
        echo "Using the system Python environment"
    fi
fi

# Install dependencies
echo "Checking dependencies..."
pip install -r requirements.txt

echo ""
echo "Starting the API service..."
echo ""
echo "The API will run at http://localhost:8000"
echo "Visit http://localhost:8000/docs for the API documentation"
echo ""
echo "Press Ctrl+C to stop the service"
echo ""

python flux_style_shaper_api.py

# If a virtual environment was used, deactivate it on exit
if [ -d "venv" ]; then
    deactivate
fi
@@ -0,0 +1,94 @@
# MCP Server - FLUX Style Shaping

A server built on the [Model Context Protocol (MCP)](https://github.com/modelcontextprotocol/python-sdk) that exposes the FLUX style-shaping functionality.

## Features

- Standardized API surface via the MCP protocol
- FLUX style shaping: combine a structure image and a style image into a new image
- Provides all three core MCP components: resources, tools, and prompt templates

## Installation

1. Install the dependencies

```bash
# Install all dependencies
pip install -r requirements.txt
```

2. Obtain access to the required models

The FLUX models are downloaded from Hugging Face, and some of them are gated:

- Visit https://huggingface.co/settings/tokens to create an access token
- Request access on the following model pages:
  - https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev
  - https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev
  - https://huggingface.co/black-forest-labs/FLUX.1-dev
- Create a `.env` file and add your token:
  ```
  HUGGING_FACE_HUB_TOKEN=your_token_here
  ```

## Usage

### Starting the server

```bash
python mcp_server/flux_starter.py --host 0.0.0.0 --port 8189
```

Available options:
- `--host`: listen address (default: 127.0.0.1)
- `--port`: listen port (default: 8189)
- `--name`: server name (default: FLUX-MCP)

### Client demo

Once the server is running, the demo client can be used to test it:

```bash
python mcp_server/flux_demo_client.py --structure path/to/structure_image.jpg --style path/to/style_image.jpg
```

Available options:
- `--host`: server address (default: 127.0.0.1)
- `--port`: server port (default: 8189)
- `--structure`: structure image path (required)
- `--style`: style image path (required)
- `--prompt`: text prompt (default: empty)
- `--depth-strength`: depth strength (default: 15.0)
- `--style-strength`: style strength (default: 0.5)
- `--output`: output directory (default: ./output)

## MCP resources and tools

The server provides the following MCP components:

### Resources

- `flux://status`: get the FLUX service status

### Tools

- `生成风格化图像` ("generate stylized image"): generates a stylized image
  - Parameters:
    - `prompt`: text prompt
    - `structure_image_base64`: structure image (Base64-encoded)
    - `style_image_base64`: style image (Base64-encoded)
    - `depth_strength`: depth strength (default: 15.0)
    - `style_strength`: style strength (default: 0.5)

### Prompts

- `风格转移提示` ("style-transfer prompt"): generates a style-transfer prompt template
  - Parameters:
    - `subject`: the subject
    - `style`: the style

## Development notes

- Requires a ComfyUI environment as the base
- Uses the `mcp` Python SDK, version 1.0.0 or later
- Fully commented source, for easy understanding and maintenance
@@ -0,0 +1,12 @@
"""
MCP server module
-----------------
This module provides an implementation of a Model Context Protocol (MCP)
server for managing model connections and resource interaction.

Based on https://github.com/modelcontextprotocol/python-sdk
"""

from .mcp import MCPServer, run, run_server

__all__ = ['MCPServer', 'run', 'run_server']
@@ -0,0 +1,96 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MCP server example script
-------------------------
Demonstrates how to start and use the MCP server.
"""

import logging
import asyncio
from typing import Dict, Any

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Import the MCP server
from mcp_server import MCPServer

# Example custom resources and tools
async def setup_custom_handlers(server: MCPServer):
    """Register custom resources and tools."""

    # Register a custom resource
    @server.register_resource("comfyui://status")
    def get_comfy_status() -> str:
        """Return ComfyUI status information."""
        return "ComfyUI is running"

    # Register a custom tool
    @server.register_tool()
    def generate_image(prompt: str, width: int = 512, height: int = 512) -> Dict[str, Any]:
        """Image-generation tool (example only)."""
        # This is just a stub; a real implementation would integrate with ComfyUI.
        return {
            "status": "success",
            "prompt": prompt,
            "dimensions": f"{width}x{height}",
            "message": "Image generation request submitted"
        }

    # Register a custom prompt
    @server.register_prompt("图像生成")
    def image_generation_prompt(style: str, subject: str) -> str:
        """Create an image-generation prompt."""
        return f"Please create an image of {subject} in the style of {style}."

async def register_models(server: MCPServer):
    """Register example models."""
    model1 = server.register_model({
        "name": "stable-diffusion-v1.5",
        "type": "diffusion",
        "description": "Stable Diffusion v1.5 image-generation model"
    })

    model2 = server.register_model({
        "name": "llama-3-8b",
        "type": "llm",
        "description": "Llama 3 8B language model"
    })

    logging.info(f"Registered models: {model1}, {model2}")
    return [model1, model2]

async def main():
    """Entry point."""
    # Create the MCP server instance
    server = MCPServer(server_name="ComfyUI-MCP Example")

    # Set up the custom handlers
    await setup_custom_handlers(server)

    # Register the models
    models = await register_models(server)

    # Print server info
    logging.info(f"Server stats: {server.get_stats()}")

    # Start the server (default address 127.0.0.1:8189)
    try:
        logging.info("Starting MCP server...")
        await server.start()
    except KeyboardInterrupt:
        logging.info("Server interrupted by user")
    finally:
        # Clean up resources
        for model_id in models:
            server.unregister_model(model_id)
        logging.info("Server shut down")

if __name__ == "__main__":
    # Run the entry point
    asyncio.run(main())
@@ -0,0 +1,245 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
FLUX style-shaping MCP client demo
----------------------------------
Demonstrates how to connect to the server with an MCP client and use
the FLUX style-shaping functionality.
"""

import asyncio
import argparse
import logging
import sys
import os
import base64
from typing import Dict, Any
from pathlib import Path

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Try to import the MCP client library
try:
    from mcp import ClientSession, HttpServerParameters, types
    from mcp.client.http import http_client
except ImportError:
    logger.error("MCP client library not found. Please install the Model Context Protocol (MCP) Python SDK.")
    logger.error("Install with: pip install mcp")
    sys.exit(1)

async def save_base64_image(base64_data: str, output_path: str) -> str:
    """
    Save base64-encoded image data to a file.

    Args:
        base64_data: base64-encoded image data, optionally prefixed with a MIME type
        output_path: output directory path

    Returns:
        str: path of the saved file
    """
    # Make sure the output directory exists
    os.makedirs(output_path, exist_ok=True)

    # Split the MIME type from the data portion of the base64 string
    if ";" in base64_data and "," in base64_data:
        mime_part, data_part = base64_data.split(",", 1)
        content_type = mime_part.split(";")[0].split(":")[1]

        # Pick the file extension from the content type
        if content_type == "image/jpeg":
            ext = ".jpg"
        elif content_type == "image/png":
            ext = ".png"
        elif content_type == "image/webp":
            ext = ".webp"
        else:
            ext = ".png"  # default extension
    else:
        # Malformed data URL: assume PNG
        data_part = base64_data
        ext = ".png"

    # Decode the base64 data
    image_data = base64.b64decode(data_part)

    # Build the output filename
    timestamp = asyncio.get_event_loop().time()
    filename = f"flux_output_{int(timestamp)}{ext}"
    file_path = os.path.join(output_path, filename)

    # Save the image
    with open(file_path, "wb") as f:
        f.write(image_data)

    logger.info(f"Image saved to: {file_path}")
    return file_path
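The parsing above undoes the standard `data:` URL framing, `data:<mime>;base64,<payload>`. A minimal stdlib round trip of that framing, independent of the client code (function names are illustrative):

```python
import base64

def to_data_url(raw: bytes, mime: str = "image/png") -> str:
    """Wrap raw bytes in a data: URL with a base64 payload."""
    return f"data:{mime};base64," + base64.b64encode(raw).decode("utf-8")

def from_data_url(url: str) -> tuple[str, bytes]:
    """Split a data: URL back into (mime type, raw bytes)."""
    mime_part, payload = url.split(",", 1)       # "data:image/png;base64" / payload
    mime = mime_part.split(";")[0].split(":")[1]  # -> "image/png"
    return mime, base64.b64decode(payload)

url = to_data_url(b"\x89PNG...", "image/png")
mime, raw = from_data_url(url)
print(mime, raw == b"\x89PNG...")  # image/png True
```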

async def read_image_to_base64(image_path: str) -> str:
    """
    Read an image file and convert it to base64.

    Args:
        image_path: path of the image file

    Returns:
        str: base64-encoded image data as a data URL
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file does not exist: {image_path}")

    # Read the image file
    with open(image_path, "rb") as f:
        image_data = f.read()

    # Determine the MIME type from the extension
    ext = os.path.splitext(image_path)[1].lower()
    if ext in [".jpg", ".jpeg"]:
        mime_type = "image/jpeg"
    elif ext == ".png":
        mime_type = "image/png"
    elif ext == ".webp":
        mime_type = "image/webp"
    else:
        mime_type = "image/png"  # default MIME type

    # Encode as base64
    encoded = base64.b64encode(image_data).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"
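The manual extension-to-MIME mapping above can also be handled by the standard library's `mimetypes` module, which covers the same cases plus many more:

```python
import mimetypes

# guess_type maps by file extension; the second tuple element (encoding) is ignored here.
print(mimetypes.guess_type("photo.jpg")[0])  # image/jpeg
print(mimetypes.guess_type("mask.png")[0])   # image/png

# .webp is only registered by default on newer Python versions;
# it can be added explicitly when needed:
mimetypes.add_type("image/webp", ".webp")
print(mimetypes.guess_type("style.webp")[0])  # image/webp
```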

async def run_demo(host: str, port: int, structure_image: str, style_image: str, prompt: str,
                   depth_strength: float, style_strength: float, output_dir: str):
    """
    Run the FLUX style-shaping demo.

    Args:
        host: server host address
        port: server port
        structure_image: structure image path
        style_image: style image path
        prompt: text prompt
        depth_strength: depth strength
        style_strength: style strength
        output_dir: output directory
    """
    # Build the server parameters
    server_params = HttpServerParameters(
        host=host,
        port=port
    )

    # Connect to the server
    logger.info(f"Connecting to MCP server: {host}:{port}")
    async with http_client(server_params) as (read, write):
        # Create a client session
        async with ClientSession(read, write) as session:
            # Initialize the connection
            logger.info("Initializing MCP client session...")
            await session.initialize()

            # Check the FLUX status
            try:
                logger.info("Checking FLUX service status...")
                content, mime_type = await session.read_resource("flux://status")
                logger.info(f"FLUX service status: {content}")
            except Exception as e:
                logger.error(f"Failed to get FLUX status: {e}")
                return

            # List the available tools
            logger.info("Fetching the list of available tools...")
            tools = await session.list_tools()
            logger.info(f"Available tools: {[tool.name for tool in tools]}")

            # Convert the input images to base64
            logger.info("Reading input images...")
            structure_image_base64 = await read_image_to_base64(structure_image)
            style_image_base64 = await read_image_to_base64(style_image)

            # Call the stylized-image generation tool
            logger.info("Calling the stylized-image generation tool...")
            try:
                result = await session.call_tool(
                    "生成风格化图像",
                    arguments={
                        "prompt": prompt,
                        "structure_image_base64": structure_image_base64,
                        "style_image_base64": style_image_base64,
                        "depth_strength": depth_strength,
                        "style_strength": style_strength
                    }
                )

                # Handle the result
                if result.get("status") == "success":
                    logger.info("Image generated successfully")

                    # Save the generated image
                    image_base64 = result.get("image_base64")
                    if image_base64:
                        await save_base64_image(image_base64, output_dir)
                    else:
                        logger.warning("The result contains no image data")
                else:
                    logger.error(f"Image generation failed: {result.get('message', 'unknown error')}")
            except Exception as e:
                logger.error(f"Call to the stylized-image generation tool failed: {e}")

async def main_async():
    """Async entry point."""
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='FLUX style-shaping MCP client demo')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='MCP server address')
    parser.add_argument('--port', type=int, default=8189, help='MCP server port')
    parser.add_argument('--structure', type=str, required=True, help='structure image path')
    parser.add_argument('--style', type=str, required=True, help='style image path')
    parser.add_argument('--prompt', type=str, default='', help='text prompt')
    parser.add_argument('--depth-strength', type=float, default=15.0, help='depth strength')
    parser.add_argument('--style-strength', type=float, default=0.5, help='style strength')
    parser.add_argument('--output', type=str, default='./output', help='output directory')

    args = parser.parse_args()

    # Check that the input files exist
    if not os.path.exists(args.structure):
        logger.error(f"Structure image file does not exist: {args.structure}")
        return 1

    if not os.path.exists(args.style):
        logger.error(f"Style image file does not exist: {args.style}")
        return 1

    # Run the demo
    try:
        await run_demo(
            host=args.host,
            port=args.port,
            structure_image=args.structure,
            style_image=args.style,
            prompt=args.prompt,
            depth_strength=args.depth_strength,
            style_strength=args.style_strength,
            output_dir=args.output
        )
        return 0
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        return 0
    except Exception as e:
        logger.error(f"Demo failed: {e}")
        return 1

def main():
    """Entry point."""
    try:
        return asyncio.run(main_async())
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        return 0
    except Exception as e:
        logger.error(f"Program error: {e}")
        return 1

if __name__ == "__main__":
    sys.exit(main())
@@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
FLUX style-shaping MCP server launcher
--------------------------------------
Starts an MCP server with the FLUX style-shaping functionality enabled.
"""

import logging
import argparse
import sys
import asyncio

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Import the MCP server
from mcp_server import MCPServer
from flux_style_resource import register_flux_resources

async def main_async():
    """Async entry point."""
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Start an MCP server with FLUX style shaping')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='server listen address')
    parser.add_argument('--port', type=int, default=8189, help='server listen port')
    parser.add_argument('--name', type=str, default='FLUX-MCP', help='server name')
    args = parser.parse_args()

    try:
        # Create the MCP server instance
        logger.info(f"Initializing MCP server: {args.name}...")
        server = MCPServer(server_name=args.name)

        # Register the FLUX style-shaping resources and tools
        logger.info("Registering FLUX style-shaping resources...")
        register_flux_resources(server)

        # Start the server
        logger.info(f"Starting MCP server, listening on: {args.host}:{args.port}")
        await server.start(host=args.host, port=args.port)
    except KeyboardInterrupt:
        logger.info("Interrupt received, shutting down")
    except Exception as e:
        logger.error(f"Server failed to start: {e}")
        return 1

    return 0

def main():
    """Entry point."""
    try:
        return asyncio.run(main_async())
    except KeyboardInterrupt:
        logger.info("Interrupt received, shutting down")
        return 0
    except Exception as e:
        logger.error(f"Server error: {e}")
        return 1

if __name__ == "__main__":
    sys.exit(main())
@@ -0,0 +1,593 @@
"""
FLUX style-shaping MCP resource module
--------------------------------------
Integrates the FLUX style-shaping API into the MCP server,
exposing it to MCP clients as resources and tools.
"""

import os
import sys
import random
import logging
import tempfile
import shutil
from pathlib import Path
from typing import Dict, Any, Union, Tuple, Optional
import base64
from io import BytesIO

import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from dotenv import load_dotenv

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv(override=True)

# Locate the ComfyUI root directory
def get_comfyui_dir():
    """Return the ComfyUI root directory path."""
    current_dir = os.path.dirname(os.path.abspath(__file__))
    comfyui_dir = os.path.dirname(current_dir)  # mcp_server sits directly under the ComfyUI root
    return comfyui_dir

# Download the required models
def download_models():
    """Check for, and download if missing, the required models."""
    logger.info("Checking for required models...")
    # ComfyUI root directory
    comfyui_dir = get_comfyui_dir()

    # Hugging Face token
    hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if not hf_token:
        logger.warning("No Hugging Face access token found! Gated models may fail to download.")
        logger.warning("Set the HUGGING_FACE_HUB_TOKEN environment variable or add it to the .env file.")
    else:
        logger.info("Hugging Face access token found")

    models = [
        {"repo_id": "black-forest-labs/FLUX.1-Redux-dev", "filename": "flux1-redux-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/style_models")},
        {"repo_id": "black-forest-labs/FLUX.1-Depth-dev", "filename": "flux1-depth-dev.safetensors", "local_dir": os.path.join(comfyui_dir, "models/diffusion_models")},
        {"repo_id": "Comfy-Org/sigclip_vision_384", "filename": "sigclip_vision_patch14_384.safetensors", "local_dir": os.path.join(comfyui_dir, "models/clip_vision")},
        {"repo_id": "Kijai/DepthAnythingV2-safetensors", "filename": "depth_anything_v2_vitl_fp32.safetensors", "local_dir": os.path.join(comfyui_dir, "models/depthanything")},
        {"repo_id": "black-forest-labs/FLUX.1-dev", "filename": "ae.safetensors", "local_dir": os.path.join(comfyui_dir, "models/vae/FLUX1")},
        {"repo_id": "comfyanonymous/flux_text_encoders", "filename": "clip_l.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders")},
        {"repo_id": "comfyanonymous/flux_text_encoders", "filename": "t5xxl_fp16.safetensors", "local_dir": os.path.join(comfyui_dir, "models/text_encoders/t5")}
    ]

    for model in models:
        try:
            local_path = os.path.join(model["local_dir"], model["filename"])
            if not os.path.exists(local_path):
                logger.info(f"Downloading {model['filename']} to {model['local_dir']}")
                os.makedirs(model["local_dir"], exist_ok=True)
                hf_hub_download(
                    repo_id=model["repo_id"],
                    filename=model["filename"],
                    local_dir=model["local_dir"],
                    token=hf_token  # use the token for the download
                )
            else:
                logger.info(f"Model {model['filename']} already exists in {model['local_dir']}")
        except Exception as e:
            if "403 Client Error" in str(e) and "Access to model" in str(e) and "is restricted" in str(e):
                logger.error(f"Permission error downloading {model['filename']}: you do not have access")
                logger.error("To resolve this:")
                logger.error("1. Visit https://huggingface.co/settings/tokens and create an access token")
                logger.error(f"2. Visit https://huggingface.co/{model['repo_id']} and request access")
                logger.error("3. Set the HUGGING_FACE_HUB_TOKEN environment variable to your token")
            else:
                logger.error(f"Error downloading {model['filename']}: {e}")

            # Abort if this is a critical model
            if model["filename"] in ["flux1-redux-dev.safetensors", "flux1-depth-dev.safetensors", "ae.safetensors"]:
                raise RuntimeError(f"Cannot download critical model {model['filename']}; the program cannot continue")

# Helper to fetch a value by index
def get_value_at_index(obj: Union[Dict, list], index: int) -> Any:
    """Return obj[index], falling back to obj["result"][index] for dict outputs."""
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]
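ComfyUI node methods return either a plain tuple/list of outputs or a dict whose outputs live under a `"result"` key; the helper above handles both by catching the `KeyError` that an integer index raises on a dict. A quick illustration with dummy values:

```python
def get_value_at_index(obj, index):
    """Same fallback logic as above: plain sequence first, then the "result" key."""
    try:
        return obj[index]           # works for tuples and lists
    except KeyError:
        return obj["result"][index]  # dict output: integer index raised KeyError

print(get_value_at_index(("latent", "width", "height"), 1))  # width
print(get_value_at_index({"result": ("image", "mask")}, 0))  # image
```

Note that an out-of-range index on a sequence raises `IndexError`, which is deliberately not caught.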
|
||||
|
||||
# 添加ComfyUI目录到系统路径
|
||||
def add_comfyui_directory_to_sys_path() -> None:
|
||||
"""将ComfyUI目录添加到系统路径"""
|
||||
comfyui_path = get_comfyui_dir()
|
||||
if comfyui_path is not None and os.path.isdir(comfyui_path):
|
||||
sys.path.append(comfyui_path)
|
||||
logger.info(f"'{comfyui_path}' 已添加到 sys.path")
|
||||
|
||||
# 添加额外模型路径
|
||||
def add_extra_model_paths() -> None:
|
||||
"""添加额外模型路径配置"""
|
||||
comfyui_dir = get_comfyui_dir()
|
||||
sys.path.append(comfyui_dir) # 确保能导入ComfyUI模块
|
||||
|
||||
try:
|
||||
# 尝试直接从main.py导入load_extra_path_config
|
||||
from main import load_extra_path_config
|
||||
|
||||
extra_model_paths = os.path.join(comfyui_dir, "extra_model_paths.yaml")
|
||||
if os.path.exists(extra_model_paths):
|
||||
logger.info(f"找到extra_model_paths配置: {extra_model_paths}")
|
||||
load_extra_path_config(extra_model_paths)
|
||||
else:
|
||||
logger.info("未找到extra_model_paths.yaml文件,将使用默认模型路径")
|
||||
except ImportError as e:
|
||||
logger.warning(f"无法导入load_extra_path_config: {e}")
|
||||
logger.info("将使用默认模型路径继续运行")
|
||||
|
||||
# 导入自定义节点
|
||||
def import_custom_nodes() -> None:
|
||||
"""导入ComfyUI自定义节点"""
|
||||
comfyui_dir = get_comfyui_dir()
|
||||
sys.path.append(comfyui_dir)
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
import execution
|
||||
from nodes import init_extra_nodes
|
||||
import server
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
server_instance = server.PromptServer(loop)
|
||||
execution.PromptQueue(server_instance)
|
||||
init_extra_nodes()
|
||||
logger.info("自定义节点导入成功")
|
||||
except Exception as e:
|
||||
logger.error(f"导入自定义节点时出错: {e}")
|
||||
raise
|
||||
|
||||
# 设置ComfyUI环境
|
||||
def setup_environment():
|
||||
"""设置ComfyUI环境,加载路径和模型"""
|
||||
logger.info("设置ComfyUI环境...")
|
||||
# 初始化路径
|
||||
add_comfyui_directory_to_sys_path()
|
||||
add_extra_model_paths()
|
||||
|
||||
# 下载所需模型
|
||||
download_models()
|
||||
|
||||
# 导入自定义节点
|
||||
try:
|
||||
import_custom_nodes()
|
||||
logger.info("自定义节点导入成功")
|
||||
except Exception as e:
|
||||
logger.error(f"导入自定义节点时出错: {e}")
|
||||
raise RuntimeError("无法导入自定义节点,程序无法继续运行")
|
||||
|
||||
# 保存临时图像文件
|
||||
def save_image_to_temp(image_data: Union[str, bytes]) -> str:
|
||||
"""
|
||||
将图像数据保存到临时文件
|
||||
|
||||
参数:
|
||||
image_data: 图像数据,可以是base64编码字符串或字节数据
|
||||
|
||||
返回:
|
||||
str: 临时文件路径
|
||||
"""
|
||||
# 如果是base64编码字符串,先解码
|
||||
if isinstance(image_data, str) and image_data.startswith(('data:image', 'data:application')):
|
||||
# 提取实际的base64编码部分
|
||||
content_type, base64_data = image_data.split(';base64,')
|
||||
image_data = base64.b64decode(base64_data)
|
||||
|
||||
# 创建临时文件并保存图像
|
||||
suffix = '.png' # 默认文件扩展名
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
if isinstance(image_data, str):
|
||||
# 如果仍然是字符串,假设是文件路径
|
||||
shutil.copyfile(image_data, tmp.name)
|
||||
else:
|
||||
# 否则是字节数据
|
||||
tmp.write(image_data)
|
||||
tmp_path = tmp.name
|
||||
|
||||
return tmp_path
|
||||
|
||||
# 图像生成主函数
|
||||
def generate_image(prompt: str, structure_image: Union[str, bytes], style_image: Union[str, bytes],
|
||||
depth_strength: float = 15.0, style_strength: float = 0.5) -> Tuple[str, bytes]:
|
||||
"""
|
||||
生成风格化图像
|
||||
|
||||
参数:
|
||||
prompt: 文本提示
|
||||
structure_image: 结构图像数据(可以是base64编码字符串、字节数据或文件路径)
|
||||
style_image: 风格图像数据(可以是base64编码字符串、字节数据或文件路径)
|
||||
depth_strength: 深度强度
|
||||
style_strength: 风格强度
|
||||
|
||||
返回:
|
||||
Tuple[str, bytes]: 返回生成图像的路径和图像数据
|
||||
"""
|
||||
# 保存输入图像到临时文件
|
||||
structure_image_path = save_image_to_temp(structure_image)
|
||||
style_image_path = save_image_to_temp(style_image)
|
||||
|
||||
try:
|
||||
# 从这里开始复制原始生成函数的实现
|
||||
logger.info("开始图像生成过程...")
|
||||
|
||||
        # ====================== Stage 1: import required components ======================
        # Import every node type we need from ComfyUI's nodes.py
        from nodes import (
            StyleModelLoader, VAEEncode, NODE_CLASS_MAPPINGS, LoadImage, CLIPVisionLoader,
            SaveImage, VAELoader, CLIPVisionEncode, DualCLIPLoader, EmptyLatentImage,
            VAEDecode, UNETLoader, CLIPTextEncode,
        )

        # ====================== Stage 2: initialize constants ======================
        intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
        CONST_1024 = intconstant.get_value(value=1024)

        logger.info("Loading models...")

        # ====================== Stage 3: load models ======================
        dualcliploader = DualCLIPLoader()
        CLIP_MODEL = dualcliploader.load_clip(
            clip_name1="t5/t5xxl_fp16.safetensors",
            clip_name2="clip_l.safetensors",
            type="flux",
        )

        vaeloader = VAELoader()
        VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")

        unetloader = UNETLoader()
        UNET_MODEL = unetloader.load_unet(
            unet_name="flux1-depth-dev.safetensors",
            weight_dtype="default"
        )

        clipvisionloader = CLIPVisionLoader()
        CLIP_VISION_MODEL = clipvisionloader.load_clip(
            clip_name="sigclip_vision_patch14_384.safetensors"
        )

        stylemodelloader = StyleModelLoader()
        STYLE_MODEL = stylemodelloader.load_style_model(
            style_model_name="flux1-redux-dev.safetensors"
        )

        ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
        SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")

        cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
        downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
        DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
            model="depth_anything_v2_vitl_fp32.safetensors"
        )

        # ====================== Stage 4: instantiate helper nodes ======================
        cliptextencode = CLIPTextEncode()
        loadimage = LoadImage()
        vaeencode = VAEEncode()
        fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
        instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
        clipvisionencode = CLIPVisionEncode()
        stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
        emptylatentimage = EmptyLatentImage()
        basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
        basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
        randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
        samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
        vaedecode = VAEDecode()
        cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
        saveimage = SaveImage()
        getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
        depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
        imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()

        # ====================== Stage 5: preload models onto the GPU ======================
        model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]

        from comfy import model_management
        model_management.load_models_gpu([
            loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
        ])

        logger.info("Preparing input images...")

        # ====================== Stage 6: prepare input images ======================
        comfyui_dir = get_comfyui_dir()
        input_dir = os.path.join(comfyui_dir, "input")
        os.makedirs(input_dir, exist_ok=True)

        # Use randomized file names to avoid collisions
        structure_filename = f"structure_{random.randint(1000, 9999)}.png"
        style_filename = f"style_{random.randint(1000, 9999)}.png"

        # Copy the uploaded images into ComfyUI's input directory
        structure_input_path = os.path.join(input_dir, structure_filename)
        style_input_path = os.path.join(input_dir, style_filename)

        shutil.copy(structure_image_path, structure_input_path)
        shutil.copy(style_image_path, style_input_path)
        # ====================== Stage 7: image-generation pipeline ======================
        with torch.inference_mode():
            # ------- 7.1: set up CLIP -------
            clip_switch = cr_clip_input_switch.switch(
                Input=1,
                clip1=get_value_at_index(CLIP_MODEL, 0),
                clip2=get_value_at_index(CLIP_MODEL, 0),
            )

            # ------- 7.2: text encoding -------
            text_encoded = cliptextencode.encode(
                text=prompt,
                clip=get_value_at_index(clip_switch, 0),
            )
            empty_text = cliptextencode.encode(
                text="",
                clip=get_value_at_index(clip_switch, 0),
            )

            logger.info("Processing structure image...")

            # ------- 7.3: process the structure image -------
            structure_img = loadimage.load_image(image=structure_filename)

            resized_img = imageresize.execute(
                width=get_value_at_index(CONST_1024, 0),
                height=get_value_at_index(CONST_1024, 0),
                interpolation="bicubic",
                method="keep proportion",
                condition="always",
                multiple_of=16,
                image=get_value_at_index(structure_img, 0),
            )

            size_info = getimagesizeandcount.getsize(
                image=get_value_at_index(resized_img, 0)
            )

            vae_encoded = vaeencode.encode(
                pixels=get_value_at_index(size_info, 0),
                vae=get_value_at_index(VAE_MODEL, 0),
            )

            logger.info("Estimating depth...")

            # ------- 7.4: depth processing -------
            depth_processed = depthanything_v2.process(
                da_model=get_value_at_index(DEPTH_MODEL, 0),
                images=get_value_at_index(size_info, 0),
            )

            flux_guided = fluxguidance.append(
                guidance=depth_strength,
                conditioning=get_value_at_index(text_encoded, 0),
            )

            logger.info("Processing style image...")

            # ------- 7.5: style processing -------
            style_img = loadimage.load_image(image=style_filename)

            style_encoded = clipvisionencode.encode(
                crop="center",
                clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
                image=get_value_at_index(style_img, 0),
            )

            logger.info("Building conditioning...")

            # ------- 7.6: set up conditioning -------
            conditioning = instructpixtopixconditioning.encode(
                positive=get_value_at_index(flux_guided, 0),
                negative=get_value_at_index(empty_text, 0),
                vae=get_value_at_index(VAE_MODEL, 0),
                pixels=get_value_at_index(depth_processed, 0),
            )

            style_applied = stylemodelapplyadvanced.apply_stylemodel(
                strength=style_strength,
                conditioning=get_value_at_index(conditioning, 0),
                style_model=get_value_at_index(STYLE_MODEL, 0),
                clip_vision_output=get_value_at_index(style_encoded, 0),
            )

            # ------- 7.7: create the latent -------
            empty_latent = emptylatentimage.generate(
                width=get_value_at_index(resized_img, 1),
                height=get_value_at_index(resized_img, 2),
                batch_size=1,
            )

            logger.info("Setting up sampling...")

            # ------- 7.8: set up diffusion guidance -------
            guided = basicguider.get_guider(
                model=get_value_at_index(UNET_MODEL, 0),
                conditioning=get_value_at_index(style_applied, 0),
            )

            schedule = basicscheduler.get_sigmas(
                scheduler="simple",
                steps=28,
                denoise=1,
                model=get_value_at_index(UNET_MODEL, 0),
            )

            # randint is inclusive on both ends, so cap at the 64-bit maximum
            noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64 - 1))

            logger.info("Sampling...")

            # ------- 7.9: run the sampler -------
            sampled = samplercustomadvanced.sample(
                noise=get_value_at_index(noise, 0),
                guider=get_value_at_index(guided, 0),
                sampler=get_value_at_index(SAMPLER, 0),
                sigmas=get_value_at_index(schedule, 0),
                latent_image=get_value_at_index(empty_latent, 0),
            )

            logger.info("Decoding image...")

            # ------- 7.10: decode the result -------
            decoded = vaedecode.decode(
                samples=get_value_at_index(sampled, 0),
                vae=get_value_at_index(VAE_MODEL, 0),
            )

            # ------- 7.11: save the image -------
            prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")

            saved = saveimage.save_images(
                filename_prefix=get_value_at_index(prefix, 0),
                images=get_value_at_index(decoded, 0),
            )

            # Resolve the output path
            saved_path = os.path.join(comfyui_dir, "output", saved['ui']['images'][0]['filename'])
            logger.info(f"Image saved to {saved_path}")

            # Read the generated image back
            with open(saved_path, "rb") as f:
                image_data = f.read()

            # Clean up temporary files
            try:
                os.remove(structure_input_path)
                os.remove(style_input_path)
                os.remove(structure_image_path)
                os.remove(style_image_path)
            except Exception as e:
                logger.warning(f"Error while cleaning up temporary files: {e}")

            # Return the generated image's path and data
            return saved_path, image_data

    except Exception as e:
        # Clean up temporary files before re-raising
        try:
            os.remove(structure_image_path)
            os.remove(style_image_path)
        except OSError:
            pass
        logger.error(f"Error while generating image: {e}")
        raise

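The cleanup discipline above (temporary inputs are deleted whether generation succeeds or fails) can be sketched in isolation. This is a minimal stand-in, not the project's code: `simulate_generation` and the fake payload are hypothetical, and only stdlib `tempfile`/`os` are used.

```python
import os
import tempfile

def simulate_generation(paths_out: list) -> None:
    """Create a temp input, fail mid-generation, and still clean up,
    mirroring the try/except cleanup pattern in generate_image."""
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp.write(b"fake image bytes")  # hypothetical payload standing in for a real PNG
    tmp.close()
    paths_out.append(tmp.name)
    try:
        raise RuntimeError("simulated generation failure")
    finally:
        # Removal is attempted on every exit path, as in the API code
        try:
            os.remove(tmp.name)
        except OSError:
            pass

created = []
try:
    simulate_generation(created)
except RuntimeError:
    pass
assert not os.path.exists(created[0])  # temp file gone despite the error
```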
# Convert image data to base64
def image_to_base64(image_path: str) -> str:
    """
    Convert an image file to a base64-encoded string.

    Args:
        image_path: Path to the image file.

    Returns:
        str: Base64-encoded image data, prefixed with its MIME type.
    """
    with open(image_path, "rb") as f:
        image_data = f.read()

    encoded = base64.b64encode(image_data).decode("utf-8")
    mime_type = "image/png"  # default MIME type

    # Pick the MIME type from the file extension
    ext = os.path.splitext(image_path)[1].lower()
    if ext in ('.jpg', '.jpeg'):
        mime_type = "image/jpeg"
    elif ext == '.webp':
        mime_type = "image/webp"

    return f"data:{mime_type};base64,{encoded}"

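The data-URL construction can be exercised standalone. This sketch re-implements the same extension-to-MIME mapping with only the stdlib; the three magic bytes are a hypothetical stand-in for real JPEG data.

```python
import base64
import os
import tempfile

def image_to_data_url(path: str) -> str:
    # Same logic as image_to_base64: map the extension to a MIME type,
    # defaulting to PNG, then build a data URL.
    ext = os.path.splitext(path)[1].lower()
    mime = {".jpg": "image/jpeg", ".jpeg": "image/jpeg",
            ".webp": "image/webp"}.get(ext, "image/png")
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{encoded}"

# Round-trip check with a tiny fake payload (JPEG magic bytes only)
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
    tmp.write(b"\xff\xd8\xff")
    tmp_path = tmp.name

url = image_to_data_url(tmp_path)
assert url.startswith("data:image/jpeg;base64,")
os.unlink(tmp_path)
```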
# Initialize the FLUX environment
def init_flux():
    """Initialize the FLUX environment: set up paths and import models."""
    try:
        setup_environment()
        logger.info("FLUX environment initialized successfully")
        return True
    except Exception as e:
        logger.error(f"FLUX environment initialization failed: {e}")
        return False

# Integration hook for the MCP server
def register_flux_resources(server):
    """
    Register the FLUX style-shaping features with an MCP server.

    Args:
        server: The MCP server instance.
    """
    # Initialize the FLUX environment
    initialized = init_flux()
    if not initialized:
        logger.error("FLUX environment initialization failed; resources will not be registered")
        return

    # Resource: service status
    @server.register_resource("flux://status")
    def get_flux_status() -> str:
        """Return the FLUX service status."""
        return "FLUX style-shaping service is up and running"

    # Tool: generate a stylized image
    @server.register_tool("生成风格化图像")
    def generate_styled_image(prompt: str, structure_image_base64: str, style_image_base64: str,
                              depth_strength: float = 15.0, style_strength: float = 0.5) -> Dict[str, Any]:
        """
        Tool for generating a stylized image.

        Args:
            prompt: Text prompt guiding the generation.
            structure_image_base64: Base64-encoded structure image providing the composition.
            style_image_base64: Base64-encoded style image providing the artistic style.
            depth_strength: Depth strength, controlling how strongly the structure is preserved.
            style_strength: Style strength, controlling how strongly the style is applied.

        Returns:
            Dict: The generated image as base64 plus related metadata.
        """
        try:
            # Generate the image
            saved_path, _ = generate_image(
                prompt=prompt,
                structure_image=structure_image_base64,
                style_image=style_image_base64,
                depth_strength=depth_strength,
                style_strength=style_strength
            )

            # Convert the generated image to base64
            output_base64 = image_to_base64(saved_path)

            # Return the result
            return {
                "status": "success",
                "message": "Image generated successfully",
                "image_base64": output_base64,
                "parameters": {
                    "prompt": prompt,
                    "depth_strength": depth_strength,
                    "style_strength": style_strength
                }
            }
        except Exception as e:
            logger.error(f"Image generation failed: {e}")
            return {
                "status": "error",
                "message": f"Image generation failed: {str(e)}"
            }

    # Prompt: style-transfer prompt template
    @server.register_prompt("风格转移提示")
    def style_transfer_prompt(subject: str, style: str) -> str:
        """Build a style-transfer prompt template."""
        return f"将{subject}转换为{style}风格"

    logger.info("FLUX style-shaping resources and tools registered with the MCP server")

@ -0,0 +1,245 @@
import asyncio
import logging
from typing import Dict, List, Optional, Any, Callable, Union
import os

# MCP SDK imports
from mcp.server.fastmcp import FastMCP
from mcp.types import Resource, Tool, Prompt
from mcp import types

class MCPServer:
    """
    Model Context Protocol server.
    Manages model connections and resource interactions.
    """
    instance = None

    def __init__(self, server_name: str = "ComfyUI-MCP"):
        """
        Initialize the MCP server.

        Args:
            server_name: Server name.
        """
        MCPServer.instance = self

        # Initialize the FastMCP instance
        self.mcp = FastMCP(server_name)

        # Registered resources, tools, and prompts
        self.registered_resources: Dict[str, Callable] = {}
        self.registered_tools: Dict[str, Callable] = {}
        self.registered_prompts: Dict[str, Callable] = {}

        # Model connection bookkeeping
        self.active_models: Dict[str, Dict[str, Any]] = {}
        self.model_count: int = 0

        # Status flag
        self.is_running: bool = False

        # Register the built-in resources and tools
        self._setup_default_handlers()

        logging.info(f"[MCP server] Initialized, server name: {server_name}")
    def _setup_default_handlers(self):
        """Register the default resource handlers and tools."""

        # Status resource
        @self.mcp.resource("status://server")
        def get_server_status() -> str:
            """Return server status information."""
            return f"""
Server status:
- running: {self.is_running}
- active models: {len(self.active_models)}
- registered resources: {len(self.registered_resources)}
- registered tools: {len(self.registered_tools)}
- registered prompts: {len(self.registered_prompts)}
"""

        # Model-list resource
        @self.mcp.resource("models://list")
        def list_models() -> str:
            """Return the list of currently connected models."""
            if not self.active_models:
                return "No active model connections"

            result = "Active models:\n"
            for model_id, model_info in self.active_models.items():
                result += f"- ID: {model_id}, name: {model_info.get('name', 'Unknown')}\n"
            return result

        # Echo tool
        @self.mcp.tool()
        def echo(message: str) -> str:
            """Simple echo tool for connection testing."""
            return f"MCP server echo: {message}"

        # System-info tool
        @self.mcp.tool()
        def system_info() -> dict:
            """Return basic system information."""
            import platform
            return {
                "os": platform.system(),
                "python_version": platform.python_version(),
                "hostname": platform.node(),
                "cpu": platform.processor()
            }
    def register_resource(self, uri_pattern: str):
        """
        Register a resource handler function.

        Args:
            uri_pattern: Resource URI pattern.
        """
        def decorator(func):
            self.registered_resources[uri_pattern] = func
            self.mcp.resource(uri_pattern)(func)
            logging.info(f"[MCP server] Registered resource: {uri_pattern}")
            return func
        return decorator

    def register_tool(self, name: Optional[str] = None):
        """
        Register a tool function.

        Args:
            name: Tool name (optional; defaults to the function name).
        """
        def decorator(func):
            tool_name = name or func.__name__
            self.registered_tools[tool_name] = func
            self.mcp.tool(name=tool_name)(func)
            logging.info(f"[MCP server] Registered tool: {tool_name}")
            return func
        return decorator

    def register_prompt(self, name: Optional[str] = None):
        """
        Register a prompt template.

        Args:
            name: Prompt name (optional; defaults to the function name).
        """
        def decorator(func):
            prompt_name = name or func.__name__
            self.registered_prompts[prompt_name] = func
            self.mcp.prompt(name=prompt_name)(func)
            logging.info(f"[MCP server] Registered prompt: {prompt_name}")
            return func
        return decorator
    def register_model(self, model_info: Dict[str, Any]) -> str:
        """
        Register a model with the MCP server.

        Args:
            model_info: Model information dictionary.

        Returns:
            model_id: The assigned model ID.
        """
        model_id = f"model_{self.model_count}"
        self.model_count += 1

        self.active_models[model_id] = {
            "id": model_id,
            "registered_at": asyncio.get_event_loop().time(),
            **model_info
        }

        logging.info(f"[MCP server] Registered model: {model_id}")
        return model_id

    def unregister_model(self, model_id: str) -> bool:
        """
        Unregister a model from the MCP server.

        Args:
            model_id: Model ID.

        Returns:
            Whether the model was found and removed.
        """
        if model_id in self.active_models:
            del self.active_models[model_id]
            logging.info(f"[MCP server] Unregistered model: {model_id}")
            return True

        logging.warning(f"[MCP server] Attempted to unregister unknown model: {model_id}")
        return False
    async def start(self, host: str = "127.0.0.1", port: int = 8189):
        """
        Start the MCP server.

        Args:
            host: Host address.
            port: Port number.
        """
        self.is_running = True
        try:
            # Configuration for the FastMCP HTTP server
            uvicorn_config = {
                "host": host,
                "port": port,
                "log_level": "info"
            }

            logging.info(f"[MCP server] Starting... address: {host}:{port}")

            # Start via FastMCP's HTTP server entry point
            await self.mcp.serve(**uvicorn_config)
        except Exception as e:
            logging.error(f"[MCP server] Failed to start: {str(e)}")
            raise
        finally:
            self.is_running = False

    def get_stats(self) -> Dict[str, Any]:
        """Return server statistics."""
        return {
            "is_running": self.is_running,
            "active_models_count": len(self.active_models),
            "resources_count": len(self.registered_resources),
            "tools_count": len(self.registered_tools),
            "prompts_count": len(self.registered_prompts)
        }
# Async helper to run the MCP server
async def run_server(server_name: str = "ComfyUI-MCP", host: str = "127.0.0.1", port: int = 8189):
    """
    Convenience coroutine that creates and starts an MCP server.

    Args:
        server_name: Server name.
        host: Host address.
        port: Port number.
    """
    server = MCPServer(server_name)
    await server.start(host, port)
    return server


# Synchronous entry point
def run(server_name: str = "ComfyUI-MCP", host: str = "127.0.0.1", port: int = 8189):
    """
    Run the MCP server synchronously.

    Args:
        server_name: Server name.
        host: Host address.
        port: Port number.
    """
    # asyncio.get_event_loop() is deprecated outside a running loop;
    # create and manage a fresh loop explicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(run_server(server_name, host, port))
    except KeyboardInterrupt:
        logging.info("[MCP server] Interrupt received, shutting down...")
    finally:
        loop.close()
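The `register_tool` decorator factory above records functions under an optional explicit name while leaving them callable. A stripped-down version of that pattern, with no MCP dependency, can be tested in isolation; `ToolRegistry` and the sample tools here are hypothetical names for illustration only.

```python
from typing import Callable, Dict, Optional

class ToolRegistry:
    """Minimal sketch of MCPServer.register_tool: a decorator factory
    that stores each function in a dict under name or func.__name__."""

    def __init__(self) -> None:
        self.tools: Dict[str, Callable] = {}

    def register_tool(self, name: Optional[str] = None):
        def decorator(func: Callable) -> Callable:
            tool_name = name or func.__name__
            self.tools[tool_name] = func
            return func  # return the original so it stays directly callable
        return decorator

registry = ToolRegistry()

@registry.register_tool()
def echo(message: str) -> str:
    return f"echo: {message}"

@registry.register_tool(name="add")
def _add(a: int, b: int) -> int:
    return a + b

assert set(registry.tools) == {"echo", "add"}
assert registry.tools["add"](2, 3) == 5
assert echo("hi") == "echo: hi"  # decorated function is still usable directly
```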
@ -0,0 +1,56 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MCP server launch script
------------------------
Starts the MCP server from the command line.
"""

import argparse
import logging
import sys
from mcp_server import run

def main():
    """Parse command-line arguments and start the server."""

    # Command-line argument parsing
    parser = argparse.ArgumentParser(description='Start a Model Context Protocol (MCP) server')

    # Server configuration
    parser.add_argument('--name', type=str, default='ComfyUI-MCP',
                        help='Server name (default: ComfyUI-MCP)')
    parser.add_argument('--host', type=str, default='127.0.0.1',
                        help='Listen address (default: 127.0.0.1)')
    parser.add_argument('--port', type=int, default=8189,
                        help='Listen port (default: 8189)')
    parser.add_argument('--verbose', action='store_true',
                        help='Enable verbose logging')

    # Parse arguments
    args = parser.parse_args()

    # Configure the log level
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Log startup information
    logging.info(f"Starting MCP server '{args.name}' on {args.host}:{args.port}")

    try:
        # Start the server
        run(server_name=args.name, host=args.host, port=args.port)
    except KeyboardInterrupt:
        logging.info("Server interrupted by user")
        return 0
    except Exception as e:
        logging.error(f"Server failed to start: {str(e)}")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
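The argument handling in the launch script can be verified without starting a server by passing an explicit argv list to `parse_args`. This sketch rebuilds the same parser with only the stdlib:

```python
import argparse

# Mirror of the launch script's parser definition
parser = argparse.ArgumentParser(description='Start a Model Context Protocol (MCP) server')
parser.add_argument('--name', type=str, default='ComfyUI-MCP')
parser.add_argument('--host', type=str, default='127.0.0.1')
parser.add_argument('--port', type=int, default=8189)
parser.add_argument('--verbose', action='store_true')

# Parsing an explicit list mirrors what the script does with sys.argv
args = parser.parse_args(['--port', '9000', '--verbose'])
assert args.port == 9000          # type=int converts the string
assert args.verbose is True       # store_true flag flipped on
assert args.host == '127.0.0.1'   # defaults fill the rest
assert args.name == 'ComfyUI-MCP'
```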
@ -1,24 +1,41 @@
# ComfyUI core dependencies
comfyui-frontend-package==1.16.9
comfyui-workflow-templates==0.1.3
torch
torchsde
torchvision
torchaudio
numpy>=1.25.0
numpy>=1.25.0  # flux pins 1.24.3, but we use a newer version
einops
transformers>=4.28.1
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
aiohttp>=3.11.8
aiohttp>=3.11.8  # use the newer version requirement
yarl>=1.18.0
pyyaml
Pillow
Pillow>=9.3.0  # flux's minimum requirement
scipy
tqdm
tqdm>=4.65.0  # flux's minimum requirement
psutil

#non essential dependencies:
# FLUX style-shaping API dependencies
fastapi>=0.102.0
uvicorn>=0.23.0
python-multipart>=0.0.6
huggingface_hub>=0.19.0
GitPython>=3.1.30
python-dotenv>=1.0.0
numba
colour-science
rembg
pixeloe
transparent-background

# MCP server dependencies
mcp>=1.0.0  # Model Context Protocol Python SDK

# Non-essential dependencies
kornia>=0.7.1
spandrel
soundfile