Switched to exception handling
This commit is contained in:
parent
907a88b2d0
commit
b2368a3bce
@ -22,7 +22,6 @@ class PersonalizedBase(Dataset):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
self.extns = [".jpg",".jpeg",".png",".webp",".bmp"]
|
||||
|
||||
self.dataset = []
|
||||
|
||||
@ -33,12 +32,13 @@ class PersonalizedBase(Dataset):
|
||||
|
||||
assert data_root, 'dataset directory not specified'
|
||||
|
||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in self.extns]
|
||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||
print("Preparing dataset...")
|
||||
for path in tqdm.tqdm(self.image_paths):
|
||||
image = Image.open(path)
|
||||
image = image.convert('RGB')
|
||||
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||
try:
|
||||
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
filename = os.path.basename(path)
|
||||
filename_tokens = os.path.splitext(filename)[0]
|
||||
|
@ -12,13 +12,12 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
||||
height = process_height
|
||||
src = os.path.abspath(process_src)
|
||||
dst = os.path.abspath(process_dst)
|
||||
extns = [".jpg",".jpeg",".png",".webp",".bmp"]
|
||||
|
||||
assert src != dst, 'same directory specified as source and destination'
|
||||
|
||||
os.makedirs(dst, exist_ok=True)
|
||||
|
||||
files = [i for i in os.listdir(src) if os.path.splitext(i.casefold())[1] in extns]
|
||||
files = os.listdir(src)
|
||||
|
||||
shared.state.textinfo = "Preprocessing..."
|
||||
shared.state.job_count = len(files)
|
||||
@ -47,7 +46,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||
subindex = [0]
|
||||
filename = os.path.join(src, imagefile)
|
||||
img = Image.open(filename).convert("RGB")
|
||||
try:
|
||||
img = Image.open(filename).convert("RGB")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
@ -161,7 +161,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
shared.state.job_count = steps
|
||||
extns = [".jpg",".jpeg",".png",".webp",".bmp"]
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
|
||||
@ -201,10 +200,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
if ititial_step > steps:
|
||||
return embedding, filename
|
||||
|
||||
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in extns])
|
||||
|
||||
epoch_len = (tr_img_len * num_repeats) + tr_img_len
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, (x, text) in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
@ -228,10 +223,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
epoch_num = embedding.step // epoch_len
|
||||
epoch_step = embedding.step - (epoch_num * epoch_len) + 1
|
||||
epoch_num = embedding.step // len(ds)
|
||||
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
|
||||
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
|
||||
|
||||
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||
@ -243,9 +238,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
prompt=text,
|
||||
steps=20,
|
||||
height=training_height,
|
||||
steps=28,
|
||||
height=768,
|
||||
width=training_width,
|
||||
negative_prompt="lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry, artist name",
|
||||
cfg_scale=7.0,
|
||||
sampler_index=0,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user