158 lines
4.5 KiB
Python
158 lines
4.5 KiB
Python
|
import os
|
|||
|
import random
|
|||
|
import logging
|
|||
|
import torch
|
|||
|
from typing import Union, Sequence, Mapping, Any
|
|||
|
|
|||
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|||
|
|
|||
|
# 从索引获取值的辅助函数
|
|||
|
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
|
|||
|
try:
|
|||
|
return obj[index]
|
|||
|
except KeyError:
|
|||
|
return obj["result"][index]
|
|||
|
|
|||
|
# 获取ComfyUI根目录
|
|||
|
def get_comfyui_dir():
|
|||
|
# 从当前目录向上一级查找ComfyUI根目录
|
|||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|||
|
comfyui_dir = os.path.dirname(current_dir) # 当前目录的上一级目录
|
|||
|
return comfyui_dir
|
|||
|
|
|||
|
# Largest seed torch.manual_seed accepts (0xFFFF_FFFF_FFFF_FFFF).
# random.randint's upper bound is inclusive, so 2**64 itself must be excluded.
_MAX_SEED = 2**64 - 1


def _saved_image_path(saveimage, saved) -> str:
    """Build the absolute path of the first image reported by ``save_images``.

    Args:
        saveimage: The ``SaveImage`` node instance that wrote the file.
        saved: The dict returned by ``saveimage.save_images``. This reads
            ``saved["ui"]["images"][0]["filename"]`` and, if present, its
            ``"subfolder"`` entry.

    Returns:
        ``<output dir>/<subfolder>/<filename>`` as a string.
    """
    image_info = saved["ui"]["images"][0]
    # NOTE(review): assumes SaveImage exposes ``output_dir`` (current ComfyUI
    # does). Using it honours a relocated output directory; otherwise fall
    # back to the "<ComfyUI root>/output" location.
    output_dir = getattr(saveimage, "output_dir", None) or os.path.join(
        get_comfyui_dir(), "output"
    )
    # An empty subfolder contributes nothing to the joined path.
    return os.path.join(
        output_dir, image_info.get("subfolder", ""), image_info["filename"]
    )


def generate_image_from_text(
    prompt: str,
    prompt_negative: str = "",
    width: int = 1024,
    height: int = 1024,
    guidance: float = 7.5,
    steps: int = 30,
    seed: int = -1,
) -> str:
    """Generate an image from text with the FLUX model and save it to disk.

    Loads the CLIP, VAE and UNET models, moves them to the GPU, then encodes
    the prompts and applies FluxGuidance. It samples a latent of the requested
    size, decodes it, and saves the image with the prefix "Flux_Text2Image".

    Args:
        prompt: Positive prompt.
        prompt_negative: Negative prompt. NOTE(review): sampling runs with
            cfg=1.0, at which ComfyUI is expected to skip the negative branch;
            confirm whether this prompt has any effect.
        width: Image width in pixels.
        height: Image height in pixels.
        guidance: Guidance strength passed to the FluxGuidance node.
        steps: Number of sampling steps.
        seed: Sampling seed. -1 picks a random seed in [1, 2**64 - 1].

    Returns:
        Path of the saved image file.
    """
    logger.info("开始文生图过程...")

    # -1 means "pick a random seed"; keep it within torch's accepted range.
    if seed == -1:
        seed = random.randint(1, _MAX_SEED)

    # Imported lazily, presumably because these modules only resolve inside
    # a ComfyUI environment.
    from nodes import (
        NODE_CLASS_MAPPINGS, SaveImage, VAELoader,
        UNETLoader, DualCLIPLoader, CLIPTextEncode,
        VAEDecode, EmptyLatentImage,
    )

    logger.info("加载模型...")

    # Load the text encoders (T5 + CLIP-L), the VAE and the diffusion model.
    clip_model = DualCLIPLoader().load_clip(
        clip_name1="t5/t5xxl_fp16.safetensors",
        clip_name2="clip_l.safetensors",
        type="flux",
    )
    vae_model = VAELoader().load_vae(vae_name="FLUX1/ae.safetensors")
    unet_model = UNETLoader().load_unet(
        unet_name="flux1-dev.safetensors",
        weight_dtype="default",
    )

    # Instantiate the remaining pipeline nodes.
    cliptextencode = CLIPTextEncode()
    fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
    emptylatentimage = EmptyLatentImage()
    ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
    vaedecode = VAEDecode()
    saveimage = SaveImage()

    # Preload all three models onto the GPU. Objects exposing ``.patcher``
    # are loaded through it; others (presumably the UNET, which is already a
    # patcher) are passed as-is.
    from comfy import model_management
    model_management.load_models_gpu([
        out[0].patcher if hasattr(out[0], "patcher") else out[0]
        for out in (clip_model, vae_model, unet_model)
    ])

    with torch.inference_mode():
        clip = get_value_at_index(clip_model, 0)
        text_encoded = cliptextencode.encode(text=prompt, clip=clip)
        negative_encoded = cliptextencode.encode(text=prompt_negative, clip=clip)

        # FluxGuidance is applied to the positive conditioning only.
        flux_guided = fluxguidance.append(
            guidance=guidance,
            conditioning=get_value_at_index(text_encoded, 0),
        )

        empty_latent = emptylatentimage.generate(
            width=width,
            height=height,
            batch_size=1,
        )

        sampled = ksampler.sample(
            model=get_value_at_index(unet_model, 0),
            positive=get_value_at_index(flux_guided, 0),
            negative=get_value_at_index(negative_encoded, 0),
            latent_image=get_value_at_index(empty_latent, 0),
            seed=seed,
            steps=steps,
            cfg=1.0,
            sampler_name="euler",
            scheduler="normal",
            denoise=1.0,
        )

        decoded = vaedecode.decode(
            samples=get_value_at_index(sampled, 0),
            vae=get_value_at_index(vae_model, 0),
        )

        # Pass the prefix string itself. Indexing it with
        # get_value_at_index(prefix, 0) would yield only its first character "F".
        saved = saveimage.save_images(
            filename_prefix="Flux_Text2Image",
            images=get_value_at_index(decoded, 0),
        )

    saved_path = _saved_image_path(saveimage, saved)
    logger.info("图像保存到 %s", saved_path)

    return saved_path
|