918 lines
38 KiB
Python
918 lines
38 KiB
Python
import os
|
||
import sys
|
||
import asyncio
|
||
import traceback
|
||
|
||
import nodes
|
||
import folder_paths
|
||
import execution
|
||
import uuid
|
||
import urllib
|
||
import json
|
||
import glob
|
||
import struct
|
||
import ssl
|
||
import socket
|
||
import ipaddress
|
||
from PIL import Image, ImageOps
|
||
from PIL.PngImagePlugin import PngInfo
|
||
from io import BytesIO
|
||
|
||
import aiohttp
|
||
from aiohttp import web
|
||
import logging
|
||
|
||
import mimetypes
|
||
from comfy.cli_args import args
|
||
import comfy.utils
|
||
import comfy.model_management
|
||
import node_helpers
|
||
from comfyui_version import __version__
|
||
from app.frontend_management import FrontendManager
|
||
from app.user_manager import UserManager
|
||
from app.model_manager import ModelFileManager
|
||
from app.custom_node_manager import CustomNodeManager
|
||
from typing import Optional
|
||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||
|
||
class BinaryEventTypes:
    """Integer type codes used for binary (non-JSON) events sent to clients."""

    # Already-encoded preview image bytes.
    PREVIEW_IMAGE = 1
    # Raw preview payload that still has to be encoded (see PromptServer.send).
    UNENCODED_PREVIEW_IMAGE = 2
async def send_socket_catch_exception(function, message):
    """Await ``function(message)``, logging instead of raising on send failures.

    aiohttp client errors and connection errors (reset, broken pipe, ...) are
    reported with ``logging.warning`` and swallowed; any other exception
    propagates to the caller.
    """
    try:
        await function(message)
    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError,
            BrokenPipeError, ConnectionError) as send_failure:
        logging.warning(f"send error: {send_failure}")
@web.middleware
async def cache_control(request: web.Request, handler):
    """Default JS, CSS and ``index.json`` responses to ``Cache-Control: no-cache``.

    ``setdefault`` is used so a Cache-Control header already set by the
    handler is preserved.
    """
    response: web.Response = await handler(request)
    no_cache_suffixes = ('.js', '.css', 'index.json')
    if request.path.endswith(no_cache_suffixes):
        response.headers.setdefault('Cache-Control', 'no-cache')
    return response
@web.middleware
async def compress_body(request: web.Request, handler):
    """Enable response compression for non-empty JSON / plain-text bodies.

    Compression is only turned on when the response is a ``web.Response``
    whose content type is ``application/json`` or ``text/plain``, it has a
    body, and the client's ``Accept-Encoding`` mentions gzip.
    """
    client_encodings = request.headers.get("Accept-Encoding", "")
    response: web.Response = await handler(request)
    compressible = (
        isinstance(response, web.Response)
        and response.content_type in ["application/json", "text/plain"]
    )
    if compressible and response.body and "gzip" in client_encodings:
        response.enable_compression()
    return response
def create_cors_middleware(allowed_origin: str):
    """Build middleware that attaches CORS headers for ``allowed_origin``.

    OPTIONS (pre-flight) requests are answered with an empty response without
    calling the handler; every other request is passed through. All responses
    receive the same Access-Control-* headers.
    """
    cors_headers = {
        'Access-Control-Allow-Origin': allowed_origin,
        'Access-Control-Allow-Methods': 'POST, GET, DELETE, PUT, OPTIONS',
        'Access-Control-Allow-Headers': 'Content-Type, Authorization',
        'Access-Control-Allow-Credentials': 'true',
    }

    @web.middleware
    async def cors_middleware(request: web.Request, handler):
        is_preflight = request.method == "OPTIONS"
        response = web.Response() if is_preflight else await handler(request)
        for header_name, header_value in cors_headers.items():
            response.headers[header_name] = header_value
        return response

    return cors_middleware
def is_loopback(host):
    """Return True if ``host`` refers to a loopback address.

    Args:
        host: an IP literal, a hostname, or None.

    Returns:
        - False for None.
        - For an IP literal, whether it is a loopback address.
        - Otherwise, the name is resolved for IPv4 and then IPv6 (unresolvable
          families are skipped). Resolution stops at the first non-loopback
          address. The result is True only if a loopback address was seen
          before that point, so a name that does not resolve yields False.
    """
    if host is None:
        return False

    try:
        return ipaddress.ip_address(host).is_loopback
    except ValueError:
        # Not an IP literal: fall through to name resolution. (The previous
        # bare ``except:`` also swallowed KeyboardInterrupt/SystemExit.)
        pass

    # NOTE(review): a name resolving to loopback first and non-loopback later
    # returns True, which keeps the Host/Origin check in
    # create_origin_only_middleware enabled for it.
    loopback = False
    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            addr_infos = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
        except socket.gaierror:
            continue
        for _family, _type, _proto, _canonname, sockaddr in addr_infos:
            if not ipaddress.ip_address(sockaddr[0]).is_loopback:
                return loopback
            loopback = True

    return loopback
