2023-01-30 10:11:30 +03:00
|
|
|
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
2022-09-03 12:08:45 +03:00
|
|
|
|
2023-01-30 09:51:06 +03:00
|
|
|
# Re-exported for backward compatibility: these functions used to be defined here, and other modules still import them from this module.
|
2023-05-10 09:02:23 +03:00
|
|
|
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
2022-09-03 17:21:15 +03:00
|
|
|
|
2022-10-06 12:08:48 +03:00
|
|
|
# Every registered sampler config: the k-diffusion entries first, followed by
# the CompVis entries. Order is significant because create_sampler() uses the
# first entry when it is given no name.
all_samplers = list(sd_samplers_kdiffusion.samplers_data_k_diffusion)
all_samplers.extend(sd_samplers_compvis.samplers_data_compvis)

# Exact-name lookup of sampler configs; if two configs share a name, the later
# one wins.
all_samplers_map = dict((entry.name, entry) for entry in all_samplers)

# Placeholders that set_samplers() fills in. The two lists are rebound there;
# samplers_map is cleared and refilled in place.
samplers = []
samplers_for_img2img = []
samplers_map = {}
|
2022-10-06 12:08:48 +03:00
|
|
|
|
|
|
|
|
2022-11-19 12:01:51 +03:00
|
|
|
def create_sampler(name, model):
    """Instantiate the sampler registered under ``name`` for ``model``.

    Args:
        name: Exact sampler name, looked up as a key of ``all_samplers_map``.
            ``None`` selects the first entry of ``all_samplers``.
        model: Passed unchanged to the selected config's ``constructor``.

    Returns:
        The object returned by ``config.constructor(model)``, with its
        ``config`` attribute set to the config that produced it.

    Raises:
        AssertionError: if ``name`` is not a key of ``all_samplers_map``.
    """
    if name is not None:
        config = all_samplers_map.get(name, None)
    else:
        config = all_samplers[0]

    # Raise explicitly instead of using ``assert`` so the check still runs under
    # ``python -O``. Without it, an unknown name would surface as an opaque
    # AttributeError on ``None.constructor``. AssertionError and the message are
    # kept unchanged so existing callers see the same exception.
    if config is None:
        raise AssertionError(f'bad sampler name: {name}')

    sampler = config.constructor(model)
    sampler.config = config

    return sampler
|
|
|
|
|
|
|
|
|
2022-10-06 12:08:48 +03:00
|
|
|
def set_samplers():
    """Rebuild the visible sampler lists and the lowercase name/alias lookup.

    Reads ``shared.opts.hide_samplers`` and does three things:

    * Rebinds the module-level ``samplers`` to every entry of ``all_samplers``
      whose name is not hidden.
    * Rebinds ``samplers_for_img2img`` the same way, additionally excluding
      'PLMS' and 'UniPC'.
    * Clears and refills ``samplers_map`` in place. For every sampler in
      ``all_samplers``, hidden or not, it maps the lowercased name and each
      lowercased alias to the sampler's name.
    """
    global samplers, samplers_for_img2img

    hide_list = shared.opts.hide_samplers
    excluded_default = set(hide_list)
    # 'PLMS' and 'UniPC' are always left out of the img2img list.
    excluded_img2img = set(hide_list + ['PLMS', 'UniPC'])

    samplers = [cfg for cfg in all_samplers if cfg.name not in excluded_default]
    samplers_for_img2img = [cfg for cfg in all_samplers if cfg.name not in excluded_img2img]

    # Mutated in place rather than rebound, so the dict object stays the same.
    samplers_map.clear()
    for cfg in all_samplers:
        canonical = cfg.name
        samplers_map[canonical.lower()] = canonical
        samplers_map.update((alias.lower(), canonical) for alias in cfg.aliases)
|
|
|
|
|
2022-10-06 12:08:48 +03:00
|
|
|
|
|
|
|
# Populate samplers, samplers_for_img2img and samplers_map when this module is
# imported. set_samplers() can be called again later to rebuild them.
set_samplers()
|