2023-03-12 02:09:28 +08:00
import os
2024-05-27 01:44:17 +08:00
from spandrel import ModelLoader , ImageModelDescriptor
2023-04-16 06:55:17 +08:00
from comfy import model_management
2023-03-12 02:09:28 +08:00
import torch
2023-03-12 03:04:13 +08:00
import comfy . utils
2023-03-18 05:57:57 +08:00
import folder_paths
2023-03-12 02:09:28 +08:00
class UpscaleModelLoader:
    """Node that loads a file from the ``upscale_models`` folder as an upscale model.

    The result is returned under the ``UPSCALE_MODEL`` type and is restricted
    to spandrel ``ImageModelDescriptor`` instances (single-image models).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the single required input: a file name from ``upscale_models``."""
        return {"required": {"model_name": (folder_paths.get_filename_list("upscale_models"), ),
                             }}

    RETURN_TYPES = ("UPSCALE_MODEL",)
    FUNCTION = "load_model"

    CATEGORY = "loaders"

    def load_model(self, model_name):
        """Load ``model_name`` and return it as an eval-mode image model.

        Args:
            model_name: File name as offered by ``INPUT_TYPES``.

        Returns:
            A 1-tuple holding the loaded ``ImageModelDescriptor`` (``.eval()``
            has already been called on it).

        Raises:
            FileNotFoundError: if ``model_name`` does not resolve to a file in
                the ``upscale_models`` folder.
            Exception: if the file loads as something other than a
                single-image model.
        """
        model_path = folder_paths.get_full_path("upscale_models", model_name)
        if model_path is None:
            # NOTE(review): get_full_path reports "not found" by returning None
            # (ComfyUI's folder_paths does this — confirm). Fail here with a
            # clear message instead of handing None to load_torch_file.
            raise FileNotFoundError(f"Upscale model not found: {model_name}")

        sd = comfy.utils.load_torch_file(model_path, safe_load=True)

        # Some checkpoints store every key under a leading "module." prefix.
        # Only this one probe key is checked; when it is present the prefix is
        # stripped from all keys, presumably so ModelLoader can recognise the
        # bare key names.
        if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
            sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": ""})

        out = ModelLoader().load_from_state_dict(sd).eval()
        if not isinstance(out, ImageModelDescriptor):
            raise Exception("Upscale model must be a single-image model.")

        return (out, )
class ImageUpscaleWithModel:
    """Node that upscales an ``IMAGE`` batch with an ``UPSCALE_MODEL`` via tiled inference."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the required inputs: the upscale model and the image batch."""
        return {"required": {"upscale_model": ("UPSCALE_MODEL",),
                             "image": ("IMAGE",),
                             }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, upscale_model, image):
        """Upscale ``image`` with ``upscale_model``.

        Frees device memory for a rough estimate of the job, then moves the
        model to the torch device. Inference runs through
        ``comfy.utils.tiled_scale`` starting with 512px tiles, and the tile size
        is halved after each out-of-memory error (512 -> 256 -> 128). After a
        failure at 128px the error is re-raised. The model is moved back to
        the CPU in every case, including when inference fails.

        Args:
            upscale_model: Model from ``UpscaleModelLoader``; it is called on
                tiles and must expose ``scale`` and ``to``.
            image: Image tensor with channels in the last dimension. It is
                moved to channels-first with ``movedim(-1, -3)`` for the model,
                and dim 0 is treated as the batch count for the progress bar.
                Presumably (batch, height, width, channels) — confirm.

        Returns:
            A 1-tuple with the upscaled image, channels moved back to the last
            dimension and clamped to [0, 1].

        Raises:
            model_management.OOM_EXCEPTION: if even 128px tiles run out of memory.
        """
        device = model_management.get_torch_device()

        # module_size needs a torch module. A spandrel descriptor exposes its
        # underlying nn.Module as `.model`, so measure that. A bare module
        # without a `.model` attribute is measured as-is.
        module = getattr(upscale_model, "model", upscale_model)
        memory_required = model_management.module_size(module)
        memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0  # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
        memory_required += image.nelement() * image.element_size()
        model_management.free_memory(memory_required, device)

        upscale_model.to(device)
        try:
            in_img = image.movedim(-1, -3).to(device)

            tile = 512
            overlap = 32
            while True:
                try:
                    steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                    pbar = comfy.utils.ProgressBar(steps)
                    s = comfy.utils.tiled_scale(in_img, upscale_model, tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
                    break
                except model_management.OOM_EXCEPTION:
                    # Retry with smaller tiles; give up once below 128px.
                    tile //= 2
                    if tile < 128:
                        raise
        finally:
            # Always release the device copy, even when tiling gives up.
            upscale_model.to("cpu")

        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)
# Node type name -> implementing class for the nodes defined in this module
# (presumably read by ComfyUI's node registration — confirm).
NODE_CLASS_MAPPINGS = dict(
    UpscaleModelLoader=UpscaleModelLoader,
    ImageUpscaleWithModel=ImageUpscaleWithModel,
)