KICCO_AI_IMAGE/flux_style_shaper_api/flux_fill_api.py

import os
import random
import sys
import logging
import torch
from typing import Union, Sequence, Mapping, Any
import shutil
# Set up logging
logger = logging.getLogger(__name__)
# Helper to fetch a value from a node output by index
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]
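# Illustrative note (not part of the original pipeline): ComfyUI node methods return
# either a tuple of outputs or a mapping with a "result" key, so both of these work:
#   get_value_at_index(("IMAGE", "MASK"), 1)              -> "MASK"
#   get_value_at_index({"result": ("IMAGE", "MASK")}, 1)  -> "MASK"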
# Get the ComfyUI root directory
def get_comfyui_dir():
    # The ComfyUI root is assumed to be one level above this file's directory
    current_dir = os.path.dirname(os.path.abspath(__file__))
    comfyui_dir = os.path.dirname(current_dir)  # parent of the current directory
    return comfyui_dir
# Image fill service function
def generate_fill_image(
    prompt: str,
    prompt_negative: str = "",
    input_image_path: str = None,
    style_image: str = None,
    reference_strength: float = 0.8,
    guidance: float = 30.0,
    steps: int = 28,
    seed: int = -1
):
    """
    Fill an image with the FLUX Fill model, optionally guided by a reference image.

    Args:
        prompt (str): Positive prompt.
        prompt_negative (str): Negative prompt.
        input_image_path (str): Path to the input image (its transparent region is used as the fill mask).
        style_image (str): Path to a reference image that guides the fill style and content.
        reference_strength (float): Influence strength of the reference image, between 0 and 1 (default 0.8).
        guidance (float): Guidance strength (default 30.0).
        steps (int): Number of sampling steps (default 28).
        seed (int): Random seed; -1 means generate one at random.

    Returns:
        str: Path where the generated image was saved.
    """
    logger.info("Starting image fill...")

    # Use a random seed when -1 was requested
    if seed == -1:
        seed = random.randint(1, 2**64 - 1)

    # Import the required ComfyUI nodes
    from nodes import (
        NODE_CLASS_MAPPINGS, SaveImage, VAELoader,
        UNETLoader, DualCLIPLoader, CLIPTextEncode,
        VAEDecode, LoadImage, InpaintModelConditioning,
        CLIPVisionEncode, ConditioningCombine, CLIPVisionLoader, StyleModelLoader, StyleModelApply
    )
    logger.info("Loading models...")

    # Load the dual CLIP text encoders
    dualcliploader = DualCLIPLoader()
    CLIP_MODEL = dualcliploader.load_clip(
        clip_name1="clip_l.safetensors",
        clip_name2="t5/t5xxl_fp16.safetensors",
        type="flux",
        device="default"
    )

    # Load the CLIP Vision model
    clipvisionloader = CLIPVisionLoader()
    CLIP_VISION_MODEL = clipvisionloader.load_clip(
        clip_name="sigclip_vision_patch14_384.safetensors"
    )

    # Load the VAE
    vaeloader = VAELoader()
    VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")

    # Load the UNET (the Fill model)
    unetloader = UNETLoader()
    UNET_MODEL = unetloader.load_unet(
        unet_name="flux1-fill-dev.safetensors",
        weight_dtype="default"
    )

    # Initialize the remaining nodes
    cliptextencode = CLIPTextEncode()
    clipvisionencode = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
    conditioningcombine = NODE_CLASS_MAPPINGS["ConditioningCombine"]()
    fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
    differentialdiffusion = NODE_CLASS_MAPPINGS["DifferentialDiffusion"]()
    ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
    vaedecode = VAEDecode()
    saveimage = SaveImage()
    loadimage = LoadImage()
    inpaintmodelconditioning = InpaintModelConditioning()
    # Copy the input images into ComfyUI's input directory
    comfyui_dir = get_comfyui_dir()
    input_dir = os.path.join(comfyui_dir, "input")
    os.makedirs(input_dir, exist_ok=True)

    # Use randomized temporary filenames to avoid collisions
    input_filename = f"input_{random.randint(1000, 9999)}.png"
    ref_filename = None
    if style_image:
        ref_filename = f"ref_{random.randint(1000, 9999)}.png"

    input_comfy_path = os.path.join(input_dir, input_filename)
    shutil.copy(input_image_path, input_comfy_path)

    ref_comfy_path = None
    if style_image:
        ref_comfy_path = os.path.join(input_dir, ref_filename)
        shutil.copy(style_image, ref_comfy_path)

    # Preload the models onto the GPU
    from comfy import model_management
    model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]
    model_management.load_models_gpu([
        loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
    ])
    # Image generation pipeline
    with torch.inference_mode():
        # Encode the positive and negative prompts
        text_encoded = cliptextencode.encode(
            text=prompt,
            clip=get_value_at_index(CLIP_MODEL, 0),
        )
        negative_encoded = cliptextencode.encode(
            text=prompt_negative,
            clip=get_value_at_index(CLIP_MODEL, 0),
        )

        # Apply FluxGuidance
        flux_guided = fluxguidance.append(
            guidance=guidance,
            conditioning=get_value_at_index(text_encoded, 0),
        )
        # If a reference image was provided, apply it via StyleModelApplyAdvanced
        final_positive_cond = get_value_at_index(flux_guided, 0)
        if style_image:
            try:
                # Initialize the StyleModelLoader and StyleModelApplyAdvanced nodes
                stylemodelloader = StyleModelLoader()
                STYLE_MODEL = stylemodelloader.load_style_model(
                    style_model_name="flux1-redux-dev.safetensors"
                )
                stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()  # applies the style

                # Load the reference image
                ref_img = loadimage.load_image(image=ref_filename)

                # Encode the reference image with CLIP Vision
                vision_encoded = clipvisionencode.encode(
                    crop="center",  # center crop
                    clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
                    image=get_value_at_index(ref_img, 0)
                )

                # Apply the style: inject the reference image features into the conditioning
                style_applied = stylemodelapplyadvanced.apply_stylemodel(
                    conditioning=final_positive_cond,                           # conditioning to modify
                    style_model=get_value_at_index(STYLE_MODEL, 0),             # style model
                    clip_vision_output=get_value_at_index(vision_encoded, 0),   # reference image features
                    strength=reference_strength,                                # strength of the reference image
                )

                # Use the conditioning with the style applied
                final_positive_cond = get_value_at_index(style_applied, 0)
                logger.info("Reference image style applied successfully")
            except Exception as e:
                logger.error(f"Error while applying the reference image style: {e}")
                # Fall back to the original conditioning instead of aborting
                logger.info("Continuing with the original conditioning, without the reference image")
        # Load the input image (returns both the image and its mask)
        loaded_image = loadimage.load_image(image=input_filename)
        input_img = get_value_at_index(loaded_image, 0)  # first output: image
        mask_img = get_value_at_index(loaded_image, 1)   # second output: mask

        # Apply DifferentialDiffusion
        diff_model = differentialdiffusion.apply(
            model=get_value_at_index(UNET_MODEL, 0),
        )

        # Prepare the InpaintModelConditioning
        inpaint_cond = inpaintmodelconditioning.encode(
            positive=final_positive_cond,
            negative=get_value_at_index(negative_encoded, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
            pixels=input_img,
            mask=mask_img,
            noise_mask=False,
        )

        # Extract the positive conditioning, negative conditioning, and latent image
        positive_cond = get_value_at_index(inpaint_cond, 0)
        negative_cond = get_value_at_index(inpaint_cond, 1)
        latent_image = get_value_at_index(inpaint_cond, 2)
        # Sampling
        sampled = ksampler.sample(
            model=get_value_at_index(diff_model, 0),
            positive=positive_cond,
            negative=negative_cond,
            latent_image=latent_image,
            seed=seed,
            steps=steps,
            cfg=1.0,  # cfg is usually set to 1.0 when FluxGuidance is used
            sampler_name="euler",
            scheduler="normal",
            denoise=1.0,
        )

        # Decode the latents
        decoded = vaedecode.decode(
            samples=get_value_at_index(sampled, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )

        # Filename prefix for the saved image
        prefix = "Flux_Fill"

        # Save the image
        saved = saveimage.save_images(
            filename_prefix=prefix,
            images=get_value_at_index(decoded, 0),
        )
        # Resolve the path of the saved image
        saved_path = os.path.join(comfyui_dir, "output", saved['ui']['images'][0]['filename'])
        logger.info(f"Image saved to {saved_path}")

        # Clean up the temporary input files
        try:
            os.remove(input_comfy_path)
            if ref_comfy_path:
                os.remove(ref_comfy_path)
        except Exception as e:
            logger.warning(f"Error while cleaning up temporary files: {e}")

        return saved_path
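

# A minimal usage sketch (not part of the original API): it assumes this module is
# imported inside a running ComfyUI environment (so `nodes` and `comfy` are importable)
# and that the two image paths below exist; both paths are placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    output_path = generate_fill_image(
        prompt="a wooden bookshelf filled with plants, warm light",
        prompt_negative="blurry, low quality",
        input_image_path="/path/to/image_with_transparent_mask.png",
        style_image="/path/to/reference_style.png",  # optional; pass None to fill without a style reference
        reference_strength=0.8,
        guidance=30.0,
        steps=28,
        seed=-1,  # -1 picks a random seed
    )
    print(f"Generated image saved to: {output_path}")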