def create_origin_only_middleware():
    """Build middleware that rejects mismatched Host/Origin requests to a loopback host.

    Returns an aiohttp middleware. When a request carries both ``Host`` and
    ``Origin`` headers, the Host resolves to a loopback address (per
    ``is_loopback``), and the compared Host and Origin names differ, the
    request is rejected with 403. OPTIONS requests that pass the check are
    answered with an empty response without calling the handler.
    """
    @web.middleware
    async def origin_only_middleware(request: web.Request, handler):
        # Prevents a third-party web page from queueing workflows by POSTing to
        # 127.0.0.1, which browsers allow cross-origin. In that case the Host
        # and Origin hostnames will not match.
        # A cookie/token would be the proper fix; this check is a stop-gap.
        if 'Host' in request.headers and 'Origin' in request.headers:
            host = request.headers['Host']
            origin = request.headers['Origin']
            host_domain = host.lower()
            parsed = urllib.parse.urlparse(origin)
            origin_domain = parsed.netloc.lower()
            host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)

            # Only enforce the check when the Host is a loopback address. This
            # is less strict than checking every request but still blocks the
            # attack described above.
            loopback = is_loopback(host_domain_parsed.hostname)

            # If the Origin has no explicit port, compare against the Host
            # without its port (and vice versa) to tolerate browsers that omit
            # the port in one of the headers.
            # NOTE(review): reading urllib's `.port` can raise ValueError on a
            # malformed port; it is not caught here -- confirm that is intended.
            if parsed.port is None:
                host_domain = host_domain_parsed.hostname
            if host_domain_parsed.port is None:
                origin_domain = parsed.hostname

            # Skip the comparison when either name is missing or empty.
            if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
                if host_domain != origin_domain:
                    logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
                    return web.Response(status=403)

        if request.method == "OPTIONS":
            response = web.Response()
        else:
            response = await handler(request)

        return response

    return origin_only_middleware
