remove/simplify some changes from #6481
This commit is contained in:
parent
bdd57ad073
commit
43bb5190fc
@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
|||||||
|
|
||||||
|
|
||||||
class DatasetEntry:
|
class DatasetEntry:
|
||||||
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None):
|
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
self.filename_text = filename_text
|
self.filename_text = filename_text
|
||||||
self.latent_dist = latent_dist
|
self.latent_dist = latent_dist
|
||||||
@ -25,7 +25,6 @@ class DatasetEntry:
|
|||||||
self.cond = cond
|
self.cond = cond
|
||||||
self.cond_text = cond_text
|
self.cond_text = cond_text
|
||||||
self.pixel_values = pixel_values
|
self.pixel_values = pixel_values
|
||||||
self.img_shape = img_shape
|
|
||||||
|
|
||||||
|
|
||||||
class PersonalizedBase(Dataset):
|
class PersonalizedBase(Dataset):
|
||||||
@ -46,12 +45,10 @@ class PersonalizedBase(Dataset):
|
|||||||
assert data_root, 'dataset directory not specified'
|
assert data_root, 'dataset directory not specified'
|
||||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||||
assert os.listdir(data_root), "Dataset directory is empty"
|
assert os.listdir(data_root), "Dataset directory is empty"
|
||||||
if varsize:
|
assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
|
||||||
assert batch_size == 1, 'variable img size must have batch size 1'
|
|
||||||
|
|
||||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
|
|
||||||
|
|
||||||
self.shuffle_tags = shuffle_tags
|
self.shuffle_tags = shuffle_tags
|
||||||
self.tag_drop_out = tag_drop_out
|
self.tag_drop_out = tag_drop_out
|
||||||
|
|
||||||
@ -91,14 +88,14 @@ class PersonalizedBase(Dataset):
|
|||||||
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
|
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
|
||||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||||
latent_sampling_method = "once"
|
latent_sampling_method = "once"
|
||||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size)
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||||
elif latent_sampling_method == "deterministic":
|
elif latent_sampling_method == "deterministic":
|
||||||
# Works only for DiagonalGaussianDistribution
|
# Works only for DiagonalGaussianDistribution
|
||||||
latent_dist.std = 0
|
latent_dist.std = 0
|
||||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size)
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||||
elif latent_sampling_method == "random":
|
elif latent_sampling_method == "random":
|
||||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size)
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
|
||||||
|
|
||||||
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||||
entry.cond_text = self.create_text(filename_text)
|
entry.cond_text = self.create_text(filename_text)
|
||||||
@ -154,7 +151,6 @@ class BatchLoader:
|
|||||||
self.cond_text = [entry.cond_text for entry in data]
|
self.cond_text = [entry.cond_text for entry in data]
|
||||||
self.cond = [entry.cond for entry in data]
|
self.cond = [entry.cond for entry in data]
|
||||||
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
||||||
self.img_shape = [entry.img_shape for entry in data]
|
|
||||||
#self.emb_index = [entry.emb_index for entry in data]
|
#self.emb_index = [entry.emb_index for entry in data]
|
||||||
#print(self.latent_sample.device)
|
#print(self.latent_sample.device)
|
||||||
|
|
||||||
|
@ -492,8 +492,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||||||
else:
|
else:
|
||||||
p.prompt = batch.cond_text[0]
|
p.prompt = batch.cond_text[0]
|
||||||
p.steps = 20
|
p.steps = 20
|
||||||
p.width = batch.img_shape[0][0]
|
p.width = training_width
|
||||||
p.height = batch.img_shape[0][1]
|
p.height = training_height
|
||||||
|
|
||||||
preview_text = p.prompt
|
preview_text = p.prompt
|
||||||
|
|
||||||
|
@ -1348,7 +1348,7 @@ def create_ui():
|
|||||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
|
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
|
||||||
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
|
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
|
||||||
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
|
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
|
||||||
varsize = gr.Checkbox(label="Ignore dimension settings and do not resize images", value=False, elem_id="train_varsize")
|
varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
|
||||||
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
|
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
|
Loading…
Reference in New Issue
Block a user