instruct-pix2pix support

This commit is contained in:
Kyle 2023-02-02 09:37:01 -05:00
parent 226d840e84
commit 269833067d
2 changed files with 5 additions and 5 deletions

View File

@ -186,7 +186,7 @@ class StableDiffusionProcessing:
return conditioning return conditioning
def edit_image_conditioning(self, source_image): def edit_image_conditioning(self, source_image):
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
return conditioning_image return conditioning_image

View File

@ -77,9 +77,9 @@ class CFGDenoiser(torch.nn.Module):
batch_size = len(conds_list) batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)] repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [image_cond])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
cfg_denoiser_callback(denoiser_params) cfg_denoiser_callback(denoiser_params)
@ -88,7 +88,7 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = denoiser_params.sigma sigma_in = denoiser_params.sigma
if tensor.shape[1] == uncond.shape[1]: if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond]) cond_in = torch.cat([tensor, uncond, uncond])
if shared.batch_cond_uncond: if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})