From 924e222004ab54273806c5f2ca7a0e7cfa76ad83 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 18 Jan 2023 23:04:24 +0300 Subject: [PATCH] add option to show/hide warnings removed hiding warnings from LDSR fixed/reworked few places that produced warnings --- extensions-builtin/LDSR/ldsr_model_arch.py | 3 -- javascript/localization.js | 2 +- modules/hypernetworks/hypernetwork.py | 7 +++- modules/sd_hijack.py | 8 ---- modules/sd_hijack_checkpoint.py | 38 ++++++++++++++++++- modules/shared.py | 1 + .../textual_inversion/textual_inversion.py | 6 ++- modules/ui.py | 31 ++++++++------- scripts/prompts_from_file.py | 2 +- style.css | 5 +-- 10 files changed, 71 insertions(+), 32 deletions(-) diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 0ad49f4e..bc11cc6e 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -1,7 +1,6 @@ import os import gc import time -import warnings import numpy as np import torch @@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config, ismap from modules import shared, sd_hijack -warnings.filterwarnings("ignore", category=UserWarning) - cached_ldsr_model: torch.nn.Module = None diff --git a/javascript/localization.js b/javascript/localization.js index bf9e1506..1a5a1dbb 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -11,7 +11,7 @@ ignore_ids_for_localization={ train_embedding: 'OPTION', train_hypernetwork: 'OPTION', txt2img_styles: 'OPTION', - img2img_styles 'OPTION', + img2img_styles: 'OPTION', setting_random_artist_categories: 'SPAN', setting_face_restoration_model: 'SPAN', setting_realesrgan_enabled_models: 'SPAN', diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c963fc40..74e78582 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes +from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -575,6 +575,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi pbar = tqdm.tqdm(total=steps - initial_step) try: + sd_hijack_checkpoint.add() + for i in range((steps-initial_step) * gradient_step): if scheduler.finished: break @@ -724,6 +726,9 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close() hypernetwork.eval() #report_statistics(loss_dict) + sd_hijack_checkpoint.remove() + + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') hypernetwork.optimizer_name = optimizer_name diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6b0d95af..870eba88 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -69,12 +69,6 @@ def undo_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -def fix_checkpoint(): - ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward - ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward - ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward - - class StableDiffusionModelHijack: fixes = None comments = [] @@ -106,8 +100,6 @@ class StableDiffusionModelHijack: self.optimization_method = apply_optimizations() self.clip = m.cond_stage_model - - fix_checkpoint() def flatten(el): flattened = [flatten(children) for children in el.children()] diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py index 5712972f..2604d969 100644 --- a/modules/sd_hijack_checkpoint.py +++ b/modules/sd_hijack_checkpoint.py @@ -1,10 +1,46 @@ from torch.utils.checkpoint import checkpoint +import ldm.modules.attention +import ldm.modules.diffusionmodules.openaimodel + + def BasicTransformerBlock_forward(self, x, context=None): return checkpoint(self._forward, x, context) + def AttentionBlock_forward(self, x): return checkpoint(self._forward, x) + def ResBlock_forward(self, x, emb): - return checkpoint(self._forward, x, emb) \ No newline at end of file + return checkpoint(self._forward, x, emb) + + +stored = [] + + +def add(): + if len(stored) != 0: + return + + stored.extend([ + ldm.modules.attention.BasicTransformerBlock.forward, + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward, + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward + ]) + + ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward + + +def remove(): + if len(stored) == 0: + return + + ldm.modules.attention.BasicTransformerBlock.forward = stored[0] + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1] + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2] + + stored.clear() + diff --git a/modules/shared.py b/modules/shared.py index a708f23c..ddb97f99 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -369,6 +369,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration" })) options_templates.update(options_section(('system', "System"), { + "show_warnings": OptionInfo(False, "Show warnings in console."), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7e4a6d24..5a7be422 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -15,7 +15,7 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers +from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -452,6 +452,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st pbar = tqdm.tqdm(total=steps - initial_step) try: + sd_hijack_checkpoint.add() + for i in range((steps-initial_step) * gradient_step): if scheduler.finished: break @@ -617,9 +619,11 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close() shared.sd_model.first_stage_model.to(devices.device) shared.parallel_processing_allowed = old_parallel_processing_allowed + sd_hijack_checkpoint.remove() return embedding, filename + def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True): old_embedding_name = embedding.name old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None diff --git a/modules/ui.py b/modules/ui.py index 6d70a795..25818fb0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -11,6 +11,7 @@ import tempfile import time import traceback from functools import partial, reduce +import warnings import gradio as gr import gradio.routes @@ -41,6 +42,8 @@ from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text +warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) + # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') @@ -417,17 +420,16 @@ def apply_setting(key, value): return value -def update_generation_info(args): - generation_info, html_info, img_index = args +def update_generation_info(generation_info, html_info, img_index): try: generation_info = json.loads(generation_info) if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) + return html_info, gr.update() + return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() except Exception: pass # if the json parse or anything else fails, just return the old html_info - return html_info + return html_info, gr.update() def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): @@ -508,10 +510,9 @@ Requested path was: {f} generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False + _js="function(x, y, z){ console.log(x, y, z); return [x, y, selected_gallery_index()] }", + inputs=[generation_info, html_info, html_info], + outputs=[html_info, html_info], ) save.click( @@ -526,7 +527,8 @@ Requested path was: {f} outputs=[ download_files, html_log, - ] + ], + show_progress=False, ) save_zip.click( @@ -588,7 +590,7 @@ def create_ui(): txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): @@ -768,7 +770,7 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) with FormRow().style(equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): @@ -1768,7 +1770,10 @@ def create_ui(): if saved_value is None: ui_settings[key] = getattr(obj, field) elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + pass + + # this warning is generally not useful; + # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') else: setattr(obj, field, saved_value) if init_field is not None: diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index f3e711d7..76dc5778 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -116,7 +116,7 @@ class Script(scripts.Script): checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) - file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=self.elem_id("file")) + file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file")) file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) diff --git a/style.css b/style.css index 61279a19..0845519a 100644 --- a/style.css +++ b/style.css @@ -299,9 +299,8 @@ input[type="range"]{ } /* more gradio's garbage cleanup */ -.min-h-\[4rem\] { - min-height: unset !important; -} +.min-h-\[4rem\] { min-height: unset !important; } +.min-h-\[6rem\] { min-height: unset !important; } .progressDiv{ position: absolute;