diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 1568b2b8..af9fbcf2 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -115,7 +115,7 @@ class PersonalizedBase(Dataset): weight /= weight.mean() elif use_weight: #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later - weight = torch.ones([channels] + latent_size) + weight = torch.ones(latent_sample.shape) else: weight = None