controlnet.py 47.8 KB
Newer Older
Mikubill's avatar
Mikubill committed
1
import gc
Kakigōri Maker's avatar
Kakigōri Maker committed
2
3
import os
from collections import OrderedDict
ljleb's avatar
ljleb committed
4
from copy import copy
5
from typing import Dict, Optional, Tuple
6
import importlib
Kakigōri Maker's avatar
Kakigōri Maker committed
7
import modules.scripts as scripts
FNSpd's avatar
FNSpd committed
8
from modules import shared, devices, script_callbacks, processing, masking, images
Kakigōri Maker's avatar
Kakigōri Maker committed
9
import gradio as gr
Mikubill's avatar
Mikubill committed
10

Kakigōri Maker's avatar
Kakigōri Maker committed
11
from einops import rearrange
12
from scripts import global_state, hook, external_code, processor, batch_hijack, controlnet_version, utils
13
from scripts.controlnet_ui import controlnet_ui_group
ljleb's avatar
ljleb committed
14
importlib.reload(processor)
15
importlib.reload(utils)
ljleb's avatar
ljleb committed
16
17
18
importlib.reload(global_state)
importlib.reload(hook)
importlib.reload(external_code)
19
20
21
22
# Reload ui group as `ControlNetUnit` is redefined in `external_code`. If `controlnet_ui_group`
# is not reloaded, `UiControlNetUnit` will inherit from a stale version of `ControlNetUnit`,
# which can cause typecheck to fail.
importlib.reload(controlnet_ui_group)  
ljleb's avatar
ljleb committed
23
importlib.reload(batch_hijack)
ljleb's avatar
ljleb committed
24
from scripts.cldm import PlugableControlModel
25
from scripts.processor import *
ljleb's avatar
ljleb committed
26
27
from scripts.adapter import PlugableAdapter
from scripts.utils import load_state_dict
lvmin's avatar
lvmin committed
28
from scripts.hook import ControlParams, UnetHook, ControlModelType
29
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
Chenlei Hu's avatar
Chenlei Hu committed
30
from scripts.logging import logger
31
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img
105gun's avatar
105gun committed
32
from modules.images import save_image
33

lvmin's avatar
lvmin committed
34
import cv2
35
36
37
import numpy as np
import torch

lvmin's avatar
lvmin committed
38
from pathlib import Path
39
from PIL import Image, ImageFilter, ImageOps
lvmin's avatar
lvmin committed
40
from scripts.lvminthin import lvmin_thin, nake_nms
41
from scripts.processor import model_free_preprocessors
Kakigōri Maker's avatar
Kakigōri Maker committed
42

Mikubill's avatar
Mikubill committed
43
44
45
46
47
48
49
50
51
# `gradio_compat` is False only when the installed gradio is older than 3.10.
# Any import problem (missing helper module, gradio distribution not found)
# leaves the optimistic default of True untouched.
gradio_compat = True
try:
    from distutils.version import LooseVersion
    try:
        from importlib_metadata import version
    except ImportError:
        # The third-party backport is optional; the standard library ships the
        # same `version()` helper, so fall back to it instead of skipping the check.
        from importlib.metadata import version
    if LooseVersion(version("gradio")) < LooseVersion("3.10"):
        gradio_compat = False
except ImportError:
    pass

lllyasviel's avatar
lllyasviel committed
52
53
54
55
56
57
58

# Gradio 3.32 bug fix: make sure a `gradio` directory exists inside the system
# temporary directory (created at import time; a no-op when already present).
# NOTE(review): presumably Gradio 3.32 expects this directory without creating
# it itself — confirm against the upstream issue.
import tempfile
gradio_tempfile_path = os.path.join(tempfile.gettempdir(), 'gradio')
os.makedirs(gradio_tempfile_path, exist_ok=True)


Kakigōri Maker's avatar
Kakigōri Maker committed
59
60
61
def find_closest_lora_model_name(search: str):
    """
    Resolve a user-supplied model name to a name known to `global_state`.

    Lookup order:
      1. `search` itself, when it is a key of `global_state.cn_models`.
      2. The lower-cased `search` as a key of `global_state.cn_models_names`.
      3. The shortest key of `global_state.cn_models_names` that contains the
         lower-cased `search` as a substring (the first one wins on ties).

    Args:
        search (str): The name to look up. Empty or None yields None.

    Returns:
        The matching model name, or None when nothing matches.
    """
    if not search:
        return None

    # An exact registry hit needs no normalisation.
    if search in global_state.cn_models:
        return search

    needle = search.lower()
    if needle in global_state.cn_models_names:
        return global_state.cn_models_names.get(needle)

    candidates = [
        key for key in global_state.cn_models_names.keys()
        if needle in key.lower()
    ]
    if not candidates:
        return None

    # `min` keeps the first of several equally short candidates.
    closest = min(candidates, key=len)
    return global_state.cn_models_names[closest]
Kakigōri Maker's avatar
Kakigōri Maker committed
73
74


75
76
77
78
79
80
81
82
83
def swap_img2img_pipeline(p: processing.StableDiffusionProcessingImg2Img):
    """
    Turn an img2img processing object into a txt2img one, in place.

    The object's class is reassigned to `StableDiffusionProcessingTxt2Img`, then
    every instance attribute of a freshly constructed txt2img object that `p`
    does not already have is copied onto `p`. Attributes `p` already has are
    left untouched.

    Args:
        p: The processing object to convert; it is mutated and nothing is returned.
    """
    p.__class__ = processing.StableDiffusionProcessingTxt2Img
    template = processing.StableDiffusionProcessingTxt2Img()
    for attr_name, attr_value in template.__dict__.items():
        if not hasattr(p, attr_name):
            setattr(p, attr_name, attr_value)


ljleb's avatar
ljleb committed
84
# Runs once at import time.
# NOTE(review): presumably this fills `global_state.cn_models` /
# `global_state.cn_models_names`, which the lookups in this file read — confirm
# in `global_state`.
global_state.update_cn_models()
Kakigōri Maker's avatar
Kakigōri Maker committed
85

ljleb's avatar
ljleb committed
86
87

