KICCO_AI_IMAGE/flux_style_shaper_api/text_to_image.py

158 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import random
import logging
import torch
from typing import Union, Sequence, Mapping, Any
# Module-level logger for this script
logger = logging.getLogger(__name__)
# Helper: fetch a value from a sequence/mapping by index
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    ComfyUI nodes return either a plain tuple/list or a dict whose payload
    lives under the "result" key; this helper handles both uniformly.
    A KeyError on the direct lookup triggers the fallback; any IndexError
    from a sequence propagates unchanged.
    """
    try:
        value = obj[index]
    except KeyError:
        value = obj["result"][index]
    return value
# Locate the ComfyUI root directory
def get_comfyui_dir():
    """Return the presumed ComfyUI root: the parent of this file's directory.

    This module is expected to live one level below the ComfyUI root
    (e.g. ``ComfyUI/flux_style_shaper_api/``), so going up two path
    components from this file yields the root.
    """
    here = os.path.abspath(__file__)
    return os.path.dirname(os.path.dirname(here))
# Text-to-image entry point
def generate_image_from_text(
    prompt: str,
    prompt_negative: str = "",
    width: int = 1024,
    height: int = 1024,
    guidance: float = 7.5,
    steps: int = 30,
    seed: int = -1
):
    """
    Generate an image from text using the FLUX model.

    Args:
        prompt: Positive prompt text.
        prompt_negative: Negative prompt text. Note the sampler runs with
            cfg=1.0 (FLUX convention), so this conditioning has little
            practical effect, but it is wired through regardless.
        width: Output image width in pixels.
        height: Output image height in pixels.
        guidance: FluxGuidance strength applied to the positive conditioning.
        steps: Number of sampling steps.
        seed: Random seed; -1 means pick one at random.

    Returns:
        str: Filesystem path where the generated image was saved.
    """
    logger.info("开始文生图过程...")
    # Pick a random seed when the caller asked for one (-1).
    # FIX: upper bound is 2**64 - 1 — randint is inclusive, and the seed
    # field is an unsigned 64-bit value; the original randint(1, 2**64)
    # could produce 2**64, one past the valid range.
    if seed == -1:
        seed = random.randint(0, 2**64 - 1)

    # Import ComfyUI node classes lazily so this module can be imported
    # before ComfyUI is on sys.path.
    from nodes import (
        NODE_CLASS_MAPPINGS, SaveImage, VAELoader,
        UNETLoader, DualCLIPLoader, CLIPTextEncode,
        VAEDecode, EmptyLatentImage
    )

    logger.info("加载模型...")
    # Dual CLIP text encoders (T5-XXL + CLIP-L), as FLUX requires.
    dualcliploader = DualCLIPLoader()
    CLIP_MODEL = dualcliploader.load_clip(
        clip_name1="t5/t5xxl_fp16.safetensors",
        clip_name2="clip_l.safetensors",
        type="flux",
    )
    # VAE used to decode latents back to pixels.
    vaeloader = VAELoader()
    VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
    # FLUX diffusion UNet.
    unetloader = UNETLoader()
    UNET_MODEL = unetloader.load_unet(
        unet_name="flux1-dev.safetensors",
        weight_dtype="default"
    )

    # Instantiate the remaining pipeline nodes.
    cliptextencode = CLIPTextEncode()
    fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
    emptylatentimage = EmptyLatentImage()
    ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
    vaedecode = VAEDecode()
    saveimage = SaveImage()

    # Pre-load the models onto the GPU before sampling.
    from comfy import model_management
    model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL]
    model_management.load_models_gpu([
        loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0]
        for loader in model_loaders
    ])

    # Run the whole generation pipeline without autograd bookkeeping.
    with torch.inference_mode():
        # Encode positive and negative prompts with the same CLIP stack.
        text_encoded = cliptextencode.encode(
            text=prompt,
            clip=get_value_at_index(CLIP_MODEL, 0),
        )
        negative_encoded = cliptextencode.encode(
            text=prompt_negative,
            clip=get_value_at_index(CLIP_MODEL, 0),
        )
        # FluxGuidance is applied to the positive conditioning only.
        flux_guided = fluxguidance.append(
            guidance=guidance,
            conditioning=get_value_at_index(text_encoded, 0),
        )
        # Empty latent canvas at the requested resolution.
        empty_latent = emptylatentimage.generate(
            width=width,
            height=height,
            batch_size=1,
        )
        # Sample. cfg=1.0 is the FLUX convention — guidance strength is
        # carried by the FluxGuidance conditioning, not classifier-free cfg.
        sampled = ksampler.sample(
            model=get_value_at_index(UNET_MODEL, 0),
            positive=get_value_at_index(flux_guided, 0),
            negative=get_value_at_index(negative_encoded, 0),
            latent_image=get_value_at_index(empty_latent, 0),
            seed=seed,
            steps=steps,
            cfg=1.0,
            sampler_name="euler",
            scheduler="normal",
            denoise=1.0,
        )
        # Decode latents into an image tensor.
        decoded = vaedecode.decode(
            samples=get_value_at_index(sampled, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )
        # Filename prefix for the saved output.
        prefix = "Flux_Text2Image"
        # BUG FIX: the original passed get_value_at_index(prefix, 0), which
        # indexes the *string* and yields only its first character "F".
        # Pass the whole prefix so files are named Flux_Text2Image_*.
        saved = saveimage.save_images(
            filename_prefix=prefix,
            images=get_value_at_index(decoded, 0),
        )
        # Resolve the absolute path of the saved file under ComfyUI/output.
        comfyui_dir = get_comfyui_dir()
        saved_path = os.path.join(
            comfyui_dir, "output", saved['ui']['images'][0]['filename']
        )
        logger.info(f"图像保存到 {saved_path}")
        return saved_path