From 4adcea72287981f88758151c73a1892c82724ffe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Mar 2023 14:30:43 -0400 Subject: [PATCH] I don't think controlnets were being handled correctly by MPS. --- comfy/model_management.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index dfeef81a..0d5702b9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -62,8 +62,7 @@ if "--novram" in sys.argv: set_vram_to = NO_VRAM if "--highvram" in sys.argv: vram_state = HIGH_VRAM -if torch.backends.mps.is_available(): - vram_state = MPS + if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: try: @@ -78,6 +77,12 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: total_vram_available_mb = (total_vram - 1024) // 2 total_vram_available_mb = int(max(256, total_vram_available_mb)) +try: + if torch.backends.mps.is_available(): + vram_state = MPS +except: + pass + if "--cpu" in sys.argv: vram_state = CPU @@ -152,9 +157,6 @@ def load_controlnet_gpu(models): global vram_state if vram_state == CPU: return - - if vram_state == MPS: - return if vram_state == LOW_VRAM or vram_state == NO_VRAM: #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after @@ -164,9 +166,10 @@ def load_controlnet_gpu(models): if m not in models: m.cpu() + device = get_torch_device() current_gpu_controlnets = [] for m in models: - current_gpu_controlnets.append(m.cuda()) + current_gpu_controlnets.append(m.to(device)) def load_if_low_vram(model):