Add custom karras scheduler

This commit is contained in:
Kohaku-Blueleaf 2023-05-22 21:52:46 +08:00
parent ee65e72931
commit a104879869
7 changed files with 75 additions and 3 deletions

View File

@ -78,7 +78,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
processed_image.save(os.path.join(output_dir, filename)) processed_image.save(os.path.join(output_dir, filename))
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, enable_k_sched, sigma_min, sigma_max, rho, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5 is_batch = mode == 5
@ -155,6 +155,10 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
inpaint_full_res_padding=inpaint_full_res_padding, inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert, inpainting_mask_invert=inpainting_mask_invert,
override_settings=override_settings, override_settings=override_settings,
enable_karras=enable_k_sched,
sigma_min=sigma_min,
sigma_max=sigma_max,
rho=rho
) )
p.scripts = modules.scripts.scripts_img2img p.scripts = modules.scripts.scripts_img2img

View File

@ -106,7 +106,7 @@ class StableDiffusionProcessing:
""" """
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
""" """
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None, enable_karras: bool = False, sigma_min: float=0.1, sigma_max: float=10.0, rho: float=7.0):
if sampler_index is not None: if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@ -146,6 +146,10 @@ class StableDiffusionProcessing:
self.s_tmin = s_tmin or opts.s_tmin self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise self.s_noise = s_noise or opts.s_noise
self.enable_karras = enable_karras
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.rho = rho
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False self.is_using_inpainting_conditioning = False
@ -558,6 +562,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
generation_params = { generation_params = {
"Steps": p.steps, "Steps": p.steps,
"Sampler": p.sampler_name, "Sampler": p.sampler_name,
"Enable Custom Karras Schedule": p.enable_karras,
"Karras Scheduler sigma_max": p.sigma_max,
"Karras Scheduler sigma_min": p.sigma_min,
"Karras Scheduler rho": p.rho,
"CFG scale": p.cfg_scale, "CFG scale": p.cfg_scale,
"Image CFG scale": getattr(p, 'image_cfg_scale', None), "Image CFG scale": getattr(p, 'image_cfg_scale', None),
"Seed": all_seeds[index], "Seed": all_seeds[index],

View File

@ -304,6 +304,12 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override: if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps) sigmas = p.sampler_noise_scheduler_override(steps)
elif p.enable_karras:
sigma_max = p.sigma_max
sigma_min = p.sigma_min
rho = p.rho
print(f"\nsigma_min: {sigma_min}, sigma_max: {sigma_max}, rho: {rho}")
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())

View File

@ -47,6 +47,7 @@ ui_reorder_categories = [
"inpaint", "inpaint",
"sampler", "sampler",
"checkboxes", "checkboxes",
"karras_scheduler",
"hires_fix", "hires_fix",
"dimensions", "dimensions",
"cfg", "cfg",

View File

@ -7,7 +7,7 @@ from modules.ui import plaintext_to_html
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args): def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, enable_k_sched, sigma_min, sigma_max, rho, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
@ -43,6 +43,10 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_prompt=hr_prompt, hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt, hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings, override_settings=override_settings,
enable_karras=enable_k_sched,
sigma_min=sigma_min,
sigma_max=sigma_max,
rho=rho
) )
p.scripts = modules.scripts.scripts_txt2img p.scripts = modules.scripts.scripts_txt2img

View File

