245 lines
7.6 KiB
Python
245 lines
7.6 KiB
Python
import asyncio
|
||
import logging
|
||
from typing import Dict, List, Optional, Any, Callable, Union
|
||
import os
|
||
|
||
# 导入MCP SDK相关库
|
||
from mcp.server.fastmcp import FastMCP
|
||
from mcp.types import Resource, Tool, Prompt
|
||
from mcp import types
|
||
|
||
class MCPServer:
|
||
"""
|
||
Model Context Protocol 服务器
|
||
用于管理模型连接和资源交互
|
||
"""
|
||
instance = None
|
||
|
||
def __init__(self, server_name: str = "ComfyUI-MCP"):
|
||
"""
|
||
初始化MCP服务器
|
||
|
||
参数:
|
||
server_name: 服务器名称
|
||
"""
|
||
MCPServer.instance = self
|
||
|
||
# 初始化FastMCP实例
|
||
self.mcp = FastMCP(server_name)
|
||
|
||
# 存储注册的资源、工具和提示
|
||
self.registered_resources: Dict[str, Callable] = {}
|
||
self.registered_tools: Dict[str, Callable] = {}
|
||
self.registered_prompts: Dict[str, Callable] = {}
|
||
|
||
# 存储模型连接信息
|
||
self.active_models: Dict[str, Dict[str, Any]] = {}
|
||
self.model_count: int = 0
|
||
|
||
# 状态标志
|
||
self.is_running: bool = False
|
||
|
||
# 设置基本资源和工具
|
||
self._setup_default_handlers()
|
||
|
||
logging.info(f"[MCP服务器] 已初始化,服务器名称: {server_name}")
|
||
|
||
def _setup_default_handlers(self):
|
||
"""设置默认的资源处理器和工具"""
|
||
|
||
# 注册状态资源
|
||
@self.mcp.resource("status://server")
|
||
def get_server_status() -> str:
|
||
"""返回服务器状态信息"""
|
||
return f"""
|
||
服务器状态:
|
||
- 运行中: {self.is_running}
|
||
- 活跃模型数: {len(self.active_models)}
|
||
- 已注册资源数: {len(self.registered_resources)}
|
||
- 已注册工具数: {len(self.registered_tools)}
|
||
- 已注册提示数: {len(self.registered_prompts)}
|
||
"""
|
||
|
||
# 注册模型列表资源
|
||
@self.mcp.resource("models://list")
|
||
def list_models() -> str:
|
||
"""返回当前连接的模型列表"""
|
||
if not self.active_models:
|
||
return "当前没有活跃的模型连接"
|
||
|
||
result = "活跃模型列表:\n"
|
||
for model_id, model_info in self.active_models.items():
|
||
result += f"- ID: {model_id}, 名称: {model_info.get('name', 'Unknown')}\n"
|
||
return result
|
||
|
||
# 注册Echo工具
|
||
@self.mcp.tool()
|
||
def echo(message: str) -> str:
|
||
"""简单的Echo工具,用于测试连接"""
|
||
return f"MCP服务器回声: {message}"
|
||
|
||
# 注册系统信息工具
|
||
@self.mcp.tool()
|
||
def system_info() -> dict:
|
||
"""返回系统信息"""
|
||
import platform
|
||
return {
|
||
"os": platform.system(),
|
||
"python_version": platform.python_version(),
|
||
"hostname": platform.node(),
|
||
"cpu": platform.processor()
|
||
}
|
||
|
||
def register_resource(self, uri_pattern: str):
|
||
"""
|
||
注册一个资源处理函数
|
||
|
||
参数:
|
||
uri_pattern: 资源URI模式
|
||
"""
|
||
def decorator(func):
|
||
self.registered_resources[uri_pattern] = func
|
||
self.mcp.resource(uri_pattern)(func)
|
||
logging.info(f"[MCP服务器] 已注册资源: {uri_pattern}")
|
||
return func
|
||
return decorator
|
||
|
||
def register_tool(self, name: Optional[str] = None):
|
||
"""
|
||
注册一个工具函数
|
||
|
||
参数:
|
||
name: 工具名称(可选)
|
||
"""
|
||
def decorator(func):
|
||
tool_name = name or func.__name__
|
||
self.registered_tools[tool_name] = func
|
||
self.mcp.tool(name=tool_name)(func)
|
||
logging.info(f"[MCP服务器] 已注册工具: {tool_name}")
|
||
return func
|
||
return decorator
|
||
|
||
def register_prompt(self, name: Optional[str] = None):
|
||
"""
|
||
注册一个提示模板
|
||
|
||
参数:
|
||
name: 提示名称(可选)
|
||
"""
|
||
def decorator(func):
|
||
prompt_name = name or func.__name__
|
||
self.registered_prompts[prompt_name] = func
|
||
self.mcp.prompt(name=prompt_name)(func)
|
||
logging.info(f"[MCP服务器] 已注册提示: {prompt_name}")
|
||
return func
|
||
return decorator
|
||
|
||
def register_model(self, model_info: Dict[str, Any]) -> str:
|
||
"""
|
||
注册一个模型到MCP服务器
|
||
|
||
参数:
|
||
model_info: 模型信息字典
|
||
|
||
返回:
|
||
model_id: 模型ID
|
||
"""
|
||
model_id = f"model_{self.model_count}"
|
||
self.model_count += 1
|
||
|
||
self.active_models[model_id] = {
|
||
"id": model_id,
|
||
"registered_at": asyncio.get_event_loop().time(),
|
||
**model_info
|
||
}
|
||
|
||
logging.info(f"[MCP服务器] 已注册模型: {model_id}")
|
||
return model_id
|
||
|
||
def unregister_model(self, model_id: str) -> bool:
|
||
"""
|
||
从MCP服务器注销一个模型
|
||
|
||
参数:
|
||
model_id: 模型ID
|
||
|
||
返回:
|
||
成功与否
|
||
"""
|
||
if model_id in self.active_models:
|
||
del self.active_models[model_id]
|
||
logging.info(f"[MCP服务器] 已注销模型: {model_id}")
|
||
return True
|
||
|
||
logging.warning(f"[MCP服务器] 尝试注销未知模型: {model_id}")
|
||
return False
|
||
|
||
async def start(self, host: str = "127.0.0.1", port: int = 8189):
|
||
"""
|
||
启动MCP服务器
|
||
|
||
参数:
|
||
host: 主机地址
|
||
port: 端口号
|
||
"""
|
||
self.is_running = True
|
||
try:
|
||
# 启动FastMCP HTTP服务器
|
||
uvicorn_config = {
|
||
"host": host,
|
||
"port": port,
|
||
"log_level": "info"
|
||
}
|
||
|
||
logging.info(f"[MCP服务器] 正在启动... 地址: {host}:{port}")
|
||
|
||
# 使用FastMCP的HTTP服务器启动方法
|
||
await self.mcp.serve(**uvicorn_config)
|
||
except Exception as e:
|
||
self.is_running = False
|
||
logging.error(f"[MCP服务器] 启动失败: {str(e)}")
|
||
raise
|
||
finally:
|
||
self.is_running = False
|
||
|
||
def get_stats(self) -> Dict[str, Any]:
|
||
"""获取服务器统计信息"""
|
||
return {
|
||
"is_running": self.is_running,
|
||
"active_models_count": len(self.active_models),
|
||
"resources_count": len(self.registered_resources),
|
||
"tools_count": len(self.registered_tools),
|
||
"prompts_count": len(self.registered_prompts)
|
||
}
|
||
|
||
# 运行MCP服务器的异步函数
|
||
async def run_server(server_name: str = "ComfyUI-MCP", host: str = "127.0.0.1", port: int = 8189):
|
||
"""
|
||
运行MCP服务器的便捷函数
|
||
|
||
参数:
|
||
server_name: 服务器名称
|
||
host: 主机地址
|
||
port: 端口号
|
||
"""
|
||
server = MCPServer(server_name)
|
||
await server.start(host, port)
|
||
return server
|
||
|
||
# 同步运行MCP服务器的函数
|
||
def run(server_name: str = "ComfyUI-MCP", host: str = "127.0.0.1", port: int = 8189):
|
||
"""
|
||
同步运行MCP服务器
|
||
|
||
参数:
|
||
server_name: 服务器名称
|
||
host: 主机地址
|
||
port: 端口号
|
||
"""
|
||
loop = asyncio.get_event_loop()
|
||
try:
|
||
return loop.run_until_complete(run_server(server_name, host, port))
|
||
except KeyboardInterrupt:
|
||
logging.info("[MCP服务器] 收到中断信号,正在关闭...")
|
||
finally:
|
||
loop.close() |