# KICCO_AI_IMAGE/mcp_server/flux_demo_client.py
#
# 245 lines
# 8.3 KiB
# Python
# Raw Normal View History
#
# 2025-04-24 19:25:29 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
FLUX风格塑形MCP客户端演示脚本
--------------------------
演示如何使用MCP客户端连接到服务器并使用FLUX风格塑形功能
"""
import argparse
import asyncio
import base64
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Dict, Any
# Configure logging once at import time so every message carries a timestamp,
# logger name and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Import the MCP client library; exit with status 1 if the import fails.
# NOTE(review): the public MCP Python SDK documents `StdioServerParameters` with
# `mcp.client.stdio.stdio_client`, plus SSE (`mcp.client.sse.sse_client`) and
# streamable-HTTP transports. I could not find `HttpServerParameters` or
# `mcp.client.http` there. Confirm these names exist in the SDK version this
# project pins. Otherwise this import always raises ImportError and the script
# exits here, with a misleading "library not installed" message even when `mcp`
# is installed.
try:
    from mcp import ClientSession, HttpServerParameters, types
    from mcp.client.http import http_client
except ImportError:
    logger.error("未找到MCP客户端库。请安装Model Context Protocol (MCP) Python SDK。")
    logger.error("安装命令: pip install mcp")
    sys.exit(1)
async def save_base64_image(base64_data: str, output_path: str) -> str:
    """
    Decode base64 image data and write it to a new file in ``output_path``.

    Args:
        base64_data: A data URI such as ``data:image/png;base64,<data>`` or
            bare base64 data. Bare data, or a data URI with an unrecognised
            MIME type, is saved with a ``.png`` extension.
        output_path: Target directory; it is created if it does not exist.

    Returns:
        str: Path of the file that was written. The name is
        ``flux_output_<unix-seconds>[_<n>]<ext>``. The ``_<n>`` suffix is only
        added when a file with the plain name already exists.

    Raises:
        binascii.Error: If the payload cannot be decoded as base64.
        OSError: If the directory or the file cannot be created.
    """
    extensions = {
        "image/jpeg": ".jpg",
        "image/jpg": ".jpg",
        "image/png": ".png",
        "image/webp": ".webp",
    }

    # Make sure the output directory exists.
    os.makedirs(output_path, exist_ok=True)

    # A data URI is "<header>,<payload>", where the header looks like
    # "data:image/png;base64". Standard base64 never contains a comma, so
    # bare data leaves `sep` empty.
    header, sep, payload = base64_data.partition(",")
    if sep and (header.startswith("data:") or ";" in header):
        if header.startswith("data:"):
            header = header[len("data:"):]
        content_type = header.split(";", 1)[0].strip().lower()
        ext = extensions.get(content_type, ".png")
        data_part = payload
    else:
        # Not a data URI: treat the whole string as PNG data.
        data_part = base64_data
        ext = ".png"

    image_data = base64.b64decode(data_part)

    # Name the file after wall-clock time; the event loop's clock is monotonic
    # and unrelated to the date. Exclusive-create mode ("xb") means a second
    # image saved in the same second gets a numeric suffix instead of silently
    # overwriting the first.
    stamp = int(time.time())
    attempt = 0
    while True:
        suffix = "" if attempt == 0 else f"_{attempt}"
        file_path = os.path.join(output_path, f"flux_output_{stamp}{suffix}{ext}")
        try:
            with open(file_path, "xb") as f:
                f.write(image_data)
        except FileExistsError:
            attempt += 1
            continue
        break

    logger.info(f"图像已保存到: {file_path}")
    return file_path
async def read_image_to_base64(image_path: str) -> str:
    """
    Load an image file and return its contents as a base64 data URI.

    The MIME type comes from the file extension, compared case-insensitively:
    .jpg/.jpeg -> image/jpeg, .png -> image/png, .webp -> image/webp.
    Any other extension, or none at all, is labelled image/png.

    Args:
        image_path: Path of the image file to read.

    Returns:
        str: A string of the form ``data:<mime>;base64,<data>``.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"图像文件不存在: {image_path}")

    mime_by_suffix = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
    }
    suffix = os.path.splitext(image_path)[1].lower()
    mime = mime_by_suffix.get(suffix, "image/png")

    with open(image_path, "rb") as handle:
        payload = base64.b64encode(handle.read()).decode("utf-8")
    return "data:" + mime + ";base64," + payload
