KICCO_AI_IMAGE/flux_style_shaper_api/text_to_image.py

158 lines
4.5 KiB
Python
Raw Normal View History

2025-05-06 18:51:36 +08:00
import os
import random
import logging
import torch
from typing import Union, Sequence, Mapping, Any
# 设置日志
logger = logging.getLogger(__name__)
# 从索引获取值的辅助函数
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    ComfyUI node calls return either a plain indexable sequence or a
    mapping that wraps its outputs under a ``"result"`` key; this helper
    reads from whichever form it is given. Only ``KeyError`` triggers the
    fallback — an out-of-range ``IndexError`` propagates to the caller.
    """
    try:
        value = obj[index]
    except KeyError:
        value = obj["result"][index]
    return value
# 获取ComfyUI根目录
def get_comfyui_dir():
    """Return the ComfyUI root directory.

    The root is taken to be the parent of the directory containing this
    file (this module lives one level below the ComfyUI root).
    """
    this_file = os.path.abspath(__file__)
    api_dir = os.path.dirname(this_file)
    return os.path.dirname(api_dir)
# 文生图函数
def generate_image_from_text(
    prompt: str,
    prompt_negative: str = "",
    width: int = 1024,
    height: int = 1024,
    guidance: float = 7.5,
    steps: int = 30,
    seed: int = -1
):
    """
    Generate an image from text using the FLUX model.

    Args:
        prompt (str): Positive prompt text.
        prompt_negative (str): Negative prompt text.
        width (int): Output image width in pixels.
        height (int): Output image height in pixels.
        guidance (float): FluxGuidance strength applied to the positive
            conditioning.
        steps (int): Number of sampling steps.
        seed (int): Random seed; -1 means pick one at random.

    Returns:
        str: Filesystem path of the saved image.
    """
    logger.info("开始文生图过程...")
    # Pick a random seed when the caller asked for one (-1).
    # random.randint is inclusive on both ends, so the upper bound must be
    # 2**64 - 1 to stay inside the uint64 range (the previous bound of 2**64
    # could produce a value one past it).
    if seed == -1:
        seed = random.randint(0, 2**64 - 1)
    # Imported lazily: these modules only resolve inside a running ComfyUI
    # environment.
    from nodes import (
        NODE_CLASS_MAPPINGS, SaveImage, VAELoader,
        UNETLoader, DualCLIPLoader, CLIPTextEncode,
        VAEDecode, EmptyLatentImage
    )
    logger.info("加载模型...")
    # Load the dual CLIP text encoders used by FLUX (T5-XXL + CLIP-L).
    dualcliploader = DualCLIPLoader()
    CLIP_MODEL = dualcliploader.load_clip(
        clip_name1="t5/t5xxl_fp16.safetensors",
        clip_name2="clip_l.safetensors",
        type="flux",
    )
    # Load the VAE used to decode latents to pixels.
    vaeloader = VAELoader()
    VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
    # Load the diffusion UNET.
    unetloader = UNETLoader()
    UNET_MODEL = unetloader.load_unet(
        unet_name="flux1-dev.safetensors",
        weight_dtype="default"
    )
    # Instantiate the remaining pipeline nodes.
    cliptextencode = CLIPTextEncode()
    fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
    emptylatentimage = EmptyLatentImage()
    ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
    vaedecode = VAEDecode()
    saveimage = SaveImage()
    # Pre-load all models onto the GPU so sampling does not pay the
    # transfer cost mid-pipeline.
    from comfy import model_management
    model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL]
    model_management.load_models_gpu([
        loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0]
        for loader in model_loaders
    ])
    # Run the whole generation pipeline without autograd bookkeeping.
    with torch.inference_mode():
        # Encode positive and negative prompts with the same CLIP stack.
        text_encoded = cliptextencode.encode(
            text=prompt,
            clip=get_value_at_index(CLIP_MODEL, 0),
        )
        negative_encoded = cliptextencode.encode(
            text=prompt_negative,
            clip=get_value_at_index(CLIP_MODEL, 0),
        )
        # Apply FluxGuidance to the positive conditioning only.
        flux_guided = fluxguidance.append(
            guidance=guidance,
            conditioning=get_value_at_index(text_encoded, 0),
        )
        # Create the empty latent image to be denoised.
        empty_latent = emptylatentimage.generate(
            width=width,
            height=height,
            batch_size=1,
        )
        # Sample. cfg is fixed at 1.0 in this workflow; guidance strength
        # comes from the FluxGuidance node applied above.
        sampled = ksampler.sample(
            model=get_value_at_index(UNET_MODEL, 0),
            positive=get_value_at_index(flux_guided, 0),
            negative=get_value_at_index(negative_encoded, 0),
            latent_image=get_value_at_index(empty_latent, 0),
            seed=seed,
            steps=steps,
            cfg=1.0,
            sampler_name="euler",
            scheduler="normal",
            denoise=1.0,
        )
        # Decode latents back to pixel space.
        decoded = vaedecode.decode(
            samples=get_value_at_index(sampled, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )
        # Filename prefix for the saved output.
        prefix = "Flux_Text2Image"
        # BUG FIX: the original passed get_value_at_index(prefix, 0), which
        # indexes the *string* and yields just "F" as the prefix; pass the
        # whole prefix string instead.
        saved = saveimage.save_images(
            filename_prefix=prefix,
            images=get_value_at_index(decoded, 0),
        )
    # Resolve the absolute path of the first saved image under the ComfyUI
    # output directory.
    comfyui_dir = get_comfyui_dir()
    saved_path = os.path.join(comfyui_dir, "output", saved['ui']['images'][0]['filename'])
    logger.info(f"图像保存到 {saved_path}")
    return saved_path