class PromptServer():
    """aiohttp HTTP/WebSocket server for the prompt queue and the web frontend.

    ``__init__`` builds the aiohttp application, its middleware stack and the
    route table; ``add_routes`` registers those routes (plus ``/api``-prefixed
    mirrors and static routes) on the app. The most recently constructed
    instance is exposed as ``PromptServer.instance``. Messages queued with
    ``send_sync`` are delivered to WebSocket clients by ``publish_loop``.
    """

    def __init__(self, loop):
        """Create managers, middleware, the aiohttp app and all route handlers.

        Args:
            loop: event loop that ``send_sync`` schedules messages onto.
        """
        PromptServer.instance = self

        mimetypes.init()
        mimetypes.add_type('application/javascript; charset=utf-8', '.js')
        mimetypes.add_type('image/webp', '.webp')

        self.user_manager = UserManager()
        self.model_file_manager = ModelFileManager()
        self.custom_node_manager = CustomNodeManager()
        self.internal_routes = InternalRoutes(self)
        self.supports = ["custom_nodes_from_web"]
        self.prompt_queue = None
        self.loop = loop
        self.messages = asyncio.Queue()
        self.client_session:Optional[aiohttp.ClientSession] = None
        self.number = 0

        middlewares = [cache_control]
        if args.enable_compress_response_body:
            middlewares.append(compress_body)

        if args.enable_cors_header:
            middlewares.append(create_cors_middleware(args.enable_cors_header))
        else:
            middlewares.append(create_origin_only_middleware())

        # args.max_upload_size is scaled by 1024*1024 to give aiohttp's
        # byte-based client_max_size.
        max_upload_size = round(args.max_upload_size * 1024 * 1024)
        self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
        self.sockets = dict()
        self.web_root = (
            FrontendManager.init_frontend(args.front_end_version)
            if args.front_end_root is None
            else args.front_end_root
        )
        logging.info(f"[Prompt Server] web root: {self.web_root}")
        routes = web.RouteTableDef()
        self.routes = routes
        self.last_node_id = None
        self.client_id = None

        self.on_prompt_handlers = []

        @routes.get('/ws')
        async def websocket_handler(request):
            """Open a WebSocket, register it under a client id and keep it alive."""
            ws = web.WebSocketResponse()
            await ws.prepare(request)
            sid = request.rel_url.query.get('clientId', '')
            if sid:
                # Reusing existing session, remove old
                self.sockets.pop(sid, None)
            else:
                sid = uuid.uuid4().hex

            self.sockets[sid] = ws

            try:
                # Send initial state to the new client
                await self.send("status", { "status": self.get_queue_info(), 'sid': sid }, sid)
                # On reconnect if we are the currently executing client send the current node
                if self.client_id == sid and self.last_node_id is not None:
                    await self.send("executing", { "node": self.last_node_id }, sid)

                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.ERROR:
                        logging.warning('ws connection closed with exception %s' % ws.exception())
            finally:
                # Only drop the entry if it still belongs to this connection: a
                # reconnect with the same clientId may already have replaced it.
                if self.sockets.get(sid) is ws:
                    self.sockets.pop(sid, None)
            return ws

        @routes.get("/")
        async def get_root(request):
            """Serve the frontend's index.html with caching disabled."""
            response = web.FileResponse(os.path.join(self.web_root, "index.html"))
            response.headers['Cache-Control'] = 'no-cache'
            response.headers["Pragma"] = "no-cache"
            response.headers["Expires"] = "0"
            return response

        @routes.get("/embeddings")
        async def get_embeddings(request):
            """List embedding names (file names without extension)."""
            embeddings = folder_paths.get_filename_list("embeddings")
            return web.json_response([os.path.splitext(a)[0] for a in embeddings])

        @routes.get("/models")
        async def list_model_types(request):
            """List the known model folder types."""
            model_types = list(folder_paths.folder_names_and_paths.keys())

            return web.json_response(model_types)

        @routes.get("/models/{folder}")
        async def get_models(request):
            """List the files of one model folder type; 404 if the type is unknown."""
            folder = request.match_info.get("folder", None)
            if folder not in folder_paths.folder_names_and_paths:
                return web.Response(status=404)
            files = folder_paths.get_filename_list(folder)
            return web.json_response(files)

        @routes.get("/extensions")
        async def get_extensions(request):
            """List URLs of all frontend extension .js files (built-in and custom)."""
            files = glob.glob(os.path.join(
                glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)

            extensions = ["/" + os.path.relpath(f, self.web_root).replace("\\", "/") for f in files]

            for name, ext_dir in nodes.EXTENSION_WEB_DIRS.items():
                files = glob.glob(os.path.join(glob.escape(ext_dir), '**/*.js'), recursive=True)
                url_prefix = "/extensions/" + urllib.parse.quote(name) + "/"
                extensions.extend(
                    url_prefix + os.path.relpath(f, ext_dir).replace("\\", "/") for f in files
                )

            return web.json_response(extensions)

        def get_dir_by_type(dir_type):
            """Map an upload type ("input" (default), "temp", "output") to its directory.

            Returns:
                (directory, dir_type). directory is None for an unknown type
                (previously this raised UnboundLocalError).
            """
            if dir_type is None:
                dir_type = "input"

            if dir_type == "input":
                type_dir = folder_paths.get_input_directory()
            elif dir_type == "temp":
                type_dir = folder_paths.get_temp_directory()
            elif dir_type == "output":
                type_dir = folder_paths.get_output_directory()
            else:
                type_dir = None

            return type_dir, dir_type

        def compare_image_hash(filepath, image):
            """Return True if the file at filepath has the same hash as the uploaded image.

            Used to avoid saving duplicate uploads under a new name (fix for #3465).
            The upload's file position is reset to 0 afterwards.
            """
            hasher = node_helpers.hasher()

            if not os.path.exists(filepath):
                return False

            existing = hasher()
            uploaded = hasher()
            with open(filepath, "rb") as f:
                existing.update(f.read())
            uploaded.update(image.file.read())
            image.file.seek(0)
            return existing.hexdigest() == uploaded.hexdigest()

        def image_upload(post, image_save_function=None):
            """Save an uploaded image from a multipart form and describe where it went.

            Form fields read: "image", "overwrite", "type", "subfolder".
            Without overwrite ("true"/"1"), a name clash is resolved by appending
            " (n)" unless the existing file has identical content.

            Args:
                post: parsed multipart form data.
                image_save_function: optional ``(image, post, filepath)`` callback
                    that writes the file instead of the default raw write. If it
                    returns a ``web.Response`` (e.g. a 400/403 validation error),
                    that response is returned to the client.

            Returns:
                JSON with name/subfolder/type, or a 400 response.
            """
            image = post.get("image")
            overwrite = post.get("overwrite")
            image_is_duplicate = False

            image_upload_type = post.get("type")
            upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
            if upload_dir is None:
                return web.Response(status=400)

            if not (image and image.file):
                return web.Response(status=400)

            filename = image.filename
            if not filename:
                return web.Response(status=400)

            subfolder = post.get("subfolder", "")
            full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder))
            filepath = os.path.abspath(os.path.join(full_output_folder, filename))

            # Reject paths that escape the upload directory.
            if os.path.commonpath((upload_dir, filepath)) != upload_dir:
                return web.Response(status=400)

            os.makedirs(full_output_folder, exist_ok=True)

            split = os.path.splitext(filename)

            if not (overwrite is not None and (overwrite == "true" or overwrite == "1")):
                i = 1
                while os.path.exists(filepath):
                    if compare_image_hash(filepath, image):  # same content already saved under this name (#3465)
                        image_is_duplicate = True
                        break
                    filename = f"{split[0]} ({i}){split[1]}"
                    filepath = os.path.join(full_output_folder, filename)
                    i += 1

            if not image_is_duplicate:
                if image_save_function is not None:
                    save_error = image_save_function(image, post, filepath)
                    if isinstance(save_error, web.Response):
                        return save_error
                else:
                    with open(filepath, "wb") as f:
                        f.write(image.file.read())

            return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})

        @routes.post("/upload/image")
        async def upload_image(request):
            """Save an uploaded image file."""
            post = await request.post()
            return image_upload(post)

        @routes.post("/upload/mask")
        async def upload_mask(request):
            """Save an uploaded mask: its alpha is copied onto the referenced original image."""
            post = await request.post()

            def image_save_function(image, post, filepath):
                # "original_ref" is a JSON object with filename and optional type/subfolder.
                original_ref = json.loads(post.get("original_ref"))
                filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])

                if not filename:
                    return web.Response(status=400)

                # validation for security: prevent accessing arbitrary path
                if filename[0] == '/' or '..' in filename:
                    return web.Response(status=400)

                if output_dir is None:
                    ref_type = original_ref.get("type", "output")
                    output_dir = folder_paths.get_directory_by_type(ref_type)

                if output_dir is None:
                    return web.Response(status=400)

                if original_ref.get("subfolder", "") != "":
                    full_output_dir = os.path.join(output_dir, original_ref["subfolder"])
                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                        return web.Response(status=403)
                    output_dir = full_output_dir

                file = os.path.join(output_dir, filename)

                # NOTE(review): if the original file does not exist nothing is
                # written, but the upload still reports success.
                if os.path.isfile(file):
                    with Image.open(file) as original_pil:
                        # Preserve the original PNG text metadata.
                        metadata = PngInfo()
                        if hasattr(original_pil,'text'):
                            for key in original_pil.text:
                                metadata.add_text(key, original_pil.text[key])
                        original_pil = original_pil.convert('RGBA')
                        mask_pil = Image.open(image.file).convert('RGBA')

                        # Copy the mask's alpha channel onto the original.
                        new_alpha = mask_pil.getchannel('A')
                        original_pil.putalpha(new_alpha)
                        original_pil.save(filepath, compress_level=4, pnginfo=metadata)

            return image_upload(post, image_save_function)

        @routes.get("/view")
        async def view_image(request):
            """Serve a file from an input/output/temp directory.

            Query params: ``filename`` (required), ``type`` (default "output"),
            ``subfolder``, ``preview`` ("<format>;<quality>", re-encodes the
            image) and ``channel`` ("rgba" default, "rgb" or "a").
            Returns 400/403 for rejected paths and 404 if the file is missing.
            """
            if "filename" in request.rel_url.query:
                filename = request.rel_url.query["filename"]
                filename,output_dir = folder_paths.annotated_filepath(filename)

                if not filename:
                    return web.Response(status=400)

                # validation for security: prevent accessing arbitrary path
                if filename[0] == '/' or '..' in filename:
                    return web.Response(status=400)

                if output_dir is None:
                    dir_type_name = request.rel_url.query.get("type", "output")
                    output_dir = folder_paths.get_directory_by_type(dir_type_name)

                if output_dir is None:
                    return web.Response(status=400)

                if "subfolder" in request.rel_url.query:
                    full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
                        return web.Response(status=403)
                    output_dir = full_output_dir

                filename = os.path.basename(filename)
                file = os.path.join(output_dir, filename)

                if os.path.isfile(file):
                    if 'preview' in request.rel_url.query:
                        with Image.open(file) as img:
                            preview_info = request.rel_url.query['preview'].split(';')
                            image_format = preview_info[0]
                            # Only webp/jpeg are allowed; alpha requests need webp.
                            if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''):
                                image_format = 'webp'

                            quality = 90
                            if preview_info[-1].isdigit():
                                quality = int(preview_info[-1])

                            buffer = BytesIO()
                            if image_format in ['jpeg'] or request.rel_url.query.get('channel', '') == 'rgb':
                                img = img.convert("RGB")
                            img.save(buffer, format=image_format, quality=quality)
                            buffer.seek(0)

                            return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})

                    if 'channel' not in request.rel_url.query:
                        channel = 'rgba'
                    else:
                        channel = request.rel_url.query["channel"]

                    if channel == 'rgb':
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                r, g, b, a = img.split()
                                new_img = Image.merge('RGB', (r, g, b))
                            else:
                                new_img = img.convert("RGB")

                            buffer = BytesIO()
                            new_img.save(buffer, format='PNG')
                            buffer.seek(0)

                            return web.Response(body=buffer.read(), content_type='image/png',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})

                    elif channel == 'a':
                        with Image.open(file) as img:
                            if img.mode == "RGBA":
                                _, _, _, a = img.split()
                            else:
                                a = Image.new('L', img.size, 255)

                            # Build an RGBA image that carries only the alpha channel.
                            alpha_img = Image.new('RGBA', img.size)
                            alpha_img.putalpha(a)
                            alpha_buffer = BytesIO()
                            alpha_img.save(alpha_buffer, format='PNG')
                            alpha_buffer.seek(0)

                            return web.Response(body=alpha_buffer.read(), content_type='image/png',
                                                headers={"Content-Disposition": f"filename=\"{filename}\""})
                    else:
                        # Get content type from mimetype, defaulting to 'application/octet-stream'
                        content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'

                        # For security, force certain extensions to download instead of display.
                        # NOTE(review): other script-capable types (e.g. .svg) are still served inline.
                        file_extension = os.path.splitext(filename)[1].lower()
                        if file_extension in {'.html', '.htm', '.js', '.css'}:
                            content_type = 'application/octet-stream'  # Forces download

                        return web.FileResponse(
                            file,
                            headers={
                                "Content-Disposition": f"filename=\"{filename}\"",
                                "Content-Type": content_type
                            }
                        )

            return web.Response(status=404)

        @routes.get("/view_metadata/{folder_name}")
        async def view_metadata(request):
            """Return the ``__metadata__`` header of a .safetensors file, or 404."""
            folder_name = request.match_info.get("folder_name", None)
            if folder_name is None:
                return web.Response(status=404)
            if "filename" not in request.rel_url.query:
                return web.Response(status=404)

            filename = request.rel_url.query["filename"]
            if not filename.endswith(".safetensors"):
                return web.Response(status=404)

            safetensors_path = folder_paths.get_full_path(folder_name, filename)
            if safetensors_path is None:
                return web.Response(status=404)
            out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024)
            if out is None:
                return web.Response(status=404)
            dt = json.loads(out)
            if "__metadata__" not in dt:
                return web.Response(status=404)
            return web.json_response(dt["__metadata__"])

        @routes.get("/system_stats")
        async def system_stats(request):
            """Report OS/Python/PyTorch versions and RAM/VRAM totals and free amounts."""
            device = comfy.model_management.get_torch_device()
            device_name = comfy.model_management.get_torch_device_name(device)
            cpu_device = comfy.model_management.torch.device("cpu")
            ram_total = comfy.model_management.get_total_memory(cpu_device)
            ram_free = comfy.model_management.get_free_memory(cpu_device)
            vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
            vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)

            system_stats = {
                "system": {
                    "os": os.name,
                    "ram_total": ram_total,
                    "ram_free": ram_free,
                    "comfyui_version": __version__,
                    "python_version": sys.version,
                    "pytorch_version": comfy.model_management.torch_version,
                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
                    "argv": sys.argv
                },
                "devices": [
                    {
                        "name": device_name,
                        "type": device.type,
                        "index": device.index,
                        "vram_total": vram_total,
                        "vram_free": vram_free,
                        "torch_vram_total": torch_vram_total,
                        "torch_vram_free": torch_vram_free,
                    }
                ]
            }
            return web.json_response(system_stats)

        @routes.get("/prompt")
        async def get_prompt(request):
            """Return the current queue status."""
            return web.json_response(self.get_queue_info())

        def node_info(node_class):
            """Build the JSON description of a registered node class."""
            obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
            # Call INPUT_TYPES once so 'input' and 'input_order' describe the same result.
            input_types = obj_class.INPUT_TYPES()
            info = {}
            info['input'] = input_types
            info['input_order'] = {key: list(value.keys()) for (key, value) in input_types.items()}
            info['output'] = obj_class.RETURN_TYPES
            info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
            info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
            info['name'] = node_class
            info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_class, node_class)
            info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else ''
            info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
            info['category'] = 'sd'
            if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
                info['output_node'] = True
            else:
                info['output_node'] = False

            if hasattr(obj_class, 'CATEGORY'):
                info['category'] = obj_class.CATEGORY

            if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
                info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS

            if getattr(obj_class, "DEPRECATED", False):
                info['deprecated'] = True
            if getattr(obj_class, "EXPERIMENTAL", False):
                info['experimental'] = True

            if hasattr(obj_class, 'API_NODE'):
                info['api_node'] = obj_class.API_NODE
            return info

        @routes.get("/object_info")
        async def get_object_info(request):
            """Describe every registered node; nodes that fail are logged and skipped."""
            with folder_paths.cache_helper:
                out = {}
                for x in nodes.NODE_CLASS_MAPPINGS:
                    try:
                        out[x] = node_info(x)
                    except Exception:
                        logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
                        logging.error(traceback.format_exc())
                return web.json_response(out)

        @routes.get("/object_info/{node_class}")
        async def get_object_info_node(request):
            """Describe a single node class (empty object if unknown)."""
            node_class = request.match_info.get("node_class", None)
            out = {}
            if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
                out[node_class] = node_info(node_class)
            return web.json_response(out)

        @routes.get("/history")
        async def get_history(request):
            """Return prompt history, optionally limited by an integer ``max_items``."""
            max_items = request.rel_url.query.get("max_items", None)
            if max_items is not None:
                try:
                    max_items = int(max_items)
                except ValueError:
                    return web.Response(status=400)
            return web.json_response(self.prompt_queue.get_history(max_items=max_items))

        @routes.get("/history/{prompt_id}")
        async def get_history_prompt_id(request):
            """Return the history entry for one prompt id."""
            prompt_id = request.match_info.get("prompt_id", None)
            return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))

        @routes.get("/queue")
        async def get_queue(request):
            """Return the running and pending queue items."""
            queue_info = {}
            current_queue = self.prompt_queue.get_current_queue()
            queue_info['queue_running'] = current_queue[0]
            queue_info['queue_pending'] = current_queue[1]
            return web.json_response(queue_info)

        @routes.post("/prompt")
        async def post_prompt(request):
            """Validate a submitted prompt and enqueue it.

            Without an explicit "number", the server counter is used (negated
            when "front" is truthy) and incremented. Returns the new prompt id,
            or a 400 with validation errors.
            """
            logging.info("got prompt")
            json_data = await request.json()
            json_data = self.trigger_on_prompt(json_data)

            if "number" in json_data:
                number = float(json_data['number'])
            else:
                number = self.number
                if "front" in json_data:
                    if json_data['front']:
                        number = -number

                self.number += 1

            if "prompt" in json_data:
                prompt = json_data["prompt"]
                valid = execution.validate_prompt(prompt)
                extra_data = {}
                if "extra_data" in json_data:
                    extra_data = json_data["extra_data"]

                if "client_id" in json_data:
                    extra_data["client_id"] = json_data["client_id"]
                if valid[0]:
                    prompt_id = str(uuid.uuid4())
                    outputs_to_execute = valid[2]
                    self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
                    response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
                    return web.json_response(response)
                else:
                    logging.warning("invalid prompt: {}".format(valid[1]))
                    return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
            else:
                error = {
                    "type": "no_prompt",
                    "message": "No prompt provided",
                    "details": "No prompt provided",
                    "extra_info": {}
                }
                return web.json_response({"error": error, "node_errors": {}}, status=400)

        @routes.post("/queue")
        async def post_queue(request):
            """Clear the queue and/or delete queued items by prompt id."""
            json_data = await request.json()
            if "clear" in json_data:
                if json_data["clear"]:
                    self.prompt_queue.wipe_queue()
            if "delete" in json_data:
                to_delete = json_data['delete']
                for id_to_delete in to_delete:
                    # Bind the id eagerly so the predicate never sees a later loop value.
                    delete_func = lambda a, target=id_to_delete: a[1] == target
                    self.prompt_queue.delete_queue_item(delete_func)

            return web.Response(status=200)

        @routes.post("/interrupt")
        async def post_interrupt(request):
            """Interrupt the currently running prompt."""
            nodes.interrupt_processing()
            return web.Response(status=200)

        @routes.post("/free")
        async def post_free(request):
            """Set the queue's "unload_models" / "free_memory" flags when requested."""
            json_data = await request.json()
            unload_models = json_data.get("unload_models", False)
            free_memory = json_data.get("free_memory", False)
            if unload_models:
                self.prompt_queue.set_flag("unload_models", unload_models)
            if free_memory:
                self.prompt_queue.set_flag("free_memory", free_memory)
            return web.Response(status=200)

        @routes.post("/history")
        async def post_history(request):
            """Clear history and/or delete history items by prompt id."""
            json_data = await request.json()
            if "clear" in json_data:
                if json_data["clear"]:
                    self.prompt_queue.wipe_history()
            if "delete" in json_data:
                to_delete = json_data['delete']
                for id_to_delete in to_delete:
                    self.prompt_queue.delete_history_item(id_to_delete)

            return web.Response(status=200)

    async def setup(self):
        """Create the shared aiohttp client session (no total timeout)."""
        timeout = aiohttp.ClientTimeout(total=None)  # no timeout
        self.client_session = aiohttp.ClientSession(timeout=timeout)

    def add_routes(self):
        """Register all routes on the aiohttp app.

        Every non-static route is registered both under its own path and under
        an ``/api`` prefix. A RuntimeError from aiohttp while adding routes is
        logged, and the routes are then added one at a time, skipping the ones
        that fail, so the server can still start. Static routes (extension web
        dirs, workflow templates, frontend root) are added last.
        """
        self.user_manager.add_routes(self.routes)
        self.model_file_manager.add_routes(self.routes)
        self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
        self.app.add_subapp('/internal', self.internal_routes.get_app())

        # Prefix every route with /api for easier matching for delegation.
        # This is very useful for frontend dev server, which need to forward
        # everything except serving of static files.
        # Currently both the old endpoints without prefix and new endpoints with
        # prefix are supported.
        api_routes = web.RouteTableDef()

        # Track method:path pairs already mirrored under /api to avoid duplicates.
        added_routes = set()

        for route in self.routes:
            # Custom nodes might add extra static routes. Only process non-static
            # routes to add /api prefix.
            if isinstance(route, web.RouteDef):
                route_key = f"{route.method}:{route.path}"
                if route_key not in added_routes:
                    # Route kwargs belong to .route(), not to the decorator it
                    # returns. A route name must be unique per router, so the
                    # /api mirror does not reuse it.
                    api_kwargs = {k: v for k, v in route.kwargs.items() if k != "name"}
                    api_routes.route(route.method, "/api" + route.path, **api_kwargs)(route.handler)
                    added_routes.add(route_key)

        # Add the /api routes; on a conflict fall back to adding them one by
        # one and skipping those that fail.
        try:
            self.app.add_routes(api_routes)
        except RuntimeError as e:
            logging.warning(f"添加API路由时出错: {e}")
            logging.warning("尝试跳过冲突路由继续启动服务...")
            for route in api_routes:
                try:
                    self.app.router.add_route(route.method, route.path, route.handler, **route.kwargs)
                except RuntimeError:
                    continue

        # Add the original (unprefixed) routes with the same fallback.
        try:
            self.app.add_routes(self.routes)
        except RuntimeError as e:
            logging.warning(f"添加主路由时出错: {e}")
            for route in self.routes:
                if isinstance(route, web.RouteDef):
                    try:
                        self.app.router.add_route(route.method, route.path, route.handler, **route.kwargs)
                    except RuntimeError:
                        continue
                else:
                    try:
                        self.app.add_routes([route])
                    except RuntimeError:
                        continue

        # Add routes from web extensions.
        for name, ext_dir in nodes.EXTENSION_WEB_DIRS.items():
            self.app.add_routes([web.static('/extensions/' + name, ext_dir)])

        workflow_templates_path = FrontendManager.templates_path()
        if workflow_templates_path:
            self.app.add_routes([
                web.static('/templates', workflow_templates_path)
            ])

        self.app.add_routes([
            web.static('/', self.web_root),
        ])

    def get_queue_info(self):
        """Return ``{"exec_info": {"queue_remaining": <tasks remaining>}}``."""
        prompt_info = {}
        exec_info = {}
        exec_info['queue_remaining'] = self.prompt_queue.get_tasks_remaining()
        prompt_info['exec_info'] = exec_info
        return prompt_info

    async def send(self, event, data, sid=None):
        """Send an event to one client (sid) or all clients.

        Unencoded preview images are encoded first, bytes payloads are sent as
        binary messages and everything else is sent as JSON.
        """
        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            await self.send_image(data, sid=sid)
        elif isinstance(data, (bytes, bytearray)):
            await self.send_bytes(event, data, sid)
        else:
            await self.send_json(event, data, sid)

    def encode_bytes(self, event, data):
        """Prefix data with the event id packed as a 4-byte big-endian unsigned int.

        Raises:
            RuntimeError: if event is not an int.
        """
        if not isinstance(event, int):
            raise RuntimeError(f"Binary event types must be integers, got {event}")

        packed = struct.pack(">I", event)
        message = bytearray(packed)
        message.extend(data)
        return message

    async def send_image(self, image_data, sid=None):
        """Encode a preview image and send it as a PREVIEW_IMAGE binary event.

        Args:
            image_data: (image_type, image, max_size) where image_type is the
                PIL format name ("JPEG" or "PNG" map to type 1 and 2), image is
                saved with PIL, and max_size (if not None) bounds both sides.
            sid: target client id, or None for all clients.
        """
        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]
        if max_size is not None:
            if hasattr(Image, 'Resampling'):
                resampling = Image.Resampling.BILINEAR
            else:
                resampling = Image.ANTIALIAS

            image = ImageOps.contain(image, (max_size, max_size), resampling)
        type_num = 1
        if image_type == "JPEG":
            type_num = 1
        elif image_type == "PNG":
            type_num = 2

        bytesIO = BytesIO()
        header = struct.pack(">I", type_num)
        bytesIO.write(header)
        image.save(bytesIO, format=image_type, quality=95, compress_level=1)
        preview_bytes = bytesIO.getvalue()
        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

    async def send_bytes(self, event, data, sid=None):
        """Send a binary event to one client (sid) or to all connected clients."""
        message = self.encode_bytes(event, data)

        if sid is None:
            # Snapshot the sockets: the dict can change while we await.
            sockets = list(self.sockets.values())
            for ws in sockets:
                await send_socket_catch_exception(ws.send_bytes, message)
        elif sid in self.sockets:
            await send_socket_catch_exception(self.sockets[sid].send_bytes, message)

    async def send_json(self, event, data, sid=None):
        """Send ``{"type": event, "data": data}`` to one client (sid) or to all clients."""
        message = {"type": event, "data": data}

        if sid is None:
            # Snapshot the sockets: the dict can change while we await.
            sockets = list(self.sockets.values())
            for ws in sockets:
                await send_socket_catch_exception(ws.send_json, message)
        elif sid in self.sockets:
            await send_socket_catch_exception(self.sockets[sid].send_json, message)

    def send_sync(self, event, data, sid=None):
        """Queue a message for publish_loop via call_soon_threadsafe on self.loop."""
        self.loop.call_soon_threadsafe(
            self.messages.put_nowait, (event, data, sid))

    def queue_updated(self):
        """Broadcast the current queue status."""
        self.send_sync("status", { "status": self.get_queue_info() })

    async def publish_loop(self):
        """Forever forward queued (event, data, sid) messages to ``send``."""
        while True:
            msg = await self.messages.get()
            await self.send(*msg)

    async def start(self, address, port, verbose=True, call_on_start=None):
        """Start the server on a single (address, port); see start_multi_address."""
        await self.start_multi_address([(address, port)], call_on_start=call_on_start, verbose=verbose)

    async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
        """Start listening on each (address, port) pair, with TLS if configured.

        The first address/port is stored on ``self.address``/``self.port`` and
        passed to ``call_on_start(scheme, address, port)`` if given.
        """
        runner = web.AppRunner(self.app, access_log=None)
        await runner.setup()
        ssl_ctx = None
        scheme = "http"
        if args.tls_keyfile and args.tls_certfile:
            ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER)
            # verify_mode is not an SSLContext constructor parameter; set it explicitly.
            ssl_ctx.verify_mode = ssl.CERT_NONE
            ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
                                    keyfile=args.tls_keyfile)
            scheme = "https"

        if verbose:
            logging.info("Starting server\n")
        for addr in addresses:
            address = addr[0]
            port = addr[1]
            site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
            await site.start()

            if not hasattr(self, 'address'):
                self.address = address #TODO: remove this
                self.port = port

            # IPv6 literals need brackets in URLs.
            if ':' in address:
                address_print = "[{}]".format(address)
            else:
                address_print = address

            if verbose:
                logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))

        if call_on_start is not None:
            call_on_start(scheme, self.address, self.port)

    def add_on_prompt_handler(self, handler):
        """Register a callable that may transform incoming prompt JSON."""
        self.on_prompt_handlers.append(handler)

    def trigger_on_prompt(self, json_data):
        """Run all on-prompt handlers in order; a failing handler is logged and skipped."""
        for handler in self.on_prompt_handlers:
            try:
                json_data = handler(json_data)
            except Exception:
                logging.warning("[ERROR] An error occurred during the on_prompt_handler processing")
                logging.warning(traceback.format_exc())

        return json_data