# KICCO_AI_IMAGE/flux_style_shaper_api/flux_fill_api.py

import os
import random
import sys
import logging
import shutil
from typing import Union, Sequence, Mapping, Any, Optional

import torch

# Set up logging
logger = logging.getLogger(__name__)


# Helper to fetch a value from a node output by index
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]
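
# The helper above accepts either a plain sequence or a mapping that carries a
# "result" tuple; for example, both of these hypothetical calls resolve to "mask":
#     get_value_at_index(("image", "mask"), 1)
#     get_value_at_index({"result": ("image", "mask")}, 1)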


# Locate the ComfyUI root directory
def get_comfyui_dir():
    # This file is expected to sit one level below the ComfyUI root,
    # so the parent of this file's directory is taken as the ComfyUI directory.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    comfyui_dir = os.path.dirname(current_dir)  # parent of the current directory
    return comfyui_dir


# Image fill service function
def generate_fill_image(
    prompt: str,
    prompt_negative: str = "",
    input_image_path: Optional[str] = None,
    style_image: Optional[str] = None,
    reference_strength: float = 0.8,
    guidance: float = 30.0,
    steps: int = 28,
    seed: int = -1,
):
    """
    Fill an image with the FLUX Fill model, optionally guided by a reference image.

    Args:
        prompt (str): Positive prompt.
        prompt_negative (str): Negative prompt.
        input_image_path (str): Path to the input image; its transparent region is used as the fill mask.
        style_image (str): Path to a reference image that guides the style and content of the fill.
        reference_strength (float): Influence strength of the reference image, between 0 and 1 (default 0.8).
        guidance (float): Guidance strength (default 30.0).
        steps (int): Number of sampling steps (default 28).
        seed (int): Random seed; -1 means a random seed is generated.

    Returns:
        str: Path where the generated image is saved.
    """
logger.info("开始图像填充过程...")
# 使用随机种子(如果指定为-1
if seed == -1:
seed = random.randint(1, 2**64)
# 导入必要组件
from nodes import (
NODE_CLASS_MAPPINGS, SaveImage, VAELoader,
UNETLoader, DualCLIPLoader, CLIPTextEncode,
VAEDecode, LoadImage, InpaintModelConditioning,
2025-05-06 20:36:56 +08:00
CLIPVisionEncode, ConditioningCombine, CLIPVisionLoader,StyleModelLoader,StyleModelApply
2025-05-06 18:51:36 +08:00
)
logger.info("加载模型...")
# 加载CLIP模型
dualcliploader = DualCLIPLoader()
CLIP_MODEL = dualcliploader.load_clip(
clip_name1="clip_l.safetensors",
clip_name2="t5/t5xxl_fp16.safetensors",
type="flux",
device="default"
)
# 加载CLIP Vision模型
clipvisionloader = CLIPVisionLoader()
CLIP_VISION_MODEL = clipvisionloader.load_clip(
clip_name="sigclip_vision_patch14_384.safetensors"
)
# 加载VAE模型
vaeloader = VAELoader()
VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
# 加载UNET模型Fill模型
unetloader = UNETLoader()
UNET_MODEL = unetloader.load_unet(
unet_name="flux1-fill-dev.safetensors",
weight_dtype="default"
)
# 初始化其他节点
cliptextencode = CLIPTextEncode()
clipvisionencode = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
conditioningcombine = NODE_CLASS_MAPPINGS["ConditioningCombine"]()
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
differentialdiffusion = NODE_CLASS_MAPPINGS["DifferentialDiffusion"]()
ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
vaedecode = VAEDecode()
saveimage = SaveImage()
loadimage = LoadImage()
inpaintmodelconditioning = InpaintModelConditioning()
    # Copy the input image into ComfyUI's input directory
    comfyui_dir = get_comfyui_dir()
    input_dir = os.path.join(comfyui_dir, "input")
    os.makedirs(input_dir, exist_ok=True)

    # Use randomized temporary file names to avoid collisions
    input_filename = f"input_{random.randint(1000, 9999)}.png"
    ref_filename = None
    if style_image:
        ref_filename = f"ref_{random.randint(1000, 9999)}.png"

    input_comfy_path = os.path.join(input_dir, input_filename)
    shutil.copy(input_image_path, input_comfy_path)

    ref_comfy_path = None
    if style_image:
        ref_comfy_path = os.path.join(input_dir, ref_filename)
        shutil.copy(style_image, ref_comfy_path)

    # Preload the models onto the GPU
    from comfy import model_management
    model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]
    model_management.load_models_gpu([
        loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0]
        for loader in model_loaders
    ])
    # Image generation pipeline
    with torch.inference_mode():
        # Encode the text prompts
        text_encoded = cliptextencode.encode(
            text=prompt,
            clip=get_value_at_index(CLIP_MODEL, 0),
        )
        negative_encoded = cliptextencode.encode(
            text=prompt_negative,
            clip=get_value_at_index(CLIP_MODEL, 0),
        )

        # Apply FluxGuidance to the positive conditioning
        flux_guided = fluxguidance.append(
            guidance=guidance,
            conditioning=get_value_at_index(text_encoded, 0),
        )
        # If a reference image was provided, apply it with StyleModelApplyAdvanced
        final_positive_cond = get_value_at_index(flux_guided, 0)
        if style_image:
            try:
                # Initialize the StyleModelLoader and StyleModelApplyAdvanced nodes
                stylemodelloader = StyleModelLoader()
                STYLE_MODEL = stylemodelloader.load_style_model(
                    style_model_name="flux1-redux-dev.safetensors"
                )
                stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()

                # Load the reference image
                ref_img = loadimage.load_image(image=ref_filename)

                # Encode the reference image with CLIP Vision
                vision_encoded = clipvisionencode.encode(
                    crop="center",  # center crop
                    clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
                    image=get_value_at_index(ref_img, 0)
                )

                # Apply the style model: inject the reference image features into the conditioning
                style_applied = stylemodelapplyadvanced.apply_stylemodel(
                    conditioning=final_positive_cond,
                    style_model=get_value_at_index(STYLE_MODEL, 0),
                    clip_vision_output=get_value_at_index(vision_encoded, 0),
                    strength=reference_strength,  # style strength
                )

                # Use the style-conditioned positive conditioning from here on
                final_positive_cond = get_value_at_index(style_applied, 0)
                logger.info("Reference image style applied successfully")
            except Exception as e:
                logger.error(f"Error while applying the reference image style: {e}")
                # Keep the original conditioning and continue instead of aborting
                logger.info("Continuing with the original conditioning, without the reference image")
        # Load the input image (returns both the image and the mask)
        loaded_image = loadimage.load_image(image=input_filename)
        input_img = get_value_at_index(loaded_image, 0)  # first output: image
        mask_img = get_value_at_index(loaded_image, 1)   # second output: mask

        # Apply DifferentialDiffusion
        diff_model = differentialdiffusion.apply(
            model=get_value_at_index(UNET_MODEL, 0),
        )

        # Prepare the inpaint conditioning
        inpaint_cond = inpaintmodelconditioning.encode(
            positive=final_positive_cond,
            negative=get_value_at_index(negative_encoded, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
            pixels=input_img,
            mask=mask_img,
            noise_mask=False,
        )

        # Extract the positive conditioning, negative conditioning, and latent image
        positive_cond = get_value_at_index(inpaint_cond, 0)
        negative_cond = get_value_at_index(inpaint_cond, 1)
        latent_image = get_value_at_index(inpaint_cond, 2)

        # Sample
        sampled = ksampler.sample(
            model=get_value_at_index(diff_model, 0),
            positive=positive_cond,
            negative=negative_cond,
            latent_image=latent_image,
            seed=seed,
            steps=steps,
            cfg=1.0,  # cfg is usually set to 1.0 when FluxGuidance is used
            sampler_name="euler",
            scheduler="normal",
            denoise=1.0,
        )

        # Decode the latents
        decoded = vaedecode.decode(
            samples=get_value_at_index(sampled, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )

        # Filename prefix for the saved output
        prefix = "Flux_Fill"

        # Save the image
        saved = saveimage.save_images(
            filename_prefix=prefix,
            images=get_value_at_index(decoded, 0),
        )

    # Resolve the path of the saved image
    saved_path = os.path.join(comfyui_dir, "output", saved['ui']['images'][0]['filename'])
    logger.info(f"Image saved to {saved_path}")

    # Clean up temporary files
    try:
        os.remove(input_comfy_path)
        if ref_comfy_path:
            os.remove(ref_comfy_path)
    except Exception as e:
        logger.warning(f"Error while cleaning up temporary files: {e}")

    return saved_path
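

# A minimal usage sketch, assuming this module lives inside a ComfyUI checkout
# (so `from nodes import ...` resolves) and the referenced model checkpoints are
# installed. The file paths and prompts below are placeholders, not values from
# the original project.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    result_path = generate_fill_image(
        prompt="a vase of fresh flowers on a wooden table",          # placeholder prompt
        prompt_negative="blurry, low quality",                       # placeholder negative prompt
        input_image_path="/path/to/image_with_transparency.png",     # transparent region = fill mask
        style_image="/path/to/style_reference.png",                  # optional reference image
        reference_strength=0.8,
        guidance=30.0,
        steps=28,
        seed=-1,
    )
    print(f"Generated image saved to: {result_path}")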