def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
    """
    Normalise the accepted image inputs into an `{'image', 'mask'}` dict.

    Accepted inputs:
      * None                     -> None is returned.
      * tuple / list             -> element 0 is the image, element 1 the mask.
      * dict                     -> shallow-copied; a missing 'mask' key is
                                    treated as `None`.
      * anything else            -> used as the image, with no mask.

    String values are resolved as follows: an existing filesystem path is opened
    with PIL and converted to a uint8 array; any other non-empty string is
    handed to `external_code.to_base64_nparray`; an empty string means
    "no image" (for 'image') or "empty mask" (for 'mask').

    Returns:
        None for a None input. Otherwise a dict with keys 'image' and 'mask'.
        When there is no image both values are None; when there is an image but
        no mask, the mask is an all-zero uint8 array shaped like the image.

    Raises:
        IndexError: a tuple/list with fewer than two elements.
        KeyError: a dict without an 'image' key.
    """
    if image is None:
        return None

    if isinstance(image, (tuple, list)):
        image = {'image': image[0], 'mask': image[1]}
    elif not isinstance(image, dict):
        image = {'image': image, 'mask': None}
    else:  # type(image) is dict
        # copy to enable modifying the dict and prevent response serialization error
        image = dict(image)
        # A payload that only carries an image is valid: it simply has no mask.
        image.setdefault('mask', None)

    if isinstance(image['image'], str):
        if os.path.exists(image['image']):
            # The context manager closes the file handle once pixels are read.
            with Image.open(image['image']) as pil_image:
                image['image'] = np.array(pil_image).astype('uint8')
        elif image['image']:
            image['image'] = external_code.to_base64_nparray(image['image'])
        else:
            image['image'] = None

    # If there is no image, return image with None image and None mask
    if image['image'] is None:
        image['mask'] = None
        return image

    if isinstance(image['mask'], str):
        if os.path.exists(image['mask']):
            with Image.open(image['mask']) as pil_mask:
                image['mask'] = np.array(pil_mask).astype('uint8')
        elif image['mask']:
            image['mask'] = external_code.to_base64_nparray(image['mask'])
        else:
            image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
    elif image['mask'] is None:
        image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)

    return image


Chenlei Hu's avatar
Chenlei Hu committed
125
def image_has_mask(input_image: np.ndarray) -> bool:
    """
    Determine if an image has an alpha channel (mask) that is not empty.

    The image counts as masked when it is three-dimensional, its last axis has
    exactly four entries, and at least one value in channel index 3 is greater
    than 127. Arrays of any other rank or channel count, and arrays with no
    pixels at all, have no mask.

    Args:
        input_image (np.ndarray): The image array to inspect.

    Returns:
        bool: True if channel 3 exists and contains a value above 127,
        False otherwise. Always a plain Python `bool`.
    """
    if input_image.ndim != 3 or input_image.shape[2] != 4:
        return False

    alpha = input_image[:, :, 3]
    # `np.max` raises on zero-size input, so an image without pixels is
    # answered directly. `bool()` converts numpy's `np.bool_` to the annotated type.
    return bool(alpha.size > 0 and np.max(alpha) > 127)


149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def prepare_mask(
    mask: Image.Image, p: processing.StableDiffusionProcessing
) -> Image.Image:
    """
    Prepare an image mask for the inpainting process.

    Steps, in order:

    1. The mask is converted to grayscale (mode "L").
    2. When `p.inpainting_mask_invert` is truthy, the mask is inverted.
    3. When `p.mask_blur` is greater than 0, a Gaussian blur with that radius
       is applied.

    Both attributes are optional on `p`; a missing `inpainting_mask_invert` is
    read as False and a missing `mask_blur` as 0.

    Args:
        mask (Image.Image): The input mask as a PIL Image object.
        p (processing.StableDiffusionProcessing): The processing object holding
            the parameters described above.

    Returns:
        Image.Image: The prepared mask.
    """
    prepared = mask.convert("L")

    invert_requested = getattr(p, "inpainting_mask_invert", False)
    if invert_requested:
        prepared = ImageOps.invert(prepared)

    blur_enabled = getattr(p, "mask_blur", 0) > 0
    if blur_enabled:
        blur_filter = ImageFilter.GaussianBlur(p.mask_blur)
        prepared = prepared.filter(blur_filter)

    return prepared


180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]:
    """
    Seed NumPy's global random generator from the processing object's seeds.

    `p.seed` and `p.subseed` are each resolved the same way: a value of -1 is
    replaced by `p.all_seeds[0]`, any other value is converted to int and
    clamped to be non-negative. The two resolved values are summed and masked
    with 0xFFFFFFFF so the result fits in 32 bits, then passed to
    `np.random.seed`.

    Args:
        p (processing.StableDiffusionProcessing): Source of `seed`, `subseed`
            and `all_seeds`.

    Returns:
        Optional[int]: The seed that was applied, or None if anything raised
        (in which case two warnings are logged and NumPy is left unseeded).
    """
    def resolve(raw_seed):
        # -1 stands for "not fixed": use the first concrete seed instead.
        if raw_seed == -1:
            return int(p.all_seeds[0])
        return int(max(int(raw_seed), 0))

    try:
        combined = resolve(p.seed) + resolve(p.subseed)
        seed = combined & 0xFFFFFFFF
        np.random.seed(seed)
        return seed
    except Exception as e:
        logger.warning(e)
        logger.warning('Warning: Failed to use consistent random seed.')
        return None


Kakigōri Maker's avatar
Kakigōri Maker committed
209
class Script(scripts.Script):
Jairo Correa's avatar
Jairo Correa committed
210
    model_cache = OrderedDict()
ljleb's avatar
ljleb committed
211

Kakigōri Maker's avatar
Kakigōri Maker committed
212
213
214
    def __init__(self) -> None:
        """
        Initialise per-instance state and hook this instance into the batch hijack.

        The bound `batch_tab_*` methods are registered on `batch_hijack.instance`:
        the two "process" callbacks are appended, the two "postprocess"
        callbacks are inserted at index 0 of their lists.
        """
        super().__init__()

        # State carried between runs.
        self.latest_network = None
        self.preprocessor = global_state.cache_preprocessors(global_state.cn_preprocessor_modules)
        self.unloadable = global_state.cn_preprocessor_unloadable
        self.input_image = None
        self.latest_model_hash = ""

        # Per-generation bookkeeping; each starts out empty.
        self.enabled_units = []
        self.detected_map = []
        self.post_processors = []

        hijack = batch_hijack.instance
        hijack.process_batch_callbacks.append(self.batch_tab_process)
        hijack.process_batch_each_callbacks.append(self.batch_tab_process_each)
        hijack.postprocess_batch_each_callbacks.insert(0, self.batch_tab_postprocess_each)
        hijack.postprocess_batch_callbacks.insert(0, self.batch_tab_postprocess)
Kakigōri Maker's avatar
Kakigōri Maker committed
226
227

    def title(self):
        """Return the display name of this script: "ControlNet"."""
        return "ControlNet"
Kakigōri Maker's avatar
Kakigōri Maker committed
229
230
231

    def show(self, is_img2img):
        """Return `scripts.AlwaysVisible` for every value of `is_img2img`."""
        return scripts.AlwaysVisible
