KICCO_AI_IMAGE/mcp_server/flux_demo_client.py

245 lines
8.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
FLUX风格塑形MCP客户端演示脚本
--------------------------
演示如何使用MCP客户端连接到服务器并使用FLUX风格塑形功能
"""
import asyncio
import argparse
import logging
import sys
import os
import base64
from typing import Dict, Any
from pathlib import Path
# Configure root logging once for the whole script: INFO level, with a
# timestamp / logger name / level prefix on every record.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Try to import the MCP client library. If any of the names below cannot be
# imported, log installation instructions and terminate with exit status 1.
# NOTE(review): confirm that `HttpServerParameters` and `mcp.client.http`
# exist in the installed MCP SDK version; if they do not, this branch fires
# even when the `mcp` package itself is installed, and the "library not
# found" message would be misleading.
try:
    from mcp import ClientSession, HttpServerParameters, types
    from mcp.client.http import http_client
except ImportError:
    logger.error("未找到MCP客户端库。请安装Model Context Protocol (MCP) Python SDK。")
    logger.error("安装命令: pip install mcp")
    sys.exit(1)
async def save_base64_image(base64_data: str, output_path: str) -> str:
    """Decode a base64 image and write it into a directory.

    Args:
        base64_data: Either a data URL (``data:<mime>;base64,<payload>``) or
            a bare base64 payload. A header without the ``data:`` scheme
            (``<mime>;base64,<payload>``) is accepted as well.
        output_path: Directory to write into; created if it does not exist.

    Returns:
        The path of the file that was written.

    Raises:
        binascii.Error: If the payload is not valid base64.
        OSError: If the directory or file cannot be created.
    """
    import time  # local import: only needed to timestamp the file name

    # Make sure the output directory exists.
    os.makedirs(output_path, exist_ok=True)

    extension_by_type = {
        "image/jpeg": ".jpg",
        "image/png": ".png",
        "image/webp": ".webp",
    }
    ext = ".png"  # default for bare payloads and unrecognised MIME types
    data_part = base64_data
    if ";" in base64_data and "," in base64_data:
        header, data_part = base64_data.split(",", 1)
        # header looks like "data:image/png;base64". rpartition never raises,
        # so a header lacking the "data:" scheme no longer causes IndexError.
        content_type = header.split(";", 1)[0].rpartition(":")[2].strip().lower()
        ext = extension_by_type.get(content_type, ".png")

    # Decode before touching the file system so bad input leaves no file.
    image_data = base64.b64decode(data_part)

    # Use wall-clock time: the event loop clock is monotonic with an arbitrary
    # origin, so it is not a meaningful timestamp across runs.
    stem = f"flux_output_{int(time.time())}"
    file_path = os.path.join(output_path, f"{stem}{ext}")
    # Two saves within the same second must not overwrite each other.
    collision = 0
    while os.path.exists(file_path):
        collision += 1
        file_path = os.path.join(output_path, f"{stem}_{collision}{ext}")

    with open(file_path, "wb") as f:
        f.write(image_data)
    logger.info(f"图像已保存到: {file_path}")
    return file_path
async def read_image_to_base64(image_path: str) -> str:
    """Read an image file and return it as a base64 data URL.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    import mimetypes  # local import: only used for the extension fallback

    if not os.path.exists(image_path):
        raise FileNotFoundError(f"图像文件不存在: {image_path}")

    with open(image_path, "rb") as f:
        image_data = f.read()

    # Extensions handled explicitly so the result does not depend on the
    # platform's MIME database.
    known_types = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
    }
    ext = os.path.splitext(image_path)[1].lower()
    mime_type = known_types.get(ext)
    if mime_type is None:
        # Other image formats (gif, bmp, ...) get their real MIME type rather
        # than being mislabelled as PNG; PNG remains the last-resort default.
        guessed, _ = mimetypes.guess_type("file" + ext)
        if guessed and guessed.startswith("image/"):
            mime_type = guessed
        else:
            mime_type = "image/png"

    encoded = base64.b64encode(image_data).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"
def _tool_result_to_dict(result: Any) -> Dict[str, Any]:
    """Coerce a ``call_tool`` result into a plain dict of result fields."""
    import json
    from collections.abc import Mapping

    if isinstance(result, Mapping):
        return dict(result)

    # NOTE(review): assumes a non-dict result exposes `.content`, a list of
    # items whose `.text` may hold the server's JSON payload — confirm
    # against the installed MCP SDK and the server implementation.
    first_text = None
    for item in getattr(result, "content", None) or []:
        text = getattr(item, "text", None)
        if not text:
            continue
        if first_text is None:
            first_text = text
        try:
            parsed = json.loads(text)
        except ValueError:
            continue
        if isinstance(parsed, dict):
            return parsed
    # No structured payload: surface any plain text as the failure message.
    return {"message": first_text} if first_text else {}