@ -484,6 +484,7 @@ def create_ui():
with FormRow(elem_classes="checkboxes-row", variant="compact"): with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
t2i_enable_k_sched = gr.Checkbox(label='Custom Karras Scheduler', value=False, elem_id="txt2img_enable_k_sched")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
@ -510,6 +511,13 @@ def create_ui():
with gr.Row(): with gr.Row():
hr_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"]) hr_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
elif category == "karras_scheduler":
with FormGroup(visible=False, elem_id="txt2img_karras_scheduler") as t2i_k_sched_options:
with FormRow(elem_id="txt2img_karras_scheduler_row1", variant="compact"):
t2i_k_sched_sigma_max = gr.Slider(minimum=0.0, maximum=0.5, step=0.05, label='sigma min', value=0.1, elem_id="txt2img_sigma_min")
t2i_k_sched_sigma_min = gr.Slider(minimum=5.0, maximum=50.0, step=0.1, label='sigma max', value=10.0, elem_id="txt2img_sigma_max")
t2i_k_sched_rho = gr.Slider(minimum=3.0, maximum=10.0, step=0.5, label='rho', value=7.0, elem_id="txt2img_rho")
elif category == "batch": elif category == "batch":
if not opts.dimensions_and_batch_together: if not opts.dimensions_and_batch_together:
with FormRow(elem_id="txt2img_column_batch"): with FormRow(elem_id="txt2img_column_batch"):
@ -578,6 +586,10 @@ def create_ui():
hr_prompt, hr_prompt,
hr_negative_prompt, hr_negative_prompt,
override_settings, override_settings,
t2i_enable_k_sched,
t2i_k_sched_sigma_max,
t2i_k_sched_sigma_min,
t2i_k_sched_rho
] + custom_inputs, ] + custom_inputs,
@ -627,6 +639,13 @@ def create_ui():
show_progress = False, show_progress = False,
) )
t2i_enable_k_sched.change(
fn=lambda x: gr_show(x),
inputs=[t2i_enable_k_sched],
outputs=[t2i_k_sched_options],
show_progress=False
)
txt2img_paste_fields = [ txt2img_paste_fields = [
(txt2img_prompt, "Prompt"), (txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"), (txt2img_negative_prompt, "Negative prompt"),
@ -655,6 +674,10 @@ def create_ui():
(hr_prompt, "Hires prompt"), (hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"), (hr_negative_prompt, "Hires negative prompt"),
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
(t2i_enable_k_sched, "Enable CustomKarras Schedule"),
(t2i_k_sched_sigma_max, "Karras Scheduler sigma_max"),
(t2i_k_sched_sigma_min, "Karras Scheduler sigma_min"),
(t2i_k_sched_rho, "Karras Scheduler rho"),
*modules.scripts.scripts_txt2img.infotext_fields *modules.scripts.scripts_txt2img.infotext_fields
] ]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
@ -846,6 +869,14 @@ def create_ui():
with FormRow(elem_classes="checkboxes-row", variant="compact"): with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
i2i_enable_k_sched = gr.Checkbox(label='Custom Karras Scheduler', value=False, elem_id="txt2img_enable_k_sched")
elif category == "karras_scheduler":
with FormGroup(visible=False, elem_id="img2img_karras_scheduler") as i2i_k_sched_options:
with FormRow(elem_id="img2img_karras_scheduler_row1", variant="compact"):
i2i_k_sched_sigma_max = gr.Slider(minimum=0.0, maximum=0.5, step=0.05, label='sigma min', value=0.1, elem_id="txt2img_sigma_min")
i2i_k_sched_sigma_min = gr.Slider(minimum=5.0, maximum=50.0, step=0.1, label='sigma max', value=10.0, elem_id="txt2img_sigma_max")
i2i_k_sched_rho = gr.Slider(minimum=3.0, maximum=10.0, step=0.5, label='rho', value=7.0, elem_id="txt2img_rho")
elif category == "batch": elif category == "batch":
if not opts.dimensions_and_batch_together: if not opts.dimensions_and_batch_together:
@ -949,6 +980,10 @@ def create_ui():
img2img_batch_output_dir, img2img_batch_output_dir,
img2img_batch_inpaint_mask_dir, img2img_batch_inpaint_mask_dir,
override_settings, override_settings,
i2i_enable_k_sched,
i2i_k_sched_sigma_max,
i2i_k_sched_sigma_min,
i2i_k_sched_rho
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
img2img_gallery, img2img_gallery,
@ -1032,6 +1067,13 @@ def create_ui():
outputs=[prompt, negative_prompt, styles], outputs=[prompt, negative_prompt, styles],
) )
i2i_enable_k_sched.change(
fn=lambda x: gr_show(x),
inputs=[i2i_enable_k_sched],
outputs=[i2i_k_sched_options],
show_progress=False
)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
@ -1043,6 +1085,10 @@ def create_ui():
(steps, "Steps"), (steps, "Steps"),
(sampler_index, "Sampler"), (sampler_index, "Sampler"),
(restore_faces, "Face restoration"), (restore_faces, "Face restoration"),
(i2i_enable_k_sched, "Enable Karras Schedule"),
(i2i_k_sched_sigma_max, "Karras Scheduler sigma_max"),
(i2i_k_sched_sigma_min, "Karras Scheduler sigma_min"),
(i2i_k_sched_rho, "Karras Scheduler rho"),
(cfg_scale, "CFG scale"), (cfg_scale, "CFG scale"),
(image_cfg_scale, "Image CFG scale"), (image_cfg_scale, "Image CFG scale"),
(seed, "Seed"), (seed, "Seed"),

View File

@ -220,6 +220,9 @@ axis_options = [
AxisOption("Sigma min", float, apply_field("s_tmin")), AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax")), AxisOption("Sigma max", float, apply_field("s_tmax")),
AxisOption("Sigma noise", float, apply_field("s_noise")), AxisOption("Sigma noise", float, apply_field("s_noise")),
AxisOption("Karras Scheduler Sigma Min", float, apply_field("sigma_min")),
AxisOption("Karras Scheduler Sigma Max", float, apply_field("sigma_max")),
AxisOption("Karras Scheduler rho", float, apply_field("rho")),
AxisOption("Eta", float, apply_field("eta")), AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip), AxisOption("Clip skip", int, apply_clip_skip),
AxisOption("Denoising", float, apply_field("denoising_strength")), AxisOption("Denoising", float, apply_field("denoising_strength")),