lvmin's avatar
lvmin committed
232

233
234
    @staticmethod
    def get_default_ui_unit(is_ui=True):
        """
        Build a disabled unit with no preprocessor and no model selected.

        Args:
            is_ui: When truthy the unit is a `UiControlNetUnit`, otherwise an
                `external_code.ControlNetUnit`.

        Returns:
            A new unit constructed with enabled=False, module="none", model="None".
        """
        if is_ui:
            unit_type = UiControlNetUnit
        else:
            unit_type = external_code.ControlNetUnit

        defaults = dict(enabled=False, module="none", model="None")
        return unit_type(**defaults)
241

242
243
244
245
    def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str):
        """
        Create, render and wire up the UI group for a single ControlNet unit.

        Args:
            tabname: Name passed to the group's render calls.
            is_img2img: Whether the group is built for the img2img tab.
            elem_id_tabname: Element-id prefix passed to `render`.

        Returns:
            Whatever `ControlNetUiGroup.render_and_register_unit` returns for
            this unit.
        """
        ui_group = ControlNetUiGroup(
            gradio_compat,
            self.infotext_fields,
            Script.get_default_ui_unit(),
            self.preprocessor,
        )
        ui_group.render(tabname, elem_id_tabname)
        ui_group.register_callbacks(is_img2img)
        unit_component = ui_group.render_and_register_unit(tabname, is_img2img)
        return unit_component
Kakigōri Maker's avatar
Kakigōri Maker committed
252

253
254
255
256
257
    def ui(self, is_img2img):
        """this function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be an array of all components that are used in processing.
        Values of those returned components will be passed to run() and process() functions.

        One unit group is built per configured unit ("control_net_max_models_num",
        default 1). With more than one unit each group lives in its own tab;
        otherwise a single group is placed in a column. When the
        "control_net_sync_field_args" option is set, the field names collected in
        `self.infotext_fields` are also listed in `self.paste_field_names`.

        Returns:
            A tuple with one entry per unit, in unit order.
        """
        self.infotext_fields = []
        self.paste_field_names = []
        unit_components = []
        max_models = shared.opts.data.get("control_net_max_models_num", 1)

        tab_kind = "img2img" if is_img2img else "txt2img"
        elem_id_tabname = tab_kind + "_controlnet"

        with gr.Group(elem_id=elem_id_tabname):
            with gr.Accordion(f"ControlNet {controlnet_version.version_flag}", open=False, elem_id="controlnet"):
                if max_models > 1:
                    with gr.Tabs(elem_id=f"{elem_id_tabname}_tabs"):
                        for unit_index in range(max_models):
                            tab_label = f"ControlNet Unit {unit_index}"
                            with gr.Tab(tab_label, elem_classes=['cnet-unit-tab']):
                                unit_components.append(
                                    self.uigroup(f"ControlNet-{unit_index}", is_img2img, elem_id_tabname)
                                )
                else:
                    with gr.Column():
                        unit_components.append(self.uigroup("ControlNet", is_img2img, elem_id_tabname))

        if shared.opts.data.get("control_net_sync_field_args", False):
            self.paste_field_names.extend(
                field_name for _, field_name in self.infotext_fields
            )

        return tuple(unit_components)
280
281
282
    
    @staticmethod
    def clear_control_model_cache():
        """
        Drop every cached control model and ask Python and torch to free memory.

        Empties `Script.model_cache`, then runs the garbage collector followed
        by `devices.torch_gc()`.
        """
        cache = Script.model_cache
        cache.clear()

        # Collect first so the dropped models are unreferenced before torch cleans up.
        gc.collect()
        devices.torch_gc()

287
288
    @staticmethod
    def load_control_model(p, unet, model, lowvram):
        """
        Return the control model named `model`, building it on a cache miss.

        `Script.model_cache` is keyed by model name and bounded by the
        "control_net_model_cache_size" option (default 2). Eviction is
        first-in-first-out: the oldest inserted entries are dropped before a new
        model is built, until there is room for it. Evicting in a loop (rather
        than a single entry) also shrinks the cache when the option was lowered
        since the last load. With a size of 0 nothing is cached.

        Args:
            p: Processing object, forwarded to `build_control_model`.
            unet: Base model, forwarded to `build_control_model`.
            model: Name of the control model; also the cache key.
            lowvram: Low-VRAM flag, forwarded to `build_control_model`.

        Returns:
            The cached or newly built control model.

        Raises:
            Whatever `Script.build_control_model` raises on a cache miss.
        """
        if model in Script.model_cache:
            logger.info(f"Loading model from cache: {model}")
            return Script.model_cache[model]

        cache_size = shared.opts.data.get("control_net_model_cache_size", 2)

        # Remove models from cache to clear space before building another model
        evicted_any = False
        while len(Script.model_cache) > 0 and len(Script.model_cache) >= cache_size:
            Script.model_cache.popitem(last=False)
            evicted_any = True
        if evicted_any:
            gc.collect()
            devices.torch_gc()

        model_net = Script.build_control_model(p, unet, model, lowvram)

        if cache_size > 0:
            Script.model_cache[model] = model_net

        return model_net

