From e87aa1873f6a01148898b41dc3498ca6f82410d8 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 12 Apr 2023 19:36:35 -0600 Subject: [PATCH 01/35] Add slider setting type --- web/scripts/ui.js | 24 ++++++++++++++++++++++++ web/style.css | 8 ++++++++ 2 files changed, 32 insertions(+) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 09861c44..1c7fdc8a 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -270,6 +270,30 @@ class ComfySettingsDialog extends ComfyDialog { ]), ]); break; + case "slider": + element = $el("div", [ + $el("label", { textContent: name }, [ + $el("input", { + type: "range", + value, + oninput: (e) => { + setter(e.target.value); + e.target.nextElementSibling.value = e.target.value; + }, + ...attrs + }), + $el("input", { + type: "number", + value, + oninput: (e) => { + setter(e.target.value); + e.target.previousElementSibling.value = e.target.value; + }, + ...attrs + }), + ]), + ]); + break; default: console.warn("Unsupported setting type, defaulting to text"); element = $el("div", [ diff --git a/web/style.css b/web/style.css index 34e31726..e3b44576 100644 --- a/web/style.css +++ b/web/style.css @@ -217,6 +217,14 @@ button.comfy-queue-btn { z-index: 99; } +.comfy-modal.comfy-settings input[type="range"] { + vertical-align: middle; +} + +.comfy-modal.comfy-settings input[type="range"] + input[type="number"] { + width: 3.5em; +} + .comfy-modal input, .comfy-modal select { color: var(--input-text); From 8810e1b4b9e2c14860800a2bdc97d50d1aa2f904 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 12 Apr 2023 21:15:21 -0600 Subject: [PATCH 02/35] Fix indentation --- web/scripts/ui.js | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 1c7fdc8a..6cbc9383 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -270,7 +270,7 @@ class ComfySettingsDialog extends ComfyDialog { ]), ]); break; - case "slider": + case "slider": element = $el("div", [ $el("label", { 
textContent: name }, [ $el("input", { @@ -278,16 +278,16 @@ class ComfySettingsDialog extends ComfyDialog { value, oninput: (e) => { setter(e.target.value); - e.target.nextElementSibling.value = e.target.value; + e.target.nextElementSibling.value = e.target.value; }, ...attrs }), - $el("input", { + $el("input", { type: "number", value, oninput: (e) => { setter(e.target.value); - e.target.previousElementSibling.value = e.target.value; + e.target.previousElementSibling.value = e.target.value; }, ...attrs }), From 307ef543bf66e5ffd718b3a0b148c72287b65a89 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Thu, 13 Apr 2023 10:04:06 -0600 Subject: [PATCH 03/35] Change grid size to slider --- web/extensions/core/snapToGrid.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/snapToGrid.js b/web/extensions/core/snapToGrid.js index 20b245e1..cb5fc154 100644 --- a/web/extensions/core/snapToGrid.js +++ b/web/extensions/core/snapToGrid.js @@ -9,7 +9,7 @@ app.registerExtension({ app.ui.settings.addSetting({ id: "Comfy.SnapToGrid.GridSize", name: "Grid Size", - type: "number", + type: "slider", attrs: { min: 1, max: 500, From 8489cba1405f222f4675c120aee4a3722affb3f8 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Thu, 13 Apr 2023 22:01:01 +0200 Subject: [PATCH 04/35] add unique ID per word/embedding for tokenizer --- comfy/sd1_clip.py | 117 ++++++++++++++++++++++++++++------------------ 1 file changed, 71 insertions(+), 46 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 4f51657c..3dd8262a 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -224,60 +224,85 @@ class SD1Tokenizer: self.inv_vocab = {v: k for k, v in vocab.items()} self.embedding_directory = embedding_directory self.max_word_length = 8 + self.embedding_identifier = "embedding:" - def tokenize_with_weights(self, text): + def _try_get_embedding(self, name:str): + ''' + Takes a potential embedding name and 
tries to retrieve it. + Returns a Tuple consisting of the embedding and any leftover string, embedding can be None. + ''' + embedding_name = name[len(self.embedding_identifier):].strip('\n') + embed = load_embed(embedding_name, self.embedding_directory) + if embed is None: + stripped = embedding_name.strip(',') + if len(stripped) < len(embedding_name): + embed = load_embed(stripped, self.embedding_directory) + return (embed, embedding_name[len(stripped):]) + return (embed, "") + + + def tokenize_with_weights(self, text:str): + ''' + Takes a prompt and converts it to a list of (token, weight, word id) elements. + Tokens can both be integer tokens and pre computed CLIP tensors. + Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens. + Returned list has the dimensions NxM where M is the input size of CLIP + ''' text = escape_important(text) parsed_weights = token_weights(text, 1.0) + #tokenize words tokens = [] - for t in parsed_weights: - to_tokenize = unescape_important(t[0]).replace("\n", " ").split(' ') - while len(to_tokenize) > 0: - word = to_tokenize.pop(0) - temp_tokens = [] - embedding_identifier = "embedding:" - if word.startswith(embedding_identifier) and self.embedding_directory is not None: - embedding_name = word[len(embedding_identifier):].strip('\n') - embed = load_embed(embedding_name, self.embedding_directory) + for weighted_segment, weight in parsed_weights: + to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ') + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + #if we find an embedding, deal with the embedding + if word.startswith(self.embedding_identifier) and self.embedding_directory is not None: + embed, leftover = self._try_get_embedding(word) if embed is None: - stripped = embedding_name.strip(',') - if len(stripped) < len(embedding_name): - embed = load_embed(stripped, self.embedding_directory) - if embed is not None: - to_tokenize.insert(0, 
embedding_name[len(stripped):]) - - if embed is not None: - if len(embed.shape) == 1: - temp_tokens += [(embed, t[1])] - else: - for x in range(embed.shape[0]): - temp_tokens += [(embed[x], t[1])] + print(f"warning, embedding:{word} does not exist, ignoring") else: - print("warning, embedding:{} does not exist, ignoring".format(embedding_name)) - elif len(word) > 0: - tt = self.tokenizer(word)["input_ids"][1:-1] - for x in tt: - temp_tokens += [(x, t[1])] - tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section) + if len(embed.shape) == 1: + tokens.append([(embed, weight)]) + else: + tokens.append([(embed[x], weight) for x in range(embed.shape[0])]) + #if we accidentally have leftover text, continue parsing using leftover, else move on to next word + if leftover != "": + word = leftover + else: + continue + #parse word + tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]]) + + #reshape token array to CLIP input size + batched_tokens = [] + batch = [] + batched_tokens.append(batch) + for i, t_group in enumerate(tokens): + #start a new batch if there is not enough room + if len(t_group) + len(batch) > self.max_tokens_per_section: + remaining_length = self.max_tokens_per_section - len(batch) + #fill remaining space depending on length of tokens + if len(t_group) > self.max_word_length: + #put part of group of tokens in the batch + batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) + t_group = t_group[remaining_length:] + else: + #filler tokens + batch.extend([(self.end_token, 1.0, 0)] * remaining_length) + batch = [] + batched_tokens.append(batch) + #put current group of tokens in the batch + batch.extend([(t,w,i+1) for t,w in t_group]) + + #fill last batch + batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch))) + + #add start and end tokens + batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens] + return 
batched_tokens - #try not to split words in different sections - if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length): - for x in range(tokens_left): - tokens += [(self.end_token, 1.0)] - tokens += temp_tokens - - out_tokens = [] - for x in range(0, len(tokens), self.max_tokens_per_section): - o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))] - o_token += [(self.end_token, 1.0)] - if self.pad_with_end: - o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token)) - else: - o_token +=[(0, 1.0)] * (self.max_length - len(o_token)) - - out_tokens += [o_token] - - return out_tokens def untokenize(self, token_weight_pair): return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) From 73175cf58c0371903bf9bec107f0e82c5c4363d0 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Thu, 13 Apr 2023 22:06:50 +0200 Subject: [PATCH 05/35] split tokenizer from encoder --- comfy/sd.py | 6 ++++-- nodes.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 2d7ff5ab..6bd30daf 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -372,10 +372,12 @@ class CLIP: def clip_layer(self, layer_idx): self.layer_idx = layer_idx - def encode(self, text): + def tokenize(self, text): + return self.tokenizer.tokenize_with_weights(text) + + def encode(self, tokens): if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) - tokens = self.tokenizer.tokenize_with_weights(text) try: self.patcher.patch_model() cond = self.cond_stage_model.encode_token_weights(tokens) diff --git a/nodes.py b/nodes.py index 946c6685..b81d1601 100644 --- a/nodes.py +++ b/nodes.py @@ -44,7 +44,8 @@ class CLIPTextEncode: CATEGORY = "conditioning" def encode(self, clip, text): - return ([[clip.encode(text), {}]], ) + tokens = clip.tokenize(text) + return ([[clip.encode(tokens), {}]], ) class ConditioningCombine: 
@classmethod From 752f7a162ba728b3ab7b9ce53be73c271da25dd5 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 14 Apr 2023 21:02:45 +0200 Subject: [PATCH 06/35] align behavior with old tokenize function --- comfy/sd1_clip.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 3dd8262a..45bc9526 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -226,12 +226,11 @@ class SD1Tokenizer: self.max_word_length = 8 self.embedding_identifier = "embedding:" - def _try_get_embedding(self, name:str): + def _try_get_embedding(self, embedding_name:str): ''' Takes a potential embedding name and tries to retrieve it. Returns a Tuple consisting of the embedding and any leftover string, embedding can be None. ''' - embedding_name = name[len(self.embedding_identifier):].strip('\n') embed = load_embed(embedding_name, self.embedding_directory) if embed is None: stripped = embedding_name.strip(',') @@ -259,9 +258,10 @@ class SD1Tokenizer: for word in to_tokenize: #if we find an embedding, deal with the embedding if word.startswith(self.embedding_identifier) and self.embedding_directory is not None: - embed, leftover = self._try_get_embedding(word) + embedding_name = word[len(self.embedding_identifier):].strip('\n') + embed, leftover = self._try_get_embedding(embedding_name) if embed is None: - print(f"warning, embedding:{word} does not exist, ignoring") + print(f"warning, embedding:{embedding_name} does not exist, ignoring") else: if len(embed.shape) == 1: tokens.append([(embed, weight)]) @@ -280,21 +280,21 @@ class SD1Tokenizer: batch = [] batched_tokens.append(batch) for i, t_group in enumerate(tokens): - #start a new batch if there is not enough room - if len(t_group) + len(batch) > self.max_tokens_per_section: - remaining_length = self.max_tokens_per_section - len(batch) - #fill remaining space depending on length of tokens - if 
len(t_group) > self.max_word_length: - #put part of group of tokens in the batch - batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) - t_group = t_group[remaining_length:] + #determine if we're going to try and keep the tokens in a single batch + is_large = len(t_group) >= self.max_word_length + while len(t_group) > 0: + if len(t_group) + len(batch) > self.max_tokens_per_section: + remaining_length = self.max_tokens_per_section - len(batch) + if is_large: + batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) + t_group = t_group[remaining_length:] + else: + batch.extend([(self.end_token, 1.0, 0)] * remaining_length) + batch = [] + batched_tokens.append(batch) else: - #filler tokens - batch.extend([(self.end_token, 1.0, 0)] * remaining_length) - batch = [] - batched_tokens.append(batch) - #put current group of tokens in the batch - batch.extend([(t,w,i+1) for t,w in t_group]) + batch.extend([(t,w,i+1) for t,w in t_group]) + t_group = [] #fill last batch batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch))) From da115bd78d7c4571dc0747dcb17e280b5c8ff4ea Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 14 Apr 2023 21:16:55 +0200 Subject: [PATCH 07/35] ensure backwards compat with optional args --- comfy/sd.py | 10 +++++++--- comfy/sd1_clip.py | 6 +++++- nodes.py | 3 +-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 6bd30daf..6e54bc60 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -372,12 +372,16 @@ class CLIP: def clip_layer(self, layer_idx): self.layer_idx = layer_idx - def tokenize(self, text): - return self.tokenizer.tokenize_with_weights(text) + def tokenize(self, text, return_word_ids=False): + return self.tokenizer.tokenize_with_weights(text, return_word_ids) - def encode(self, tokens): + def encode(self, text, from_tokens=False): if self.layer_idx is not None: 
self.cond_stage_model.clip_layer(self.layer_idx) + if from_tokens: + tokens = text + else: + tokens = self.tokenizer.tokenize_with_weights(text) try: self.patcher.patch_model() cond = self.cond_stage_model.encode_token_weights(tokens) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 45bc9526..02e925c8 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -240,7 +240,7 @@ class SD1Tokenizer: return (embed, "") - def tokenize_with_weights(self, text:str): + def tokenize_with_weights(self, text:str, return_word_ids=False): ''' Takes a prompt and converts it to a list of (token, weight, word id) elements. Tokens can both be integer tokens and pre computed CLIP tensors. @@ -301,6 +301,10 @@ class SD1Tokenizer: #add start and end tokens batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens] + + if not return_word_ids: + batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] + return batched_tokens diff --git a/nodes.py b/nodes.py index b68c8ef4..6468ac6b 100644 --- a/nodes.py +++ b/nodes.py @@ -44,8 +44,7 @@ class CLIPTextEncode: CATEGORY = "conditioning" def encode(self, clip, text): - tokens = clip.tokenize(text) - return ([[clip.encode(tokens), {}]], ) + return ([[clip.encode(text), {}]], ) class ConditioningCombine: @classmethod From d0b1b6c6bf60a6f85e742a3340e9fcd9b06d0bde Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sat, 15 Apr 2023 19:38:21 +0200 Subject: [PATCH 08/35] fixed improper padding --- comfy/sd1_clip.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 02e925c8..32612cf3 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -247,6 +247,11 @@ class SD1Tokenizer: Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens. 
Returned list has the dimensions NxM where M is the input size of CLIP ''' + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + text = escape_important(text) parsed_weights = token_weights(text, 1.0) @@ -277,30 +282,33 @@ class SD1Tokenizer: #reshape token array to CLIP input size batched_tokens = [] - batch = [] + batch = [(self.start_token, 1.0, 0)] batched_tokens.append(batch) for i, t_group in enumerate(tokens): #determine if we're going to try and keep the tokens in a single batch is_large = len(t_group) >= self.max_word_length + while len(t_group) > 0: - if len(t_group) + len(batch) > self.max_tokens_per_section: - remaining_length = self.max_tokens_per_section - len(batch) + if len(t_group) + len(batch) > self.max_length - 1: + remaining_length = self.max_length - len(batch) - 1 + #break word in two and add end token if is_large: batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]]) + batch.append((self.end_token, 1.0, 0)) t_group = t_group[remaining_length:] + #add end token and pad else: - batch.extend([(self.end_token, 1.0, 0)] * remaining_length) - batch = [] + batch.append((self.end_token, 1.0, 0)) + batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) + #start new batch + batch = [(self.start_token, 1.0, 0)] batched_tokens.append(batch) else: batch.extend([(t,w,i+1) for t,w in t_group]) t_group = [] #fill last batch - batch.extend([(self.end_token, 1.0, 0)] * (self.max_tokens_per_section - len(batch))) - - #add start and end tokens - batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens] + batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1)) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] From eb4035c8bd8504531b5b11dac05303d19b42ee05 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 15 Apr 2023 21:40:39 +0100 
Subject: [PATCH 09/35] Adds jsdoc for better dev experience --- web/scripts/app.js | 34 +- web/types/comfy.d.ts | 78 ++ web/types/litegraph.d.ts | 1506 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 1614 insertions(+), 4 deletions(-) create mode 100644 web/types/comfy.d.ts create mode 100644 web/types/litegraph.d.ts diff --git a/web/scripts/app.js b/web/scripts/app.js index 42addc8c..940c5ecf 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -4,27 +4,49 @@ import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; import { getPngMetadata, importA1111 } from "./pnginfo.js"; -class ComfyApp { /** - * List of {number, batchCount} entries to queue + * @typedef {import("types/comfy").ComfyExtension} ComfyExtension + */ + +export class ComfyApp { + /** + * List of entries to queue + * @type {{number: number, batchCount: number}[]} */ #queueItems = []; /** * If the queue is currently being processed + * @type {boolean} */ #processingQueue = false; constructor() { this.ui = new ComfyUI(this); + + /** + * List of extensions that are registered with the app + * @type {ComfyExtension[]} + */ this.extensions = []; + + /** + * Stores the execution output data for each node + * @type {Record} + */ this.nodeOutputs = {}; + + + /** + * If the shift key on the keyboard is pressed + * @type {boolean} + */ this.shiftDown = false; } /** * Invoke an extension callback - * @param {string} method The extension callback to execute - * @param {...any} args Any arguments to pass to the callback + * @param {keyof ComfyExtension} method The extension callback to execute + * @param {any[]} args Any arguments to pass to the callback * @returns */ #invokeExtensions(method, ...args) { @@ -1120,6 +1142,10 @@ class ComfyApp { } } + /** + * Registers a Comfy web extension with the app + * @param {ComfyExtension} extension + */ registerExtension(extension) { if (!extension.name) { throw new Error("Extensions must have a 'name' property."); diff --git 
a/web/types/comfy.d.ts b/web/types/comfy.d.ts new file mode 100644 index 00000000..3bb92454 --- /dev/null +++ b/web/types/comfy.d.ts @@ -0,0 +1,78 @@ +import { LGraphNode, IWidget } from "./litegraph"; +import { ComfyApp } from "/scripts/app"; + +export interface ComfyExtension { + /** + * The name of the extension + */ + name: string; + /** + * Allows any initialisation, e.g. loading resources. Called after the canvas is created but before nodes are added + * @param app The ComfyUI app instance + */ + init(app: ComfyApp): Promise; + /** + * Allows any additonal setup, called after the application is fully set up and running + * @param app The ComfyUI app instance + */ + setup(app: ComfyApp): Promise; + /** + * Called before nodes are registered with the graph + * @param defs The collection of node definitions, add custom ones or edit existing ones + * @param app The ComfyUI app instance + */ + addCustomNodeDefs(defs: Record, app: ComfyApp): Promise; + /** + * Allows the extension to add custom widgets + * @param app The ComfyUI app instance + * @returns An array of {[widget name]: widget data} + */ + getCustomWidgets( + app: ComfyApp + ): Promise< + Array< + Record { widget?: IWidget; minWidth?: number; minHeight?: number }> + > + >; + /** + * Allows the extension to add additional handling to the node before it is registered with LGraph + * @param nodeType The node class (not an instance) + * @param nodeData The original node object info config object + * @param app The ComfyUI app instance + */ + beforeRegisterNodeDef(nodeType: typeof LGraphNode, nodeData: ComfyObjectInfo, app: ComfyApp): Promise; + /** + * Allows the extension to register additional nodes with LGraph after standard nodes are added + * @param app The ComfyUI app instance + */ + registerCustomNodes(app: ComfyApp): Promise; + /** + * Allows the extension to modify a node that has been reloaded onto the graph. 
+ * If you break something in the backend and want to patch workflows in the frontend + * This is the place to do this + * @param node The node that has been loaded + * @param app The ComfyUI app instance + */ + loadedGraphNode(node: LGraphNode, app: ComfyApp); + /** + * Allows the extension to run code after the constructor of the node + * @param node The node that has been created + * @param app The ComfyUI app instance + */ + nodeCreated(node: LGraphNode, app: ComfyApp); +} + +export type ComfyObjectInfo = { + name: string; + display_name?: string; + description?: string; + category: string; + input?: { + required?: Record; + optional?: Record; + }; + output?: string[]; + output_name: string[]; +}; + +export type ComfyObjectInfoConfig = [string | any[]] | [string | any[], any]; diff --git a/web/types/litegraph.d.ts b/web/types/litegraph.d.ts new file mode 100644 index 00000000..6629e779 --- /dev/null +++ b/web/types/litegraph.d.ts @@ -0,0 +1,1506 @@ +// Type definitions for litegraph.js 0.7.0 +// Project: litegraph.js +// Definitions by: NateScarlet + +export type Vector2 = [number, number]; +export type Vector4 = [number, number, number, number]; +export type widgetTypes = + | "number" + | "slider" + | "combo" + | "text" + | "toggle" + | "button"; +export type SlotShape = + | typeof LiteGraph.BOX_SHAPE + | typeof LiteGraph.CIRCLE_SHAPE + | typeof LiteGraph.ARROW_SHAPE + | typeof LiteGraph.SQUARE_SHAPE + | number; // For custom shapes + +/** https://github.com/jagenjo/litegraph.js/tree/master/guides#node-slots */ +export interface INodeSlot { + name: string; + type: string | -1; + label?: string; + dir?: + | typeof LiteGraph.UP + | typeof LiteGraph.RIGHT + | typeof LiteGraph.DOWN + | typeof LiteGraph.LEFT; + color_on?: string; + color_off?: string; + shape?: SlotShape; + locked?: boolean; + nameLocked?: boolean; +} + +export interface INodeInputSlot extends INodeSlot { + link: LLink["id"] | null; +} +export interface INodeOutputSlot extends INodeSlot { + links: 
LLink["id"][] | null; +} + +export type WidgetCallback = ( + this: T, + value: T["value"], + graphCanvas: LGraphCanvas, + node: LGraphNode, + pos: Vector2, + event?: MouseEvent +) => void; + +export interface IWidget { + name: string | null; + value: TValue; + options?: TOptions; + type?: widgetTypes; + y?: number; + property?: string; + last_y?: number; + clicked?: boolean; + marker?: boolean; + callback?: WidgetCallback; + /** Called by `LGraphCanvas.drawNodeWidgets` */ + draw?( + ctx: CanvasRenderingContext2D, + node: LGraphNode, + width: number, + posY: number, + height: number + ): void; + /** + * Called by `LGraphCanvas.processNodeWidgets` + * https://github.com/jagenjo/litegraph.js/issues/76 + */ + mouse?( + event: MouseEvent, + pos: Vector2, + node: LGraphNode + ): boolean; + /** Called by `LGraphNode.computeSize` */ + computeSize?(width: number): [number, number]; +} +export interface IButtonWidget extends IWidget { + type: "button"; +} +export interface IToggleWidget + extends IWidget { + type: "toggle"; +} +export interface ISliderWidget + extends IWidget { + type: "slider"; +} +export interface INumberWidget extends IWidget { + type: "number"; +} +export interface IComboWidget + extends IWidget< + string[], + { + values: + | string[] + | ((widget: IComboWidget, node: LGraphNode) => string[]); + } + > { + type: "combo"; +} + +export interface ITextWidget extends IWidget { + type: "text"; +} + +export interface IContextMenuItem { + content: string; + callback?: ContextMenuEventListener; + /** Used as innerHTML for extra child element */ + title?: string; + disabled?: boolean; + has_submenu?: boolean; + submenu?: { + options: ContextMenuItem[]; + } & IContextMenuOptions; + className?: string; +} +export interface IContextMenuOptions { + callback?: ContextMenuEventListener; + ignore_item_callbacks?: Boolean; + event?: MouseEvent | CustomEvent; + parentMenu?: ContextMenu; + autoopen?: boolean; + title?: string; + extra?: any; +} + +export type 
ContextMenuItem = IContextMenuItem | null; +export type ContextMenuEventListener = ( + value: ContextMenuItem, + options: IContextMenuOptions, + event: MouseEvent, + parentMenu: ContextMenu | undefined, + node: LGraphNode +) => boolean | void; + +export const LiteGraph: { + VERSION: number; + + CANVAS_GRID_SIZE: number; + + NODE_TITLE_HEIGHT: number; + NODE_TITLE_TEXT_Y: number; + NODE_SLOT_HEIGHT: number; + NODE_WIDGET_HEIGHT: number; + NODE_WIDTH: number; + NODE_MIN_WIDTH: number; + NODE_COLLAPSED_RADIUS: number; + NODE_COLLAPSED_WIDTH: number; + NODE_TITLE_COLOR: string; + NODE_TEXT_SIZE: number; + NODE_TEXT_COLOR: string; + NODE_SUBTEXT_SIZE: number; + NODE_DEFAULT_COLOR: string; + NODE_DEFAULT_BGCOLOR: string; + NODE_DEFAULT_BOXCOLOR: string; + NODE_DEFAULT_SHAPE: string; + DEFAULT_SHADOW_COLOR: string; + DEFAULT_GROUP_FONT: number; + + LINK_COLOR: string; + EVENT_LINK_COLOR: string; + CONNECTING_LINK_COLOR: string; + + MAX_NUMBER_OF_NODES: number; //avoid infinite loops + DEFAULT_POSITION: Vector2; //default node position + VALID_SHAPES: ["default", "box", "round", "card"]; //,"circle" + + //shapes are used for nodes but also for slots + BOX_SHAPE: 1; + ROUND_SHAPE: 2; + CIRCLE_SHAPE: 3; + CARD_SHAPE: 4; + ARROW_SHAPE: 5; + SQUARE_SHAPE: 6; + + //enums + INPUT: 1; + OUTPUT: 2; + + EVENT: -1; //for outputs + ACTION: -1; //for inputs + + ALWAYS: 0; + ON_EVENT: 1; + NEVER: 2; + ON_TRIGGER: 3; + + UP: 1; + DOWN: 2; + LEFT: 3; + RIGHT: 4; + CENTER: 5; + + STRAIGHT_LINK: 0; + LINEAR_LINK: 1; + SPLINE_LINK: 2; + + NORMAL_TITLE: 0; + NO_TITLE: 1; + TRANSPARENT_TITLE: 2; + AUTOHIDE_TITLE: 3; + + node_images_path: string; + + debug: boolean; + catch_exceptions: boolean; + throw_errors: boolean; + /** if set to true some nodes like Formula would be allowed to evaluate code that comes from unsafe sources (like node configuration), which could lead to exploits */ + allow_scripts: boolean; + /** node types by string */ + registered_node_types: Record; + /** used for 
dropping files in the canvas */ + node_types_by_file_extension: Record; + /** node types by class name */ + Nodes: Record; + + /** used to add extra features to the search box */ + searchbox_extras: Record< + string, + { + data: { outputs: string[][]; title: string }; + desc: string; + type: string; + } + >; + + createNode(type: string): T; + /** Register a node class so it can be listed when the user wants to create a new one */ + registerNodeType(type: string, base: { new (): LGraphNode }): void; + /** removes a node type from the system */ + unregisterNodeType(type: string): void; + /** Removes all previously registered node's types. */ + clearRegisteredTypes(): void; + /** + * Create a new node type by passing a function, it wraps it with a proper class and generates inputs according to the parameters of the function. + * Useful to wrap simple methods that do not require properties, and that only process some input to generate an output. + * @param name node name with namespace (p.e.: 'math/sum') + * @param func + * @param param_types an array containing the type of every parameter, otherwise parameters will accept any type + * @param return_type string with the return type, otherwise it will be generic + * @param properties properties to be configurable + */ + wrapFunctionAsNode( + name: string, + func: (...args: any[]) => any, + param_types?: string[], + return_type?: string, + properties?: object + ): void; + + /** + * Adds this method to all node types, existing and to be created + * (You can add it to LGraphNode.prototype but then existing node types wont have it) + */ + addNodeMethod(name: string, func: (...args: any[]) => any): void; + + /** + * Create a node of a given type with a name. The node is not attached to any graph yet. + * @param type full name of the node class. p.e. 
"math/sin" + * @param name a name to distinguish from other nodes + * @param options to set options + */ + createNode( + type: string, + title: string, + options: object + ): T; + + /** + * Returns a registered node type with a given name + * @param type full name of the node class. p.e. "math/sin" + */ + getNodeType(type: string): LGraphNodeConstructor; + + /** + * Returns a list of node types matching one category + * @method getNodeTypesInCategory + * @param {String} category category name + * @param {String} filter only nodes with ctor.filter equal can be shown + * @return {Array} array with all the node classes + */ + getNodeTypesInCategory( + category: string, + filter: string + ): LGraphNodeConstructor[]; + + /** + * Returns a list with all the node type categories + * @method getNodeTypesCategories + * @param {String} filter only nodes with ctor.filter equal can be shown + * @return {Array} array with all the names of the categories + */ + getNodeTypesCategories(filter: string): string[]; + + /** debug purposes: reloads all the js scripts that matches a wildcard */ + reloadNodes(folder_wildcard: string): void; + + getTime(): number; + LLink: typeof LLink; + LGraph: typeof LGraph; + DragAndScale: typeof DragAndScale; + compareObjects(a: object, b: object): boolean; + distance(a: Vector2, b: Vector2): number; + colorToString(c: string): string; + isInsideRectangle( + x: number, + y: number, + left: number, + top: number, + width: number, + height: number + ): boolean; + growBounding(bounding: Vector4, x: number, y: number): Vector4; + isInsideBounding(p: Vector2, bb: Vector4): boolean; + hex2num(hex: string): [number, number, number]; + num2hex(triplet: [number, number, number]): string; + ContextMenu: typeof ContextMenu; + extendClass(target: A, origin: B): A & B; + getParameterNames(func: string): string[]; +}; + +export type serializedLGraph< + TNode = ReturnType, + // https://github.com/jagenjo/litegraph.js/issues/74 + TLink = [number, number, number, 
number, number, string], + TGroup = ReturnType +> = { + last_node_id: LGraph["last_node_id"]; + last_link_id: LGraph["last_link_id"]; + nodes: TNode[]; + links: TLink[]; + groups: TGroup[]; + config: LGraph["config"]; + version: typeof LiteGraph.VERSION; +}; + +export declare class LGraph { + static supported_types: string[]; + static STATUS_STOPPED: 1; + static STATUS_RUNNING: 2; + + constructor(o?: object); + + filter: string; + catch_errors: boolean; + /** custom data */ + config: object; + elapsed_time: number; + fixedtime: number; + fixedtime_lapse: number; + globaltime: number; + inputs: any; + iteration: number; + last_link_id: number; + last_node_id: number; + last_update_time: number; + links: Record; + list_of_graphcanvas: LGraphCanvas[]; + outputs: any; + runningtime: number; + starttime: number; + status: typeof LGraph.STATUS_RUNNING | typeof LGraph.STATUS_STOPPED; + + private _nodes: LGraphNode[]; + private _groups: LGraphGroup[]; + private _nodes_by_id: Record; + /** nodes that are executable sorted in execution order */ + private _nodes_executable: + | (LGraphNode & { onExecute: NonNullable }[]) + | null; + /** nodes that contain onExecute */ + private _nodes_in_order: LGraphNode[]; + private _version: number; + + getSupportedTypes(): string[]; + /** Removes all nodes from this graph */ + clear(): void; + /** Attach Canvas to this graph */ + attachCanvas(graphCanvas: LGraphCanvas): void; + /** Detach Canvas to this graph */ + detachCanvas(graphCanvas: LGraphCanvas): void; + /** + * Starts running this graph every interval milliseconds. 
+ * @param interval amount of milliseconds between executions, if 0 then it renders to the monitor refresh rate + */ + start(interval?: number): void; + /** Stops the execution loop of the graph */ + stop(): void; + /** + * Run N steps (cycles) of the graph + * @param num number of steps to run, default is 1 + */ + runStep(num?: number, do_not_catch_errors?: boolean): void; + /** + * Updates the graph execution order according to relevance of the nodes (nodes with only outputs have more relevance than + * nodes with only inputs. + */ + updateExecutionOrder(): void; + /** This is more internal, it computes the executable nodes in order and returns it */ + computeExecutionOrder(only_onExecute: boolean, set_level: any): T; + /** + * Returns all the nodes that could affect this one (ancestors) by crawling all the inputs recursively. + * It doesn't include the node itself + * @return an array with all the LGraphNodes that affect this node, in order of execution + */ + getAncestors(node: LGraphNode): LGraphNode[]; + /** + * Positions every node in a more readable manner + */ + arrange(margin?: number,layout?: string): void; + /** + * Returns the amount of time the graph has been running in milliseconds + * @return number of milliseconds the graph has been running + */ + getTime(): number; + + /** + * Returns the amount of time accumulated using the fixedtime_lapse var. This is used in context where the time increments should be constant + * @return number of milliseconds the graph has been running + */ + getFixedTime(): number; + + /** + * Returns the amount of time it took to compute the latest iteration. 
Take into account that this number could be not correct + * if the nodes are using graphical actions + * @return number of milliseconds it took the last cycle + */ + getElapsedTime(): number; + /** + * Sends an event to all the nodes, useful to trigger stuff + * @param eventName the name of the event (function to be called) + * @param params parameters in array format + */ + sendEventToAllNodes(eventName: string, params: any[], mode?: any): void; + + sendActionToCanvas(action: any, params: any[]): void; + /** + * Adds a new node instance to this graph + * @param node the instance of the node + */ + add(node: LGraphNode, skip_compute_order?: boolean): void; + /** + * Called when a new node is added + * @param node the instance of the node + */ + onNodeAdded(node: LGraphNode): void; + /** Removes a node from the graph */ + remove(node: LGraphNode): void; + /** Returns a node by its id. */ + getNodeById(id: number): LGraphNode | undefined; + /** + * Returns a list of nodes that matches a class + * @param classObject the class itself (not an string) + * @return a list with all the nodes of this type + */ + findNodesByClass( + classObject: LGraphNodeConstructor + ): T[]; + /** + * Returns a list of nodes that matches a type + * @param type the name of the node type + * @return a list with all the nodes of this type + */ + findNodesByType(type: string): T[]; + /** + * Returns the first node that matches a name in its title + * @param title the name of the node to search + * @return the node or null + */ + findNodeByTitle(title: string): T | null; + /** + * Returns a list of nodes that matches a name + * @param title the name of the node to search + * @return a list with all the nodes with this name + */ + findNodesByTitle(title: string): T[]; + /** + * Returns the top-most node in this position of the canvas + * @param x the x coordinate in canvas space + * @param y the y coordinate in canvas space + * @param nodes_list a list with all the nodes to search from, by 
default is all the nodes in the graph + * @return the node at this position or null + */ + getNodeOnPos( + x: number, + y: number, + node_list?: LGraphNode[], + margin?: number + ): T | null; + /** + * Returns the top-most group in that position + * @param x the x coordinate in canvas space + * @param y the y coordinate in canvas space + * @return the group or null + */ + getGroupOnPos(x: number, y: number): LGraphGroup | null; + + onAction(action: any, param: any): void; + trigger(action: any, param: any): void; + /** Tell this graph it has a global graph input of this type */ + addInput(name: string, type: string, value?: any): void; + /** Assign a data to the global graph input */ + setInputData(name: string, data: any): void; + /** Returns the current value of a global graph input */ + getInputData(name: string): T; + /** Changes the name of a global graph input */ + renameInput(old_name: string, name: string): false | undefined; + /** Changes the type of a global graph input */ + changeInputType(name: string, type: string): false | undefined; + /** Removes a global graph input */ + removeInput(name: string): boolean; + /** Creates a global graph output */ + addOutput(name: string, type: string, value: any): void; + /** Assign a data to the global output */ + setOutputData(name: string, value: string): void; + /** Returns the current value of a global graph output */ + getOutputData(name: string): T; + + /** Renames a global graph output */ + renameOutput(old_name: string, name: string): false | undefined; + /** Changes the type of a global graph output */ + changeOutputType(name: string, type: string): false | undefined; + /** Removes a global graph output */ + removeOutput(name: string): boolean; + triggerInput(name: string, value: any): void; + setCallback(name: string, func: (...args: any[]) => any): void; + beforeChange(info?: LGraphNode): void; + afterChange(info?: LGraphNode): void; + connectionChange(node: LGraphNode): void; + /** returns if the graph 
is in live mode */ + isLive(): boolean; + /** clears the triggered slot animation in all links (stop visual animation) */ + clearTriggeredSlots(): void; + /* Called when something visually changed (not the graph!) */ + change(): void; + setDirtyCanvas(fg: boolean, bg: boolean): void; + /** Destroys a link */ + removeLink(link_id: number): void; + /** Creates a Object containing all the info about this graph, it can be serialized */ + serialize(): T; + /** + * Configure a graph from a JSON string + * @param data configure a graph from a JSON string + * @returns if there was any error parsing + */ + configure(data: object, keep_old?: boolean): boolean | undefined; + load(url: string): void; +} + +export type SerializedLLink = [number, string, number, number, number, number]; +export declare class LLink { + id: number; + type: string; + origin_id: number; + origin_slot: number; + target_id: number; + target_slot: number; + constructor( + id: number, + type: string, + origin_id: number, + origin_slot: number, + target_id: number, + target_slot: number + ); + configure(o: LLink | SerializedLLink): void; + serialize(): SerializedLLink; +} + +export type SerializedLGraphNode = { + id: T["id"]; + type: T["type"]; + pos: T["pos"]; + size: T["size"]; + flags: T["flags"]; + mode: T["mode"]; + inputs: T["inputs"]; + outputs: T["outputs"]; + title: T["title"]; + properties: T["properties"]; + widgets_values?: IWidget["value"][]; +}; + +/** https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#lgraphnode */ +export declare class LGraphNode { + static title_color: string; + static title: string; + static type: null | string; + static widgets_up: boolean; + constructor(title?: string); + + title: string; + type: null | string; + size: Vector2; + graph: null | LGraph; + graph_version: number; + pos: Vector2; + is_selected: boolean; + mouseOver: boolean; + + id: number; + + //inputs available: array of inputs + inputs: INodeInputSlot[]; + outputs: INodeOutputSlot[]; + 
connections: any[]; + + //local data + properties: Record; + properties_info: any[]; + + flags: Partial<{ + collapsed: boolean + }>; + + color: string; + bgcolor: string; + boxcolor: string; + shape: + | typeof LiteGraph.BOX_SHAPE + | typeof LiteGraph.ROUND_SHAPE + | typeof LiteGraph.CIRCLE_SHAPE + | typeof LiteGraph.CARD_SHAPE + | typeof LiteGraph.ARROW_SHAPE; + + serialize_widgets: boolean; + skip_list: boolean; + + /** Used in `LGraphCanvas.onMenuNodeMode` */ + mode?: + | typeof LiteGraph.ON_EVENT + | typeof LiteGraph.ON_TRIGGER + | typeof LiteGraph.NEVER + | typeof LiteGraph.ALWAYS; + + /** If set to true widgets do not start after the slots */ + widgets_up: boolean; + /** widgets start at y distance from the top of the node */ + widgets_start_y: number; + /** if you render outside the node, it will be clipped */ + clip_area: boolean; + /** if set to false it wont be resizable with the mouse */ + resizable: boolean; + /** slots are distributed horizontally */ + horizontal: boolean; + /** if true, the node will show the bgcolor as 'red' */ + has_errors?: boolean; + + /** configure a node from an object containing the serialized info */ + configure(info: SerializedLGraphNode): void; + /** serialize the content */ + serialize(): SerializedLGraphNode; + /** Creates a clone of this node */ + clone(): this; + /** serialize and stringify */ + toString(): string; + /** get the title string */ + getTitle(): string; + /** sets the value of a property */ + setProperty(name: string, value: any): void; + /** sets the output data */ + setOutputData(slot: number, data: any): void; + /** sets the output data */ + setOutputDataType(slot: number, type: string): void; + /** + * Retrieves the input data (data traveling through the connection) from one slot + * @param slot + * @param force_update if set to true it will force the connected node of this slot to output data into this link + * @return data or if it is not connected returns undefined + */ + getInputData(slot: number, 
force_update?: boolean): T; + /** + * Retrieves the input data type (in case this supports multiple input types) + * @param slot + * @return datatype in string format + */ + getInputDataType(slot: number): string; + /** + * Retrieves the input data from one slot using its name instead of slot number + * @param slot_name + * @param force_update if set to true it will force the connected node of this slot to output data into this link + * @return data or if it is not connected returns null + */ + getInputDataByName(slot_name: string, force_update?: boolean): T; + /** tells you if there is a connection in one input slot */ + isInputConnected(slot: number): boolean; + /** tells you info about an input connection (which node, type, etc) */ + getInputInfo( + slot: number + ): { link: number; name: string; type: string | 0 } | null; + /** returns the node connected in the input slot */ + getInputNode(slot: number): LGraphNode | null; + /** returns the value of an input with this name, otherwise checks if there is a property with that name */ + getInputOrProperty(name: string): T; + /** tells you the last output data that went in that slot */ + getOutputData(slot: number): T | null; + /** tells you info about an output connection (which node, type, etc) */ + getOutputInfo( + slot: number + ): { name: string; type: string; links: number[] } | null; + /** tells you if there is a connection in one output slot */ + isOutputConnected(slot: number): boolean; + /** tells you if there is any connection in the output slots */ + isAnyOutputConnected(): boolean; + /** retrieves all the nodes connected to this output slot */ + getOutputNodes(slot: number): LGraphNode[]; + /** Triggers an event in this node, this will trigger any output with the same name */ + trigger(action: string, param: any): void; + /** + * Triggers an slot event in this node + * @param slot the index of the output slot + * @param param + * @param link_id in case you want to trigger and specific output link in a 
slot + */ + triggerSlot(slot: number, param: any, link_id?: number): void; + /** + * clears the trigger slot animation + * @param slot the index of the output slot + * @param link_id in case you want to trigger and specific output link in a slot + */ + clearTriggeredSlot(slot: number, link_id?: number): void; + /** + * add a new property to this node + * @param name + * @param default_value + * @param type string defining the output type ("vec3","number",...) + * @param extra_info this can be used to have special properties of the property (like values, etc) + */ + addProperty( + name: string, + default_value: any, + type: string, + extra_info?: object + ): T; + /** + * add a new output slot to use in this node + * @param name + * @param type string defining the output type ("vec3","number",...) + * @param extra_info this can be used to have special properties of an output (label, special color, position, etc) + */ + addOutput( + name: string, + type: string | -1, + extra_info?: Partial + ): INodeOutputSlot; + /** + * add a new output slot to use in this node + * @param array of triplets like [[name,type,extra_info],[...]] + */ + addOutputs( + array: [string, string | -1, Partial | undefined][] + ): void; + /** remove an existing output slot */ + removeOutput(slot: number): void; + /** + * add a new input slot to use in this node + * @param name + * @param type string defining the input type ("vec3","number",...), it its a generic one use 0 + * @param extra_info this can be used to have special properties of an input (label, color, position, etc) + */ + addInput( + name: string, + type: string | -1, + extra_info?: Partial + ): INodeInputSlot; + /** + * add several new input slots in this node + * @param array of triplets like [[name,type,extra_info],[...]] + */ + addInputs( + array: [string, string | -1, Partial | undefined][] + ): void; + /** remove an existing input slot */ + removeInput(slot: number): void; + /** + * add an special connection to this node (used 
for special kinds of graphs) + * @param name + * @param type string defining the input type ("vec3","number",...) + * @param pos position of the connection inside the node + * @param direction if is input or output + */ + addConnection( + name: string, + type: string, + pos: Vector2, + direction: string + ): { + name: string; + type: string; + pos: Vector2; + direction: string; + links: null; + }; + setValue(v: any): void; + /** computes the size of a node according to its inputs and output slots */ + computeSize(): [number, number]; + /** + * https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#node-widgets + * @return created widget + */ + addWidget( + type: T["type"], + name: string, + value: T["value"], + callback?: WidgetCallback | string, + options?: T["options"] + ): T; + + addCustomWidget(customWidget: T): T; + + /** + * returns the bounding of the object, used for rendering purposes + * @return [x, y, width, height] + */ + getBounding(): Vector4; + /** checks if a point is inside the shape of a node */ + isPointInside( + x: number, + y: number, + margin?: number, + skipTitle?: boolean + ): boolean; + /** checks if a point is inside a node slot, and returns info about which slot */ + getSlotInPosition( + x: number, + y: number + ): { + input?: INodeInputSlot; + output?: INodeOutputSlot; + slot: number; + link_pos: Vector2; + }; + /** + * returns the input slot with a given name (used for dynamic slots), -1 if not found + * @param name the name of the slot + * @return the slot (-1 if not found) + */ + findInputSlot(name: string): number; + /** + * returns the output slot with a given name (used for dynamic slots), -1 if not found + * @param name the name of the slot + * @return the slot (-1 if not found) + */ + findOutputSlot(name: string): number; + /** + * connect this node output to the input of another node + * @param slot (could be the number of the slot or the string with the name of the slot) + * @param targetNode the target node + * 
@param targetSlot the input slot of the target node (could be the number of the slot or the string with the name of the slot, or -1 to connect a trigger) + * @return {Object} the link_info is created, otherwise null + */ + connect( + slot: number | string, + targetNode: LGraphNode, + targetSlot: number | string + ): T | null; + /** + * disconnect one output to an specific node + * @param slot (could be the number of the slot or the string with the name of the slot) + * @param target_node the target node to which this slot is connected [Optional, if not target_node is specified all nodes will be disconnected] + * @return if it was disconnected successfully + */ + disconnectOutput(slot: number | string, targetNode?: LGraphNode): boolean; + /** + * disconnect one input + * @param slot (could be the number of the slot or the string with the name of the slot) + * @return if it was disconnected successfully + */ + disconnectInput(slot: number | string): boolean; + /** + * returns the center of a connection point in canvas coords + * @param is_input true if if a input slot, false if it is an output + * @param slot (could be the number of the slot or the string with the name of the slot) + * @param out a place to store the output, to free garbage + * @return the position + **/ + getConnectionPos( + is_input: boolean, + slot: number | string, + out?: Vector2 + ): Vector2; + /** Force align to grid */ + alignToGrid(): void; + /** Console output */ + trace(msg: string): void; + /** Forces to redraw or the main canvas (LGraphNode) or the bg canvas (links) */ + setDirtyCanvas(fg: boolean, bg: boolean): void; + loadImage(url: string): void; + /** Allows to get onMouseMove and onMouseUp events even if the mouse is out of focus */ + captureInput(v: any): void; + /** Collapse the node to make it smaller on the canvas */ + collapse(force: boolean): void; + /** Forces the node to do not move or realign on Z */ + pin(v?: boolean): void; + localToScreen(x: number, y: number, 
graphCanvas: LGraphCanvas): Vector2; + + // https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#custom-node-appearance + onDrawBackground?( + ctx: CanvasRenderingContext2D, + canvas: HTMLCanvasElement + ): void; + onDrawForeground?( + ctx: CanvasRenderingContext2D, + canvas: HTMLCanvasElement + ): void; + + // https://github.com/jagenjo/litegraph.js/blob/master/guides/README.md#custom-node-behaviour + onMouseDown?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseMove?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseUp?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseEnter?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onMouseLeave?( + event: MouseEvent, + pos: Vector2, + graphCanvas: LGraphCanvas + ): void; + onKey?(event: KeyboardEvent, pos: Vector2, graphCanvas: LGraphCanvas): void; + + /** Called by `LGraphCanvas.selectNodes` */ + onSelected?(): void; + /** Called by `LGraphCanvas.deselectNode` */ + onDeselected?(): void; + /** Called by `LGraph.runStep` `LGraphNode.getInputData` */ + onExecute?(): void; + /** Called by `LGraph.serialize` */ + onSerialize?(o: SerializedLGraphNode): void; + /** Called by `LGraph.configure` */ + onConfigure?(o: SerializedLGraphNode): void; + /** + * when added to graph (warning: this is called BEFORE the node is configured when loading) + * Called by `LGraph.add` + */ + onAdded?(graph: LGraph): void; + /** + * when removed from graph + * Called by `LGraph.remove` `LGraph.clear` + */ + onRemoved?(): void; + /** + * if returns false the incoming connection will be canceled + * Called by `LGraph.connect` + * @param inputIndex target input slot number + * @param outputType type of output slot + * @param outputSlot output slot object + * @param outputNode node containing the output + * @param outputIndex index of output slot + */ + onConnectInput?( + inputIndex: number, + 
outputType: INodeOutputSlot["type"], + outputSlot: INodeOutputSlot, + outputNode: LGraphNode, + outputIndex: number + ): boolean; + /** + * if returns false the incoming connection will be canceled + * Called by `LGraph.connect` + * @param outputIndex target output slot number + * @param inputType type of input slot + * @param inputSlot input slot object + * @param inputNode node containing the input + * @param inputIndex index of input slot + */ + onConnectOutput?( + outputIndex: number, + inputType: INodeInputSlot["type"], + inputSlot: INodeInputSlot, + inputNode: LGraphNode, + inputIndex: number + ): boolean; + + /** + * Called just before connection (or disconnect - if input is linked). + * A convenient place to switch to another input, or create new one. + * This allow for ability to automatically add slots if needed + * @param inputIndex + * @return selected input slot index, can differ from parameter value + */ + onBeforeConnectInput?( + inputIndex: number + ): number; + + /** a connection changed (new one or removed) (LiteGraph.INPUT or LiteGraph.OUTPUT, slot, true if connected, link_info, input_info or output_info ) */ + onConnectionsChange( + type: number, + slotIndex: number, + isConnected: boolean, + link: LLink, + ioSlot: (INodeOutputSlot | INodeInputSlot) + ): void; + + /** + * if returns false, will abort the `LGraphNode.setProperty` + * Called when a property is changed + * @param property + * @param value + * @param prevValue + */ + onPropertyChanged?(property: string, value: any, prevValue: any): void | boolean; + + /** Called by `LGraphCanvas.processContextMenu` */ + getMenuOptions?(graphCanvas: LGraphCanvas): ContextMenuItem[]; + getSlotMenuOptions?(slot: INodeSlot): ContextMenuItem[]; +} + +export type LGraphNodeConstructor = { + new (): T; +}; + +export type SerializedLGraphGroup = { + title: LGraphGroup["title"]; + bounding: LGraphGroup["_bounding"]; + color: LGraphGroup["color"]; + font: LGraphGroup["font"]; +}; +export declare class 
LGraphGroup { + title: string; + private _bounding: Vector4; + color: string; + font: string; + + configure(o: SerializedLGraphGroup): void; + serialize(): SerializedLGraphGroup; + move(deltaX: number, deltaY: number, ignoreNodes?: boolean): void; + recomputeInsideNodes(): void; + isPointInside: LGraphNode["isPointInside"]; + setDirtyCanvas: LGraphNode["setDirtyCanvas"]; +} + +export declare class DragAndScale { + constructor(element?: HTMLElement, skipEvents?: boolean); + offset: [number, number]; + scale: number; + max_scale: number; + min_scale: number; + onredraw: Function | null; + enabled: boolean; + last_mouse: Vector2; + element: HTMLElement | null; + visible_area: Vector4; + bindEvents(element: HTMLElement): void; + computeVisibleArea(): void; + onMouse(e: MouseEvent): void; + toCanvasContext(ctx: CanvasRenderingContext2D): void; + convertOffsetToCanvas(pos: Vector2): Vector2; + convertCanvasToOffset(pos: Vector2): Vector2; + mouseDrag(x: number, y: number): void; + changeScale(value: number, zooming_center?: Vector2): void; + changeDeltaScale(value: number, zooming_center?: Vector2): void; + reset(): void; +} + +/** + * This class is in charge of rendering one graph inside a canvas. And provides all the interaction required. 
+ * Valid callbacks are: onNodeSelected, onNodeDeselected, onShowNodePanel, onNodeDblClicked + * + * @param canvas the canvas where you want to render (it accepts a selector in string format or the canvas element itself) + * @param graph + * @param options { skip_rendering, autoresize } + */ +export declare class LGraphCanvas { + static node_colors: Record< + string, + { + color: string; + bgcolor: string; + groupcolor: string; + } + >; + static link_type_colors: Record; + static gradients: object; + static search_limit: number; + + static getFileExtension(url: string): string; + static decodeHTML(str: string): string; + + static onMenuCollapseAll(): void; + static onMenuNodeEdit(): void; + static onShowPropertyEditor( + item: any, + options: any, + e: any, + menu: any, + node: any + ): void; + /** Create menu for `Add Group` */ + static onGroupAdd: ContextMenuEventListener; + /** Create menu for `Add Node` */ + static onMenuAdd: ContextMenuEventListener; + static showMenuNodeOptionalInputs: ContextMenuEventListener; + static showMenuNodeOptionalOutputs: ContextMenuEventListener; + static onShowMenuNodeProperties: ContextMenuEventListener; + static onResizeNode: ContextMenuEventListener; + static onMenuNodeCollapse: ContextMenuEventListener; + static onMenuNodePin: ContextMenuEventListener; + static onMenuNodeMode: ContextMenuEventListener; + static onMenuNodeColors: ContextMenuEventListener; + static onMenuNodeShapes: ContextMenuEventListener; + static onMenuNodeRemove: ContextMenuEventListener; + static onMenuNodeClone: ContextMenuEventListener; + + constructor( + canvas: HTMLCanvasElement | string, + graph?: LGraph, + options?: { + skip_render?: boolean; + autoresize?: boolean; + } + ); + + static active_canvas: HTMLCanvasElement; + + allow_dragcanvas: boolean; + allow_dragnodes: boolean; + /** allow to control widgets, buttons, collapse, etc */ + allow_interaction: boolean; + /** allows to change a connection with having to redo it again */ + 
allow_reconnect_links: boolean; + /** allow selecting multi nodes without pressing extra keys */ + multi_select: boolean; + /** No effect */ + allow_searchbox: boolean; + always_render_background: boolean; + autoresize?: boolean; + background_image: string; + bgcanvas: HTMLCanvasElement; + bgctx: CanvasRenderingContext2D; + canvas: HTMLCanvasElement; + canvas_mouse: Vector2; + clear_background: boolean; + connecting_node: LGraphNode | null; + connections_width: number; + ctx: CanvasRenderingContext2D; + current_node: LGraphNode | null; + default_connection_color: { + input_off: string; + input_on: string; + output_off: string; + output_on: string; + }; + default_link_color: string; + dirty_area: Vector4 | null; + dirty_bgcanvas?: boolean; + dirty_canvas?: boolean; + drag_mode: boolean; + dragging_canvas: boolean; + dragging_rectangle: Vector4 | null; + ds: DragAndScale; + /** used for transition */ + editor_alpha: number; + filter: any; + fps: number; + frame: number; + graph: LGraph; + highlighted_links: Record; + highquality_render: boolean; + inner_text_font: string; + is_rendering: boolean; + last_draw_time: number; + last_mouse: Vector2; + /** + * Possible duplicated with `last_mouse` + * https://github.com/jagenjo/litegraph.js/issues/70 + */ + last_mouse_position: Vector2; + /** Timestamp of last mouse click, defaults to 0 */ + last_mouseclick: number; + links_render_mode: + | typeof LiteGraph.STRAIGHT_LINK + | typeof LiteGraph.LINEAR_LINK + | typeof LiteGraph.SPLINE_LINK; + live_mode: boolean; + node_capturing_input: LGraphNode | null; + node_dragged: LGraphNode | null; + node_in_panel: LGraphNode | null; + node_over: LGraphNode | null; + node_title_color: string; + node_widget: [LGraphNode, IWidget] | null; + /** Called by `LGraphCanvas.drawBackCanvas` */ + onDrawBackground: + | ((ctx: CanvasRenderingContext2D, visibleArea: Vector4) => void) + | null; + /** Called by `LGraphCanvas.drawFrontCanvas` */ + onDrawForeground: + | ((ctx: CanvasRenderingContext2D, 
visibleArea: Vector4) => void) + | null; + onDrawOverlay: ((ctx: CanvasRenderingContext2D) => void) | null; + /** Called by `LGraphCanvas.processMouseDown` */ + onMouse: ((event: MouseEvent) => boolean) | null; + /** Called by `LGraphCanvas.drawFrontCanvas` and `LGraphCanvas.drawLinkTooltip` */ + onDrawLinkTooltip: ((ctx: CanvasRenderingContext2D, link: LLink, _this: this) => void) | null; + /** Called by `LGraphCanvas.selectNodes` */ + onNodeMoved: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.processNodeSelected` */ + onNodeSelected: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.deselectNode` */ + onNodeDeselected: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.processNodeDblClicked` */ + onShowNodePanel: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.processNodeDblClicked` */ + onNodeDblClicked: ((node: LGraphNode) => void) | null; + /** Called by `LGraphCanvas.selectNodes` */ + onSelectionChange: ((nodes: Record) => void) | null; + /** Called by `LGraphCanvas.showSearchBox` */ + onSearchBox: + | (( + helper: Element, + value: string, + graphCanvas: LGraphCanvas + ) => string[]) + | null; + onSearchBoxSelection: + | ((name: string, event: MouseEvent, graphCanvas: LGraphCanvas) => void) + | null; + pause_rendering: boolean; + render_canvas_border: boolean; + render_collapsed_slots: boolean; + render_connection_arrows: boolean; + render_connections_border: boolean; + render_connections_shadows: boolean; + render_curved_connections: boolean; + render_execution_order: boolean; + render_only_selected: boolean; + render_shadows: boolean; + render_title_colored: boolean; + round_radius: number; + selected_group: null | LGraphGroup; + selected_group_resizing: boolean; + selected_nodes: Record; + show_info: boolean; + title_text_font: string; + /** set to true to render title bar with gradients */ + use_gradients: boolean; + visible_area: DragAndScale["visible_area"]; + visible_links: 
LLink[]; + visible_nodes: LGraphNode[]; + zoom_modify_alpha: boolean; + + /** clears all the data inside */ + clear(): void; + /** assigns a graph, you can reassign graphs to the same canvas */ + setGraph(graph: LGraph, skipClear?: boolean): void; + /** opens a graph contained inside a node in the current graph */ + openSubgraph(graph: LGraph): void; + /** closes a subgraph contained inside a node */ + closeSubgraph(): void; + /** assigns a canvas */ + setCanvas(canvas: HTMLCanvasElement, skipEvents?: boolean): void; + /** binds mouse, keyboard, touch and drag events to the canvas */ + bindEvents(): void; + /** unbinds mouse events from the canvas */ + unbindEvents(): void; + + /** + * this function allows to render the canvas using WebGL instead of Canvas2D + * this is useful if you plant to render 3D objects inside your nodes, it uses litegl.js for webgl and canvas2DtoWebGL to emulate the Canvas2D calls in webGL + **/ + enableWebGL(): void; + + /** + * marks as dirty the canvas, this way it will be rendered again + * @param fg if the foreground canvas is dirty (the one containing the nodes) + * @param bg if the background canvas is dirty (the one containing the wires) + */ + setDirty(fg: boolean, bg: boolean): void; + + /** + * Used to attach the canvas in a popup + * @return the window where the canvas is attached (the DOM root node) + */ + getCanvasWindow(): Window; + /** starts rendering the content of the canvas when needed */ + startRendering(): void; + /** stops rendering the content of the canvas (to save resources) */ + stopRendering(): void; + + processMouseDown(e: MouseEvent): boolean | undefined; + processMouseMove(e: MouseEvent): boolean | undefined; + processMouseUp(e: MouseEvent): boolean | undefined; + processMouseWheel(e: MouseEvent): boolean | undefined; + + /** returns true if a position (in graph space) is on top of a node little corner box */ + isOverNodeBox(node: LGraphNode, canvasX: number, canvasY: number): boolean; + /** returns true if a 
position (in graph space) is on top of a node input slot */ + isOverNodeInput( + node: LGraphNode, + canvasX: number, + canvasY: number, + slotPos: Vector2 + ): boolean; + + /** process a key event */ + processKey(e: KeyboardEvent): boolean | undefined; + + copyToClipboard(): void; + pasteFromClipboard(): void; + processDrop(e: DragEvent): void; + checkDropItem(e: DragEvent): void; + processNodeDblClicked(n: LGraphNode): void; + processNodeSelected(n: LGraphNode, e: MouseEvent): void; + processNodeDeselected(node: LGraphNode): void; + + /** selects a given node (or adds it to the current selection) */ + selectNode(node: LGraphNode, add?: boolean): void; + /** selects several nodes (or adds them to the current selection) */ + selectNodes(nodes?: LGraphNode[], add?: boolean): void; + /** removes a node from the current selection */ + deselectNode(node: LGraphNode): void; + /** removes all nodes from the current selection */ + deselectAllNodes(): void; + /** deletes all nodes in the current selection from the graph */ + deleteSelectedNodes(): void; + + /** centers the camera on a given node */ + centerOnNode(node: LGraphNode): void; + /** changes the zoom level of the graph (default is 1), you can pass also a place used to pivot the zoom */ + setZoom(value: number, center: Vector2): void; + /** brings a node to front (above all other nodes) */ + bringToFront(node: LGraphNode): void; + /** sends a node to the back (below all other nodes) */ + sendToBack(node: LGraphNode): void; + /** checks which nodes are visible (inside the camera area) */ + computeVisibleNodes(nodes: LGraphNode[]): LGraphNode[]; + /** renders the whole canvas content, by rendering in two separated canvas, one containing the background grid and the connections, and one containing the nodes) */ + draw(forceFG?: boolean, forceBG?: boolean): void; + /** draws the front canvas (the one containing all the nodes) */ + drawFrontCanvas(): void; + /** draws some useful stats in the corner of the canvas */ + 
renderInfo(ctx: CanvasRenderingContext2D, x: number, y: number): void; + /** draws the back canvas (the one containing the background and the connections) */ + drawBackCanvas(): void; + /** draws the given node inside the canvas */ + drawNode(node: LGraphNode, ctx: CanvasRenderingContext2D): void; + /** draws graphic for node's slot */ + drawSlotGraphic(ctx: CanvasRenderingContext2D, pos: number[], shape: SlotShape, horizontal: boolean): void; + /** draws the shape of the given node in the canvas */ + drawNodeShape( + node: LGraphNode, + ctx: CanvasRenderingContext2D, + size: [number, number], + fgColor: string, + bgColor: string, + selected: boolean, + mouseOver: boolean + ): void; + /** draws every connection visible in the canvas */ + drawConnections(ctx: CanvasRenderingContext2D): void; + /** + * draws a link between two points + * @param a start pos + * @param b end pos + * @param link the link object with all the link info + * @param skipBorder ignore the shadow of the link + * @param flow show flow animation (for events) + * @param color the color for the link + * @param startDir the direction enum + * @param endDir the direction enum + * @param numSublines number of sublines (useful to represent vec3 or rgb) + **/ + renderLink( + a: Vector2, + b: Vector2, + link: object, + skipBorder: boolean, + flow: boolean, + color?: string, + startDir?: number, + endDir?: number, + numSublines?: number + ): void; + + computeConnectionPoint( + a: Vector2, + b: Vector2, + t: number, + startDir?: number, + endDir?: number + ): void; + + drawExecutionOrder(ctx: CanvasRenderingContext2D): void; + /** draws the widgets stored inside a node */ + drawNodeWidgets( + node: LGraphNode, + posY: number, + ctx: CanvasRenderingContext2D, + activeWidget: object + ): void; + /** process an event on widgets */ + processNodeWidgets( + node: LGraphNode, + pos: Vector2, + event: Event, + activeWidget: object + ): void; + /** draws every group area in the background */ + drawGroups(canvas: 
any, ctx: CanvasRenderingContext2D): void; + adjustNodesSize(): void; + /** resizes the canvas to a given size, if no size is passed, then it tries to fill the parentNode */ + resize(width?: number, height?: number): void; + /** + * switches to live mode (node shapes are not rendered, only the content) + * this feature was designed when graphs where meant to create user interfaces + **/ + switchLiveMode(transition?: boolean): void; + onNodeSelectionChange(): void; + touchHandler(event: TouchEvent): void; + + showLinkMenu(link: LLink, e: any): false; + prompt( + title: string, + value: any, + callback: Function, + event: any + ): HTMLDivElement; + showSearchBox(event?: MouseEvent): void; + showEditPropertyValue(node: LGraphNode, property: any, options: any): void; + createDialog( + html: string, + options?: { position?: Vector2; event?: MouseEvent } + ): void; + + convertOffsetToCanvas: DragAndScale["convertOffsetToCanvas"]; + convertCanvasToOffset: DragAndScale["convertCanvasToOffset"]; + /** converts event coordinates from canvas2D to graph coordinates */ + convertEventToCanvasOffset(e: MouseEvent): Vector2; + /** adds some useful properties to a mouse event, like the position in graph coordinates */ + adjustMouseEvent(e: MouseEvent): void; + + getCanvasMenuOptions(): ContextMenuItem[]; + getNodeMenuOptions(node: LGraphNode): ContextMenuItem[]; + getGroupMenuOptions(): ContextMenuItem[]; + /** Called by `getCanvasMenuOptions`, replace default options */ + getMenuOptions?(): ContextMenuItem[]; + /** Called by `getCanvasMenuOptions`, append to default options */ + getExtraMenuOptions?(): ContextMenuItem[]; + /** Called when mouse right click */ + processContextMenu(node: LGraphNode, event: Event): void; +} + +declare class ContextMenu { + static trigger( + element: HTMLElement, + event_name: string, + params: any, + origin: any + ): void; + static isCursorOverElement(event: MouseEvent, element: HTMLElement): void; + static closeAllContextMenus(window: Window): void; 
+ constructor(values: ContextMenuItem[], options?: IContextMenuOptions, window?: Window); + options: IContextMenuOptions; + parentMenu?: ContextMenu; + lock: boolean; + current_submenu?: ContextMenu; + addItem( + name: string, + value: ContextMenuItem, + options?: IContextMenuOptions + ): void; + close(e?: MouseEvent, ignore_parent_menu?: boolean): void; + getTopMenu(): void; + getFirstEvent(): void; +} + +declare global { + interface CanvasRenderingContext2D { + /** like rect but rounded corners */ + roundRect( + x: number, + y: number, + width: number, + height: number, + radius: number, + radiusLow: number + ): void; + } + + interface Math { + clamp(v: number, min: number, max: number): number; + } +} From 2dd28d4d20ee4f272db5d674bb229a2fe37dadb5 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 15 Apr 2023 21:41:21 +0100 Subject: [PATCH 10/35] style --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 940c5ecf..1695dcae 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -4,7 +4,7 @@ import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; import { getPngMetadata, importA1111 } from "./pnginfo.js"; - /** +/** * @typedef {import("types/comfy").ComfyExtension} ComfyExtension */ From a908e12d23c820b916900f9d9ce2d5ecd507f3a2 Mon Sep 17 00:00:00 2001 From: Jake D <122334950+jwd-dev@users.noreply.github.com> Date: Sat, 15 Apr 2023 18:18:19 -0400 Subject: [PATCH 11/35] Update nodes.py with new Note node --- nodes.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/nodes.py b/nodes.py index 6468ac6b..d49c830e 100644 --- a/nodes.py +++ b/nodes.py @@ -510,6 +510,14 @@ class EmptyLatentImage: return ({"samples":latent}, ) +class Note: + @classmethod + def INPUT_TYPES(s): + return {"required": {"text": ("STRING", {"multiline": True})}} + + CATEGORY = "other" + RETURN_TYPES = () + class 
LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] @@ -1072,6 +1080,7 @@ NODE_CLASS_MAPPINGS = { "VAEEncodeForInpaint": VAEEncodeForInpaint, "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, + "Note": Note, "LatentUpscale": LatentUpscale, "SaveImage": SaveImage, "PreviewImage": PreviewImage, @@ -1138,6 +1147,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentFlip": "Flip Latent", "LatentCrop": "Crop Latent", "EmptyLatentImage": "Empty Latent Image", + "Note": "Note", "LatentUpscale": "Upscale Latent", "LatentComposite": "Latent Composite", # Image From 81d1f00df32e64053343e863c9c71a5d97761675 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 15 Apr 2023 18:46:58 -0400 Subject: [PATCH 12/35] Some refactoring: from_tokens -> encode_from_tokens --- comfy/sd.py | 10 +++++----- comfy/sd1_clip.py | 6 +++--- comfy/sd2_clip.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 6e54bc60..d6d45fef 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -375,13 +375,9 @@ class CLIP: def tokenize(self, text, return_word_ids=False): return self.tokenizer.tokenize_with_weights(text, return_word_ids) - def encode(self, text, from_tokens=False): + def encode_from_tokens(self, tokens): if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) - if from_tokens: - tokens = text - else: - tokens = self.tokenizer.tokenize_with_weights(text) try: self.patcher.patch_model() cond = self.cond_stage_model.encode_token_weights(tokens) @@ -391,6 +387,10 @@ class CLIP: raise e return cond + def encode(self, text): + tokens = self.tokenizer.tokenize_with_weights(text) + return self.encode_from_tokens(tokens) + class VAE: def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None): if config is None: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 97b96953..7f1217c3 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -315,7 +315,7 @@ class SD1Tokenizer: continue 
#parse word tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]]) - + #reshape token array to CLIP input size batched_tokens = [] batch = [(self.start_token, 1.0, 0)] @@ -338,11 +338,11 @@ class SD1Tokenizer: batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) #start new batch batch = [(self.start_token, 1.0, 0)] - batched_tokens.append(batch) + batched_tokens.append(batch) else: batch.extend([(t,w,i+1) for t,w in t_group]) t_group = [] - + #fill last batch batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1)) diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index fda793eb..32f202ae 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -1,4 +1,4 @@ -import sd1_clip +from comfy import sd1_clip import torch import os From 73c3e11e83f6fcf1a47b4965fe60b03075e1a762 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 15 Apr 2023 18:55:17 -0400 Subject: [PATCH 13/35] Fix model_management import so it doesn't get executed twice. --- comfy/ldm/modules/attention.py | 2 +- comfy/ldm/modules/diffusionmodules/model.py | 2 +- comfy/ldm/modules/sub_quadratic_attention.py | 2 +- comfy/samplers.py | 2 +- comfy/sd.py | 4 ++-- comfy_extras/nodes_upscale_model.py | 2 +- nodes.py | 14 +++++++------- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 92b3eca7..c8338734 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -9,7 +9,7 @@ from typing import Optional, Any from ldm.modules.diffusionmodules.util import checkpoint from .sub_quadratic_attention import efficient_dot_product_attention -import model_management +from comfy import model_management from . 
import tomesd diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 788a6fc4..1599d386 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -7,7 +7,7 @@ from einops import rearrange from typing import Optional, Any from ldm.modules.attention import MemoryEfficientCrossAttention -import model_management +from comfy import model_management if model_management.xformers_enabled_vae(): import xformers diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index f3c83f38..573cce74 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -24,7 +24,7 @@ except ImportError: from torch import Tensor from typing import List -import model_management +from comfy import model_management def dynamic_slice( x: Tensor, diff --git a/comfy/samplers.py b/comfy/samplers.py index 93f5d361..ed36442a 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -3,7 +3,7 @@ from .k_diffusion import external as k_diffusion_external from .extra_samplers import uni_pc import torch import contextlib -import model_management +from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps diff --git a/comfy/sd.py b/comfy/sd.py index d6d45fef..9c632e24 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -4,7 +4,7 @@ import copy import sd1_clip import sd2_clip -import model_management +from comfy import model_management from .ldm.util import instantiate_from_config from .ldm.models.autoencoder import AutoencoderKL import yaml @@ -388,7 +388,7 @@ class CLIP: return cond def encode(self, text): - tokens = self.tokenizer.tokenize_with_weights(text) + tokens = self.tokenize(text) return self.encode_from_tokens(tokens) class VAE: diff --git a/comfy_extras/nodes_upscale_model.py 
b/comfy_extras/nodes_upscale_model.py index 6a7d0e51..d8754698 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -1,6 +1,6 @@ import os from comfy_extras.chainner_models import model_loading -import model_management +from comfy import model_management import torch import comfy.utils import folder_paths diff --git a/nodes.py b/nodes.py index 6468ac6b..e6ad9434 100644 --- a/nodes.py +++ b/nodes.py @@ -21,16 +21,16 @@ import comfy.utils import comfy.clip_vision -import model_management +import comfy.model_management import importlib import folder_paths def before_node_execution(): - model_management.throw_exception_if_processing_interrupted() + comfy.model_management.throw_exception_if_processing_interrupted() def interrupt_processing(value=True): - model_management.interrupt_current_processing(value) + comfy.model_management.interrupt_current_processing(value) MAX_RESOLUTION=8192 @@ -241,7 +241,7 @@ class DiffusersLoader: model_path = os.path.join(search_path, model_path) break - return comfy.diffusers_convert.load_diffusers(model_path, fp16=model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: @@ -680,7 +680,7 @@ class SetLatentNoiseMask: def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): latent_image = latent["samples"] noise_mask = None - device = model_management.get_torch_device() + device = comfy.model_management.get_torch_device() if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, 
layout=latent_image.layout, device="cpu") @@ -696,7 +696,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise_mask = noise_mask.to(device) real_model = None - model_management.load_model_gpu(model) + comfy.model_management.load_model_gpu(model) real_model = model.model noise = noise.to(device) @@ -726,7 +726,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, control_net_models = [] for x in control_nets: control_net_models += x.get_control_models() - model_management.load_controlnet_gpu(control_net_models) + comfy.model_management.load_controlnet_gpu(control_net_models) if sampler_name in comfy.samplers.KSampler.SAMPLERS: sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) From 6c35ea505efbeb78ffe3c6bfc6b63e68f4290561 Mon Sep 17 00:00:00 2001 From: Jake D <122334950+jwd-dev@users.noreply.github.com> Date: Sat, 15 Apr 2023 19:48:24 -0400 Subject: [PATCH 14/35] reverting changes --- nodes.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/nodes.py b/nodes.py index d49c830e..6468ac6b 100644 --- a/nodes.py +++ b/nodes.py @@ -510,14 +510,6 @@ class EmptyLatentImage: return ({"samples":latent}, ) -class Note: - @classmethod - def INPUT_TYPES(s): - return {"required": {"text": ("STRING", {"multiline": True})}} - - CATEGORY = "other" - RETURN_TYPES = () - class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] @@ -1080,7 +1072,6 @@ NODE_CLASS_MAPPINGS = { "VAEEncodeForInpaint": VAEEncodeForInpaint, "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, - "Note": Note, "LatentUpscale": LatentUpscale, "SaveImage": SaveImage, "PreviewImage": PreviewImage, @@ -1147,7 +1138,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentFlip": "Flip Latent", "LatentCrop": "Crop Latent", "EmptyLatentImage": "Empty Latent Image", - "Note": "Note", "LatentUpscale": "Upscale 
Latent", "LatentComposite": "Latent Composite", # Image From 9587ea90c82998abc73387aab594ec7217f6d50a Mon Sep 17 00:00:00 2001 From: Jake D <122334950+jwd-dev@users.noreply.github.com> Date: Sat, 15 Apr 2023 19:50:05 -0400 Subject: [PATCH 15/35] Create noteNode.js --- web/extensions/core/noteNode.js | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 web/extensions/core/noteNode.js diff --git a/web/extensions/core/noteNode.js b/web/extensions/core/noteNode.js new file mode 100644 index 00000000..12428773 --- /dev/null +++ b/web/extensions/core/noteNode.js @@ -0,0 +1,38 @@ +import {app} from "../../scripts/app.js"; +import {ComfyWidgets} from "../../scripts/widgets.js"; +// Node that add notes to your project + +app.registerExtension({ + name: "Comfy.NoteNode", + registerCustomNodes() { + class NoteNode { + color=LGraphCanvas.node_colors.yellow.color; + bgcolor=LGraphCanvas.node_colors.yellow.bgcolor; + groupcolor = LGraphCanvas.node_colors.yellow.groupcolor; + constructor() { + if (!this.properties) { + this.properties = {}; + } + + ComfyWidgets.STRING(this, "", ["", {multiline: true}], app) + // This node is purely frontend and does not impact the resulting prompt so should not be serialized + this.isVirtualNode = true; + } + + + } + + // Load default visibility + + LiteGraph.registerNodeType( + "Note", + Object.assign(NoteNode, { + title_mode: LiteGraph.NORMAL_TITLE, + title: "Note", + collapsable: true, + }) + ); + + NoteNode.category = "utils"; + }, +}); From fb61c75e392ae0a3813955d56fb5aceecacff2e4 Mon Sep 17 00:00:00 2001 From: jwd-dev Date: Sat, 15 Apr 2023 19:58:46 -0400 Subject: [PATCH 16/35] default text property incase we need one. 
--- web/extensions/core/noteNode.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/extensions/core/noteNode.js b/web/extensions/core/noteNode.js index 12428773..1412d437 100644 --- a/web/extensions/core/noteNode.js +++ b/web/extensions/core/noteNode.js @@ -12,9 +12,10 @@ app.registerExtension({ constructor() { if (!this.properties) { this.properties = {}; + this.properties.text=""; } - ComfyWidgets.STRING(this, "", ["", {multiline: true}], app) + ComfyWidgets.STRING(this, "", ["", {default:this.properties.text, multiline: true}], app) // This node is purely frontend and does not impact the resulting prompt so should not be serialized this.isVirtualNode = true; } From 8cd170f40daa635ad17c29fab12296cb5936df69 Mon Sep 17 00:00:00 2001 From: jwd-dev Date: Sat, 15 Apr 2023 20:16:28 -0400 Subject: [PATCH 17/35] node serialization --- web/extensions/core/noteNode.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/extensions/core/noteNode.js b/web/extensions/core/noteNode.js index 1412d437..8d89054e 100644 --- a/web/extensions/core/noteNode.js +++ b/web/extensions/core/noteNode.js @@ -16,8 +16,10 @@ app.registerExtension({ } ComfyWidgets.STRING(this, "", ["", {default:this.properties.text, multiline: true}], app) - // This node is purely frontend and does not impact the resulting prompt so should not be serialized + + this.serialize_widgets = true; this.isVirtualNode = true; + } From bc16b70bdef76d118f055c023279a4b0d4ce16a7 Mon Sep 17 00:00:00 2001 From: Karun Date: Sun, 16 Apr 2023 01:25:11 -0400 Subject: [PATCH 18/35] Adds several keybinds that interact with ComfyUI (#491) * adds keybinds that interact w/ comfy menu * adds remaining keybinds * adds keybinds to readme and converts to table * ctrl s and o save and open workflow * moves keybinds to sep file, update readme * remap load default, support keycodes * update keybinds table, prepends comfy to ids * escape exits out of modals * modifier keys also use map * adds 
setting for filename prompt * better handle filename prompt Co-authored-by: missionfloyd --- README.md | 30 +++++++++---- web/extensions/core/keybinds.js | 76 +++++++++++++++++++++++++++++++++ web/scripts/app.js | 6 --- web/scripts/ui.js | 32 +++++++++++--- 4 files changed, 124 insertions(+), 20 deletions(-) create mode 100644 web/extensions/core/keybinds.js diff --git a/README.md b/README.md index 77d979ac..f610f949 100644 --- a/README.md +++ b/README.md @@ -32,14 +32,28 @@ This ui will let you design and execute advanced stable diffusion pipelines usin Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/) ## Shortcuts -- **Ctrl + A** select all nodes -- **Ctrl + M** mute/unmute selected nodes -- **Delete** or **Backspace** delete selected nodes -- **Space** Holding space key while moving the cursor moves the canvas around. It works when holding the mouse button down so it is easier to connect different nodes when the canvas gets too large. -- **Ctrl/Shift + Click** Add clicked node to selection. -- **Ctrl + C/Ctrl + V** - Copy and paste selected nodes, without maintaining the connection to the outputs of unselected nodes. -- **Ctrl + C/Ctrl + Shift + V** - Copy and paste selected nodes, and maintaining the connection from the outputs of unselected nodes to the inputs of the newly pasted nodes. -- Holding **Shift** and drag selected nodes - Move multiple selected nodes at the same time. 
+ +| Keybind | Explanation | +| - | - | +| Ctrl + Enter | Queue up current graph for generation | +| Ctrl + Shift + Enter | Queue up current graph as first for generation | +| Ctrl + S | Save workflow | +| Ctrl + O | Load workflow | +| Ctrl + A | Select all nodes | +| Ctrl + M | Mute/unmute selected nodes | +| Delete/Backspace | Delete selected nodes | +| Ctrl + Delete/Backspace | Delete the current graph | +| Space | Move the canvas around when held and moving the cursor | +| Ctrl/Shift + Click | Add clicked node to selection | +| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) | +| Ctrl + C/Ctrl + Shift + V| Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) | +| Shift + Drag | Move multiple selected nodes at the same time | +| Ctrl + D | Load default graph | +| Q | Toggle visibility of the queue | +| H | Toggle visibility of history | +| R | Refresh graph | + +Ctrl can also be replaced with Cmd instead for MacOS users # Installing diff --git a/web/extensions/core/keybinds.js b/web/extensions/core/keybinds.js new file mode 100644 index 00000000..1825007a --- /dev/null +++ b/web/extensions/core/keybinds.js @@ -0,0 +1,76 @@ +import { app } from "/scripts/app.js"; + +const id = "Comfy.Keybinds"; +app.registerExtension({ + name: id, + init() { + const keybindListener = function(event) { + const target = event.composedPath()[0]; + + if (target.tagName === "INPUT" || target.tagName === "TEXTAREA") { + return; + } + + const modifierPressed = event.ctrlKey || event.metaKey; + + // Queue prompt using ctrl or command + enter + if (modifierPressed && (event.key === "Enter" || event.keyCode === 13 || event.keyCode === 10)) { + app.queuePrompt(event.shiftKey ? 
-1 : 0); + return; + } + + const modifierKeyIdMap = { + "s": "#comfy-save-button", + 83: "#comfy-save-button", + "o": "#comfy-file-input", + 79: "#comfy-file-input", + "Backspace": "#comfy-clear-button", + 8: "#comfy-clear-button", + "Delete": "#comfy-clear-button", + 46: "#comfy-clear-button", + "d": "#comfy-load-default-button", + 68: "#comfy-load-default-button", + }; + + const modifierKeybindId = modifierKeyIdMap[event.key] || modifierKeyIdMap[event.keyCode]; + if (modifierPressed && modifierKeybindId) { + event.preventDefault(); + + const elem = document.querySelector(modifierKeybindId); + elem.click(); + return; + } + + // Finished Handling all modifier keybinds, now handle the rest + if (event.ctrlKey || event.altKey || event.metaKey) { + return; + } + + // Close out of modals using escape + if (event.key === "Escape" || event.keyCode === 27) { + const modals = document.querySelectorAll(".comfy-modal"); + const modal = Array.from(modals).find(modal => window.getComputedStyle(modal).getPropertyValue("display") !== "none"); + if (modal) { + modal.style.display = "none"; + } + } + + const keyIdMap = { + "q": "#comfy-view-queue-button", + 81: "#comfy-view-queue-button", + "h": "#comfy-view-history-button", + 72: "#comfy-view-history-button", + "r": "#comfy-refresh-button", + 82: "#comfy-refresh-button", + }; + + const buttonId = keyIdMap[event.key] || keyIdMap[event.keyCode]; + if (buttonId) { + const button = document.querySelector(buttonId); + button.click(); + } + } + + window.addEventListener("keydown", keybindListener, true); + } +}); diff --git a/web/scripts/app.js b/web/scripts/app.js index 1695dcae..f158f345 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -35,7 +35,6 @@ export class ComfyApp { */ this.nodeOutputs = {}; - /** * If the shift key on the keyboard is pressed * @type {boolean} @@ -713,11 +712,6 @@ export class ComfyApp { #addKeyboardHandler() { window.addEventListener("keydown", (e) => { this.shiftDown = e.shiftKey; - - // Queue 
prompt using ctrl or command + enter - if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) { - this.queuePrompt(e.shiftKey ? -1 : 0); - } }); window.addEventListener("keyup", (e) => { this.shiftDown = e.shiftKey; diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 09861c44..f320f840 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -431,7 +431,15 @@ export class ComfyUI { defaultValue: true, }); + const promptFilename = this.settings.addSetting({ + id: "Comfy.PromptFilename", + name: "Prompt for filename when saving workflow", + type: "boolean", + defaultValue: true, + }); + const fileInput = $el("input", { + id: "comfy-file-input", type: "file", accept: ".json,image/png", style: { display: "none" }, @@ -448,6 +456,7 @@ export class ComfyUI { $el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }), ]), $el("button.comfy-queue-btn", { + id: "queue-button", textContent: "Queue Prompt", onclick: () => app.queuePrompt(0, this.batchCount), }), @@ -496,9 +505,10 @@ export class ComfyUI { ]), ]), $el("div.comfy-menu-btns", [ - $el("button", { textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }), + $el("button", { id: "queue-front-button", textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }), $el("button", { $: (b) => (this.queue.button = b), + id: "comfy-view-queue-button", textContent: "View Queue", onclick: () => { this.history.hide(); @@ -507,6 +517,7 @@ export class ComfyUI { }), $el("button", { $: (b) => (this.history.button = b), + id: "comfy-view-history-button", textContent: "View History", onclick: () => { this.queue.hide(); @@ -517,14 +528,23 @@ export class ComfyUI { this.queue.element, this.history.element, $el("button", { + id: "comfy-save-button", textContent: "Save", onclick: () => { + let filename = "workflow.json"; + if (promptFilename.value) { + filename = prompt("Save workflow as:", filename); + if 
(!filename) return; + if (!filename.toLowerCase().endsWith(".json")) { + filename += ".json"; + } + } const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string const blob = new Blob([json], { type: "application/json" }); const url = URL.createObjectURL(blob); const a = $el("a", { href: url, - download: "workflow.json", + download: filename, style: { display: "none" }, parent: document.body, }); @@ -535,15 +555,15 @@ export class ComfyUI { }, 0); }, }), - $el("button", { textContent: "Load", onclick: () => fileInput.click() }), - $el("button", { textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), - $el("button", { textContent: "Clear", onclick: () => { + $el("button", { id: "comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }), + $el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), + $el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => { if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); app.graph.clear(); } }}), - $el("button", { textContent: "Load Default", onclick: () => { + $el("button", { id: "comfy-load-default-button", textContent: "Load Default", onclick: () => { if (!confirmClear.value || confirm("Load default workflow?")) { app.loadGraphData() } From 74fc7b772656a59b344508480632d9d45f9127de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 16 Apr 2023 01:36:15 -0400 Subject: [PATCH 19/35] custom_nodes paths can now be set in the extra_model_paths.yaml --- extra_model_paths.yaml.example | 2 +- folder_paths.py | 7 +++++-- main.py | 15 ++++++++------- nodes.py | 17 +++++++++-------- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index af784fd6..f421f54d 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -18,6 +18,6 @@ a111: #other_ui: # base_path: path/to/ui 
# checkpoints: models/checkpoints - +# custom_nodes: path/custom_nodes diff --git a/folder_paths.py b/folder_paths.py index ab335934..61f446c9 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -12,8 +12,8 @@ except: folder_names_and_paths = {} - -models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") +base_path = os.path.dirname(os.path.realpath(__file__)) +models_dir = os.path.join(base_path, "models") folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions) folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"]) @@ -28,6 +28,9 @@ folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) +folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) + + output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") diff --git a/main.py b/main.py index 9c0a3d8a..02c700eb 100644 --- a/main.py +++ b/main.py @@ -81,6 +81,14 @@ if __name__ == "__main__": server = server.PromptServer(loop) q = execution.PromptQueue(server) + extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") + if os.path.isfile(extra_model_paths_config_path): + load_extra_path_config(extra_model_paths_config_path) + + if args.extra_model_paths_config: + for config_path in itertools.chain(*args.extra_model_paths_config): + load_extra_path_config(config_path) + init_custom_nodes() server.add_routes() hijack_progress(server) @@ -91,13 
+99,6 @@ if __name__ == "__main__": dont_print = args.dont_print_server - extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") - if os.path.isfile(extra_model_paths_config_path): - load_extra_path_config(extra_model_paths_config_path) - - if args.extra_model_paths_config: - for config_path in itertools.chain(*args.extra_model_paths_config): - load_extra_path_config(config_path) if args.output_directory: output_dir = os.path.abspath(args.output_directory) diff --git a/nodes.py b/nodes.py index e6ad9434..c775da00 100644 --- a/nodes.py +++ b/nodes.py @@ -1178,15 +1178,16 @@ def load_custom_node(module_path): print(f"Cannot import {module_path} module for custom nodes:", e) def load_custom_nodes(): - CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes") - possible_modules = os.listdir(CUSTOM_NODE_PATH) - if "__pycache__" in possible_modules: - possible_modules.remove("__pycache__") + node_paths = folder_paths.get_folder_paths("custom_nodes") + for custom_node_path in node_paths: + possible_modules = os.listdir(custom_node_path) + if "__pycache__" in possible_modules: + possible_modules.remove("__pycache__") - for possible_module in possible_modules: - module_path = os.path.join(CUSTOM_NODE_PATH, possible_module) - if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - load_custom_node(module_path) + for possible_module in possible_modules: + module_path = os.path.join(custom_node_path, possible_module) + if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + load_custom_node(module_path) def init_custom_nodes(): load_custom_nodes() From 22bde7957e18e8f9c4fb206227a6117dae391417 Mon Sep 17 00:00:00 2001 From: Tomoaki Hayasaka Date: Mon, 17 Apr 2023 01:58:33 +0900 Subject: [PATCH 20/35] Fix "Ctrl+Enter doesn't work when textarea has focus" regression introduced in #491. 
--- web/extensions/core/keybinds.js | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/web/extensions/core/keybinds.js b/web/extensions/core/keybinds.js index 1825007a..42c22801 100644 --- a/web/extensions/core/keybinds.js +++ b/web/extensions/core/keybinds.js @@ -5,12 +5,6 @@ app.registerExtension({ name: id, init() { const keybindListener = function(event) { - const target = event.composedPath()[0]; - - if (target.tagName === "INPUT" || target.tagName === "TEXTAREA") { - return; - } - const modifierPressed = event.ctrlKey || event.metaKey; // Queue prompt using ctrl or command + enter @@ -19,6 +13,12 @@ app.registerExtension({ return; } + const target = event.composedPath()[0]; + + if (target.tagName === "INPUT" || target.tagName === "TEXTAREA") { + return; + } + const modifierKeyIdMap = { "s": "#comfy-save-button", 83: "#comfy-save-button", From 0ab5c619eafa026d4be1a3f6bf462a6f7f9d25d6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 01:04:54 -0400 Subject: [PATCH 21/35] Clarify in README that it's AMD GPUs. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f610f949..be2cb8ec 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ Put your VAE in: models/vae At the time of writing this pytorch has issues with python versions higher than 3.10 so make sure your python/pip versions are 3.10. -### AMD (Linux only) +### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2``` From 884ea653c8d6fe19b3724f45a04a0d74cd881f2f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 11:05:15 -0400 Subject: [PATCH 22/35] Add a way for nodes to set a custom CFG function. 
--- comfy/samplers.py | 5 ++++- comfy/sd.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index ed36442a..05af6fe8 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -211,7 +211,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con max_total_area = model_management.maximum_batch_area() cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) - return uncond + (cond - uncond) * cond_scale + if "sampler_cfg_function" in model_options: + return model_options["sampler_cfg_function"](cond, uncond, cond_scale) + else: + return uncond + (cond - uncond) * cond_scale class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser): diff --git a/comfy/sd.py b/comfy/sd.py index 9c632e24..1d777474 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -250,6 +250,9 @@ class ModelPatcher: def set_model_tomesd(self, ratio): self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio} + def set_model_sampler_cfg_function(self, sampler_cfg_function): + self.model_options["sampler_cfg_function"] = sampler_cfg_function + def model_dtype(self): return self.model.diffusion_model.dtype From 6f7852bc47de2fa432672a1b93c1727c0824d78b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 17:24:58 -0400 Subject: [PATCH 23/35] Add a LatentFromBatch node to pick a single latent from a batch. Works before and after sampling. 
--- nodes.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index c775da00..c745ce28 100644 --- a/nodes.py +++ b/nodes.py @@ -510,6 +510,24 @@ class EmptyLatentImage: return ({"samples":latent}, ) +class LatentFromBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "rotate" + + CATEGORY = "latent" + + def rotate(self, samples, batch_index): + s = samples.copy() + s_in = samples["samples"] + batch_index = min(s_in.shape[0] - 1, batch_index) + s["samples"] = s_in[batch_index:batch_index + 1].clone() + s["batch_index"] = batch_index + return (s,) class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] @@ -685,7 +703,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu") + batch_index = 0 + if "batch_index" in latent: + batch_index = latent["batch_index"] + + generator = torch.manual_seed(seed) + for i in range(batch_index + 1): + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") if "noise_mask" in latent: noise_mask = latent['noise_mask'] @@ -1073,6 +1097,7 @@ NODE_CLASS_MAPPINGS = { "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, + "LatentFromBatch": LatentFromBatch, "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, From 7b5eb196dbf4248eb6c67af2843cacb28863ce2f Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 17 Apr 2023 17:29:22 -0400 Subject: [PATCH 24/35] 
allows control arrow to edit attention in textarea --- web/extensions/core/editAttention.js | 117 +++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 web/extensions/core/editAttention.js diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js new file mode 100644 index 00000000..d943290c --- /dev/null +++ b/web/extensions/core/editAttention.js @@ -0,0 +1,117 @@ +import { app } from "/scripts/app.js"; + +// Allows you to edit the attention weight by holding ctrl (or cmd) and using the up/down arrow keys + +const id = "Comfy.EditAttention"; +app.registerExtension({ +name:id, + init() { + function incrementWeight(weight, delta) { + const floatWeight = parseFloat(weight); + if (isNaN(floatWeight)) return weight; + const newWeight = floatWeight + delta; + if (newWeight < 0) return "0"; + return String(Number(newWeight.toFixed(10))); + } + + function findNearestEnclosure(text, cursorPos) { + let start = cursorPos, end = cursorPos; + let openCount = 0, closeCount = 0; + + // Find opening parenthesis before cursor + while (start >= 0) { + start--; + if (text[start] === "(" && openCount === closeCount) break; + if (text[start] === "(") openCount++; + if (text[start] === ")") closeCount++; + } + if (start < 0) return false; + + openCount = 0; + closeCount = 0; + + // Find closing parenthesis after cursor + while (end < text.length) { + if (text[end] === ")" && openCount === closeCount) break; + if (text[end] === "(") openCount++; + if (text[end] === ")") closeCount++; + end++; + } + if (end === text.length) return false; + + return { start: start + 1, end: end }; + } + + function addWeightToParentheses(text) { + const parenRegex = /^\((.*)\)$/; + const parenMatch = text.match(parenRegex); + + const floatRegex = /:([+-]?(\d*\.)?\d+([eE][+-]?\d+)?)/; + const floatMatch = text.match(floatRegex); + + if (parenMatch && !floatMatch) { + return `(${parenMatch[1]}:1.0)`; + } else { + return text; + } + }; + + function 
editAttention(event) { + const inputField = event.composedPath()[0]; + const delta = 0.1; + + if (inputField.tagName !== "TEXTAREA") return; + if (!(event.key === "ArrowUp" || event.key === "ArrowDown")) return; + if (!event.ctrlKey && !event.metaKey) return; + + event.preventDefault(); + + let start = inputField.selectionStart; + let end = inputField.selectionEnd; + let selectedText = inputField.value.substring(start, end); + + // If there is no selection, attempt to find the nearest enclosure + if (!selectedText) { + const nearestEnclosure = findNearestEnclosure(inputField.value, start); + if (nearestEnclosure) { + start = nearestEnclosure.start; + end = nearestEnclosure.end; + selectedText = inputField.value.substring(start, end); + } else { + return; + } + } + + // If the selection ends with a space, remove it + if (selectedText[selectedText.length - 1] === " ") { + selectedText = selectedText.substring(0, selectedText.length - 1); + end -= 1; + } + + // If there are parentheses left and right of the selection, select them + if (inputField.value[start - 1] === "(" && inputField.value[end] === ")") { + start -= 1; + end += 1; + selectedText = inputField.value.substring(start, end); + } + + // If the selection is not enclosed in parentheses, add them + if (selectedText[0] !== "(" || selectedText[selectedText.length - 1] !== ")") { + console.log("adding parentheses", inputField.value[start], inputField.value[end], selectedText); + selectedText = `(${selectedText})`; + } + + // If the selection does not have a weight, add a weight of 1.0 + selectedText = addWeightToParentheses(selectedText); + + // Increment the weight + const weightDelta = event.key === "ArrowUp" ? 
delta : -delta; + const updatedText = selectedText.replace(/(.*:)(\d+(\.\d+)?)(.*)/, (match, prefix, weight, _, suffix) => { + return prefix + incrementWeight(weight, weightDelta) + suffix; + }); + + inputField.setRangeText(updatedText, start, end, "select"); + } + window.addEventListener("keydown", editAttention); + }, +}); From f03dade5ab8f17a165d63efc205eb34a2330b7d8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 18:19:57 -0400 Subject: [PATCH 25/35] Fix bug. --- nodes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index c745ce28..06b69f45 100644 --- a/nodes.py +++ b/nodes.py @@ -708,8 +708,9 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, batch_index = latent["batch_index"] generator = torch.manual_seed(seed) - for i in range(batch_index + 1): + for i in range(batch_index): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") if "noise_mask" in latent: noise_mask = latent['noise_mask'] From b8c636b10d39e77742f3f435bf6b85c3aa806583 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 17 Apr 2023 18:21:24 -0400 Subject: [PATCH 26/35] Lower how much CTRL+arrow key changes the number. 
--- web/extensions/core/editAttention.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index d943290c..fe395c3c 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -58,7 +58,7 @@ name:id, function editAttention(event) { const inputField = event.composedPath()[0]; - const delta = 0.1; + const delta = 0.025; if (inputField.tagName !== "TEXTAREA") return; if (!(event.key === "ArrowUp" || event.key === "ArrowDown")) return; From 79ba0399d8d70bc655269fc3318455a70d14e180 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 17 Apr 2023 19:02:08 -0400 Subject: [PATCH 27/35] selects current word automatically --- web/extensions/core/editAttention.js | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index fe395c3c..55201953 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -70,7 +70,7 @@ name:id, let end = inputField.selectionEnd; let selectedText = inputField.value.substring(start, end); - // If there is no selection, attempt to find the nearest enclosure + // If there is no selection, attempt to find the nearest enclosure, or select the current word if (!selectedText) { const nearestEnclosure = findNearestEnclosure(inputField.value, start); if (nearestEnclosure) { @@ -78,7 +78,18 @@ name:id, end = nearestEnclosure.end; selectedText = inputField.value.substring(start, end); } else { - return; + // Select the current word, find the start and end of the word (first space before and after) + start = inputField.value.substring(0, start).lastIndexOf(" ") + 1; + end = inputField.value.substring(end).indexOf(" ") + end; + // Remove all punctuation at the end and beginning of the word + while (inputField.value[start].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { + start++; + } + while 
(inputField.value[end - 1].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { + end--; + } + selectedText = inputField.value.substring(start, end); + if (!selectedText) return; } } @@ -97,7 +108,6 @@ name:id, // If the selection is not enclosed in parentheses, add them if (selectedText[0] !== "(" || selectedText[selectedText.length - 1] !== ")") { - console.log("adding parentheses", inputField.value[start], inputField.value[end], selectedText); selectedText = `(${selectedText})`; } From a962222992479057b104cdd06bf399d2a2cae2fa Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 17 Apr 2023 23:40:44 -0400 Subject: [PATCH 28/35] correctly checks end of the text --- web/extensions/core/editAttention.js | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index 55201953..206d0830 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -79,8 +79,16 @@ name:id, selectedText = inputField.value.substring(start, end); } else { // Select the current word, find the start and end of the word (first space before and after) - start = inputField.value.substring(0, start).lastIndexOf(" ") + 1; - end = inputField.value.substring(end).indexOf(" ") + end; + const wordStart = inputField.value.substring(0, start).lastIndexOf(" ") + 1; + const wordEnd = inputField.value.substring(end).indexOf(" "); + // If there is no space after the word, select to the end of the string + if (wordEnd === -1) { + end = inputField.value.length; + } else { + end += wordEnd; + } + start = wordStart; + // Remove all punctuation at the end and beginning of the word while (inputField.value[start].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { start++; From a7c7da68dc8a5e6bf1e316b6b36c4a61c7571445 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 18 Apr 2023 00:22:05 -0600 Subject: [PATCH 29/35] Editattention setting (#533) * Add editAttention delta setting * Update editAttention.js 
* Update web/extensions/core/editAttention.js Co-authored-by: Karun * Update editAttention.js * Update editAttention.js * Fix setting value --------- Co-authored-by: Karun --- web/extensions/core/editAttention.js | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index 206d0830..66d4a837 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -2,10 +2,21 @@ import { app } from "/scripts/app.js"; // Allows you to edit the attention weight by holding ctrl (or cmd) and using the up/down arrow keys -const id = "Comfy.EditAttention"; app.registerExtension({ -name:id, + name: "Comfy.EditAttention", init() { + const editAttentionDelta = app.ui.settings.addSetting({ + id: "Comfy.EditAttention.Delta", + name: "Ctrl+up/down precision", + type: "slider", + attrs: { + min: 0.01, + max: 2, + step: 0.01, + }, + defaultValue: 0.1, + }); + function incrementWeight(weight, delta) { const floatWeight = parseFloat(weight); if (isNaN(floatWeight)) return weight; @@ -58,7 +69,7 @@ name:id, function editAttention(event) { const inputField = event.composedPath()[0]; - const delta = 0.025; + const delta = parseFloat(editAttentionDelta.value); if (inputField.tagName !== "TEXTAREA") return; if (!(event.key === "ArrowUp" || event.key === "ArrowDown")) return; @@ -125,7 +136,7 @@ name:id, // Increment the weight const weightDelta = event.key === "ArrowUp" ? delta : -delta; const updatedText = selectedText.replace(/(.*:)(\d+(\.\d+)?)(.*)/, (match, prefix, weight, _, suffix) => { - return prefix + incrementWeight(weight, weightDelta) + suffix; + return prefix + incrementWeight(weight, weightDelta) + suffix; }); inputField.setRangeText(updatedText, start, end, "select"); From b016e2769f0a16fcba21c020023413cad68f704b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 18 Apr 2023 02:25:57 -0400 Subject: [PATCH 30/35] Saner range of values. 
--- web/extensions/core/editAttention.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index 66d4a837..bebc80b1 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -11,10 +11,10 @@ app.registerExtension({ type: "slider", attrs: { min: 0.01, - max: 2, + max: 0.5, step: 0.01, }, - defaultValue: 0.1, + defaultValue: 0.05, }); function incrementWeight(weight, delta) { From 472b1cc0d881c4009e5a89e0893c5835f3a4c47d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 18 Apr 2023 19:34:07 -0400 Subject: [PATCH 31/35] Add a github action to use pip xformers package for dependencies. --- .../windows_release_cu118_dependencies_2.yml | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 .github/workflows/windows_release_cu118_dependencies_2.yml diff --git a/.github/workflows/windows_release_cu118_dependencies_2.yml b/.github/workflows/windows_release_cu118_dependencies_2.yml new file mode 100644 index 00000000..a8844952 --- /dev/null +++ b/.github/workflows/windows_release_cu118_dependencies_2.yml @@ -0,0 +1,30 @@ +name: "Windows Release cu118 dependencies 2" + +on: + workflow_dispatch: +# push: +# branches: +# - master + +jobs: + build_dependencies: + runs-on: windows-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10.9' + + - shell: bash + run: | + python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir + python -m pip install --no-cache-dir ./temp_wheel_dir/* + echo installed basic + ls -lah temp_wheel_dir + mv temp_wheel_dir cu118_python_deps + tar cf cu118_python_deps.tar cu118_python_deps + + - uses: actions/cache/save@v3 + with: + path: cu118_python_deps.tar + key: ${{ runner.os }}-build-cu118 From 
3696d1699a6fece2485c063317cf65abbcddb79b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 09:36:19 -0400 Subject: [PATCH 32/35] Add support for GLIGEN textbox model. --- comfy/gligen.py | 343 ++++++++++++++++++ comfy/ldm/modules/attention.py | 16 + .../modules/diffusionmodules/openaimodel.py | 2 + comfy/model_management.py | 6 +- comfy/samplers.py | 57 ++- comfy/sd.py | 22 +- folder_paths.py | 2 + models/gligen/put_gligen_models_here | 0 nodes.py | 71 +++- 9 files changed, 491 insertions(+), 28 deletions(-) create mode 100644 comfy/gligen.py create mode 100644 models/gligen/put_gligen_models_here diff --git a/comfy/gligen.py b/comfy/gligen.py new file mode 100644 index 00000000..8770383e --- /dev/null +++ b/comfy/gligen.py @@ -0,0 +1,343 @@ +import torch +from torch import nn, einsum +from ldm.modules.attention import CrossAttention +from inspect import isfunction + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * torch.nn.functional.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class GatedCrossAttentionDense(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, d_head): + super().__init__() + + self.attn = CrossAttention( + 
query_dim=query_dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head) + self.ff = FeedForward(query_dim, glu=True) + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) + self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) + + # this can be useful: we can externally change magnitude of tanh(alpha) + # for example, when it is set to 0, then the entire model is same as + # original one + self.scale = 1 + + def forward(self, x, objs): + + x = x + self.scale * \ + torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs) + x = x + self.scale * \ + torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) + + return x + + +class GatedSelfAttentionDense(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, d_head): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj + # feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = CrossAttention( + query_dim=query_dim, + context_dim=query_dim, + heads=n_heads, + dim_head=d_head) + self.ff = FeedForward(query_dim, glu=True) + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) + self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) + + # this can be useful: we can externally change magnitude of tanh(alpha) + # for example, when it is set to 0, then the entire model is same as + # original one + self.scale = 1 + + def forward(self, x, objs): + + N_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn( + self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :] + x = x + self.scale * \ + torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) + + return x + + +class GatedSelfAttentionDense2(nn.Module): + def __init__(self, query_dim, context_dim, 
n_heads, d_head): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj + # feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = CrossAttention( + query_dim=query_dim, context_dim=query_dim, dim_head=d_head) + self.ff = FeedForward(query_dim, glu=True) + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) + self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) + + # this can be useful: we can externally change magnitude of tanh(alpha) + # for example, when it is set to 0, then the entire model is same as + # original one + self.scale = 1 + + def forward(self, x, objs): + + B, N_visual, _ = x.shape + B, N_ground, _ = objs.shape + + objs = self.linear(objs) + + # sanity check + size_v = math.sqrt(N_visual) + size_g = math.sqrt(N_ground) + assert int(size_v) == size_v, "Visual tokens must be square rootable" + assert int(size_g) == size_g, "Grounding tokens must be square rootable" + size_v = int(size_v) + size_g = int(size_g) + + # select grounding token and resize it to visual token size as residual + out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[ + :, N_visual:, :] + out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g) + out = torch.nn.functional.interpolate( + out, (size_v, size_v), mode='bicubic') + residual = out.reshape(B, -1, N_visual).permute(0, 2, 1) + + # add residual to visual feature + x = x + self.scale * torch.tanh(self.alpha_attn) * residual + x = x + self.scale * \ + torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) + + return x + + +class FourierEmbedder(): + def __init__(self, num_freqs=64, temperature=100): + + self.num_freqs = num_freqs + self.temperature = temperature + self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs) + + @torch.no_grad() + def __call__(self, x, cat_dim=-1): + "x: arbitrary shape of tensor. 
dim: cat dim" + out = [] + for freq in self.freq_bands: + out.append(torch.sin(freq * x)) + out.append(torch.cos(freq * x)) + return torch.cat(out, cat_dim) + + +class PositionNet(nn.Module): + def __init__(self, in_dim, out_dim, fourier_freqs=8): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) + self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy + + self.linears = nn.Sequential( + nn.Linear(self.in_dim + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + self.null_positive_feature = torch.nn.Parameter( + torch.zeros([self.in_dim])) + self.null_position_feature = torch.nn.Parameter( + torch.zeros([self.position_dim])) + + def forward(self, boxes, masks, positive_embeddings): + B, N, _ = boxes.shape + masks = masks.unsqueeze(-1) + + # embedding position (it may includes padding as placeholder) + xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C + + # learnable null embedding + positive_null = self.null_positive_feature.view(1, 1, -1) + xyxy_null = self.null_position_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + positive_embeddings = positive_embeddings * \ + masks + (1 - masks) * positive_null + xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null + + objs = self.linears( + torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) + assert objs.shape == torch.Size([B, N, self.out_dim]) + return objs + + +class Gligen(nn.Module): + def __init__(self, modules, position_net, key_dim): + super().__init__() + self.module_list = nn.ModuleList(modules) + self.position_net = position_net + self.key_dim = key_dim + self.max_objs = 30 + + def _set_position(self, boxes, masks, positive_embeddings): + objs = self.position_net(boxes, masks, positive_embeddings) + + def func(key, x): + module = self.module_list[key] + return module(x, objs) + return func + + 
class Gligen(nn.Module):
    """Wraps the GLIGEN fuser modules together with the position encoder.

    ``set_position`` / ``set_empty`` encode the grounding slots once and
    return a patch function ``func(key, x)`` that the transformer blocks call
    to inject grounding tokens into their self-attention.
    """

    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        # Fixed number of grounding slots; unused slots are zero-padded.
        self.max_objs = 30

    def _set_position(self, boxes, masks, positive_embeddings):
        # Encode all slots once, then close over the result so each call only
        # runs the fuser module selected by `key` on the visual tokens `x`.
        objs = self.position_net(boxes, masks, positive_embeddings)

        def func(key, x):
            return self.module_list[key](x, objs)
        return func

    def set_position(self, latent_image_shape, position_params, device):
        """Build the injection function for a list of grounding boxes.

        Each position_params entry is (pooled_embedding, height, width, y, x)
        in latent units; boxes are normalized to [0, 1] xyxy.
        NOTE(review): assumes len(position_params) <= max_objs — confirm.
        """
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        positive_embeddings = []
        for emb, box_h, box_w, box_y, box_x in position_params:
            x1 = box_x / w
            y1 = box_y / h
            x2 = (box_x + box_w) / w
            y2 = (box_y + box_h) / h
            masks[len(boxes)] = 1.0
            boxes.append(torch.tensor((x1, y1, x2, y2)).unsqueeze(0))
            positive_embeddings.append(emb)

        # Zero-pad the remaining slots up to max_objs.
        n_used = len(boxes)
        append_boxes = []
        append_conds = []
        if n_used < self.max_objs:
            append_boxes = [torch.zeros(
                [self.max_objs - n_used, 4], device="cpu")]
            append_conds = [torch.zeros(
                [self.max_objs - n_used, self.key_dim], device="cpu")]

        box_out = torch.cat(
            boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(positive_embeddings
                          + append_conds).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def set_empty(self, latent_image_shape, device):
        """Injection function with every grounding slot padded (unconditional pass)."""
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
        box_out = torch.zeros([self.max_objs, 4],
                              device="cpu").repeat(batch, 1, 1)
        conds = torch.zeros([self.max_objs, self.key_dim],
                            device="cpu").repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def cleanup(self):
        # Nothing to release; kept for API parity with the controlnet models.
        pass

    def get_models(self):
        return [self]
def load_gligen(sd):
    """Build a Gligen model from a GLIGEN checkpoint state dict.

    Collects the ``.fuser.`` weights scattered through the UNet block names
    into GatedSelfAttentionDense modules, loads the PositionNet weights, and
    wraps everything in a Gligen container.
    """
    sd_k = sd.keys()
    output_list = []
    key_dim = 768  # default; overwritten by the first fuser found below
    for block_name in ["input_blocks", "middle_block", "output_blocks"]:
        # NOTE(review): 20 is assumed to be an upper bound on the sub-block
        # index across all supported UNets — confirm.
        for idx in range(20):
            prefix = "{}.{}.".format(block_name, idx)
            fuser_sd = {k.split(".fuser.")[-1]: sd[k]
                        for k in sd_k if prefix in k and ".fuser." in k}
            if not fuser_sd:
                continue
            query_dim = fuser_sd["linear.weight"].shape[0]
            key_dim = fuser_sd["linear.weight"].shape[1]

            if key_dim == 768:  # SD1.x
                n_heads = 8
                d_head = query_dim // n_heads
            else:
                d_head = 64
                n_heads = query_dim // d_head

            gated = GatedSelfAttentionDense(
                query_dim, key_dim, n_heads, d_head)
            gated.load_state_dict(fuser_sd, strict=False)
            output_list.append(gated)

    if "position_net.null_positive_feature" in sd_k:
        in_dim = sd["position_net.null_positive_feature"].shape[0]
        out_dim = sd["position_net.linears.4.weight"].shape[0]

        # Throwaway container so load_state_dict can route the
        # "position_net.*" keys into the PositionNet submodule.
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.position_net = PositionNet(in_dim, out_dim)
        w.load_state_dict(sd, strict=False)

    # NOTE(review): assumes the checkpoint always contains position_net
    # weights; `w` is unbound otherwise — confirm against supported files.
    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen
= p(current_index, x) + n = self.norm2(x) n = self.attn2(n, context=context) x += n x = self.ff(self.norm3(x)) + x + + if current_index is not None: + transformer_options["current_index"] += 1 return x diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 8a4e8b3e..4c69c856 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -782,6 +782,8 @@ class UNetModel(nn.Module): :return: an [N x C x ...] Tensor of outputs. """ transformer_options["original_shape"] = list(x.shape) + transformer_options["current_index"] = 0 + assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" diff --git a/comfy/model_management.py b/comfy/model_management.py index 76455e4a..a0d1313d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -176,7 +176,7 @@ def load_model_gpu(model): model_accelerated = True return current_loaded_model -def load_controlnet_gpu(models): +def load_controlnet_gpu(control_models): global current_gpu_controlnets global vram_state if vram_state == VRAMState.CPU: @@ -186,6 +186,10 @@ def load_controlnet_gpu(models): #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after return + models = [] + for m in control_models: + models += m.get_models() + for m in current_gpu_controlnets: if m not in models: m.cpu() diff --git a/comfy/samplers.py b/comfy/samplers.py index 05af6fe8..31968e18 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -70,7 +70,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con control = None if 'control' in cond[1]: control = cond[1]['control'] - return (input_x, mult, conditionning, area, control) + + patches = None + if 'gligen' in cond[1]: + gligen = cond[1]['gligen'] + patches = {} + gligen_type = gligen[0] + gligen_model = 
gligen[1] + if gligen_type == "position": + gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device) + else: + gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device) + + patches['middle_patch'] = [gligen_patch] + + return (input_x, mult, conditionning, area, control, patches) def cond_equal_size(c1, c2): if c1 is c2: @@ -91,12 +105,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con def can_concat_cond(c1, c2): if c1[0].shape != c2[0].shape: return False + + #control if (c1[4] is None) != (c2[4] is None): return False if c1[4] is not None: if c1[4] is not c2[4]: return False + #patches + if (c1[5] is None) != (c2[5] is None): + return False + if (c1[5] is not None): + if c1[5] is not c2[5]: + return False + return cond_equal_size(c1[2], c2[2]) def cond_cat(c_list): @@ -166,6 +189,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con cond_or_uncond = [] area = [] control = None + patches = None for x in to_batch: o = to_run.pop(x) p = o[0] @@ -175,6 +199,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con area += [p[3]] cond_or_uncond += [o[1]] control = p[4] + patches = p[5] batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x) @@ -184,8 +209,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if control is not None: c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond)) + transformer_options = {} if 'transformer_options' in model_options: - c['transformer_options'] = model_options['transformer_options'] + transformer_options = model_options['transformer_options'].copy() + + if patches is not None: + transformer_options["patches"] = patches + + c['transformer_options'] = transformer_options output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) del input_x @@ -309,8 +340,7 @@ def create_cond_with_same_area_if_none(conds, c): n 
= c[1].copy() conds += [[smallest[0], n]] - -def apply_control_net_to_equal_area(conds, uncond): +def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] cond_other = [] uncond_cnets = [] @@ -318,15 +348,15 @@ def apply_control_net_to_equal_area(conds, uncond): for t in range(len(conds)): x = conds[t] if 'area' not in x[1]: - if 'control' in x[1] and x[1]['control'] is not None: - cond_cnets.append(x[1]['control']) + if name in x[1] and x[1][name] is not None: + cond_cnets.append(x[1][name]) else: cond_other.append((x, t)) for t in range(len(uncond)): x = uncond[t] if 'area' not in x[1]: - if 'control' in x[1] and x[1]['control'] is not None: - uncond_cnets.append(x[1]['control']) + if name in x[1] and x[1][name] is not None: + uncond_cnets.append(x[1][name]) else: uncond_other.append((x, t)) @@ -336,15 +366,16 @@ def apply_control_net_to_equal_area(conds, uncond): for x in range(len(cond_cnets)): temp = uncond_other[x % len(uncond_other)] o = temp[0] - if 'control' in o[1] and o[1]['control'] is not None: + if name in o[1] and o[1][name] is not None: n = o[1].copy() - n['control'] = cond_cnets[x] + n[name] = uncond_fill_func(cond_cnets, x) uncond += [[o[0], n]] else: n = o[1].copy() - n['control'] = cond_cnets[x] + n[name] = uncond_fill_func(cond_cnets, x) uncond[temp[1]] = [o[0], n] + def encode_adm(noise_augmentor, conds, batch_size, device): for t in range(len(conds)): x = conds[t] @@ -378,6 +409,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device): return conds + class KSampler: SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", @@ -466,7 +498,8 @@ class KSampler: for c in negative: create_cond_with_same_area_if_none(positive, c) - apply_control_net_to_equal_area(positive, negative) + apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(positive, negative, 
def load_gligen(ckpt_path):
    """Load a GLIGEN checkpoint from disk and wrap it as a Gligen model.

    Converts the weights to fp16 when the runtime says fp16 should be used.
    """
    state_dict = utils.load_torch_file(ckpt_path)
    model = gligen.load_gligen(state_dict)
    if model_management.should_use_fp16():
        model = model.half()
    return model
class GLIGENLoader:
    """Node that loads a GLIGEN model from the `gligen` models folder."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}

    RETURN_TYPES = ("GLIGEN",)
    FUNCTION = "load_gligen"

    CATEGORY = "_for_testing/gligen"

    def load_gligen(self, gligen_name):
        # Resolve the short name to a full path, then defer to comfy.sd.
        path = folder_paths.get_full_path("gligen", gligen_name)
        model = comfy.sd.load_gligen(path)
        return (model,)
"min": 8, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "_for_testing/gligen" + + def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): + c = [] + cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True) + for t in conditioning_to: + n = [t[0], t[1].copy()] + position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)] + prev = [] + if "gligen" in n[1]: + prev = n[1]['gligen'][2] + + n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params) + c.append(n) + return (c, ) class EmptyLatentImage: def __init__(self, device="cpu"): @@ -731,27 +776,30 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative_copy = [] control_nets = [] + def get_models(cond): + models = [] + for c in cond: + if 'control' in c[1]: + models += [c[1]['control']] + if 'gligen' in c[1]: + models += [c[1]['gligen'][1]] + return models + for p in positive: t = p[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) - if 'control' in p[1]: - control_nets += [p[1]['control']] positive_copy += [[t] + p[1:]] for n in negative: t = n[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) - if 'control' in n[1]: - control_nets += [n[1]['control']] negative_copy += [[t] + n[1:]] - control_net_models = [] - for x in control_nets: - control_net_models += x.get_control_models() - comfy.model_management.load_controlnet_gpu(control_net_models) + models = get_models(positive) + get_models(negative) + comfy.model_management.load_controlnet_gpu(models) if sampler_name in comfy.samplers.KSampler.SAMPLERS: sampler = 
comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) @@ -761,8 +809,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = samples.cpu() - for c in control_nets: - c.cleanup() + for m in models: + m.cleanup() out = latent.copy() out["samples"] = samples @@ -1128,6 +1176,9 @@ NODE_CLASS_MAPPINGS = { "VAEEncodeTiled": VAEEncodeTiled, "TomePatchModel": TomePatchModel, "unCLIPCheckpointLoader": unCLIPCheckpointLoader, + "GLIGENLoader": GLIGENLoader, + "GLIGENTextBoxApply": GLIGENTextBoxApply, + "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, } From 781b724ac667e42900c331988f356a85670c0ec5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 11:30:18 -0400 Subject: [PATCH 33/35] Add GLIGEN model link to colab. 
--- notebooks/comfyui_colab.ipynb | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index c088de89..c1982d8b 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -138,6 +138,11 @@ "# Controlnet Preprocessor nodes by Fannovel16\n", "#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n", "\n", + "\n", + "# GLIGEN\n", + "#!wget -c https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/resolve/main/gligen_sd14_textbox_pruned_fp16.safetensors -P ./models/gligen/\n", + "\n", + "\n", "# ESRGAN upscale model\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n", From 2d546d510d1f7919bbae3ac08108e0d05e9c0bae Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 11:47:49 -0400 Subject: [PATCH 34/35] Add gligen entry to extra_model_paths example. --- extra_model_paths.yaml.example | 1 + 1 file changed, 1 insertion(+) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index f421f54d..ac1ffe9d 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -18,6 +18,7 @@ a111: #other_ui: # base_path: path/to/ui # checkpoints: models/checkpoints +# gligen: models/gligen # custom_nodes: path/custom_nodes From 96b57a9ad6447b95921b91e5f52fb3684f73514f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 21:11:38 -0400 Subject: [PATCH 35/35] Don't pass adm to model when it doesn't support it. 
--- comfy/samplers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 31968e18..19ebc97d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -36,8 +36,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con strength = cond[1]['strength'] adm_cond = None - if 'adm' in cond[1]: - adm_cond = cond[1]['adm'] + if 'adm_encoded' in cond[1]: + adm_cond = cond[1]['adm_encoded'] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] mult = torch.ones_like(input_x) * strength @@ -405,7 +405,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device): else: adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) x[1] = x[1].copy() - x[1]["adm"] = torch.cat([adm_out] * batch_size) + x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size) return conds