From 89fd5ed5740fb5ad87842bb0094a531b6dcd2f2d Mon Sep 17 00:00:00 2001 From: Yurii Mazurevich Date: Fri, 24 Mar 2023 14:04:50 +0200 Subject: [PATCH 1/3] Added MPS device support --- comfy/model_management.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 809b19ea..6288d762 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -4,6 +4,7 @@ NO_VRAM = 1 LOW_VRAM = 2 NORMAL_VRAM = 3 HIGH_VRAM = 4 +MPS = 4 accelerate_enabled = False vram_state = NORMAL_VRAM @@ -61,7 +62,8 @@ if "--novram" in sys.argv: set_vram_to = NO_VRAM if "--highvram" in sys.argv: vram_state = HIGH_VRAM - +if torch.backends.mps.is_available(): + vram_state = MPS if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: try: @@ -79,7 +81,7 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM: if "--cpu" in sys.argv: vram_state = CPU -print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state]) +print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state]) current_loaded_model = None @@ -128,6 +130,12 @@ def load_model_gpu(model): current_loaded_model = model if vram_state == CPU: pass + elif vram_state == MPS: + # print(inspect.getmro(real_model.__class__)) + # print(dir(real_model)) + mps_device = torch.device("mps") + real_model.to(mps_device) + pass elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM: model_accelerated = False real_model.cuda() @@ -146,6 +154,9 @@ def load_controlnet_gpu(models): global vram_state if vram_state == CPU: return + + if vram_state == MPS: + return if vram_state == LOW_VRAM or vram_state == NO_VRAM: #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after @@ -173,6 +184,8 @@ def unload_if_low_vram(model): return model def get_torch_device(): + if vram_state == MPS: + return torch.device("mps") if vram_state == CPU: return torch.device("cpu") else: @@ -195,7 +208,7 @@ def get_free_memory(dev=None, torch_free_too=False): if dev is None: dev = get_torch_device() - if hasattr(dev, 'type') and dev.type == 'cpu': + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total else: @@ -224,8 +237,12 @@ def cpu_mode(): global vram_state return vram_state == CPU +def mps_mode(): + global vram_state + return vram_state == MPS + def should_use_fp16(): - if cpu_mode(): + if cpu_mode() or mps_mode(): return False #TODO ? if torch.cuda.is_bf16_supported(): From 4b943d2b60857f5fe88723d3e81e3cd46e4c05f2 Mon Sep 17 00:00:00 2001 From: Yurii Mazurevich Date: Fri, 24 Mar 2023 14:15:30 +0200 Subject: [PATCH 2/3] Removed unnecessary comment --- comfy/model_management.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 6288d762..db774e49 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -131,8 +131,6 @@ def load_model_gpu(model): if vram_state == CPU: pass elif vram_state == MPS: - # print(inspect.getmro(real_model.__class__)) - # print(dir(real_model)) mps_device = torch.device("mps") real_model.to(mps_device) pass From fc71e7ea08db1f41eaac0409814b25bfbf37da88 Mon Sep 17 00:00:00 2001 From: Yurii Mazurevich Date: Fri, 24 Mar 2023 19:39:55 +0200 Subject: [PATCH 3/3] Fixed typo --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index db774e49..dfeef81a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -4,7 +4,7 @@ NO_VRAM = 1 LOW_VRAM = 2 NORMAL_VRAM = 3 HIGH_VRAM = 4 -MPS = 4 +MPS = 5 accelerate_enabled = False vram_state = NORMAL_VRAM