306
307
    @staticmethod
    def build_control_model(p, unet, model, lowvram):
        """
        Load the weights of `model` and build a control network around them.

        Steps:
          1. Resolve `model` to a file path via `global_state.cn_models`, falling
             back to `find_closest_lora_model_name` for a fuzzy match.
          2. Load the state dict and choose the network class: `PlugableAdapter`
             when any key starts with "body." or equals "style_embedding",
             otherwise `PlugableControlModel`.
          3. Pick the YAML config. A file named after the model (optionally with
             a `_fp16`, `_diff`, `-fp16` or `-diff` marker removed), located
             next to the model or in `<script_dir>/models`, overrides the
             default config from the options. If none exists the default is
             used and an error is logged.
          4. Instantiate the network and move it to the device and dtype of
             `p.sd_model`.

        Args:
            p: Processing object; `p.sd_model` supplies the device and dtype.
            unet: Base model handed to the network as `base_model`.
            model: Name of the control model to load.
            lowvram: Low-VRAM flag handed to the network.

        Returns:
            The constructed network.

        Raises:
            RuntimeError: no model was selected, or the name cannot be resolved.
            ValueError: the resolved model file does not exist.
            AssertionError: the model name contains "v11" or "shuffle" and no
                matching YAML config exists.
        """
        if model is None or model == 'None':
            raise RuntimeError("You have not selected any ControlNet Model.")

        # Keep the name the caller asked for: `model` is overwritten by the
        # fuzzy lookup below and may become None.
        requested_model = model

        model_path = global_state.cn_models.get(model, None)
        if model_path is None:
            model = find_closest_lora_model_name(model)
            model_path = global_state.cn_models.get(model, None)

        if model_path is None:
            raise RuntimeError(f"model not found: {requested_model}")

        # trim '"' at start/end
        if model_path.startswith("\"") and model_path.endswith("\""):
            model_path = model_path[1:-1]

        if not os.path.exists(model_path):
            raise ValueError(f"file not found: {model_path}")

        logger.info(f"Loading model: {model}")
        state_dict = load_state_dict(model_path)
        network_module = PlugableControlModel
        network_config = shared.opts.data.get("control_net_model_config", global_state.default_conf)
        if not os.path.isabs(network_config):
            network_config = os.path.join(global_state.script_dir, network_config)

        if any(k.startswith("body.") or k == 'style_embedding' for k in state_dict):
            # adapter model
            network_module = PlugableAdapter
            network_config = shared.opts.data.get("control_net_model_adapter_config", global_state.default_conf_adapter)
            if not os.path.isabs(network_config):
                network_config = os.path.join(global_state.script_dir, network_config)

        model_path = os.path.abspath(model_path)
        model_stem = Path(model_path).stem
        model_dir_name = os.path.dirname(model_path)
        bundled_models_dir = os.path.join(global_state.script_dir, 'models')

        # Candidate order: the exact stem first, then the stem with each marker
        # removed; for every stem the model's own directory is tried before the
        # bundled `models` directory.
        stem_variants = [model_stem] + [
            model_stem.replace(marker, '')
            for marker in ('_fp16', '_diff', '-fp16', '-diff')
        ]
        possible_config_filenames = [
            os.path.join(directory, stem + ".yaml")
            for stem in stem_variants
            for directory in (model_dir_name, bundled_models_dir)
        ]

        # The first existing candidate wins; otherwise the first candidate is
        # kept so the messages below can name the file that was expected.
        override_config = next(
            (candidate for candidate in possible_config_filenames if os.path.exists(candidate)),
            possible_config_filenames[0],
        )
        override_config_exists = os.path.exists(override_config)

        if 'v11' in model_stem.lower() or 'shuffle' in model_stem.lower():
            # Raised explicitly (not via `assert`) so the check survives `python -O`.
            if not override_config_exists:
                raise AssertionError(
                    f'Error: The model config {override_config} is missing. ControlNet 1.1 must have configs.'
                )

        if override_config_exists:
            network_config = override_config
        else:
            # Note: This error is triggered in unittest, but not caught.
            logger.error(
                f'ERROR: ControlNet cannot find model config [{override_config}] \n'
                f'ERROR: ControlNet will use a WRONG config [{network_config}] to load your model. \n'
                f'ERROR: The WRONG config may not match your model. The generated results can be bad. \n'
                f'ERROR: You are using a ControlNet model [{model_stem}] without correct YAML config file. \n'
                f'ERROR: The performance of this model may be worse than your expectation. \n'
                f'ERROR: If this model cannot get good results, the reason is that you do not have a YAML file for the model. \n'
                f'Solution: Please download YAML file, or ask your model provider to provide [{override_config}] for you to download.\n'
                f'Hint: You can take a look at [{os.path.join(global_state.script_dir, "models")}] to find many existing YAML files.\n'
            )

        logger.info(f"Loading config: {network_config}")
        network = network_module(
            state_dict=state_dict,
            config_path=network_config,
            lowvram=lowvram,
            base_model=unet,
        )
        network.to(p.sd_model.device, dtype=p.sd_model.dtype)
        logger.info(f"ControlNet model {model} loaded.")
        return network
391
392
393
394
395
396

    @staticmethod
    def get_remote_call(p, attribute, default=None, idx=0, strict=False, force=False):
        """
        Read a per-unit override from the processing object `p`.

        Unless `force` is set, overrides are honoured only when the
        "control_net_allow_script_control" option is enabled; otherwise
        `default` is returned unchanged.

        Both the attribute value and `default` may be a list holding one entry
        per unit, in which case entry `idx` is used (None when `idx` is past the
        end). A non-list value applies to every unit, except that with
        `strict=True` a non-list attribute value applies to unit 0 only.

        Args:
            p: Object whose attribute `attribute` is read (missing -> None).
            attribute: Name of the attribute on `p`.
            default: Fallback value, or list of per-unit fallback values.
            idx: Index of the unit.
            strict: Restrict a non-list attribute value to unit 0.
            force: Skip the option check.

        Returns:
            The selected attribute value, or the selected default when the
            attribute value is None.
        """
        script_control_allowed = force or shared.opts.data.get("control_net_allow_script_control", False)
        if not script_control_allowed:
            return default

        def pick(value, scalar_only_for_first=False):
            # Lists are indexed per unit; anything else is a shared value.
            if isinstance(value, list):
                return value[idx] if idx < len(value) else None
            if scalar_only_for_first and idx != 0:
                return None
            return value

        remote_value = pick(getattr(p, attribute, None), strict)
        if remote_value is not None:
            return remote_value
        return pick(default)

409
410
411
    @staticmethod
    def parse_remote_call(p, unit: external_code.ControlNetUnit, idx):
        """
        Overwrite the fields of `unit` with the `control_net_*` overrides found on `p`.

        Every field is resolved through `Script.get_remote_call` with the unit's
        current value as the default, so fields without an override keep their
        value. `enabled` is resolved in strict mode.

        Args:
            p: Processing object that may carry `control_net_*` attributes.
            unit: The unit to update; it is mutated in place.
            idx: Index of the unit, used to select from list-valued overrides.

        Returns:
            The same `unit` object.
        """
        selector = Script.get_remote_call

        unit.enabled = selector(p, "control_net_enabled", unit.enabled, idx, strict=True)

        # (unit field, attribute on `p`), applied in this order.
        remote_fields = (
            ("module", "control_net_module"),
            ("model", "control_net_model"),
            ("weight", "control_net_weight"),
            ("image", "control_net_image"),
            ("resize_mode", "control_net_resize_mode"),
            ("low_vram", "control_net_lowvram"),
            ("processor_res", "control_net_pres"),
            ("threshold_a", "control_net_pthr_a"),
            ("threshold_b", "control_net_pthr_b"),
            ("guidance_start", "control_net_guidance_start"),
            ("guidance_end", "control_net_guidance_end"),
            # Backward compatibility. See https://github.com/Mikubill/sd-webui-controlnet/issues/1740
            # for more details. Applied after "control_net_guidance_end", so it
            # takes precedence when both are present.
            ("guidance_end", "control_net_guidance_strength"),
            ("control_mode", "control_net_control_mode"),
            ("pixel_perfect", "control_net_pixel_perfect"),
        )
        for unit_field, remote_attribute in remote_fields:
            current_value = getattr(unit, unit_field)
            setattr(unit, unit_field, selector(p, remote_attribute, current_value, idx))

        return unit

