245 lines
8.3 KiB
Python
245 lines
8.3 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
FLUX风格塑形MCP客户端演示脚本
|
||
--------------------------
|
||
演示如何使用MCP客户端连接到服务器并使用FLUX风格塑形功能
|
||
"""
|
||
|
||
import asyncio
|
||
import argparse
|
||
import logging
|
||
import sys
|
||
import os
|
||
import base64
|
||
from typing import Dict, Any
|
||
from pathlib import Path
|
||
|
||
# 设置日志
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 尝试导入MCP客户端库
|
||
try:
|
||
from mcp import ClientSession, HttpServerParameters, types
|
||
from mcp.client.http import http_client
|
||
except ImportError:
|
||
logger.error("未找到MCP客户端库。请安装Model Context Protocol (MCP) Python SDK。")
|
||
logger.error("安装命令: pip install mcp")
|
||
sys.exit(1)
|
||
|
||
async def save_base64_image(base64_data: str, output_path: str) -> str:
|
||
"""
|
||
将base64编码的图像数据保存到文件
|
||
|
||
参数:
|
||
base64_data: 包含MIME类型的base64编码图像数据
|
||
output_path: 输出目录路径
|
||
|
||
返回:
|
||
str: 保存的文件路径
|
||
"""
|
||
# 确保输出目录存在
|
||
os.makedirs(output_path, exist_ok=True)
|
||
|
||
# 从base64字符串提取MIME类型和数据
|
||
if ";" in base64_data and "," in base64_data:
|
||
mime_part, data_part = base64_data.split(",", 1)
|
||
content_type = mime_part.split(";")[0].split(":")[1]
|
||
|
||
# 确定文件扩展名
|
||
if content_type == "image/jpeg":
|
||
ext = ".jpg"
|
||
elif content_type == "image/png":
|
||
ext = ".png"
|
||
elif content_type == "image/webp":
|
||
ext = ".webp"
|
||
else:
|
||
ext = ".png" # 默认扩展名
|
||
else:
|
||
# 如果格式不正确,假设为PNG
|
||
data_part = base64_data
|
||
ext = ".png"
|
||
|
||
# 解码base64数据
|
||
image_data = base64.b64decode(data_part)
|
||
|
||
# 创建输出文件名
|
||
timestamp = asyncio.get_event_loop().time()
|
||
filename = f"flux_output_{int(timestamp)}{ext}"
|
||
file_path = os.path.join(output_path, filename)
|
||
|
||
# 保存图像
|
||
with open(file_path, "wb") as f:
|
||
f.write(image_data)
|
||
|
||
logger.info(f"图像已保存到: {file_path}")
|
||
return file_path
|
||
|
||
async def read_image_to_base64(image_path: str) -> str:
|
||
"""
|
||
读取图像文件并转换为base64编码
|
||
|
||
参数:
|
||
image_path: 图像文件路径
|
||
|
||
返回:
|
||
str: base64编码的图像数据
|
||
"""
|
||
if not os.path.exists(image_path):
|
||
raise FileNotFoundError(f"图像文件不存在: {image_path}")
|
||
|
||
# 读取图像文件
|
||
with open(image_path, "rb") as f:
|
||
image_data = f.read()
|
||
|
||
# 确定MIME类型
|
||
ext = os.path.splitext(image_path)[1].lower()
|
||
if ext in [".jpg", ".jpeg"]:
|
||
mime_type = "image/jpeg"
|
||
elif ext == ".png":
|
||
mime_type = "image/png"
|
||
elif ext == ".webp":
|
||
mime_type = "image/webp"
|
||
else:
|
||
mime_type = "image/png" # 默认MIME类型
|
||
|
||
# 编码为base64
|
||
encoded = base64.b64encode(image_data).decode("utf-8")
|
||
return f"data:{mime_type};base64,{encoded}"
|
||
|
||
async def run_demo(host: str, port: int, structure_image: str, style_image: str, prompt: str,
|
||
depth_strength: float, style_strength: float, output_dir: str):
|
||
"""
|
||
运行FLUX风格塑形演示
|
||
|
||
参数:
|
||
host: 服务器主机地址
|
||
port: 服务器端口
|
||
structure_image: 结构图像路径
|
||
style_image: 风格图像路径
|
||
prompt: 文本提示
|
||
depth_strength: 深度强度
|
||
style_strength: 风格强度
|
||
output_dir: 输出目录
|
||
"""
|
||
# 创建服务器参数
|
||
server_params = HttpServerParameters(
|
||
host=host,
|
||
port=port
|
||
)
|
||
|
||
# 连接到服务器
|
||
logger.info(f"正在连接到MCP服务器: {host}:{port}")
|
||
async with http_client(server_params) as (read, write):
|
||
# 创建客户端会话
|
||
async with ClientSession(read, write) as session:
|
||
# 初始化连接
|
||
logger.info("正在初始化MCP客户端会话...")
|
||
await session.initialize()
|
||
|
||
# 检查FLUX状态
|
||
try:
|
||
logger.info("正在检查FLUX服务状态...")
|
||
content, mime_type = await session.read_resource("flux://status")
|
||
logger.info(f"FLUX服务状态: {content}")
|
||
except Exception as e:
|
||
logger.error(f"获取FLUX状态失败: {e}")
|
||
return
|
||
|
||
# 列出可用工具
|
||
logger.info("正在获取可用工具列表...")
|
||
tools = await session.list_tools()
|
||
logger.info(f"可用工具: {[tool.name for tool in tools]}")
|
||
|
||
# 转换输入图像为base64
|
||
logger.info("正在读取输入图像...")
|
||
structure_image_base64 = await read_image_to_base64(structure_image)
|
||
style_image_base64 = await read_image_to_base64(style_image)
|
||
|
||
# 调用生成风格化图像工具
|
||
logger.info("正在调用生成风格化图像工具...")
|
||
try:
|
||
result = await session.call_tool(
|
||
"生成风格化图像",
|
||
arguments={
|
||
"prompt": prompt,
|
||
"structure_image_base64": structure_image_base64,
|
||
"style_image_base64": style_image_base64,
|
||
"depth_strength": depth_strength,
|
||
"style_strength": style_strength
|
||
}
|
||
)
|
||
|
||
# 处理结果
|
||
if result.get("status") == "success":
|
||
logger.info("图像生成成功")
|
||
|
||
# 保存生成的图像
|
||
image_base64 = result.get("image_base64")
|
||
if image_base64:
|
||
await save_base64_image(image_base64, output_dir)
|
||
else:
|
||
logger.warning("结果中不包含图像数据")
|
||
else:
|
||
logger.error(f"图像生成失败: {result.get('message', '未知错误')}")
|
||
except Exception as e:
|
||
logger.error(f"调用生成风格化图像工具失败: {e}")
|
||
|
||
async def main_async():
|
||
"""异步主函数"""
|
||
# 解析命令行参数
|
||
parser = argparse.ArgumentParser(description='FLUX风格塑形MCP客户端演示')
|
||
parser.add_argument('--host', type=str, default='127.0.0.1', help='MCP服务器地址')
|
||
parser.add_argument('--port', type=int, default=8189, help='MCP服务器端口')
|
||
parser.add_argument('--structure', type=str, required=True, help='结构图像路径')
|
||
parser.add_argument('--style', type=str, required=True, help='风格图像路径')
|
||
parser.add_argument('--prompt', type=str, default='', help='文本提示')
|
||
parser.add_argument('--depth-strength', type=float, default=15.0, help='深度强度')
|
||
parser.add_argument('--style-strength', type=float, default=0.5, help='风格强度')
|
||
parser.add_argument('--output', type=str, default='./output', help='输出目录')
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 检查输入文件是否存在
|
||
if not os.path.exists(args.structure):
|
||
logger.error(f"结构图像文件不存在: {args.structure}")
|
||
return 1
|
||
|
||
if not os.path.exists(args.style):
|
||
logger.error(f"风格图像文件不存在: {args.style}")
|
||
return 1
|
||
|
||
# 运行演示
|
||
try:
|
||
await run_demo(
|
||
host=args.host,
|
||
port=args.port,
|
||
structure_image=args.structure,
|
||
style_image=args.style,
|
||
prompt=args.prompt,
|
||
depth_strength=args.depth_strength,
|
||
style_strength=args.style_strength,
|
||
output_dir=args.output
|
||
)
|
||
return 0
|
||
except KeyboardInterrupt:
|
||
logger.info("用户中断操作")
|
||
return 0
|
||
except Exception as e:
|
||
logger.error(f"演示运行出错: {e}")
|
||
return 1
|
||
|
||
def main():
|
||
"""主函数"""
|
||
try:
|
||
return asyncio.run(main_async())
|
||
except KeyboardInterrupt:
|
||
logger.info("用户中断操作")
|
||
return 0
|
||
except Exception as e:
|
||
logger.error(f"程序运行出错: {e}")
|
||
return 1
|
||
|
||
if __name__ == "__main__":
|
||
sys.exit(main()) |