async def run_demo(host: str, port: int, structure_image: str, style_image: str, prompt: str,
                   depth_strength: float, style_strength: float, output_dir: str) -> None:
    """Run the FLUX style-shaping demo against an MCP server.

    Connects to the server, reads the ``flux://status`` resource, lists the
    available tools, sends both input images to the image-generation tool and
    saves the returned image into ``output_dir``.

    Args:
        host: Server host address.
        port: Server port.
        structure_image: Path of the structure image.
        style_image: Path of the style image.
        prompt: Text prompt.
        depth_strength: Depth strength passed to the tool.
        style_strength: Style strength passed to the tool.
        output_dir: Directory the generated image is written to.

    Raises:
        FileNotFoundError: If an input image does not exist.
        RuntimeError: If the tool reports a status other than ``"success"``.
        Exception: Errors from the status check or the tool call are logged
            and re-raised so the caller can report a failing exit status
            instead of treating the run as successful.
    """
    # Build the server parameters.
    server_params = HttpServerParameters(
        host=host,
        port=port
    )
    # Connect to the server.
    logger.info(f"正在连接到MCP服务器: {host}:{port}")
    async with http_client(server_params) as (read, write):
        # Create the client session.
        async with ClientSession(read, write) as session:
            logger.info("正在初始化MCP客户端会话...")
            await session.initialize()

            # Check the FLUX service status.
            try:
                logger.info("正在检查FLUX服务状态...")
                status = await session.read_resource("flux://status")
            except Exception as e:
                logger.error(f"获取FLUX状态失败: {e}")
                raise
            # Accept either a (content, mime_type) tuple or a result object
            # carrying a `contents` attribute; otherwise log the value as-is.
            if isinstance(status, tuple) and status:
                content = status[0]
            else:
                content = getattr(status, "contents", status)
            logger.info(f"FLUX服务状态: {content}")

            # List the available tools. The result may be a plain iterable of
            # tools or an object wrapping them in a `tools` attribute.
            logger.info("正在获取可用工具列表...")
            tools = await session.list_tools()
            tool_list = getattr(tools, "tools", tools)
            logger.info(f"可用工具: {[tool.name for tool in tool_list]}")

            # Convert the input images to base64 data URLs.
            logger.info("正在读取输入图像...")
            structure_image_base64 = await read_image_to_base64(structure_image)
            style_image_base64 = await read_image_to_base64(style_image)

            # Call the image-generation tool. Only the remote call sits in the
            # try block so later errors are not misreported as call failures.
            logger.info("正在调用生成风格化图像工具...")
            try:
                raw_result = await session.call_tool(
                    "生成风格化图像",
                    arguments={
                        "prompt": prompt,
                        "structure_image_base64": structure_image_base64,
                        "style_image_base64": style_image_base64,
                        "depth_strength": depth_strength,
                        "style_strength": style_strength
                    }
                )
            except Exception as e:
                logger.error(f"调用生成风格化图像工具失败: {e}")
                raise

            # Handle the result.
            result = _tool_result_to_dict(raw_result)
            if result.get("status") != "success":
                message = result.get('message', '未知错误')
                logger.error(f"图像生成失败: {message}")
                raise RuntimeError(f"图像生成失败: {message}")

            logger.info("图像生成成功")
            # Save the generated image.
            image_base64 = result.get("image_base64")
            if image_base64:
                await save_base64_image(image_base64, output_dir)
            else:
                logger.warning("结果中不包含图像数据")
async def main_async():
    """Parse the command line, validate the input images and run the demo.

    Returns:
        int: 0 when the demo completes or the user interrupts it, 1 when an
        input image is missing or the demo raises an exception.
    """
    # Declare the command-line options as data, then register them in order.
    cli = argparse.ArgumentParser(description='FLUX风格塑形MCP客户端演示')
    option_specs = (
        ('--host', dict(type=str, default='127.0.0.1', help='MCP服务器地址')),
        ('--port', dict(type=int, default=8189, help='MCP服务器端口')),
        ('--structure', dict(type=str, required=True, help='结构图像路径')),
        ('--style', dict(type=str, required=True, help='风格图像路径')),
        ('--prompt', dict(type=str, default='', help='文本提示')),
        ('--depth-strength', dict(type=float, default=15.0, help='深度强度')),
        ('--style-strength', dict(type=float, default=0.5, help='风格强度')),
        ('--output', dict(type=str, default='./output', help='输出目录')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Both input images must exist; stop at the first one that is missing.
    required_inputs = (
        (opts.structure, f"结构图像文件不存在: {opts.structure}"),
        (opts.style, f"风格图像文件不存在: {opts.style}"),
    )
    for candidate, complaint in required_inputs:
        if not os.path.exists(candidate):
            logger.error(complaint)
            return 1

    # Run the demo and translate its outcome into an exit status.
    try:
        await run_demo(
            host=opts.host,
            port=opts.port,
            structure_image=opts.structure,
            style_image=opts.style,
            prompt=opts.prompt,
            depth_strength=opts.depth_strength,
            style_strength=opts.style_strength,
            output_dir=opts.output
        )
    except KeyboardInterrupt:
        logger.info("用户中断操作")
        return 0
    except Exception as e:
        logger.error(f"演示运行出错: {e}")
        return 1
    else:
        return 0
def main():
    """Synchronous entry point: run ``main_async`` and return its exit status.

    Returns:
        int: The value returned by ``main_async``; 0 if the user interrupts
        the program; 1 if any other exception escapes.
    """
    exit_status = 1  # used when an unexpected exception escapes
    try:
        exit_status = asyncio.run(main_async())
    except KeyboardInterrupt:
        logger.info("用户中断操作")
        exit_status = 0
    except Exception as e:
        logger.error(f"程序运行出错: {e}")
    return exit_status


if __name__ == "__main__":
    sys.exit(main())