433
434
    @staticmethod
    def detectmap_proc(detected_map, module, resize_mode, h, w):
        """
        Fit a detected map to `h` x `w` and convert it to a control tensor.

        The map is first normalised: cast to float32 when `module` contains
        "inpaint", otherwise passed through `HWC3`. It is then fitted according
        to `resize_mode`:

          * `ResizeMode.RESIZE`    - stretched to exactly (w, h).
          * `ResizeMode.OUTER_FIT` - scaled to fit inside (w, h) and centred on
            a background filled with the median colour of the map's border
            pixels (a fourth channel, if present, is set to 255 there).
          * any other mode         - scaled to cover (w, h) and centre-cropped.

        Args:
            detected_map: Image array; indexed as height, width, channels below.
            module: Preprocessor name.
            resize_mode: An `external_code.ResizeMode` value.
            h: Target height.
            w: Target width.

        Returns:
            A pair `(control, detected_map)`: `control` is a float tensor laid
            out as `1 c h w`, divided by 255 and moved to the "controlnet"
            device; `detected_map` is the fitted numpy array it was built from.
        """
        if 'inpaint' in module:
            detected_map = detected_map.astype(np.float32)
        else:
            detected_map = HWC3(detected_map)

        def safe_numpy(x):
            # A very safe method to make sure that Apple/Mac works
            y = x

            # below is very boring but do not change these. If you change these Apple or Mac may fail.
            y = y.copy()
            y = np.ascontiguousarray(y)
            y = y.copy()
            return y

        def get_pytorch_control(x):
            # A very safe method to make sure that Apple/Mac works
            y = x

            # below is very boring but do not change these. If you change these Apple or Mac may fail.
            y = torch.from_numpy(y)
            y = y.float() / 255.0
            y = rearrange(y, 'h w c -> 1 c h w')
            y = y.clone()
            y = y.to(devices.get_device_for("controlnet"))
            y = y.clone()
            return y

        def high_quality_resize(x, size):
            # Written by lvmin
            # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges

            # A fourth channel is split off as the inpaint mask and resized separately.
            inpaint_mask = None
            if x.ndim == 3 and x.shape[2] == 4:
                inpaint_mask = x[:, :, 3]
                x = x[:, :, 0:3]

            target_area = size[0] * size[1]
            source_area = x.shape[0] * x.shape[1]
            new_size_is_smaller = target_area < source_area
            new_size_is_bigger = target_area > source_area

            unique_color_count = np.unique(x.reshape(-1, x.shape[2]), axis=0).shape[0]
            is_one_pixel_edge = False
            is_binary = False
            if unique_color_count == 2:
                is_binary = np.min(x) < 16 and np.max(x) > 240
                if is_binary:
                    # Erode then dilate: pixels lost by this are counted as thin edges.
                    kernel = np.ones(shape=(3, 3), dtype=np.uint8)
                    opened = cv2.erode(x, kernel, iterations=1)
                    opened = cv2.dilate(opened, kernel, iterations=1)
                    one_pixel_edge_count = np.where(opened < x)[0].shape[0]
                    all_edge_count = np.where(x > 127)[0].shape[0]
                    is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count

            if 2 < unique_color_count < 200:
                interpolation = cv2.INTER_NEAREST
            elif new_size_is_smaller:
                interpolation = cv2.INTER_AREA
            else:
                interpolation = cv2.INTER_CUBIC  # Must be CUBIC because we now use nms. NEVER CHANGE THIS

            y = cv2.resize(x, size, interpolation=interpolation)
            if inpaint_mask is not None:
                inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)

            if is_binary:
                y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
                if is_one_pixel_edge:
                    y = nake_nms(y)
                _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                if is_one_pixel_edge:
                    y = lvmin_thin(y, prunings=new_size_is_bigger)
                y = np.stack([y] * 3, axis=2)

            if inpaint_mask is not None:
                inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
                inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
                y = np.concatenate([y, inpaint_mask], axis=2)

            return y

        def to_outputs(fitted_map):
            # Shared tail of every resize mode: sanitise the array, then build the tensor from it.
            fitted_map = safe_numpy(fitted_map)
            return get_pytorch_control(fitted_map), fitted_map

        if resize_mode == external_code.ResizeMode.RESIZE:
            return to_outputs(high_quality_resize(detected_map, (w, h)))

        old_h, old_w, _ = detected_map.shape
        old_w = float(old_w)
        old_h = float(old_h)
        k0 = float(h) / old_h
        k1 = float(w) / old_w

        def safeint(x):
            return int(np.round(x))

        if resize_mode == external_code.ResizeMode.OUTER_FIT:
            k = min(k0, k1)
            borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0)
            high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype)
            if len(high_quality_border_color) == 4:
                # Inpaint hijack
                high_quality_border_color[3] = 255
            high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
            detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
            new_h, new_w, _ = detected_map.shape
            pad_h = max(0, (h - new_h) // 2)
            pad_w = max(0, (w - new_w) // 2)
            high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = detected_map
            return to_outputs(high_quality_background)

        # Remaining mode: scale to cover the target, then crop the centre.
        k = max(k0, k1)
        detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
        new_h, new_w, _ = detected_map.shape
        pad_h = max(0, (new_h - h) // 2)
        pad_w = max(0, (new_w - w) // 2)
        return to_outputs(detected_map[pad_h:pad_h+h, pad_w:pad_w+w])
Ashen's avatar
Ashen committed
554

555
556
    @staticmethod
    def get_enabled_units(p):
        """Collect the enabled ControlNet units for processing object `p`.

        Every unit obtained from `external_code.get_all_units_in_processing`
        is passed through `Script.parse_remote_call`; units whose `enabled`
        flag is false are dropped. When no unit is supplied at all, a default
        UI unit is resolved through `parse_remote_call` and used if enabled.

        For each enabled unit a one-line summary of its settings is written
        into `p.extra_generation_params` under "ControlNet" (exactly one unit
        in total) or "ControlNet <idx>" (otherwise).

        Returns:
            A list of shallow copies of the enabled units.
        """
        units = external_code.get_all_units_in_processing(p)
        enabled_units = []

        nothing_supplied = len(units) == 0
        if nothing_supplied:
            # fill a null group
            fallback_unit = Script.parse_remote_call(p, Script.get_default_ui_unit(), 0)
            if fallback_unit.enabled:
                units.append(fallback_unit)

        # Characters stripped from the dict repr to build the summary string.
        strip_table = str.maketrans('', '', "'{}")

        for idx, unit in enumerate(units):
            unit = Script.parse_remote_call(p, unit, idx)
            if not unit.enabled:
                continue

            enabled_units.append(copy(unit))
            log_key = "ControlNet" if len(units) == 1 else f"ControlNet {idx}"

            summary = {
                "preprocessor": unit.module,
                "model": unit.model,
                "weight": unit.weight,
                "starting/ending": str((unit.guidance_start, unit.guidance_end)),
                "resize mode": str(unit.resize_mode),
                "pixel perfect": str(unit.pixel_perfect),
                "control mode": str(unit.control_mode),
                "preprocessor params": str((unit.processor_res, unit.threshold_a, unit.threshold_b)),
            }
            log_value = str(summary).translate(strip_table)

            p.extra_generation_params.update({log_key: log_value})

        return enabled_units

593
594
595
596
597
    @staticmethod
    def choose_input_image(
            p: processing.StableDiffusionProcessing,
            unit: external_code.ControlNetUnit,
            idx: int
        ) -> Tuple[np.ndarray, bool]:
        """ Choose input image from following sources with descending priority:
         - p.image_control: [Deprecated] Legacy way to pass image to controlnet.
           Only consulted while a batch is running.
         - p.control_net_input_image: [Deprecated] Legacy way to pass image to controlnet.
         - unit.image:
           - ControlNet tab input image.
           - Input image from API call.
         - p.init_images: A1111 img2img tab input image.

        Side effects:
            - May set `unit.module` to 'none' when the drawn mask itself is
              used as the control image.
            - Sets `shared.state.interrupted` during a batch when no image can
              be found.

        Returns:
            - The input image in ndarray form.
            - Whether input image is from A1111.

        Raises:
            ValueError: if none of the sources provides an image.
        """
        image_from_a1111 = False

        p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx)
        image = image_dict_from_any(unit.image)

        if batch_hijack.instance.is_batch and getattr(p, "image_control", None) is not None:
            logger.warning("Warn: Using legacy field 'p.image_control'.")
            input_image = HWC3(np.asarray(p.image_control))
        elif p_input_image is not None:
            # The message names the field that is actually read above.
            logger.warning("Warn: Using legacy field 'p.control_net_input_image'")
            if isinstance(p_input_image, dict) and "mask" in p_input_image and "image" in p_input_image:
                # Append the mask as an extra channel after the colour channels.
                color = HWC3(np.asarray(p_input_image['image']))
                alpha = np.asarray(p_input_image['mask'])[..., None]
                input_image = np.concatenate([color, alpha], axis=2)
            else:
                input_image = HWC3(np.asarray(p_input_image))
        elif image is not None:
            # Make sure the mask is indexable as [:, :, 0] below.
            while len(image['mask'].shape) < 3:
                image['mask'] = image['mask'][..., np.newaxis]

            # Need to check the image for API compatibility
            if isinstance(image['image'], str):
                from modules.api.api import decode_base64_to_image
                input_image = HWC3(np.asarray(decode_base64_to_image(image['image'])))
            else:
                input_image = HWC3(image['image'])

            # A mask that is uniformly 0 or uniformly 255 is treated as "no mask".
            have_mask = 'mask' in image and not ((image['mask'][:, :, 0] == 0).all() or (image['mask'][:, :, 0] == 255).all())

            if 'inpaint' in unit.module:
                logger.info("using inpaint as input")
                # Reuse the already-decoded image: `image['image']` may still be
                # a base64 string here, which HWC3 cannot handle.
                color = input_image
                if have_mask:
                    alpha = image['mask'][:, :, 0:1]
                else:
                    alpha = np.zeros_like(color)[:, :, 0:1]
                input_image = np.concatenate([color, alpha], axis=2)
            else:
                if have_mask:
                    logger.info("using mask as input")
                    input_image = HWC3(image['mask'][:, :, 0])
                    unit.module = 'none'  # Always use black bg and white line
        else:
            # use img2img init_image as default
            input_image = getattr(p, "init_images", [None])[0]
            if input_image is None:
                if batch_hijack.instance.is_batch:
                    shared.state.interrupted = True
                raise ValueError('controlnet is enabled but no input image is given')

            input_image = HWC3(np.asarray(input_image))
            image_from_a1111 = True

        assert isinstance(input_image, np.ndarray)
        return input_image, image_from_a1111
666
    
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
    @staticmethod
    def bound_check_params(unit: external_code.ControlNetUnit) -> None:
        """
        Replace negative preprocessor parameters on `unit` with their defaults.

        The slider configuration registered for the unit's module (looked up
        by module basename) is paired positionally with 'processor_res',
        'threshold_a' and 'threshold_b'. Any of those attributes that holds a
        negative value is overwritten, in place, with the slider's 'value'
        entry, and a warning is logged. Parameters without a slider entry
        (missing or None) are left untouched.

        Args:
            unit (external_code.ControlNetUnit): The ControlNetUnit instance to check.
        """
        module_name = global_state.get_module_basename(unit.module)
        slider_configs = preprocessor_sliders_config.get(module_name, [])
        param_names = ("processor_res", 'threshold_a', 'threshold_b')

        # First gather every available default, then apply them.
        defaults = {}
        for param, slider in zip(param_names, slider_configs):
            if slider is not None:
                defaults[param] = slider['value']

        for param in defaults:
            default_value = defaults[param]
            value = getattr(unit, param)
            if value < 0:
                setattr(unit, param, default_value)
                logger.warning(f'[{unit.module}.{param}] Invalid value({value}), using default value {default_value}.')
690

ljleb's avatar
ljleb committed
691
692
693
694
695
696
    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()

        For every enabled unit this method: loads the control model (unless
        the preprocessor is model-free), picks and prepares the input image,
        runs the preprocessor, resizes the result into a control tensor, and
        builds a `ControlParams` entry. Finally a `UnetHook` is installed on
        the UNet with all collected parameters.

        State written on `self`: `enabled_units` (outside batch mode),
        `latest_model_hash`, `latest_network`, `detected_map` and
        `post_processors`.
        """
        sd_ldm = p.sd_model
        unet = sd_ldm.model.diffusion_model

        setattr(p, 'controlnet_initial_noise_modifier', None)

        if self.latest_network is not None:
            # always restore (~0.05s)
            self.latest_network.restore(unet)

        if not batch_hijack.instance.is_batch:
            self.enabled_units = Script.get_enabled_units(p)

        if len(self.enabled_units) == 0:
            self.latest_network = None
            return

        detected_maps = []
        forward_params = []
        post_processors = []

        # cache stuff
        if self.latest_model_hash != p.sd_model.sd_model_hash:
            Script.clear_control_model_cache()

        # unload unused preproc
        module_list = [unit.module for unit in self.enabled_units]
        for key in self.unloadable:
            if key not in module_list:
                self.unloadable.get(key, lambda:None)()

        self.latest_model_hash = p.sd_model.sd_model_hash
        for idx, unit in enumerate(self.enabled_units):
            Script.bound_check_params(unit)

            unit.module = global_state.get_module_basename(unit.module)
            resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
            control_mode = external_code.control_mode_from_value(unit.control_mode)

            if unit.module in model_free_preprocessors:
                model_net = None
            else:
                model_net = Script.load_control_model(p, unet, unit.model, unit.low_vram)
                model_net.reset()

            input_image, image_from_a1111 = Script.choose_input_image(p, unit, idx)
            if image_from_a1111:
                # The image comes from the img2img tab: follow its resize mode.
                a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
                if a1111_i2i_resize_mode is not None:
                    resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)

            a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None)
            if 'inpaint' in unit.module and not image_has_mask(input_image) and a1111_mask_image is not None:
                a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
                # Only adopt the A1111 mask when it is 2-D and matches the image size.
                if a1111_mask.ndim == 2:
                    if a1111_mask.shape[0] == input_image.shape[0]:
                        if a1111_mask.shape[1] == input_image.shape[1]:
                            input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
                            a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
                            if a1111_i2i_resize_mode is not None:
                                resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)

            if 'reference' not in unit.module and issubclass(type(p), StableDiffusionProcessingImg2Img) \
                    and p.inpaint_full_res and a1111_mask_image is not None:

                # Process every channel as its own PIL image so that an extra
                # mask channel goes through the same resize/crop steps.
                input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
                input_image = [Image.fromarray(x) for x in input_image]

                mask = prepare_mask(a1111_mask_image, p)

                crop_region = masking.get_crop_region(np.array(mask), p.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, p.width, p.height, mask.width, mask.height)

                input_image = [
                    images.resize_image(resize_mode.int_value(), i, mask.width, mask.height)
                    for i in input_image
                ]

                input_image = [x.crop(crop_region) for x in input_image]
                input_image = [
                    images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height)
                    for x in input_image
                ]

                input_image = [np.asarray(x)[:, :, 0] for x in input_image]
                input_image = np.stack(input_image, axis=2)

            if 'inpaint_only' == unit.module and issubclass(type(p), StableDiffusionProcessingImg2Img) and p.image_mask is not None:
                logger.warning('A1111 inpaint and ControlNet inpaint duplicated. ControlNet support enabled.')
                unit.module = 'inpaint'

            # safe numpy
            input_image = np.ascontiguousarray(input_image.copy()).copy()

            logger.info(f"Loading preprocessor: {unit.module}")
            preprocessor = self.preprocessor[unit.module]

            high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)

            # Target sizes are rounded down to a multiple of 8.
            h = (p.height // 8) * 8
            w = (p.width // 8) * 8

            if high_res_fix:
                if p.hr_resize_x == 0 and p.hr_resize_y == 0:
                    hr_y = int(p.height * p.hr_scale)
                    hr_x = int(p.width * p.hr_scale)
                else:
                    hr_y, hr_x = p.hr_resize_y, p.hr_resize_x
                hr_y = (hr_y // 8) * 8
                hr_x = (hr_x // 8) * 8
            else:
                hr_y = h
                hr_x = w

            if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT:
                # inpaint_only+lama is special and required outpaint fix
                _, input_image = Script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x)

            preprocessor_resolution = unit.processor_res
            if unit.pixel_perfect:
                preprocessor_resolution = external_code.pixel_perfect_resolution(
                    input_image,
                    target_H=h,
                    target_W=w,
                    resize_mode=resize_mode
                )

            logger.info(f'preprocessor resolution = {preprocessor_resolution}')
            # Preprocessor result may depend on numpy random operations, use the
            # random seed in `StableDiffusionProcessing` to make the
            # preprocessor result reproducible.
            # Currently following preprocessors use numpy random:
            # - shuffle
            seed = set_numpy_seed(p)
            logger.debug(f"Use numpy seed {seed}.")
            detected_map, is_image = preprocessor(
                input_image,
                res=preprocessor_resolution,
                thr_a=unit.threshold_a,
                thr_b=unit.threshold_b,
            )

            if unit.module == "none" and "style" in unit.model:
                # Reinterpret the bytes of the first channel as a float32 array
                # and feed it to the model as a tensor rather than an image.
                detected_map_bytes = detected_map[:,:,0].tobytes()
                detected_map = np.ndarray((round(input_image.shape[0]/4),input_image.shape[1]),dtype="float32",buffer=detected_map_bytes)
                detected_map = torch.Tensor(detected_map).to(devices.get_device_for("controlnet"))
                is_image = False

            if high_res_fix:
                if is_image:
                    hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
                    detected_maps.append((hr_detected_map, unit.module))
                else:
                    hr_control = detected_map
            else:
                hr_control = None

            if is_image:
                control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
                detected_maps.append((detected_map, unit.module))
            else:
                control = detected_map
                if unit.module == 'clip_vision':
                    detected_maps.append((processor.clip_vision_visualization(detected_map), unit.module))

            # Later checks override earlier ones, so the order below matters.
            control_model_type = ControlModelType.ControlNet

            if isinstance(model_net, PlugableAdapter):
                control_model_type = ControlModelType.T2I_Adapter

            if getattr(model_net, "target", None) == "scripts.adapter.StyleAdapter":
                control_model_type = ControlModelType.T2I_StyleAdapter

            if 'reference' in unit.module:
                control_model_type = ControlModelType.AttentionInjection

            global_average_pooling = False

            if model_net is not None:
                if model_net.config.model.params.get("global_average_pooling", False):
                    global_average_pooling = True

            preprocessor_dict = dict(
                name=unit.module,
                preprocessor_resolution=preprocessor_resolution,
                threshold_a=unit.threshold_a,
                threshold_b=unit.threshold_b
            )

            forward_param = ControlParams(
                control_model=model_net,
                preprocessor=preprocessor_dict,
                hint_cond=control,
                weight=unit.weight,
                guidance_stopped=False,
                start_guidance_percent=unit.guidance_start,
                stop_guidance_percent=unit.guidance_end,
                advanced_weighting=None,
                control_model_type=control_model_type,
                global_average_pooling=global_average_pooling,
                hr_hint_cond=hr_control,
                soft_injection=control_mode != external_code.ControlMode.BALANCED,
                cfg_injection=control_mode == external_code.ControlMode.CONTROL,
            )
            forward_params.append(forward_param)

            if 'inpaint_only' in unit.module:
                final_inpaint_feed = hr_control if hr_control is not None else control
                final_inpaint_feed = final_inpaint_feed.detach().cpu().numpy()
                final_inpaint_feed = np.ascontiguousarray(final_inpaint_feed).copy()
                # Channel 3 of the feed is used as the mask, channels 0-2 as the image.
                final_inpaint_mask = final_inpaint_feed[0, 3, :, :].astype(np.float32)
                final_inpaint_raw = final_inpaint_feed[0, :3].astype(np.float32)
                sigma = 7
                # Grow the mask, then soften its edge.
                final_inpaint_mask = cv2.dilate(final_inpaint_mask, np.ones((sigma, sigma), dtype=np.uint8))
                final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[None]
                _, Hmask, Wmask = final_inpaint_mask.shape
                final_inpaint_raw = torch.from_numpy(np.ascontiguousarray(final_inpaint_raw).copy())
                final_inpaint_mask = torch.from_numpy(np.ascontiguousarray(final_inpaint_mask).copy())

                # The per-unit tensors are bound as default arguments so each
                # closure keeps the data of the unit that created it. Reading
                # them from the enclosing scope would make every queued
                # post-processor use the values of the last inpaint unit.
                def inpaint_only_post_processing(
                        x,
                        raw=final_inpaint_raw,
                        mask=final_inpaint_mask,
                        mask_h=Hmask,
                        mask_w=Wmask):
                    _, H, W = x.shape
                    if mask_h != H or mask_w != W:
                        logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
                        return x
                    r = raw.to(x.dtype).to(x.device)
                    m = mask.to(x.dtype).to(x.device)
                    # Keep the generated pixels inside the mask, the original outside.
                    y = m * x.clip(0, 1) + (1 - m) * r
                    y = y.clip(0, 1)
                    return y

                post_processors.append(inpaint_only_post_processing)

            if '+lama' in unit.module:
                forward_param.used_hint_cond_latent = hook.UnetHook.call_vae_using_process(p, control)
                setattr(p, 'controlnet_initial_noise_modifier', forward_param.used_hint_cond_latent)
            del model_net

        self.latest_network = UnetHook(lowvram=any(unit.low_vram for unit in self.enabled_units))
        self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p)
        self.detected_map = detected_maps
        self.post_processors = post_processors
ljleb's avatar
ljleb committed
937

938
939
940
941
942
943
944
    def postprocess_batch(self, p, *args, **kwargs):
        """Run every queued post-processor over each image of the batch, in place.

        Args:
            p: The processing object (unused here).
            **kwargs: `images` holds the batch; it is indexed along its first
                axis and each entry is replaced by the post-processed result.

        Nothing happens when `images` is missing or when no post-processor is
        queued. Returns None.
        """
        images = kwargs.get('images')
        if images is None or not self.post_processors:
            return

        # len() equals shape[0] for arrays/tensors and also works for plain
        # sequences, which have no `.shape`.
        batch_size = len(images)
        for post_processor in self.post_processors:
            for i in range(batch_size):
                images[i] = post_processor(images[i])
        return

ljleb's avatar
ljleb committed
945
    def postprocess(self, p, processed, *args):
        """Clean up after a generation run and expose the detected maps.

        Steps, in order:
          - drop queued post-processors and reset the ControlNet attributes
            set on `p`;
          - outside batch mode, clear `self.enabled_units`;
          - if detect-map autosaving is enabled, save every detected map
            (except those of module "none") into a per-module directory;
          - outside batch mode, unless disabled by option or an 'sd upscale'
            run, append the detected maps to `processed.images`;
          - restore the hooked UNet, clear cached state and free memory.

        Returns early, after the autosave step, when no network was hooked.
        """
        self.post_processors = []
        setattr(p, 'controlnet_initial_noise_modifier', None)
        setattr(p, 'controlnet_vae_cache', None)

        # Lower-cased, comma-joined keys of the extra generation params.
        processor_params_flag = (', '.join(getattr(processed, 'extra_generation_params', []))).lower()

        if not batch_hijack.instance.is_batch:
            self.enabled_units.clear()

        if shared.opts.data.get("control_net_detectmap_autosaving", False) and self.latest_network is not None:
            for detect_map, module in self.detected_map:
                # Nothing is saved for module "none"; a missing map is skipped
                # just like in the display loop below.
                if detect_map is None or module == "none":
                    continue
                detectmap_dir = os.path.join(shared.opts.data.get("control_net_detectedmap_dir", ""), module)
                if not os.path.isabs(detectmap_dir):
                    detectmap_dir = os.path.join(p.outpath_samples, detectmap_dir)
                os.makedirs(detectmap_dir, exist_ok=True)
                img = Image.fromarray(np.ascontiguousarray(detect_map.clip(0, 255).astype(np.uint8)).copy())
                save_image(img, detectmap_dir, module)

        if self.latest_network is None:
            return

        if not batch_hijack.instance.is_batch:
            if not shared.opts.data.get("control_net_no_detectmap", False):
                if 'sd upscale' not in processor_params_flag:
                    if self.detected_map is not None:
                        for detect_map, module in self.detected_map:
                            if detect_map is None:
                                continue
                            detect_map = np.ascontiguousarray(detect_map.copy()).copy()
                            detect_map = external_code.visualize_inpaint_mask(detect_map)
                            processed.images.extend([
                                Image.fromarray(
                                    detect_map.clip(0, 255).astype(np.uint8)
                                )
                            ])

        self.input_image = None
        self.latest_network.restore(p.sd_model.model.diffusion_model)
        self.latest_network = None
        self.detected_map.clear()

        gc.collect()
        devices.torch_gc()
Kakigōri Maker's avatar
Kakigōri Maker committed
991

ljleb's avatar
ljleb committed
992
993
994
995
996
997
998
999
1000
    def batch_tab_process(self, p, batches, *args, **kwargs):
        """Prepare the enabled units for a batch-tab run.

        Refreshes `self.enabled_units` from `p`, then gives each unit a
        `batch_images` iterator that yields, batch by batch, the entry found
        at that unit's position.
        """
        active_units = self.get_enabled_units(p)
        self.enabled_units = active_units
        for position, unit in enumerate(active_units):
            # Built eagerly so indexing errors surface here, not during iteration.
            per_unit_images = [batch[position] for batch in batches]
            unit.batch_images = iter(per_unit_images)

    def batch_tab_process_each(self, p, *args, **kwargs):
        for unit_i, unit in enumerate(self.enabled_units):
            if getattr(unit, 'loopback', False) and batch_hijack.instance.batch_index > 0: continue

For faster browsing, not all history is shown. View entire blame