def _tool_names(tools_result: Any) -> list:
    """Return the tool names contained in a ``session.list_tools()`` result."""
    # The result may wrap the Tool objects in a `tools` attribute; if it has
    # none, treat the value itself as an iterable of tools.
    tools = getattr(tools_result, "tools", tools_result)
    return [getattr(tool, "name", str(tool)) for tool in tools]


def _resource_text(resource_result: Any) -> Any:
    """Return a printable form of a ``session.read_resource()`` result."""
    contents = getattr(resource_result, "contents", None)
    if contents is None:
        # Tolerate a plain (content, mime_type) tuple or any other value.
        if isinstance(resource_result, tuple) and resource_result:
            return resource_result[0]
        return resource_result
    # Each content item is expected to expose `text` or a base64 `blob`.
    parts = [getattr(item, "text", None) or getattr(item, "blob", "") for item in contents]
    return "\n".join(str(part) for part in parts)


def _tool_payload(result: Any) -> Dict[str, Any]:
    """
    Normalise a ``session.call_tool()`` result into a dict.

    The returned dict carries ``status`` (plus ``image_base64`` / ``message``
    when available). The result is read in this order:
    - a dict, returned as-is;
    - ``structuredContent`` containing a ``status`` key;
    - the first text item whose text parses as a JSON object;
    - the first image item, turned into a data URI.

    Anything else, or a result flagged ``isError``, becomes
    ``{"status": "error"}``, with the collected text as ``message`` if there
    is any.
    """
    if isinstance(result, dict):
        return result

    items = list(getattr(result, "content", None) or [])
    texts = [item.text for item in items if getattr(item, "text", None) is not None]

    if not getattr(result, "isError", False):
        structured = getattr(result, "structuredContent", None)
        if isinstance(structured, dict) and "status" in structured:
            return structured
        for text in texts:
            try:
                decoded = json.loads(text)
            except ValueError:
                continue
            if isinstance(decoded, dict):
                return decoded
        for item in items:
            data = getattr(item, "data", None)
            mime = getattr(item, "mimeType", None) or ""
            if data and mime.startswith("image/"):
                return {"status": "success", "image_base64": f"data:{mime};base64,{data}"}

    payload: Dict[str, Any] = {"status": "error"}
    if texts:
        payload["message"] = "\n".join(texts)
    return payload


