- Fix file/folder opening behind the browser window
- Add WD14 and BLIP captioning to utilities - Improve overall GUI layout
This commit is contained in:
parent
0ca93a7aa7
commit
c90aa2cc61
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,7 +1,8 @@
|
|||||||
venv
|
venv
|
||||||
venv1
|
|
||||||
mytraining.ps
|
|
||||||
__pycache__
|
__pycache__
|
||||||
|
*.txt
|
||||||
|
cudnn_windows
|
||||||
.vscode
|
.vscode
|
||||||
*.egg-info
|
*.egg-info
|
||||||
build
|
build
|
||||||
|
wd14_tagger_model
|
21
BLIP_caption/configs/med_config.json
Normal file
21
BLIP_caption/configs/med_config.json
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"architectures": [
|
||||||
|
"BertModel"
|
||||||
|
],
|
||||||
|
"attention_probs_dropout_prob": 0.1,
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
"hidden_dropout_prob": 0.1,
|
||||||
|
"hidden_size": 768,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 3072,
|
||||||
|
"layer_norm_eps": 1e-12,
|
||||||
|
"max_position_embeddings": 512,
|
||||||
|
"model_type": "bert",
|
||||||
|
"num_attention_heads": 12,
|
||||||
|
"num_hidden_layers": 12,
|
||||||
|
"pad_token_id": 0,
|
||||||
|
"type_vocab_size": 2,
|
||||||
|
"vocab_size": 30524,
|
||||||
|
"encoder_width": 768,
|
||||||
|
"add_cross_attention": true
|
||||||
|
}
|
115
BLIP_caption/make_captions.py
Normal file
115
BLIP_caption/make_captions.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||||
|
# (c) 2022 Kohya S. @kohya_ss
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
from models.blip import blip_decoder
|
||||||
|
# from Salesforce_BLIP.models.blip import blip_decoder
|
||||||
|
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
cwd = os.getcwd()
|
||||||
|
print('Current Working Directory is: ', cwd)
|
||||||
|
|
||||||
|
os.chdir('.\BLIP_caption')
|
||||||
|
|
||||||
|
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
||||||
|
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||||
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
|
print(f"loading BLIP caption: {args.caption_weights}")
|
||||||
|
# image_size = 384
|
||||||
|
# model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config='configs/med_config.json')
|
||||||
|
# model.eval()
|
||||||
|
# model = model.to(device)
|
||||||
|
|
||||||
|
image_size = 384
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||||
|
])
|
||||||
|
|
||||||
|
model_url = args.caption_weights # 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
|
||||||
|
|
||||||
|
model = blip_decoder(pretrained=model_url, image_size=384, vit='large')
|
||||||
|
model.eval()
|
||||||
|
model = model.to(device)
|
||||||
|
print("BLIP loaded")
|
||||||
|
# 正方形でいいのか? という気がするがソースがそうなので
|
||||||
|
# transform = transforms.Compose([
|
||||||
|
# transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
||||||
|
# transforms.ToTensor(),
|
||||||
|
# transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||||
|
# ])
|
||||||
|
|
||||||
|
# captioningする
|
||||||
|
def run_batch(path_imgs):
|
||||||
|
imgs = torch.stack([im for _, im in path_imgs]).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
if args.beam_search:
|
||||||
|
captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
|
||||||
|
max_length=args.max_length, min_length=args.min_length)
|
||||||
|
else:
|
||||||
|
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
|
||||||
|
|
||||||
|
for (image_path, _), caption in zip(path_imgs, captions):
|
||||||
|
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
||||||
|
f.write(caption + "\n")
|
||||||
|
if args.debug:
|
||||||
|
print(image_path, caption)
|
||||||
|
|
||||||
|
b_imgs = []
|
||||||
|
for image_path in tqdm(image_paths, smoothing=0.0):
|
||||||
|
raw_image = Image.open(image_path)
|
||||||
|
if raw_image.mode != "RGB":
|
||||||
|
print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
|
||||||
|
raw_image = raw_image.convert("RGB")
|
||||||
|
|
||||||
|
image = transform(raw_image)
|
||||||
|
b_imgs.append((image_path, image))
|
||||||
|
if len(b_imgs) >= args.batch_size:
|
||||||
|
run_batch(b_imgs)
|
||||||
|
b_imgs.clear()
|
||||||
|
if len(b_imgs) > 0:
|
||||||
|
run_batch(b_imgs)
|
||||||
|
|
||||||
|
print("done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
|
parser.add_argument("caption_weights", type=str,
|
||||||
|
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
|
||||||
|
parser.add_argument("--caption_extention", type=str, default=None,
|
||||||
|
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||||
|
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||||
|
parser.add_argument("--beam_search", action="store_true",
|
||||||
|
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||||
|
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
||||||
|
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
||||||
|
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
||||||
|
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
||||||
|
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# スペルミスしていたオプションを復元する
|
||||||
|
if args.caption_extention is not None:
|
||||||
|
args.caption_extension = args.caption_extention
|
||||||
|
|
||||||
|
main(args)
|
238
BLIP_caption/models/blip.py
Normal file
238
BLIP_caption/models/blip.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
'''
|
||||||
|
* Copyright (c) 2022, salesforce.com, inc.
|
||||||
|
* All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||||
|
* By Junnan Li
|
||||||
|
'''
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
from models.vit import VisionTransformer, interpolate_pos_embed
|
||||||
|
from models.med import BertConfig, BertModel, BertLMHeadModel
|
||||||
|
from transformers import BertTokenizer
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import os
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from timm.models.hub import download_cached_file
|
||||||
|
|
||||||
|
class BLIP_Base(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
med_config = 'configs/med_config.json',
|
||||||
|
image_size = 224,
|
||||||
|
vit = 'base',
|
||||||
|
vit_grad_ckpt = False,
|
||||||
|
vit_ckpt_layer = 0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||||
|
image_size (int): input image size
|
||||||
|
vit (str): model size of vision transformer
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
||||||
|
self.tokenizer = init_tokenizer()
|
||||||
|
med_config = BertConfig.from_json_file(med_config)
|
||||||
|
med_config.encoder_width = vision_width
|
||||||
|
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, image, caption, mode):
|
||||||
|
|
||||||
|
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
|
||||||
|
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
|
||||||
|
|
||||||
|
if mode=='image':
|
||||||
|
# return image features
|
||||||
|
image_embeds = self.visual_encoder(image)
|
||||||
|
return image_embeds
|
||||||
|
|
||||||
|
elif mode=='text':
|
||||||
|
# return text features
|
||||||
|
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
||||||
|
return_dict = True, mode = 'text')
|
||||||
|
return text_output.last_hidden_state
|
||||||
|
|
||||||
|
elif mode=='multimodal':
|
||||||
|
# return multimodel features
|
||||||
|
image_embeds = self.visual_encoder(image)
|
||||||
|
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
||||||
|
|
||||||
|
text.input_ids[:,0] = self.tokenizer.enc_token_id
|
||||||
|
output = self.text_encoder(text.input_ids,
|
||||||
|
attention_mask = text.attention_mask,
|
||||||
|
encoder_hidden_states = image_embeds,
|
||||||
|
encoder_attention_mask = image_atts,
|
||||||
|
return_dict = True,
|
||||||
|
)
|
||||||
|
return output.last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BLIP_Decoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
med_config = 'configs/med_config.json',
|
||||||
|
image_size = 384,
|
||||||
|
vit = 'base',
|
||||||
|
vit_grad_ckpt = False,
|
||||||
|
vit_ckpt_layer = 0,
|
||||||
|
prompt = 'a picture of ',
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||||
|
image_size (int): input image size
|
||||||
|
vit (str): model size of vision transformer
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
||||||
|
self.tokenizer = init_tokenizer()
|
||||||
|
med_config = BertConfig.from_json_file(med_config)
|
||||||
|
med_config.encoder_width = vision_width
|
||||||
|
self.text_decoder = BertLMHeadModel(config=med_config)
|
||||||
|
|
||||||
|
self.prompt = prompt
|
||||||
|
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, image, caption):
|
||||||
|
|
||||||
|
image_embeds = self.visual_encoder(image)
|
||||||
|
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
||||||
|
|
||||||
|
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
|
||||||
|
|
||||||
|
text.input_ids[:,0] = self.tokenizer.bos_token_id
|
||||||
|
|
||||||
|
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
|
||||||
|
decoder_targets[:,:self.prompt_length] = -100
|
||||||
|
|
||||||
|
decoder_output = self.text_decoder(text.input_ids,
|
||||||
|
attention_mask = text.attention_mask,
|
||||||
|
encoder_hidden_states = image_embeds,
|
||||||
|
encoder_attention_mask = image_atts,
|
||||||
|
labels = decoder_targets,
|
||||||
|
return_dict = True,
|
||||||
|
)
|
||||||
|
loss_lm = decoder_output.loss
|
||||||
|
|
||||||
|
return loss_lm
|
||||||
|
|
||||||
|
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
|
||||||
|
image_embeds = self.visual_encoder(image)
|
||||||
|
|
||||||
|
if not sample:
|
||||||
|
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
|
||||||
|
|
||||||
|
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
||||||
|
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
|
||||||
|
|
||||||
|
prompt = [self.prompt] * image.size(0)
|
||||||
|
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
|
||||||
|
input_ids[:,0] = self.tokenizer.bos_token_id
|
||||||
|
input_ids = input_ids[:, :-1]
|
||||||
|
|
||||||
|
if sample:
|
||||||
|
#nucleus sampling
|
||||||
|
outputs = self.text_decoder.generate(input_ids=input_ids,
|
||||||
|
max_length=max_length,
|
||||||
|
min_length=min_length,
|
||||||
|
do_sample=True,
|
||||||
|
top_p=top_p,
|
||||||
|
num_return_sequences=1,
|
||||||
|
eos_token_id=self.tokenizer.sep_token_id,
|
||||||
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
|
repetition_penalty=1.1,
|
||||||
|
**model_kwargs)
|
||||||
|
else:
|
||||||
|
#beam search
|
||||||
|
outputs = self.text_decoder.generate(input_ids=input_ids,
|
||||||
|
max_length=max_length,
|
||||||
|
min_length=min_length,
|
||||||
|
num_beams=num_beams,
|
||||||
|
eos_token_id=self.tokenizer.sep_token_id,
|
||||||
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
**model_kwargs)
|
||||||
|
|
||||||
|
captions = []
|
||||||
|
for output in outputs:
|
||||||
|
caption = self.tokenizer.decode(output, skip_special_tokens=True)
|
||||||
|
captions.append(caption[len(self.prompt):])
|
||||||
|
return captions
|
||||||
|
|
||||||
|
|
||||||
|
def blip_decoder(pretrained='',**kwargs):
|
||||||
|
model = BLIP_Decoder(**kwargs)
|
||||||
|
if pretrained:
|
||||||
|
model,msg = load_checkpoint(model,pretrained)
|
||||||
|
assert(len(msg.missing_keys)==0)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def blip_feature_extractor(pretrained='',**kwargs):
|
||||||
|
model = BLIP_Base(**kwargs)
|
||||||
|
if pretrained:
|
||||||
|
model,msg = load_checkpoint(model,pretrained)
|
||||||
|
assert(len(msg.missing_keys)==0)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def init_tokenizer():
|
||||||
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
|
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
||||||
|
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
||||||
|
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
||||||
|
|
||||||
|
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
||||||
|
if vit=='base':
|
||||||
|
vision_width = 768
|
||||||
|
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
||||||
|
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
||||||
|
drop_path_rate=0 or drop_path_rate
|
||||||
|
)
|
||||||
|
elif vit=='large':
|
||||||
|
vision_width = 1024
|
||||||
|
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
||||||
|
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
||||||
|
drop_path_rate=0.1 or drop_path_rate
|
||||||
|
)
|
||||||
|
return visual_encoder, vision_width
|
||||||
|
|
||||||
|
def is_url(url_or_filename):
|
||||||
|
parsed = urlparse(url_or_filename)
|
||||||
|
return parsed.scheme in ("http", "https")
|
||||||
|
|
||||||
|
def load_checkpoint(model,url_or_filename):
|
||||||
|
if is_url(url_or_filename):
|
||||||
|
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
||||||
|
checkpoint = torch.load(cached_file, map_location='cpu')
|
||||||
|
elif os.path.isfile(url_or_filename):
|
||||||
|
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
||||||
|
else:
|
||||||
|
raise RuntimeError('checkpoint url or path is invalid')
|
||||||
|
|
||||||
|
state_dict = checkpoint['model']
|
||||||
|
|
||||||
|
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
||||||
|
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
||||||
|
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
||||||
|
model.visual_encoder_m)
|
||||||
|
for key in model.state_dict().keys():
|
||||||
|
if key in state_dict.keys():
|
||||||
|
if state_dict[key].shape!=model.state_dict()[key].shape:
|
||||||
|
del state_dict[key]
|
||||||
|
|
||||||
|
msg = model.load_state_dict(state_dict,strict=False)
|
||||||
|
print('load checkpoint from %s'%url_or_filename)
|
||||||
|
return model,msg
|
||||||
|
|
955
BLIP_caption/models/med.py
Normal file
955
BLIP_caption/models/med.py
Normal file
@ -0,0 +1,955 @@
|
|||||||
|
'''
|
||||||
|
* Copyright (c) 2022, salesforce.com, inc.
|
||||||
|
* All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||||
|
* By Junnan Li
|
||||||
|
* Based on huggingface code base
|
||||||
|
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
||||||
|
'''
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, device, dtype, nn
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.file_utils import (
|
||||||
|
ModelOutput,
|
||||||
|
)
|
||||||
|
from transformers.modeling_outputs import (
|
||||||
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
|
CausalLMOutputWithCrossAttentions,
|
||||||
|
MaskedLMOutput,
|
||||||
|
MultipleChoiceModelOutput,
|
||||||
|
NextSentencePredictorOutput,
|
||||||
|
QuestionAnsweringModelOutput,
|
||||||
|
SequenceClassifierOutput,
|
||||||
|
TokenClassifierOutput,
|
||||||
|
)
|
||||||
|
from transformers.modeling_utils import (
|
||||||
|
PreTrainedModel,
|
||||||
|
apply_chunking_to_forward,
|
||||||
|
find_pruneable_heads_and_indices,
|
||||||
|
prune_linear_layer,
|
||||||
|
)
|
||||||
|
from transformers.utils import logging
|
||||||
|
from transformers.models.bert.configuration_bert import BertConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BertEmbeddings(nn.Module):
|
||||||
|
"""Construct the embeddings from word and position embeddings."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||||
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||||
|
|
||||||
|
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||||
|
# any TensorFlow checkpoint file
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||||
|
):
|
||||||
|
if input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
else:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
|
||||||
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
embeddings = inputs_embeds
|
||||||
|
|
||||||
|
if self.position_embedding_type == "absolute":
|
||||||
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
|
embeddings += position_embeddings
|
||||||
|
embeddings = self.LayerNorm(embeddings)
|
||||||
|
embeddings = self.dropout(embeddings)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class BertSelfAttention(nn.Module):
|
||||||
|
def __init__(self, config, is_cross_attention):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||||
|
raise ValueError(
|
||||||
|
"The hidden size (%d) is not a multiple of the number of attention "
|
||||||
|
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_attention_heads = config.num_attention_heads
|
||||||
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
|
|
||||||
|
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
if is_cross_attention:
|
||||||
|
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
||||||
|
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
||||||
|
else:
|
||||||
|
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||||
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||||
|
self.save_attention = False
|
||||||
|
|
||||||
|
def save_attn_gradients(self, attn_gradients):
|
||||||
|
self.attn_gradients = attn_gradients
|
||||||
|
|
||||||
|
def get_attn_gradients(self):
|
||||||
|
return self.attn_gradients
|
||||||
|
|
||||||
|
def save_attention_map(self, attention_map):
|
||||||
|
self.attention_map = attention_map
|
||||||
|
|
||||||
|
def get_attention_map(self):
|
||||||
|
return self.attention_map
|
||||||
|
|
||||||
|
def transpose_for_scores(self, x):
|
||||||
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
|
x = x.view(*new_x_shape)
|
||||||
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
):
|
||||||
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
|
# and values come from an encoder; the attention mask needs to be
|
||||||
|
# such that the encoder's padding tokens are not attended to.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
if is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||||
|
else:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
|
||||||
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
|
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
|
|
||||||
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
|
seq_length = hidden_states.size()[1]
|
||||||
|
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||||
|
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||||
|
distance = position_ids_l - position_ids_r
|
||||||
|
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||||
|
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||||
|
|
||||||
|
if self.position_embedding_type == "relative_key":
|
||||||
|
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||||
|
attention_scores = attention_scores + relative_position_scores
|
||||||
|
elif self.position_embedding_type == "relative_key_query":
|
||||||
|
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||||
|
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||||
|
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||||
|
|
||||||
|
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||||
|
if attention_mask is not None:
|
||||||
|
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
||||||
|
attention_scores = attention_scores + attention_mask
|
||||||
|
|
||||||
|
# Normalize the attention scores to probabilities.
|
||||||
|
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||||
|
|
||||||
|
if is_cross_attention and self.save_attention:
|
||||||
|
self.save_attention_map(attention_probs)
|
||||||
|
attention_probs.register_hook(self.save_attn_gradients)
|
||||||
|
|
||||||
|
# This is actually dropping out entire tokens to attend to, which might
|
||||||
|
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||||
|
attention_probs_dropped = self.dropout(attention_probs)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if head_mask is not None:
|
||||||
|
attention_probs_dropped = attention_probs_dropped * head_mask
|
||||||
|
|
||||||
|
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
||||||
|
|
||||||
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class BertSelfOutput(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, input_tensor):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertAttention(nn.Module):
|
||||||
|
def __init__(self, config, is_cross_attention=False):
|
||||||
|
super().__init__()
|
||||||
|
self.self = BertSelfAttention(config, is_cross_attention)
|
||||||
|
self.output = BertSelfOutput(config)
|
||||||
|
self.pruned_heads = set()
|
||||||
|
|
||||||
|
def prune_heads(self, heads):
|
||||||
|
if len(heads) == 0:
|
||||||
|
return
|
||||||
|
heads, index = find_pruneable_heads_and_indices(
|
||||||
|
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prune linear layers
|
||||||
|
self.self.query = prune_linear_layer(self.self.query, index)
|
||||||
|
self.self.key = prune_linear_layer(self.self.key, index)
|
||||||
|
self.self.value = prune_linear_layer(self.self.value, index)
|
||||||
|
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||||
|
|
||||||
|
# Update hyper params and store pruned heads
|
||||||
|
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||||
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
):
|
||||||
|
self_outputs = self.self(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
)
|
||||||
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
|
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class BertIntermediate(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
|
if isinstance(config.hidden_act, str):
|
||||||
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertOutput(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, input_tensor):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertLayer(nn.Module):
|
||||||
|
def __init__(self, config, layer_num):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
|
self.attention = BertAttention(config)
|
||||||
|
self.layer_num = layer_num
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
||||||
|
self.intermediate = BertIntermediate(config)
|
||||||
|
self.output = BertOutput(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
mode=None,
|
||||||
|
):
|
||||||
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
|
)
|
||||||
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
|
||||||
|
if mode=='multimodal':
|
||||||
|
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
||||||
|
|
||||||
|
cross_attention_outputs = self.crossattention(
|
||||||
|
attention_output,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
attention_output = cross_attention_outputs[0]
|
||||||
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
layer_output = apply_chunking_to_forward(
|
||||||
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
|
)
|
||||||
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def feed_forward_chunk(self, attention_output):
|
||||||
|
intermediate_output = self.intermediate(attention_output)
|
||||||
|
layer_output = self.output(intermediate_output, attention_output)
|
||||||
|
return layer_output
|
||||||
|
|
||||||
|
|
||||||
|
class BertEncoder(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
return_dict=True,
|
||||||
|
mode='multimodal',
|
||||||
|
):
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
for i in range(self.config.num_hidden_layers):
|
||||||
|
layer_module = self.layer[i]
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warn(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(layer_module),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = layer_module(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_decoder_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
all_cross_attentions,
|
||||||
|
]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BertPooler(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
|
# to the first token.
|
||||||
|
first_token_tensor = hidden_states[:, 0]
|
||||||
|
pooled_output = self.dense(first_token_tensor)
|
||||||
|
pooled_output = self.activation(pooled_output)
|
||||||
|
return pooled_output
|
||||||
|
|
||||||
|
|
||||||
|
class BertPredictionHeadTransform(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
if isinstance(config.hidden_act, str):
|
||||||
|
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.transform_act_fn = config.hidden_act
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertLMPredictionHead(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.transform = BertPredictionHeadTransform(config)
|
||||||
|
|
||||||
|
# The output weights are the same as the input embeddings, but there is
|
||||||
|
# an output-only bias for each token.
|
||||||
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||||
|
|
||||||
|
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||||
|
self.decoder.bias = self.bias
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.transform(hidden_states)
|
||||||
|
hidden_states = self.decoder(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class BertOnlyMLMHead(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.predictions = BertLMPredictionHead(config)
|
||||||
|
|
||||||
|
def forward(self, sequence_output):
|
||||||
|
prediction_scores = self.predictions(sequence_output)
|
||||||
|
return prediction_scores
|
||||||
|
|
||||||
|
|
||||||
|
class BertPreTrainedModel(PreTrainedModel):
|
||||||
|
"""
|
||||||
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = BertConfig
|
||||||
|
base_model_prefix = "bert"
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
""" Initialize the weights """
|
||||||
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||||
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||||
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
|
||||||
|
class BertModel(BertPreTrainedModel):
|
||||||
|
"""
|
||||||
|
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||||
|
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||||
|
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
||||||
|
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||||
|
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
||||||
|
input to the forward pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config, add_pooling_layer=True):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.embeddings = BertEmbeddings(config)
|
||||||
|
|
||||||
|
self.encoder = BertEncoder(config)
|
||||||
|
|
||||||
|
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.word_embeddings = value
|
||||||
|
|
||||||
|
def _prune_heads(self, heads_to_prune):
|
||||||
|
"""
|
||||||
|
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||||
|
class PreTrainedModel
|
||||||
|
"""
|
||||||
|
for layer, heads in heads_to_prune.items():
|
||||||
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
|
|
||||||
|
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
||||||
|
"""
|
||||||
|
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
attention_mask (:obj:`torch.Tensor`):
|
||||||
|
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
||||||
|
input_shape (:obj:`Tuple[int]`):
|
||||||
|
The shape of the input to the model.
|
||||||
|
device: (:obj:`torch.device`):
|
||||||
|
The device of the input to the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
||||||
|
"""
|
||||||
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
|
if attention_mask.dim() == 3:
|
||||||
|
extended_attention_mask = attention_mask[:, None, :, :]
|
||||||
|
elif attention_mask.dim() == 2:
|
||||||
|
# Provided a padding mask of dimensions [batch_size, seq_length]
|
||||||
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||||
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if is_decoder:
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
|
seq_ids = torch.arange(seq_length, device=device)
|
||||||
|
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||||
|
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
||||||
|
# causal and attention masks must have same type with pytorch version < 1.3
|
||||||
|
causal_mask = causal_mask.to(attention_mask.dtype)
|
||||||
|
|
||||||
|
if causal_mask.shape[1] < attention_mask.shape[1]:
|
||||||
|
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
||||||
|
causal_mask = torch.cat(
|
||||||
|
[
|
||||||
|
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
||||||
|
causal_mask,
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||||
|
else:
|
||||||
|
extended_attention_mask = attention_mask[:, None, None, :]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
||||||
|
input_shape, attention_mask.shape
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
|
# positions we want to attend and -10000.0 for masked positions.
|
||||||
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
|
# effectively the same as removing these entirely.
|
||||||
|
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||||
|
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||||
|
return extended_attention_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
is_decoder=False,
|
||||||
|
mode='multimodal',
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if is_decoder:
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
else:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
device = input_ids.device
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
device = inputs_embeds.device
|
||||||
|
elif encoder_embeds is not None:
|
||||||
|
input_shape = encoder_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
device = encoder_embeds.device
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||||
|
|
||||||
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
||||||
|
device, is_decoder)
|
||||||
|
|
||||||
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
if type(encoder_hidden_states) == list:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
||||||
|
else:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
|
||||||
|
if type(encoder_attention_mask) == list:
|
||||||
|
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
||||||
|
elif encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
|
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||||
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
|
if encoder_embeds is None:
|
||||||
|
embedding_output = self.embeddings(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embedding_output = encoder_embeds
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
embedding_output,
|
||||||
|
attention_mask=extended_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
sequence_output = encoder_outputs[0]
|
||||||
|
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||||
|
|
||||||
|
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
|
last_hidden_state=sequence_output,
|
||||||
|
pooler_output=pooled_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BertLMHeadModel(BertPreTrainedModel):
|
||||||
|
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.bert = BertModel(config, add_pooling_layer=False)
|
||||||
|
self.cls = BertOnlyMLMHead(config)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.cls.predictions.decoder = new_embeddings
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
labels=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
return_logits=False,
|
||||||
|
is_decoder=True,
|
||||||
|
reduction='mean',
|
||||||
|
mode='multimodal',
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||||
|
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
||||||
|
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
Returns:
|
||||||
|
Example::
|
||||||
|
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
||||||
|
>>> import torch
|
||||||
|
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
||||||
|
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
||||||
|
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
>>> prediction_logits = outputs.logits
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
if labels is not None:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
is_decoder=is_decoder,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
prediction_scores = self.cls(sequence_output)
|
||||||
|
|
||||||
|
if return_logits:
|
||||||
|
return prediction_scores[:, :-1, :].contiguous()
|
||||||
|
|
||||||
|
lm_loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||||
|
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||||
|
labels = labels[:, 1:].contiguous()
|
||||||
|
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
||||||
|
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
if reduction=='none':
|
||||||
|
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (prediction_scores,) + outputs[2:]
|
||||||
|
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=lm_loss,
|
||||||
|
logits=prediction_scores,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, **model_kwargs):
|
||||||
|
input_shape = input_ids.shape
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_shape)
|
||||||
|
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past is not None:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"past_key_values": past,
|
||||||
|
"encoder_hidden_states": encoder_hidden_states,
|
||||||
|
"encoder_attention_mask": encoder_attention_mask,
|
||||||
|
"is_decoder": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _reorder_cache(self, past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
|
return reordered_past
|
305
BLIP_caption/models/vit.py
Normal file
305
BLIP_caption/models/vit.py
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
'''
|
||||||
|
* Copyright (c) 2022, salesforce.com, inc.
|
||||||
|
* All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||||
|
* By Junnan Li
|
||||||
|
* Based on timm code base
|
||||||
|
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||||
|
'''
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from timm.models.vision_transformer import _cfg, PatchEmbed
|
||||||
|
from timm.models.registry import register_model
|
||||||
|
from timm.models.layers import trunc_normal_, DropPath
|
||||||
|
from timm.models.helpers import named_apply, adapt_input_conv
|
||||||
|
|
||||||
|
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
||||||
|
"""
|
||||||
|
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||||
|
self.drop = nn.Dropout(drop)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
||||||
|
self.scale = qk_scale or head_dim ** -0.5
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
self.attn_gradients = None
|
||||||
|
self.attention_map = None
|
||||||
|
|
||||||
|
def save_attn_gradients(self, attn_gradients):
|
||||||
|
self.attn_gradients = attn_gradients
|
||||||
|
|
||||||
|
def get_attn_gradients(self):
|
||||||
|
return self.attn_gradients
|
||||||
|
|
||||||
|
def save_attention_map(self, attention_map):
|
||||||
|
self.attention_map = attention_map
|
||||||
|
|
||||||
|
def get_attention_map(self):
|
||||||
|
return self.attention_map
|
||||||
|
|
||||||
|
def forward(self, x, register_hook=False):
|
||||||
|
B, N, C = x.shape
|
||||||
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||||
|
|
||||||
|
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
|
||||||
|
if register_hook:
|
||||||
|
self.save_attention_map(attn)
|
||||||
|
attn.register_hook(self.save_attn_gradients)
|
||||||
|
|
||||||
|
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||||
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||||
|
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||||
|
|
||||||
|
if use_grad_checkpointing:
|
||||||
|
self.attn = checkpoint_wrapper(self.attn)
|
||||||
|
self.mlp = checkpoint_wrapper(self.mlp)
|
||||||
|
|
||||||
|
def forward(self, x, register_hook=False):
|
||||||
|
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
||||||
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VisionTransformer(nn.Module):
|
||||||
|
""" Vision Transformer
|
||||||
|
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
||||||
|
https://arxiv.org/abs/2010.11929
|
||||||
|
"""
|
||||||
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||||
|
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
||||||
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
||||||
|
use_grad_checkpointing=False, ckpt_layer=0):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
img_size (int, tuple): input image size
|
||||||
|
patch_size (int, tuple): patch size
|
||||||
|
in_chans (int): number of input channels
|
||||||
|
num_classes (int): number of classes for classification head
|
||||||
|
embed_dim (int): embedding dimension
|
||||||
|
depth (int): depth of transformer
|
||||||
|
num_heads (int): number of attention heads
|
||||||
|
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||||
|
qkv_bias (bool): enable bias for qkv if True
|
||||||
|
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
||||||
|
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
||||||
|
drop_rate (float): dropout rate
|
||||||
|
attn_drop_rate (float): attention dropout rate
|
||||||
|
drop_path_rate (float): stochastic depth rate
|
||||||
|
norm_layer: (nn.Module): normalization layer
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||||
|
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||||
|
|
||||||
|
self.patch_embed = PatchEmbed(
|
||||||
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||||
|
|
||||||
|
num_patches = self.patch_embed.num_patches
|
||||||
|
|
||||||
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||||
|
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
||||||
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||||
|
|
||||||
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
Block(
|
||||||
|
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||||
|
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||||
|
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
||||||
|
)
|
||||||
|
for i in range(depth)])
|
||||||
|
self.norm = norm_layer(embed_dim)
|
||||||
|
|
||||||
|
trunc_normal_(self.pos_embed, std=.02)
|
||||||
|
trunc_normal_(self.cls_token, std=.02)
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
trunc_normal_(m.weight, std=.02)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def no_weight_decay(self):
|
||||||
|
return {'pos_embed', 'cls_token'}
|
||||||
|
|
||||||
|
def forward(self, x, register_blk=-1):
|
||||||
|
B = x.shape[0]
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
|
||||||
|
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
|
||||||
|
x = x + self.pos_embed[:,:x.size(1),:]
|
||||||
|
x = self.pos_drop(x)
|
||||||
|
|
||||||
|
for i,blk in enumerate(self.blocks):
|
||||||
|
x = blk(x, register_blk==i)
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.jit.ignore()
|
||||||
|
def load_pretrained(self, checkpoint_path, prefix=''):
|
||||||
|
_load_weights(self, checkpoint_path, prefix)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
||||||
|
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def _n2p(w, t=True):
|
||||||
|
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
||||||
|
w = w.flatten()
|
||||||
|
if t:
|
||||||
|
if w.ndim == 4:
|
||||||
|
w = w.transpose([3, 2, 0, 1])
|
||||||
|
elif w.ndim == 3:
|
||||||
|
w = w.transpose([2, 0, 1])
|
||||||
|
elif w.ndim == 2:
|
||||||
|
w = w.transpose([1, 0])
|
||||||
|
return torch.from_numpy(w)
|
||||||
|
|
||||||
|
w = np.load(checkpoint_path)
|
||||||
|
if not prefix and 'opt/target/embedding/kernel' in w:
|
||||||
|
prefix = 'opt/target/'
|
||||||
|
|
||||||
|
if hasattr(model.patch_embed, 'backbone'):
|
||||||
|
# hybrid
|
||||||
|
backbone = model.patch_embed.backbone
|
||||||
|
stem_only = not hasattr(backbone, 'stem')
|
||||||
|
stem = backbone if stem_only else backbone.stem
|
||||||
|
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
||||||
|
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
||||||
|
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
||||||
|
if not stem_only:
|
||||||
|
for i, stage in enumerate(backbone.stages):
|
||||||
|
for j, block in enumerate(stage.blocks):
|
||||||
|
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
||||||
|
for r in range(3):
|
||||||
|
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
||||||
|
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
||||||
|
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
||||||
|
if block.downsample is not None:
|
||||||
|
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
||||||
|
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
||||||
|
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
||||||
|
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
||||||
|
else:
|
||||||
|
embed_conv_w = adapt_input_conv(
|
||||||
|
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
||||||
|
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
||||||
|
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
||||||
|
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
||||||
|
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
||||||
|
if pos_embed_w.shape != model.pos_embed.shape:
|
||||||
|
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
||||||
|
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
||||||
|
model.pos_embed.copy_(pos_embed_w)
|
||||||
|
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
||||||
|
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
||||||
|
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
||||||
|
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
||||||
|
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
||||||
|
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
||||||
|
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
||||||
|
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
||||||
|
for i, block in enumerate(model.blocks.children()):
|
||||||
|
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
||||||
|
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
||||||
|
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
||||||
|
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
||||||
|
block.attn.qkv.weight.copy_(torch.cat([
|
||||||
|
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
||||||
|
block.attn.qkv.bias.copy_(torch.cat([
|
||||||
|
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
||||||
|
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
||||||
|
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
||||||
|
for r in range(2):
|
||||||
|
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
||||||
|
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
||||||
|
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
||||||
|
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
||||||
|
|
||||||
|
|
||||||
|
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
||||||
|
# interpolate position embedding
|
||||||
|
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||||
|
num_patches = visual_encoder.patch_embed.num_patches
|
||||||
|
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
||||||
|
# height (== width) for the checkpoint position embedding
|
||||||
|
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
||||||
|
# height (== width) for the new position embedding
|
||||||
|
new_size = int(num_patches ** 0.5)
|
||||||
|
|
||||||
|
if orig_size!=new_size:
|
||||||
|
# class_token and dist_token are kept unchanged
|
||||||
|
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||||
|
# only the position tokens are interpolated
|
||||||
|
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||||
|
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
||||||
|
pos_tokens = torch.nn.functional.interpolate(
|
||||||
|
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
||||||
|
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||||
|
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||||
|
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
||||||
|
|
||||||
|
return new_pos_embed
|
||||||
|
else:
|
||||||
|
return pos_embed_checkpoint
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -130,6 +130,10 @@ Drop by the discord server for support: https://discord.com/channels/10415185624
|
|||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
|
* 12/19 (v18.2) update:
|
||||||
|
- Fix file/folder opening behind the browser window
|
||||||
|
- Add WD14 and BLIP captioning to utilities
|
||||||
|
- Improve overall GUI layout
|
||||||
* 12/18 (v18.1) update:
|
* 12/18 (v18.1) update:
|
||||||
- Add Stable Diffusion model conversion utility. Make sure to run `pip upgrade -U -r requirements.txt` after updating to this release as this introduce new pip requirements.
|
- Add Stable Diffusion model conversion utility. Make sure to run `pip upgrade -U -r requirements.txt` after updating to this release as this introduce new pip requirements.
|
||||||
* 12/17 (v18) update:
|
* 12/17 (v18) update:
|
||||||
|
@ -11,15 +11,18 @@ import subprocess
|
|||||||
import pathlib
|
import pathlib
|
||||||
import shutil
|
import shutil
|
||||||
from library.dreambooth_folder_creation_gui import gradio_dreambooth_folder_creation_tab
|
from library.dreambooth_folder_creation_gui import gradio_dreambooth_folder_creation_tab
|
||||||
from library.caption_gui import gradio_caption_gui_tab
|
from library.basic_caption_gui import gradio_basic_caption_gui_tab
|
||||||
|
from library.convert_model_gui import gradio_convert_model_tab
|
||||||
|
from library.blip_caption_gui import gradio_blip_caption_gui_tab
|
||||||
|
from library.wd14_caption_gui import gradio_wd14_caption_gui_tab
|
||||||
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
|
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
|
||||||
from library.common_gui import (
|
from library.common_gui import (
|
||||||
get_folder_path,
|
get_folder_path,
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
get_file_path,
|
get_file_path,
|
||||||
|
get_saveasfile_path
|
||||||
)
|
)
|
||||||
from library.convert_model_gui import gradio_convert_model_tab
|
from easygui import msgbox
|
||||||
from easygui import filesavebox, msgbox
|
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
@ -65,19 +68,21 @@ def save_configuration(
|
|||||||
|
|
||||||
if save_as_bool:
|
if save_as_bool:
|
||||||
print('Save as...')
|
print('Save as...')
|
||||||
file_path = filesavebox(
|
# file_path = filesavebox(
|
||||||
'Select the config file to save',
|
# 'Select the config file to save',
|
||||||
default='finetune.json',
|
# default='finetune.json',
|
||||||
filetypes='*.json',
|
# filetypes='*.json',
|
||||||
)
|
# )
|
||||||
|
file_path = get_saveasfile_path(file_path)
|
||||||
else:
|
else:
|
||||||
print('Save...')
|
print('Save...')
|
||||||
if file_path == None or file_path == '':
|
if file_path == None or file_path == '':
|
||||||
file_path = filesavebox(
|
# file_path = filesavebox(
|
||||||
'Select the config file to save',
|
# 'Select the config file to save',
|
||||||
default='finetune.json',
|
# default='finetune.json',
|
||||||
filetypes='*.json',
|
# filetypes='*.json',
|
||||||
)
|
# )
|
||||||
|
file_path = get_saveasfile_path(file_path)
|
||||||
|
|
||||||
if file_path == None:
|
if file_path == None:
|
||||||
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||||
@ -455,261 +460,258 @@ interface = gr.Blocks(css=css)
|
|||||||
with interface:
|
with interface:
|
||||||
dummy_true = gr.Label(value=True, visible=False)
|
dummy_true = gr.Label(value=True, visible=False)
|
||||||
dummy_false = gr.Label(value=False, visible=False)
|
dummy_false = gr.Label(value=False, visible=False)
|
||||||
gr.Markdown('Enter kohya finetuner parameter using this interface.')
|
with gr.Tab('Dreambooth'):
|
||||||
with gr.Accordion('Configuration File Load/Save', open=False):
|
gr.Markdown('Enter kohya finetuner parameter using this interface.')
|
||||||
with gr.Row():
|
with gr.Accordion('Configuration File Load/Save', open=False):
|
||||||
button_open_config = gr.Button('Open 📂', elem_id='open_folder')
|
with gr.Row():
|
||||||
button_save_config = gr.Button('Save 💾', elem_id='open_folder')
|
button_open_config = gr.Button('Open 📂', elem_id='open_folder')
|
||||||
button_save_as_config = gr.Button(
|
button_save_config = gr.Button('Save 💾', elem_id='open_folder')
|
||||||
'Save as... 💾', elem_id='open_folder'
|
button_save_as_config = gr.Button(
|
||||||
|
'Save as... 💾', elem_id='open_folder'
|
||||||
|
)
|
||||||
|
config_file_name = gr.Textbox(
|
||||||
|
label='',
|
||||||
|
placeholder="type the configuration file path or use the 'Open' button above to select it...",
|
||||||
)
|
)
|
||||||
config_file_name = gr.Textbox(
|
config_file_name.change(
|
||||||
label='',
|
remove_doublequote,
|
||||||
placeholder="type the configuration file path or use the 'Open' button above to select it...",
|
inputs=[config_file_name],
|
||||||
)
|
outputs=[config_file_name],
|
||||||
config_file_name.change(
|
|
||||||
remove_doublequote,
|
|
||||||
inputs=[config_file_name],
|
|
||||||
outputs=[config_file_name],
|
|
||||||
)
|
|
||||||
with gr.Tab('Source model'):
|
|
||||||
# Define the input elements
|
|
||||||
with gr.Row():
|
|
||||||
pretrained_model_name_or_path_input = gr.Textbox(
|
|
||||||
label='Pretrained model name or path',
|
|
||||||
placeholder='enter the path to custom model or name of pretrained model',
|
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_fille = gr.Button(
|
with gr.Tab('Source model'):
|
||||||
document_symbol, elem_id='open_folder_small'
|
# Define the input elements
|
||||||
|
with gr.Row():
|
||||||
|
pretrained_model_name_or_path_input = gr.Textbox(
|
||||||
|
label='Pretrained model name or path',
|
||||||
|
placeholder='enter the path to custom model or name of pretrained model',
|
||||||
|
)
|
||||||
|
pretrained_model_name_or_path_fille = gr.Button(
|
||||||
|
document_symbol, elem_id='open_folder_small'
|
||||||
|
)
|
||||||
|
pretrained_model_name_or_path_fille.click(
|
||||||
|
get_file_path, inputs=[pretrained_model_name_or_path_input], outputs=pretrained_model_name_or_path_input
|
||||||
|
)
|
||||||
|
pretrained_model_name_or_path_folder = gr.Button(
|
||||||
|
folder_symbol, elem_id='open_folder_small'
|
||||||
|
)
|
||||||
|
pretrained_model_name_or_path_folder.click(
|
||||||
|
get_folder_path, outputs=pretrained_model_name_or_path_input
|
||||||
|
)
|
||||||
|
model_list = gr.Dropdown(
|
||||||
|
label='(Optional) Model Quick Pick',
|
||||||
|
choices=[
|
||||||
|
'custom',
|
||||||
|
'stabilityai/stable-diffusion-2-1-base',
|
||||||
|
'stabilityai/stable-diffusion-2-base',
|
||||||
|
'stabilityai/stable-diffusion-2-1',
|
||||||
|
'stabilityai/stable-diffusion-2',
|
||||||
|
'runwayml/stable-diffusion-v1-5',
|
||||||
|
'CompVis/stable-diffusion-v1-4',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
save_model_as_dropdown = gr.Dropdown(
|
||||||
|
label='Save trained model as',
|
||||||
|
choices=[
|
||||||
|
'same as source model',
|
||||||
|
'ckpt',
|
||||||
|
'diffusers',
|
||||||
|
"diffusers_safetensors",
|
||||||
|
'safetensors',
|
||||||
|
],
|
||||||
|
value='same as source model'
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
v2_input = gr.Checkbox(label='v2', value=True)
|
||||||
|
v_parameterization_input = gr.Checkbox(
|
||||||
|
label='v_parameterization', value=False
|
||||||
|
)
|
||||||
|
pretrained_model_name_or_path_input.change(
|
||||||
|
remove_doublequote,
|
||||||
|
inputs=[pretrained_model_name_or_path_input],
|
||||||
|
outputs=[pretrained_model_name_or_path_input],
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_fille.click(
|
model_list.change(
|
||||||
get_file_path, inputs=[pretrained_model_name_or_path_input], outputs=pretrained_model_name_or_path_input
|
set_pretrained_model_name_or_path_input,
|
||||||
)
|
inputs=[model_list, v2_input, v_parameterization_input],
|
||||||
pretrained_model_name_or_path_folder = gr.Button(
|
outputs=[
|
||||||
folder_symbol, elem_id='open_folder_small'
|
pretrained_model_name_or_path_input,
|
||||||
)
|
v2_input,
|
||||||
pretrained_model_name_or_path_folder.click(
|
v_parameterization_input,
|
||||||
get_folder_path, outputs=pretrained_model_name_or_path_input
|
|
||||||
)
|
|
||||||
model_list = gr.Dropdown(
|
|
||||||
label='(Optional) Model Quick Pick',
|
|
||||||
choices=[
|
|
||||||
'custom',
|
|
||||||
'stabilityai/stable-diffusion-2-1-base',
|
|
||||||
'stabilityai/stable-diffusion-2-base',
|
|
||||||
'stabilityai/stable-diffusion-2-1',
|
|
||||||
'stabilityai/stable-diffusion-2',
|
|
||||||
'runwayml/stable-diffusion-v1-5',
|
|
||||||
'CompVis/stable-diffusion-v1-4',
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
save_model_as_dropdown = gr.Dropdown(
|
|
||||||
label='Save trained model as',
|
|
||||||
choices=[
|
|
||||||
'same as source model',
|
|
||||||
'ckpt',
|
|
||||||
'diffusers',
|
|
||||||
"diffusers_safetensors",
|
|
||||||
'safetensors',
|
|
||||||
],
|
|
||||||
value='same as source model'
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
v2_input = gr.Checkbox(label='v2', value=True)
|
|
||||||
v_parameterization_input = gr.Checkbox(
|
|
||||||
label='v_parameterization', value=False
|
|
||||||
)
|
|
||||||
pretrained_model_name_or_path_input.change(
|
|
||||||
remove_doublequote,
|
|
||||||
inputs=[pretrained_model_name_or_path_input],
|
|
||||||
outputs=[pretrained_model_name_or_path_input],
|
|
||||||
)
|
|
||||||
model_list.change(
|
|
||||||
set_pretrained_model_name_or_path_input,
|
|
||||||
inputs=[model_list, v2_input, v_parameterization_input],
|
|
||||||
outputs=[
|
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
with gr.Tab('Directories'):
|
|
||||||
with gr.Row():
|
|
||||||
train_data_dir_input = gr.Textbox(
|
|
||||||
label='Image folder',
|
|
||||||
placeholder='Folder where the training folders containing the images are located',
|
|
||||||
)
|
|
||||||
train_data_dir_input_folder = gr.Button(
|
|
||||||
'📂', elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
train_data_dir_input_folder.click(
|
|
||||||
get_folder_path, outputs=train_data_dir_input
|
|
||||||
)
|
|
||||||
reg_data_dir_input = gr.Textbox(
|
|
||||||
label='Regularisation folder',
|
|
||||||
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
|
||||||
)
|
|
||||||
reg_data_dir_input_folder = gr.Button(
|
|
||||||
'📂', elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
reg_data_dir_input_folder.click(
|
|
||||||
get_folder_path, outputs=reg_data_dir_input
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
output_dir_input = gr.Textbox(
|
|
||||||
label='Output folder',
|
|
||||||
placeholder='Folder to output trained model',
|
|
||||||
)
|
|
||||||
output_dir_input_folder = gr.Button(
|
|
||||||
'📂', elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
output_dir_input_folder.click(
|
|
||||||
get_folder_path, outputs=output_dir_input
|
|
||||||
)
|
|
||||||
logging_dir_input = gr.Textbox(
|
|
||||||
label='Logging folder',
|
|
||||||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
|
||||||
)
|
|
||||||
logging_dir_input_folder = gr.Button(
|
|
||||||
'📂', elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
logging_dir_input_folder.click(
|
|
||||||
get_folder_path, outputs=logging_dir_input
|
|
||||||
)
|
|
||||||
train_data_dir_input.change(
|
|
||||||
remove_doublequote,
|
|
||||||
inputs=[train_data_dir_input],
|
|
||||||
outputs=[train_data_dir_input],
|
|
||||||
)
|
|
||||||
reg_data_dir_input.change(
|
|
||||||
remove_doublequote,
|
|
||||||
inputs=[reg_data_dir_input],
|
|
||||||
outputs=[reg_data_dir_input],
|
|
||||||
)
|
|
||||||
output_dir_input.change(
|
|
||||||
remove_doublequote,
|
|
||||||
inputs=[output_dir_input],
|
|
||||||
outputs=[output_dir_input],
|
|
||||||
)
|
|
||||||
logging_dir_input.change(
|
|
||||||
remove_doublequote,
|
|
||||||
inputs=[logging_dir_input],
|
|
||||||
outputs=[logging_dir_input],
|
|
||||||
)
|
|
||||||
with gr.Tab('Training parameters'):
|
|
||||||
with gr.Row():
|
|
||||||
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
|
|
||||||
lr_scheduler_input = gr.Dropdown(
|
|
||||||
label='LR Scheduler',
|
|
||||||
choices=[
|
|
||||||
'constant',
|
|
||||||
'constant_with_warmup',
|
|
||||||
'cosine',
|
|
||||||
'cosine_with_restarts',
|
|
||||||
'linear',
|
|
||||||
'polynomial',
|
|
||||||
],
|
|
||||||
value='constant',
|
|
||||||
)
|
|
||||||
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
|
|
||||||
with gr.Row():
|
|
||||||
train_batch_size_input = gr.Slider(
|
|
||||||
minimum=1,
|
|
||||||
maximum=32,
|
|
||||||
label='Train batch size',
|
|
||||||
value=1,
|
|
||||||
step=1,
|
|
||||||
)
|
|
||||||
epoch_input = gr.Textbox(label='Epoch', value=1)
|
|
||||||
save_every_n_epochs_input = gr.Textbox(
|
|
||||||
label='Save every N epochs', value=1
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
mixed_precision_input = gr.Dropdown(
|
|
||||||
label='Mixed precision',
|
|
||||||
choices=[
|
|
||||||
'no',
|
|
||||||
'fp16',
|
|
||||||
'bf16',
|
|
||||||
],
|
|
||||||
value='fp16',
|
|
||||||
)
|
|
||||||
save_precision_input = gr.Dropdown(
|
|
||||||
label='Save precision',
|
|
||||||
choices=[
|
|
||||||
'float',
|
|
||||||
'fp16',
|
|
||||||
'bf16',
|
|
||||||
],
|
|
||||||
value='fp16',
|
|
||||||
)
|
|
||||||
num_cpu_threads_per_process_input = gr.Slider(
|
|
||||||
minimum=1,
|
|
||||||
maximum=os.cpu_count(),
|
|
||||||
step=1,
|
|
||||||
label='Number of CPU threads per process',
|
|
||||||
value=os.cpu_count(),
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
seed_input = gr.Textbox(label='Seed', value=1234)
|
|
||||||
max_resolution_input = gr.Textbox(
|
|
||||||
label='Max resolution', value='512,512', placeholder='512,512'
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
caption_extention_input = gr.Textbox(
|
|
||||||
label='Caption Extension',
|
|
||||||
placeholder='(Optional) Extension for caption files. default: .caption',
|
|
||||||
)
|
|
||||||
stop_text_encoder_training_input = gr.Slider(
|
|
||||||
minimum=0,
|
|
||||||
maximum=100,
|
|
||||||
value=0,
|
|
||||||
step=1,
|
|
||||||
label='Stop text encoder training',
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
full_fp16_input = gr.Checkbox(
|
|
||||||
label='Full fp16 training (experimental)', value=False
|
|
||||||
)
|
|
||||||
no_token_padding_input = gr.Checkbox(
|
|
||||||
label='No token padding', value=False
|
|
||||||
)
|
|
||||||
|
|
||||||
gradient_checkpointing_input = gr.Checkbox(
|
with gr.Tab('Directories'):
|
||||||
label='Gradient checkpointing', value=False
|
with gr.Row():
|
||||||
|
train_data_dir_input = gr.Textbox(
|
||||||
|
label='Image folder',
|
||||||
|
placeholder='Folder where the training folders containing the images are located',
|
||||||
|
)
|
||||||
|
train_data_dir_input_folder = gr.Button(
|
||||||
|
'📂', elem_id='open_folder_small'
|
||||||
|
)
|
||||||
|
train_data_dir_input_folder.click(
|
||||||
|
get_folder_path, outputs=train_data_dir_input
|
||||||
|
)
|
||||||
|
reg_data_dir_input = gr.Textbox(
|
||||||
|
label='Regularisation folder',
|
||||||
|
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
||||||
|
)
|
||||||
|
reg_data_dir_input_folder = gr.Button(
|
||||||
|
'📂', elem_id='open_folder_small'
|
||||||
|
)
|
||||||
|
reg_data_dir_input_folder.click(
|
||||||
|
get_folder_path, outputs=reg_data_dir_input
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
output_dir_input = gr.Textbox(
|
||||||
|
label='Output folder',
|
||||||
|
placeholder='Folder to output trained model',
|
||||||
|
)
|
||||||
|
output_dir_input_folder = gr.Button(
|
||||||
|
'📂', elem_id='open_folder_small'
|
||||||
|
)
|
||||||
|
output_dir_input_folder.click(
|
||||||
|
get_folder_path, outputs=output_dir_input
|
||||||
|
)
|
||||||
|
logging_dir_input = gr.Textbox(
|
||||||
|
label='Logging folder',
|
||||||
|
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||||
|
)
|
||||||
|
logging_dir_input_folder = gr.Button(
|
||||||
|
'📂', elem_id='open_folder_small'
|
||||||
|
)
|
||||||
|
logging_dir_input_folder.click(
|
||||||
|
get_folder_path, outputs=logging_dir_input
|
||||||
|
)
|
||||||
|
train_data_dir_input.change(
|
||||||
|
remove_doublequote,
|
||||||
|
inputs=[train_data_dir_input],
|
||||||
|
outputs=[train_data_dir_input],
|
||||||
)
|
)
|
||||||
with gr.Row():
|
reg_data_dir_input.change(
|
||||||
enable_bucket_input = gr.Checkbox(
|
remove_doublequote,
|
||||||
label='Enable buckets', value=True
|
inputs=[reg_data_dir_input],
|
||||||
|
outputs=[reg_data_dir_input],
|
||||||
)
|
)
|
||||||
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
|
output_dir_input.change(
|
||||||
use_8bit_adam_input = gr.Checkbox(
|
remove_doublequote,
|
||||||
label='Use 8bit adam', value=True
|
inputs=[output_dir_input],
|
||||||
|
outputs=[output_dir_input],
|
||||||
)
|
)
|
||||||
xformers_input = gr.Checkbox(label='Use xformers', value=True)
|
logging_dir_input.change(
|
||||||
|
remove_doublequote,
|
||||||
|
inputs=[logging_dir_input],
|
||||||
|
outputs=[logging_dir_input],
|
||||||
|
)
|
||||||
|
with gr.Tab('Training parameters'):
|
||||||
|
with gr.Row():
|
||||||
|
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
|
||||||
|
lr_scheduler_input = gr.Dropdown(
|
||||||
|
label='LR Scheduler',
|
||||||
|
choices=[
|
||||||
|
'constant',
|
||||||
|
'constant_with_warmup',
|
||||||
|
'cosine',
|
||||||
|
'cosine_with_restarts',
|
||||||
|
'linear',
|
||||||
|
'polynomial',
|
||||||
|
],
|
||||||
|
value='constant',
|
||||||
|
)
|
||||||
|
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
|
||||||
|
with gr.Row():
|
||||||
|
train_batch_size_input = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=32,
|
||||||
|
label='Train batch size',
|
||||||
|
value=1,
|
||||||
|
step=1,
|
||||||
|
)
|
||||||
|
epoch_input = gr.Textbox(label='Epoch', value=1)
|
||||||
|
save_every_n_epochs_input = gr.Textbox(
|
||||||
|
label='Save every N epochs', value=1
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
mixed_precision_input = gr.Dropdown(
|
||||||
|
label='Mixed precision',
|
||||||
|
choices=[
|
||||||
|
'no',
|
||||||
|
'fp16',
|
||||||
|
'bf16',
|
||||||
|
],
|
||||||
|
value='fp16',
|
||||||
|
)
|
||||||
|
save_precision_input = gr.Dropdown(
|
||||||
|
label='Save precision',
|
||||||
|
choices=[
|
||||||
|
'float',
|
||||||
|
'fp16',
|
||||||
|
'bf16',
|
||||||
|
],
|
||||||
|
value='fp16',
|
||||||
|
)
|
||||||
|
num_cpu_threads_per_process_input = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=os.cpu_count(),
|
||||||
|
step=1,
|
||||||
|
label='Number of CPU threads per process',
|
||||||
|
value=os.cpu_count(),
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
seed_input = gr.Textbox(label='Seed', value=1234)
|
||||||
|
max_resolution_input = gr.Textbox(
|
||||||
|
label='Max resolution', value='512,512', placeholder='512,512'
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
caption_extention_input = gr.Textbox(
|
||||||
|
label='Caption Extension',
|
||||||
|
placeholder='(Optional) Extension for caption files. default: .caption',
|
||||||
|
)
|
||||||
|
stop_text_encoder_training_input = gr.Slider(
|
||||||
|
minimum=0,
|
||||||
|
maximum=100,
|
||||||
|
value=0,
|
||||||
|
step=1,
|
||||||
|
label='Stop text encoder training',
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
full_fp16_input = gr.Checkbox(
|
||||||
|
label='Full fp16 training (experimental)', value=False
|
||||||
|
)
|
||||||
|
no_token_padding_input = gr.Checkbox(
|
||||||
|
label='No token padding', value=False
|
||||||
|
)
|
||||||
|
|
||||||
|
gradient_checkpointing_input = gr.Checkbox(
|
||||||
|
label='Gradient checkpointing', value=False
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
enable_bucket_input = gr.Checkbox(
|
||||||
|
label='Enable buckets', value=True
|
||||||
|
)
|
||||||
|
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
|
||||||
|
use_8bit_adam_input = gr.Checkbox(
|
||||||
|
label='Use 8bit adam', value=True
|
||||||
|
)
|
||||||
|
xformers_input = gr.Checkbox(label='Use xformers', value=True)
|
||||||
|
|
||||||
|
button_run = gr.Button('Train model')
|
||||||
|
|
||||||
with gr.Tab('Utilities'):
|
with gr.Tab('Utilities'):
|
||||||
# Dreambooth folder creation tab
|
with gr.Tab('Captioning'):
|
||||||
|
gradio_basic_caption_gui_tab()
|
||||||
|
gradio_blip_caption_gui_tab()
|
||||||
|
gradio_wd14_caption_gui_tab()
|
||||||
gradio_dreambooth_folder_creation_tab(
|
gradio_dreambooth_folder_creation_tab(
|
||||||
train_data_dir_input,
|
train_data_dir_input,
|
||||||
reg_data_dir_input,
|
reg_data_dir_input,
|
||||||
output_dir_input,
|
output_dir_input,
|
||||||
logging_dir_input,
|
logging_dir_input,
|
||||||
)
|
)
|
||||||
# Captionning tab
|
|
||||||
gradio_caption_gui_tab()
|
|
||||||
gradio_dataset_balancing_tab()
|
gradio_dataset_balancing_tab()
|
||||||
gradio_convert_model_tab()
|
gradio_convert_model_tab()
|
||||||
# with gr.Tab('Model conversion'):
|
|
||||||
# convert_to_safetensors_input = gr.Checkbox(
|
|
||||||
# label='Convert to SafeTensors', value=True
|
|
||||||
# )
|
|
||||||
# convert_to_ckpt_input = gr.Checkbox(
|
|
||||||
# label='Convert to CKPT', value=False
|
|
||||||
# )
|
|
||||||
|
|
||||||
button_run = gr.Button('Train model')
|
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
open_configuration,
|
open_configuration,
|
||||||
|
@ -41,10 +41,10 @@ def caption_images(
|
|||||||
###
|
###
|
||||||
|
|
||||||
|
|
||||||
def gradio_caption_gui_tab():
|
def gradio_basic_caption_gui_tab():
|
||||||
with gr.Tab('Captioning'):
|
with gr.Tab('Basic Captioning'):
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
'This utility will allow the creation of caption files for each images in a folder.'
|
'This utility will allow the creation of simple caption files for each images in a folder.'
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
caption_text_input = gr.Textbox(
|
caption_text_input = gr.Textbox(
|
||||||
@ -64,7 +64,7 @@ def gradio_caption_gui_tab():
|
|||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
images_dir_input = gr.Textbox(
|
images_dir_input = gr.Textbox(
|
||||||
label='Image forder to caption',
|
label='Image folder to caption',
|
||||||
placeholder='Directory containing the images to caption',
|
placeholder='Directory containing the images to caption',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
111
library/blip_caption_gui.py
Normal file
111
library/blip_caption_gui.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import gradio as gr
|
||||||
|
from easygui import msgbox
|
||||||
|
import subprocess
|
||||||
|
from .common_gui import get_folder_path
|
||||||
|
|
||||||
|
|
||||||
|
def caption_images(
|
||||||
|
train_data_dir,
|
||||||
|
caption_file_ext,
|
||||||
|
batch_size,
|
||||||
|
num_beams,
|
||||||
|
top_p,
|
||||||
|
max_length,
|
||||||
|
min_length,
|
||||||
|
beam_search,
|
||||||
|
):
|
||||||
|
# Check for caption_text_input
|
||||||
|
# if caption_text_input == "":
|
||||||
|
# msgbox("Caption text is missing...")
|
||||||
|
# return
|
||||||
|
|
||||||
|
# Check for images_dir_input
|
||||||
|
if train_data_dir == '':
|
||||||
|
msgbox('Image folder is missing...')
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f'Captioning files in {train_data_dir}...')
|
||||||
|
run_cmd = f'.\\venv\\Scripts\\python.exe "./BLIP_caption/make_captions.py"'
|
||||||
|
run_cmd += f' --batch_size="{int(batch_size)}"'
|
||||||
|
run_cmd += f' --num_beams="{int(num_beams)}"'
|
||||||
|
run_cmd += f' --top_p="{top_p}"'
|
||||||
|
run_cmd += f' --max_length="{int(max_length)}"'
|
||||||
|
run_cmd += f' --min_length="{int(min_length)}"'
|
||||||
|
if beam_search:
|
||||||
|
run_cmd += f' --beam_search'
|
||||||
|
if caption_file_ext != '':
|
||||||
|
run_cmd += f' --caption_extension="{caption_file_ext}"'
|
||||||
|
run_cmd += f' "{train_data_dir}"'
|
||||||
|
run_cmd += f' "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"'
|
||||||
|
|
||||||
|
print(run_cmd)
|
||||||
|
|
||||||
|
# Run the command
|
||||||
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
|
print('...captioning done')
|
||||||
|
|
||||||
|
|
||||||
|
###
|
||||||
|
# Gradio UI
|
||||||
|
###
|
||||||
|
|
||||||
|
|
||||||
|
def gradio_blip_caption_gui_tab():
|
||||||
|
with gr.Tab('BLIP Captioning'):
|
||||||
|
gr.Markdown(
|
||||||
|
'This utility will use BLIP to caption files for each images in a folder.'
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
train_data_dir = gr.Textbox(
|
||||||
|
label='Image folder to caption',
|
||||||
|
placeholder='Directory containing the images to caption',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
button_train_data_dir_input = gr.Button(
|
||||||
|
'📂', elem_id='open_folder_small'
|
||||||
|
)
|
||||||
|
button_train_data_dir_input.click(
|
||||||
|
get_folder_path, outputs=train_data_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
caption_file_ext = gr.Textbox(
|
||||||
|
label='Caption file extension',
|
||||||
|
placeholder='(Optional) Default: .caption',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size = gr.Number(
|
||||||
|
value=1, label='Batch size', interactive=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
beam_search = gr.Checkbox(
|
||||||
|
label='Use beam search', interactive=True, value=True
|
||||||
|
)
|
||||||
|
num_beams = gr.Number(
|
||||||
|
value=1, label='Number of beams', interactive=True
|
||||||
|
)
|
||||||
|
top_p = gr.Number(value=0.9, label='Top p', interactive=True)
|
||||||
|
max_length = gr.Number(
|
||||||
|
value=75, label='Max length', interactive=True
|
||||||
|
)
|
||||||
|
min_length = gr.Number(
|
||||||
|
value=5, label='Min length', interactive=True
|
||||||
|
)
|
||||||
|
|
||||||
|
caption_button = gr.Button('Caption images')
|
||||||
|
|
||||||
|
caption_button.click(
|
||||||
|
caption_images,
|
||||||
|
inputs=[
|
||||||
|
train_data_dir,
|
||||||
|
caption_file_ext,
|
||||||
|
batch_size,
|
||||||
|
num_beams,
|
||||||
|
top_p,
|
||||||
|
max_length,
|
||||||
|
min_length,
|
||||||
|
beam_search,
|
||||||
|
],
|
||||||
|
)
|
@ -1,16 +1,17 @@
|
|||||||
from easygui import diropenbox, fileopenbox
|
from tkinter import filedialog, Tk
|
||||||
|
|
||||||
|
def get_file_path(file_path='', defaultextension='.json'):
|
||||||
|
current_file_path = file_path
|
||||||
|
# print(f'current file path: {current_file_path}')
|
||||||
|
|
||||||
def get_folder_path():
|
root = Tk()
|
||||||
folder_path = diropenbox('Select the directory to use')
|
root.wm_attributes('-topmost', 1)
|
||||||
|
root.withdraw()
|
||||||
|
file_path = filedialog.askopenfilename(filetypes = (("Config files", "*.json"), ("All files", "*")), defaultextension=defaultextension)
|
||||||
|
root.destroy()
|
||||||
|
|
||||||
return folder_path
|
if file_path == '':
|
||||||
|
file_path = current_file_path
|
||||||
|
|
||||||
def get_file_path(file_path):
|
|
||||||
file_path = fileopenbox(
|
|
||||||
'Select the config file to load', default=file_path, filetypes='*.json',
|
|
||||||
)
|
|
||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
@ -20,3 +21,34 @@ def remove_doublequote(file_path):
|
|||||||
file_path = file_path.replace('"', '')
|
file_path = file_path.replace('"', '')
|
||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_folder_path(folder_path=''):
|
||||||
|
current_folder_path = folder_path
|
||||||
|
|
||||||
|
root = Tk()
|
||||||
|
root.wm_attributes('-topmost', 1)
|
||||||
|
root.withdraw()
|
||||||
|
folder_path = filedialog.askdirectory()
|
||||||
|
root.destroy()
|
||||||
|
|
||||||
|
if folder_path == '':
|
||||||
|
folder_path = current_folder_path
|
||||||
|
|
||||||
|
return folder_path
|
||||||
|
|
||||||
|
def get_saveasfile_path(file_path='', defaultextension='.json'):
|
||||||
|
current_file_path = file_path
|
||||||
|
# print(f'current file path: {current_file_path}')
|
||||||
|
|
||||||
|
root = Tk()
|
||||||
|
root.wm_attributes('-topmost', 1)
|
||||||
|
root.withdraw()
|
||||||
|
file_path = filedialog.asksaveasfile(filetypes = (("Config files", "*.json"), ("All files", "*")), defaultextension=defaultextension)
|
||||||
|
root.destroy()
|
||||||
|
|
||||||
|
file_path = file_path.name
|
||||||
|
if file_path == '':
|
||||||
|
file_path = current_file_path
|
||||||
|
|
||||||
|
return file_path
|
@ -50,12 +50,15 @@ def convert_model(source_model_input, source_model_type, target_model_folder_inp
|
|||||||
if not target_save_precision_type == 'unspecified':
|
if not target_save_precision_type == 'unspecified':
|
||||||
run_cmd += f' --{target_save_precision_type}'
|
run_cmd += f' --{target_save_precision_type}'
|
||||||
|
|
||||||
if target_model_type == "diffuser":
|
if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
|
||||||
run_cmd += f' --reference_model="{source_model_type}"'
|
run_cmd += f' --reference_model="{source_model_type}"'
|
||||||
|
|
||||||
|
if target_model_type == 'diffuser_safetensors':
|
||||||
|
run_cmd += ' --use_safetensors'
|
||||||
|
|
||||||
run_cmd += f' "{source_model_input}"'
|
run_cmd += f' "{source_model_input}"'
|
||||||
|
|
||||||
if target_model_type == "diffuser":
|
if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
|
||||||
target_model_path = os.path.join(target_model_folder_input, target_model_name_input)
|
target_model_path = os.path.join(target_model_folder_input, target_model_name_input)
|
||||||
run_cmd += f' "{target_model_path}"'
|
run_cmd += f' "{target_model_path}"'
|
||||||
else:
|
else:
|
||||||
@ -67,7 +70,7 @@ def convert_model(source_model_input, source_model_type, target_model_folder_inp
|
|||||||
# Run the command
|
# Run the command
|
||||||
subprocess.run(run_cmd)
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
if not target_model_type == "diffuser":
|
if not target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
|
||||||
|
|
||||||
v2_models = ['stabilityai/stable-diffusion-2-1-base',
|
v2_models = ['stabilityai/stable-diffusion-2-1-base',
|
||||||
'stabilityai/stable-diffusion-2-base',]
|
'stabilityai/stable-diffusion-2-base',]
|
||||||
@ -171,6 +174,7 @@ def gradio_convert_model_tab():
|
|||||||
)
|
)
|
||||||
target_model_type = gr.Dropdown(label="Target model type", choices=[
|
target_model_type = gr.Dropdown(label="Target model type", choices=[
|
||||||
'diffuser',
|
'diffuser',
|
||||||
|
'diffuser_safetensors',
|
||||||
'ckpt',
|
'ckpt',
|
||||||
'safetensors',
|
'safetensors',
|
||||||
],)
|
],)
|
||||||
|
73
library/wd14_caption_gui.py
Normal file
73
library/wd14_caption_gui.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
import gradio as gr
|
||||||
|
from easygui import msgbox
|
||||||
|
import subprocess
|
||||||
|
from .common_gui import get_folder_path
|
||||||
|
|
||||||
|
|
||||||
|
def caption_images(train_data_dir, caption_extension, batch_size, thresh):
|
||||||
|
# Check for caption_text_input
|
||||||
|
# if caption_text_input == "":
|
||||||
|
# msgbox("Caption text is missing...")
|
||||||
|
# return
|
||||||
|
|
||||||
|
# Check for images_dir_input
|
||||||
|
if train_data_dir == '':
|
||||||
|
msgbox('Image folder is missing...')
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f'Captioning files in {train_data_dir}...')
|
||||||
|
run_cmd = f'accelerate launch "./script/tag_images_by_wd14_tagger.py"'
|
||||||
|
run_cmd += f' --batch_size="{int(batch_size)}"'
|
||||||
|
run_cmd += f' --thresh="{thresh}"'
|
||||||
|
if caption_extension != '':
|
||||||
|
run_cmd += f' --caption_extension="{caption_extension}"'
|
||||||
|
run_cmd += f' "{train_data_dir}"'
|
||||||
|
|
||||||
|
print(run_cmd)
|
||||||
|
|
||||||
|
# Run the command
|
||||||
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
|
print('...captioning done')
|
||||||
|
|
||||||
|
|
||||||
|
###
|
||||||
|
# Gradio UI
|
||||||
|
###
|
||||||
|
|
||||||
|
|
||||||
|
def gradio_wd14_caption_gui_tab():
|
||||||
|
with gr.Tab('WD14 Captioning'):
|
||||||
|
gr.Markdown(
|
||||||
|
'This utility will use WD14 to caption files for each images in a folder.'
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
train_data_dir = gr.Textbox(
|
||||||
|
label='Image folder to caption',
|
||||||
|
placeholder='Directory containing the images to caption',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
button_train_data_dir_input = gr.Button(
|
||||||
|
'📂', elem_id='open_folder_small'
|
||||||
|
)
|
||||||
|
button_train_data_dir_input.click(
|
||||||
|
get_folder_path, outputs=train_data_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
caption_extension = gr.Textbox(
|
||||||
|
label='Caption file extension',
|
||||||
|
placeholder='(Optional) Default: .caption',
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
thresh = gr.Number(value=0.35, label='Threshold')
|
||||||
|
|
||||||
|
batch_size = gr.Number(
|
||||||
|
value=1, label='Batch size', interactive=True
|
||||||
|
)
|
||||||
|
|
||||||
|
caption_button = gr.Button('Caption images')
|
||||||
|
|
||||||
|
caption_button.click(
|
||||||
|
caption_images,
|
||||||
|
inputs=[train_data_dir, caption_extension, batch_size, thresh],
|
||||||
|
)
|
609
mytraining.ps
Normal file
609
mytraining.ps
Normal file
@ -0,0 +1,609 @@
|
|||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\v1-5-pruned.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_bernard\train_man" `
|
||||||
|
--reg_data_dir="D:\dreambooth\train_bernard\reg_man" `
|
||||||
|
--output_dir="D:\dreambooth\train_bernard" `
|
||||||
|
--prior_loss_weight=1.0 `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=1 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=3000 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--gradient_checkpointing `
|
||||||
|
--save_every_n_epochs=1
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\bernard\asd man-3000-remgb-sd15.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_bernard\train_man" `
|
||||||
|
--reg_data_dir="D:\dreambooth\train_bernard\reg_man" `
|
||||||
|
--output_dir="D:\dreambooth\train_bernard" `
|
||||||
|
--prior_loss_weight=1.0 `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=1 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=1500 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--gradient_checkpointing `
|
||||||
|
--save_every_n_epochs=1
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_bernard\train_man" `
|
||||||
|
--reg_data_dir="D:\dreambooth\train_bernard\reg_man" `
|
||||||
|
--output_dir="D:\dreambooth\train_bernard" `
|
||||||
|
--prior_loss_weight=1.0 `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=1 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=4500 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--gradient_checkpointing `
|
||||||
|
--no_token_padding `
|
||||||
|
--save_every_n_epochs=1
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\source\alex\train" `
|
||||||
|
--output_dir="D:\dreambooth\train_alex" `
|
||||||
|
--prior_loss_weight=1.0 `
|
||||||
|
--resolution="448,640" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=4500 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--shuffle_caption
|
||||||
|
|
||||||
|
# -fine_tuning
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\source\alex\train\50_portrait-pp" `
|
||||||
|
--output_dir="D:\dreambooth\train_alex" `
|
||||||
|
--resolution="448,640" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=4500 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--shuffle_caption
|
||||||
|
|
||||||
|
Resume:
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\source\alet_et_bernard\landscape-pp" `
|
||||||
|
--output_dir="D:\dreambooth\train_alex_and_bernard" `
|
||||||
|
--resolution="640,448" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=550 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--fine_tuning_repeat=200 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
# Mollie Monger
|
||||||
|
|
||||||
|
e1:
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_mollie_monger\landscape-pp" `
|
||||||
|
--output_dir="D:\dreambooth\train_mollie_monger\output" `
|
||||||
|
--resolution="640,448" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=625 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--fine_tuning_repeat=200 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\mollie_monger-kohya-l-200-sd15.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_mollie_monger\portrait-pp" `
|
||||||
|
--output_dir="D:\dreambooth\train_mollie_monger\output" `
|
||||||
|
--resolution="448,640" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=1275 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--fine_tuning_repeat=200 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\mollie_monger-kohya-l+p-200-sd15.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_mollie_monger\square-pp" `
|
||||||
|
--output_dir="D:\dreambooth\train_mollie_monger\output" `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=500 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--fine_tuning_repeat=200 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
e2:
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\mollie_monger\mollie_monger-kohya-l+p+s-r200-e1-sd15.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_mollie_monger\landscape-pp" `
|
||||||
|
--output_dir="D:\dreambooth\train_mollie_monger\output" `
|
||||||
|
--resolution="640,448" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=625 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--fine_tuning_repeat=200 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\mollie_monger\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_mollie_monger\portrait-pp" `
|
||||||
|
--output_dir="D:\dreambooth\train_mollie_monger\output" `
|
||||||
|
--resolution="448,640" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=1275 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--fine_tuning_repeat=200 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\mollie_monger\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_mollie_monger\square-pp" `
|
||||||
|
--output_dir="D:\dreambooth\train_mollie_monger\output" `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=500 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--fine_tuning_repeat=200 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
|
||||||
|
Midjourney images download:
|
||||||
|
|
||||||
|
https://storage.googleapis.com/dream-machines-output/2932e6e4-ddef-410e-947b-2a6275e31f35/0_3.png
|
||||||
|
|
||||||
|
# Midjourney
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_midjourney_v4\all data" `
|
||||||
|
--output_dir="D:\dreambooth\train_midjourney_v4\model" `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=528 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--fine_tuning_repeat=12 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\midjourney_v4-khoya-r100-e1-sd15.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_midjourney_v4\data2" `
|
||||||
|
--output_dir="D:\dreambooth\train_midjourney_v4\model" `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=850 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--fine_tuning_repeat=100 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\midjourney_v4_finetune\epoch-000001.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_midjourney_v4\newdata3" `
|
||||||
|
--output_dir="D:\dreambooth\train_midjourney_v4\model" `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=159 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--fine_tuning_repeat=24 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
# train n
|
||||||
|
|
||||||
|
# Midjourney
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\train_childrens_drawings\model\last2.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_childrens_drawings\data2-pp" `
|
||||||
|
--output_dir="D:\dreambooth\train_childrens_drawings\model" `
|
||||||
|
--resolution="704,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=312 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--fine_tuning_repeat=48 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\train_childrens_drawings\model\last2.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_childrens_drawings\data2-pp" `
|
||||||
|
--output_dir="D:\dreambooth\train_childrens_drawings\model" `
|
||||||
|
--resolution="704,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=312 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=48 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
# twq
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\source\bernardv2-ft" `
|
||||||
|
--output_dir="D:\dreambooth\train_bernard\model" `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=720 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=48 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
# the white queen
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\landscape-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l" `
|
||||||
|
--resolution="704,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=520 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=40 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\portrait-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p" `
|
||||||
|
--resolution="512,704" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=260 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=40 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\square-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p+s" `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=220 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=40 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
# the white queen slow progress init phase
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\landscape-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l" `
|
||||||
|
--resolution="704,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=260 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=80 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\portrait-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p" `
|
||||||
|
--resolution="512,704" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=130 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=80 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\square-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p+s" `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=90 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=80 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
# the white queen slow progress extra steps phase
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p+s\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\landscape-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l" `
|
||||||
|
--resolution="704,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=130 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=40 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\portrait-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p" `
|
||||||
|
--resolution="512,704" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=65 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=40 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\square-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p+s" `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=45 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=40 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
# the queen of heart init phase
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\train_qoh\landscape-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l" `
|
||||||
|
--resolution="704,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=260 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=80 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\portrait-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p" `
|
||||||
|
--resolution="512,704" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=130 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=80 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\square-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p+s" `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=90 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=80 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
# the white queen slow progress extra steps phase
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p+s\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\landscape-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l" `
|
||||||
|
--resolution="704,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=130 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=40 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\portrait-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p" `
|
||||||
|
--resolution="512,704" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=65 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=40 `
|
||||||
|
--save_half
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
|
||||||
|
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p\last.ckpt" `
|
||||||
|
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\square-ft" `
|
||||||
|
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p+s" `
|
||||||
|
--resolution="512,512" `
|
||||||
|
--train_batch_size=8 `
|
||||||
|
--learning_rate=1e-6 `
|
||||||
|
--max_train_steps=45 `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision="fp16" `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=1 `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=40 `
|
||||||
|
--seed=23 `
|
||||||
|
--save_half
|
@ -12,4 +12,13 @@ safetensors==0.2.6
|
|||||||
gradio
|
gradio
|
||||||
altair
|
altair
|
||||||
easygui
|
easygui
|
||||||
|
tkinter
|
||||||
|
# for BLIP captioning
|
||||||
|
requests
|
||||||
|
timm
|
||||||
|
fairscale
|
||||||
|
# for WD14 captioning
|
||||||
|
tensorflow<2.11
|
||||||
|
huggingface-hub
|
||||||
|
# for kohya_ss library
|
||||||
.
|
.
|
@ -1,12 +1,17 @@
|
|||||||
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
|
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
|
||||||
# v1: initial version
|
# v1: initial version
|
||||||
|
# v2: support safetensors
|
||||||
|
# v3: fix to support another format
|
||||||
|
# v4: support safetensors in Diffusers
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from diffusers import StableDiffusionPipeline
|
from diffusers import StableDiffusionPipeline
|
||||||
|
|
||||||
from library import model_util as model_util
|
from library import model_util as model_util
|
||||||
|
|
||||||
|
|
||||||
def convert(args):
|
def convert(args):
|
||||||
# 引数を確認する
|
# 引数を確認する
|
||||||
load_dtype = torch.float16 if args.fp16 else None
|
load_dtype = torch.float16 if args.fp16 else None
|
||||||
@ -56,7 +61,7 @@ def convert(args):
|
|||||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||||
else:
|
else:
|
||||||
print(f"copy scheduler/tokenizer config from: {args.reference_model}")
|
print(f"copy scheduler/tokenizer config from: {args.reference_model}")
|
||||||
model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae)
|
model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
|
||||||
print(f"model saved.")
|
print(f"model saved.")
|
||||||
|
|
||||||
|
|
||||||
@ -76,6 +81,8 @@ if __name__ == '__main__':
|
|||||||
help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
|
help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
|
||||||
parser.add_argument("--reference_model", type=str, default=None,
|
parser.add_argument("--reference_model", type=str, default=None,
|
||||||
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
|
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
|
||||||
|
parser.add_argument("--use_safetensors", action='store_true',
|
||||||
|
help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)")
|
||||||
|
|
||||||
parser.add_argument("model_to_load", type=str, default=None,
|
parser.add_argument("model_to_load", type=str, default=None,
|
||||||
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
|
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
|
||||||
|
Loading…
Reference in New Issue
Block a user