158 lines
4.5 KiB
Python
158 lines
4.5 KiB
Python
import os
|
||
import random
|
||
import logging
|
||
import torch
|
||
from typing import Union, Sequence, Mapping, Any
|
||
|
||
# Module-level logger named after this module
logger = logging.getLogger(name=__name__)
|
||
|
||
# Helper: fetch a value by index, with a fallback for "result"-wrapped mappings
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    The direct lookup is tried first. Only if it raises ``KeyError`` is the
    value looked up inside ``obj["result"]`` instead (presumably for node
    outputs that wrap their values under a ``"result"`` key -- confirm).

    Args:
        obj: An indexable sequence, or a mapping that may hold a ``"result"`` entry.
        index: Position/key to read.

    Returns:
        The value found at ``index``.

    Raises:
        KeyError: If neither ``obj[index]`` nor ``obj["result"][index]`` exists.
    """
    try:
        value = obj[index]
    except KeyError:
        wrapped = obj["result"]
        value = wrapped[index]
    return value
|
||
|
||
# Locate the ComfyUI root directory
def get_comfyui_dir():
    """Return the parent of the directory that contains this file.

    The ComfyUI root is assumed to be one level above this script's own
    directory.

    Returns:
        str: Absolute path of that parent directory.
    """
    path = os.path.abspath(__file__)
    # First dirname -> this file's directory; second -> its parent.
    for _ in range(2):
        path = os.path.dirname(path)
    return path
|
||
|
||
# Text-to-image generation
def generate_image_from_text(
    prompt: str,
    prompt_negative: str = "",
    width: int = 1024,
    height: int = 1024,
    guidance: float = 7.5,
    steps: int = 30,
    seed: int = -1
):
    """
    Generate an image from text with the FLUX model via ComfyUI nodes.

    Loads the FLUX text encoders, VAE and diffusion model, encodes the prompts,
    samples a latent, decodes it and saves the result with ``SaveImage``.

    Args:
        prompt (str): Positive prompt.
        prompt_negative (str): Negative prompt. NOTE(review): sampling runs with
            cfg=1.0, so this conditioning may have little or no effect -- confirm.
        width (int): Image width in pixels.
        height (int): Image height in pixels.
        guidance (float): Guidance value passed to the FluxGuidance node.
        steps (int): Number of sampling steps.
        seed (int): Random seed; -1 means pick one at random.

    Returns:
        str: Path of the saved image.
    """
    logger.info("开始文生图过程...")

    # -1 means "choose a random seed". random.randint's upper bound is
    # inclusive, so cap it at 2**64 - 1 (the largest unsigned 64-bit value);
    # 2**64 itself would overflow the seed range.
    if seed == -1:
        seed = random.randint(1, 2**64 - 1)

    # Import ComfyUI components inside the function (presumably so this module
    # can be imported before ComfyUI is on sys.path -- confirm).
    from nodes import (
        NODE_CLASS_MAPPINGS, SaveImage, VAELoader,
        UNETLoader, DualCLIPLoader, CLIPTextEncode,
        VAEDecode, EmptyLatentImage
    )

    logger.info("加载模型...")

    # Load the dual text encoders (T5-XXL + CLIP-L) in FLUX mode.
    clip_model = DualCLIPLoader().load_clip(
        clip_name1="t5/t5xxl_fp16.safetensors",
        clip_name2="clip_l.safetensors",
        type="flux",
    )

    # Load the VAE.
    vae_model = VAELoader().load_vae(vae_name="FLUX1/ae.safetensors")

    # Load the diffusion (UNET) model.
    unet_model = UNETLoader().load_unet(
        unet_name="flux1-dev.safetensors",
        weight_dtype="default"
    )

    # Instantiate the remaining nodes.
    cliptextencode = CLIPTextEncode()
    fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
    emptylatentimage = EmptyLatentImage()
    ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
    vaedecode = VAEDecode()
    saveimage = SaveImage()

    # Preload the models onto the GPU. Each loader output is indexed at [0] to
    # get the model object; objects exposing a ``patcher`` attribute are loaded
    # through it, the rest are passed as-is.
    from comfy import model_management

    loaded_models = [clip_model, vae_model, unet_model]
    model_management.load_models_gpu([
        out[0].patcher if hasattr(out[0], "patcher") else out[0]
        for out in loaded_models
    ])

    # Image-generation pipeline (no autograd tracking needed).
    with torch.inference_mode():
        clip = get_value_at_index(clip_model, 0)

        # Encode the positive and negative prompts.
        text_encoded = cliptextencode.encode(text=prompt, clip=clip)
        negative_encoded = cliptextencode.encode(text=prompt_negative, clip=clip)

        # Attach the guidance value to the positive conditioning only.
        flux_guided = fluxguidance.append(
            guidance=guidance,
            conditioning=get_value_at_index(text_encoded, 0),
        )

        # Start from an empty latent of the requested size.
        empty_latent = emptylatentimage.generate(
            width=width,
            height=height,
            batch_size=1,
        )

        # Sample. Guidance strength is supplied through FluxGuidance above;
        # the sampler's cfg is fixed at 1.0.
        sampled = ksampler.sample(
            model=get_value_at_index(unet_model, 0),
            positive=get_value_at_index(flux_guided, 0),
            negative=get_value_at_index(negative_encoded, 0),
            latent_image=get_value_at_index(empty_latent, 0),
            seed=seed,
            steps=steps,
            cfg=1.0,
            sampler_name="euler",
            scheduler="normal",
            denoise=1.0,
        )

        # Decode the latent into an image.
        decoded = vaedecode.decode(
            samples=get_value_at_index(sampled, 0),
            vae=get_value_at_index(vae_model, 0),
        )

        # Save the image. The prefix string is passed directly: indexing it
        # (get_value_at_index(prefix, 0)) would yield only its first character.
        prefix = "Flux_Text2Image"
        saved = saveimage.save_images(
            filename_prefix=prefix,
            images=get_value_at_index(decoded, 0),
        )

    # Build the saved file's path. Prefer the directory SaveImage itself writes
    # to (if it exposes one) and include any subfolder it reports; otherwise
    # fall back to <ComfyUI root>/output.
    image_info = saved["ui"]["images"][0]
    output_dir = getattr(saveimage, "output_dir", None) or os.path.join(
        get_comfyui_dir(), "output"
    )
    saved_path = os.path.join(
        output_dir, image_info.get("subfolder", ""), image_info["filename"]
    )
    logger.info("图像保存到 %s", saved_path)

    return saved_path