async def run_demo(host: str, port: int, structure_image: str, style_image: str, prompt: str,
                   depth_strength: float, style_strength: float, output_dir: str):
    """
    Run the FLUX style-shaping demo against an MCP server.

    The demo connects, logs the ``flux://status`` resource and the tool list,
    then calls the "生成风格化图像" tool with both input images encoded as
    data URIs. The returned image is saved into ``output_dir``.

    Args:
        host: Server host address.
        port: Server port.
        structure_image: Path of the structure image.
        style_image: Path of the style image.
        prompt: Text prompt.
        depth_strength: Depth strength passed to the tool.
        style_strength: Style strength passed to the tool.
        output_dir: Directory the generated image is written to.

    Returns:
        bool: True if an image was generated and saved, otherwise False.

    Raises:
        FileNotFoundError: If an input image does not exist.
        Exception: Connection/session errors and errors while saving the
            image propagate to the caller.
    """
    # NOTE(review): `HttpServerParameters` / `http_client` come from this
    # file's MCP import; confirm they exist in the pinned SDK version.
    server_params = HttpServerParameters(
        host=host,
        port=port
    )
    logger.info(f"正在连接到MCP服务器: {host}:{port}")
    async with http_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            logger.info("正在初始化MCP客户端会话...")
            await session.initialize()

            # Check the FLUX service status before doing any work.
            try:
                logger.info("正在检查FLUX服务状态...")
                status = await session.read_resource("flux://status")
                logger.info(f"FLUX服务状态: {_resource_text(status)}")
            except Exception as e:
                logger.error(f"获取FLUX状态失败: {e}")
                return False

            logger.info("正在获取可用工具列表...")
            tools = await session.list_tools()
            logger.info(f"可用工具: {_tool_names(tools)}")

            logger.info("正在读取输入图像...")
            structure_image_base64 = await read_image_to_base64(structure_image)
            style_image_base64 = await read_image_to_base64(style_image)

            logger.info("正在调用生成风格化图像工具...")
            try:
                result = await session.call_tool(
                    "生成风格化图像",
                    arguments={
                        "prompt": prompt,
                        "structure_image_base64": structure_image_base64,
                        "style_image_base64": style_image_base64,
                        "depth_strength": depth_strength,
                        "style_strength": style_strength
                    }
                )
            except Exception as e:
                logger.error(f"调用生成风格化图像工具失败: {e}")
                return False

            payload = _tool_payload(result)
            if payload.get("status") != "success":
                logger.error(f"图像生成失败: {payload.get('message', '未知错误')}")
                return False

            logger.info("图像生成成功")
            image_base64 = payload.get("image_base64")
            if not image_base64:
                logger.warning("结果中不包含图像数据")
                return False
            await save_base64_image(image_base64, output_dir)
            return True
async def main_async(argv=None):
    """
    Async entry point: parse CLI arguments, check the inputs and run the demo.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``, which argparse uses when this is None.

    Returns:
        int: Exit code. 0 on success or user interrupt. 1 if an input file is
        missing, the demo raises, or ``run_demo`` explicitly returns False.
    """
    parser = argparse.ArgumentParser(description='FLUX风格塑形MCP客户端演示')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='MCP服务器地址')
    parser.add_argument('--port', type=int, default=8189, help='MCP服务器端口')
    parser.add_argument('--structure', type=str, required=True, help='结构图像路径')
    parser.add_argument('--style', type=str, required=True, help='风格图像路径')
    parser.add_argument('--prompt', type=str, default='', help='文本提示')
    parser.add_argument('--depth-strength', type=float, default=15.0, help='深度强度')
    parser.add_argument('--style-strength', type=float, default=0.5, help='风格强度')
    parser.add_argument('--output', type=str, default='./output', help='输出目录')
    args = parser.parse_args(argv)

    # Fail fast on missing inputs before opening a server connection.
    if not os.path.exists(args.structure):
        logger.error(f"结构图像文件不存在: {args.structure}")
        return 1
    if not os.path.exists(args.style):
        logger.error(f"风格图像文件不存在: {args.style}")
        return 1

    try:
        outcome = await run_demo(
            host=args.host,
            port=args.port,
            structure_image=args.structure,
            style_image=args.style,
            prompt=args.prompt,
            depth_strength=args.depth_strength,
            style_strength=args.style_strength,
            output_dir=args.output
        )
    except KeyboardInterrupt:
        logger.info("用户中断操作")
        return 0
    except Exception as e:
        # logger.exception also records the traceback, which the message alone loses.
        logger.exception(f"演示运行出错: {e}")
        return 1
    # run_demo may return None. Treat only an explicit False as a failed run,
    # so failures produce a non-zero exit status.
    return 1 if outcome is False else 0
def main():
    """
    Synchronous entry point: run ``main_async()`` and return its exit code.

    Returns 0 on KeyboardInterrupt. Any other ``Exception`` is logged and
    returns 1. BaseExceptions such as SystemExit propagate unchanged.
    """
    try:
        code = asyncio.run(main_async())
    except KeyboardInterrupt:
        logger.info("用户中断操作")
        code = 0
    except Exception as err:
        logger.error(f"程序运行出错: {err}")
        code = 1
    return code
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())