- Fix file/folder opening behind the browser window

- Add WD14 and BLIP captioning to utilities
- Improve overall GUI layout
This commit is contained in:
bmaltais 2022-12-19 09:22:52 -05:00
parent 0ca93a7aa7
commit c90aa2cc61
30 changed files with 2757 additions and 271 deletions

7
.gitignore vendored
View File

@ -1,7 +1,8 @@
venv venv
venv1
mytraining.ps
__pycache__ __pycache__
*.txt
cudnn_windows
.vscode .vscode
*.egg-info *.egg-info
build build
wd14_tagger_model

View File

@ -0,0 +1,21 @@
{
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30524,
"encoder_width": 768,
"add_cross_attention": true
}

View File

@ -0,0 +1,115 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import glob
import os
import json
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from models.blip import blip_decoder
# from Salesforce_BLIP.models.blip import blip_decoder
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def main(args):
cwd = os.getcwd()
print('Current Working Directory is: ', cwd)
os.chdir('.\BLIP_caption')
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.")
print(f"loading BLIP caption: {args.caption_weights}")
# image_size = 384
# model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config='configs/med_config.json')
# model.eval()
# model = model.to(device)
image_size = 384
transform = transforms.Compose([
transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
model_url = args.caption_weights # 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
model = blip_decoder(pretrained=model_url, image_size=384, vit='large')
model.eval()
model = model.to(device)
print("BLIP loaded")
# 正方形でいいのか? という気がするがソースがそうなので
# transform = transforms.Compose([
# transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
# transforms.ToTensor(),
# transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
# ])
# captioningする
def run_batch(path_imgs):
imgs = torch.stack([im for _, im in path_imgs]).to(device)
with torch.no_grad():
if args.beam_search:
captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
max_length=args.max_length, min_length=args.min_length)
else:
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
b_imgs = []
for image_path in tqdm(image_paths, smoothing=0.0):
raw_image = Image.open(image_path)
if raw_image.mode != "RGB":
print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
raw_image = raw_image.convert("RGB")
image = transform(raw_image)
b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("caption_weights", type=str,
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--beam_search", action="store_true",
help="use beam search (default Nucleus sampling) / beam searchを使うこのオプション未指定時はNucleus sampling")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数多いと精度が上がるが時間がかかる")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
parser.add_argument("--debug", action="store_true", help="debug mode")
args = parser.parse_args()
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
main(args)

238
BLIP_caption/models/blip.py Normal file
View File

@ -0,0 +1,238 @@
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
'''
import warnings
warnings.filterwarnings("ignore")
from models.vit import VisionTransformer, interpolate_pos_embed
from models.med import BertConfig, BertModel, BertLMHeadModel
from transformers import BertTokenizer
import torch
from torch import nn
import torch.nn.functional as F
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
class BLIP_Base(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 224,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
def forward(self, image, caption, mode):
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
if mode=='image':
# return image features
image_embeds = self.visual_encoder(image)
return image_embeds
elif mode=='text':
# return text features
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
return text_output.last_hidden_state
elif mode=='multimodal':
# return multimodel features
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
text.input_ids[:,0] = self.tokenizer.enc_token_id
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
return output.last_hidden_state
class BLIP_Decoder(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 384,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
prompt = 'a picture of ',
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_decoder = BertLMHeadModel(config=med_config)
self.prompt = prompt
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
def forward(self, image, caption):
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
text.input_ids[:,0] = self.tokenizer.bos_token_id
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
decoder_targets[:,:self.prompt_length] = -100
decoder_output = self.text_decoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
labels = decoder_targets,
return_dict = True,
)
loss_lm = decoder_output.loss
return loss_lm
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
image_embeds = self.visual_encoder(image)
if not sample:
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
prompt = [self.prompt] * image.size(0)
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
input_ids[:,0] = self.tokenizer.bos_token_id
input_ids = input_ids[:, :-1]
if sample:
#nucleus sampling
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
do_sample=True,
top_p=top_p,
num_return_sequences=1,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=1.1,
**model_kwargs)
else:
#beam search
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=repetition_penalty,
**model_kwargs)
captions = []
for output in outputs:
caption = self.tokenizer.decode(output, skip_special_tokens=True)
captions.append(caption[len(self.prompt):])
return captions
def blip_decoder(pretrained='',**kwargs):
model = BLIP_Decoder(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model
def blip_feature_extractor(pretrained='',**kwargs):
model = BLIP_Base(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model
def init_tokenizer():
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
return tokenizer
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
assert vit in ['base', 'large'], "vit parameter must be base or large"
if vit=='base':
vision_width = 768
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0 or drop_path_rate
)
elif vit=='large':
vision_width = 1024
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0.1 or drop_path_rate
)
return visual_encoder, vision_width
def is_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
def load_checkpoint(model,url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
checkpoint = torch.load(cached_file, map_location='cpu')
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location='cpu')
else:
raise RuntimeError('checkpoint url or path is invalid')
state_dict = checkpoint['model']
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
model.visual_encoder_m)
for key in model.state_dict().keys():
if key in state_dict.keys():
if state_dict[key].shape!=model.state_dict()[key].shape:
del state_dict[key]
msg = model.load_state_dict(state_dict,strict=False)
print('load checkpoint from %s'%url_or_filename)
return model,msg

955
BLIP_caption/models/med.py Normal file
View File

@ -0,0 +1,955 @@
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
* Based on huggingface code base
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
'''
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import Tensor, device, dtype, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.file_utils import (
ModelOutput,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig
logger = logging.get_logger(__name__)
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
def forward(
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
embeddings = inputs_embeds
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config, is_cross_attention):
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
if is_cross_attention:
self.key = nn.Linear(config.encoder_width, self.all_head_size)
self.value = nn.Linear(config.encoder_width, self.all_head_size)
else:
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.save_attention = False
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients
def save_attention_map(self, attention_map):
self.attention_map = attention_map
def get_attention_map(self):
return self.attention_map
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
if is_cross_attention and self.save_attention:
self.save_attention_map(attention_probs)
attention_probs.register_hook(self.save_attn_gradients)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs_dropped = attention_probs_dropped * head_mask
context_layer = torch.matmul(attention_probs_dropped, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
outputs = outputs + (past_key_value,)
return outputs
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config, is_cross_attention=False):
super().__init__()
self.self = BertSelfAttention(config, is_cross_attention)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config, layer_num):
super().__init__()
self.config = config
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BertAttention(config)
self.layer_num = layer_num
if self.config.add_cross_attention:
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
mode=None,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
if mode=='multimodal':
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
mode='multimodal',
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warn(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
mode=mode,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
mode=mode,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BertConfig
base_model_prefix = "bert"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
class BertModel(BertPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
input to the forward pass.
"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (:obj:`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (:obj:`Tuple[int]`):
The shape of the input to the model.
device: (:obj:`torch.device`):
The device of the input to the model.
Returns:
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if is_decoder:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
causal_mask = torch.cat(
[
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
causal_mask,
],
axis=-1,
)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
is_decoder=False,
mode='multimodal',
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
device = input_ids.device
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = inputs_embeds.device
elif encoder_embeds is not None:
input_shape = encoder_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = encoder_embeds.device
else:
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
device, is_decoder)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
if type(encoder_hidden_states) == list:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
else:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if type(encoder_attention_mask) == list:
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
elif encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if encoder_embeds is None:
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
else:
embedding_output = encoder_embeds
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
mode=mode,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
class BertLMHeadModel(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
return_logits=False,
is_decoder=True,
reduction='mean',
mode='multimodal',
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
Returns:
Example::
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
>>> config = BertConfig.from_pretrained("bert-base-cased")
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_decoder=is_decoder,
mode=mode,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
if return_logits:
return prediction_scores[:, :-1, :].contiguous()
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if reduction=='none':
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"is_decoder": True,
}
def _reorder_cache(self, past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past

305
BLIP_caption/models/vit.py Normal file
View File

@ -0,0 +1,305 @@
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
* Based on timm code base
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from timm.models.helpers import named_apply, adapt_input_conv
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.attn_gradients = None
self.attention_map = None
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients
def save_attention_map(self, attention_map):
self.attention_map = attention_map
def get_attention_map(self):
return self.attention_map
def forward(self, x, register_hook=False):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
if register_hook:
self.save_attention_map(attn)
attn.register_hook(self.save_attn_gradients)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if use_grad_checkpointing:
self.attn = checkpoint_wrapper(self.attn)
self.mlp = checkpoint_wrapper(self.mlp)
def forward(self, x, register_hook=False):
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
https://arxiv.org/abs/2010.11929
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
use_grad_checkpointing=False, ckpt_layer=0):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
norm_layer: (nn.Module): normalization layer
"""
super().__init__()
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward(self, x, register_blk=-1):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed[:,:x.size(1),:]
x = self.pos_drop(x)
for i,blk in enumerate(self.blocks):
x = blk(x, register_blk==i)
x = self.norm(x)
return x
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path, prefix=''):
_load_weights(self, checkpoint_path, prefix)
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
"""
import numpy as np
def _n2p(w, t=True):
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
if w.ndim == 4:
w = w.transpose([3, 2, 0, 1])
elif w.ndim == 3:
w = w.transpose([2, 0, 1])
elif w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w)
w = np.load(checkpoint_path)
if not prefix and 'opt/target/embedding/kernel' in w:
prefix = 'opt/target/'
if hasattr(model.patch_embed, 'backbone'):
# hybrid
backbone = model.patch_embed.backbone
stem_only = not hasattr(backbone, 'stem')
stem = backbone if stem_only else backbone.stem
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
if not stem_only:
for i, stage in enumerate(backbone.stages):
for j, block in enumerate(stage.blocks):
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
for r in range(3):
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
if block.downsample is not None:
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
else:
embed_conv_w = adapt_input_conv(
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
model.patch_embed.proj.weight.copy_(embed_conv_w)
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
# interpolate position embedding
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = visual_encoder.patch_embed.num_patches
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
if orig_size!=new_size:
# class_token and dist_token are kept unchanged
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
return new_pos_embed
else:
return pos_embed_checkpoint

View File

@ -130,6 +130,10 @@ Drop by the discord server for support: https://discord.com/channels/10415185624
## Change history ## Change history
* 12/19 (v18.2) update:
- Fix file/folder opening behind the browser window
- Add WD14 and BLIP captioning to utilities
- Improve overall GUI layout
* 12/18 (v18.1) update: * 12/18 (v18.1) update:
- Add Stable Diffusion model conversion utility. Make sure to run `pip upgrade -U -r requirements.txt` after updating to this release as this introduce new pip requirements. - Add Stable Diffusion model conversion utility. Make sure to run `pip upgrade -U -r requirements.txt` after updating to this release as this introduce new pip requirements.
* 12/17 (v18) update: * 12/17 (v18) update:

View File

@ -11,15 +11,18 @@ import subprocess
import pathlib import pathlib
import shutil import shutil
from library.dreambooth_folder_creation_gui import gradio_dreambooth_folder_creation_tab from library.dreambooth_folder_creation_gui import gradio_dreambooth_folder_creation_tab
from library.caption_gui import gradio_caption_gui_tab from library.basic_caption_gui import gradio_basic_caption_gui_tab
from library.convert_model_gui import gradio_convert_model_tab
from library.blip_caption_gui import gradio_blip_caption_gui_tab
from library.wd14_caption_gui import gradio_wd14_caption_gui_tab
from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.common_gui import ( from library.common_gui import (
get_folder_path, get_folder_path,
remove_doublequote, remove_doublequote,
get_file_path, get_file_path,
get_saveasfile_path
) )
from library.convert_model_gui import gradio_convert_model_tab from easygui import msgbox
from easygui import filesavebox, msgbox
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
@ -65,19 +68,21 @@ def save_configuration(
if save_as_bool: if save_as_bool:
print('Save as...') print('Save as...')
file_path = filesavebox( # file_path = filesavebox(
'Select the config file to save', # 'Select the config file to save',
default='finetune.json', # default='finetune.json',
filetypes='*.json', # filetypes='*.json',
) # )
file_path = get_saveasfile_path(file_path)
else: else:
print('Save...') print('Save...')
if file_path == None or file_path == '': if file_path == None or file_path == '':
file_path = filesavebox( # file_path = filesavebox(
'Select the config file to save', # 'Select the config file to save',
default='finetune.json', # default='finetune.json',
filetypes='*.json', # filetypes='*.json',
) # )
file_path = get_saveasfile_path(file_path)
if file_path == None: if file_path == None:
return original_file_path # In case a file_path was provided and the user decide to cancel the open action return original_file_path # In case a file_path was provided and the user decide to cancel the open action
@ -455,261 +460,258 @@ interface = gr.Blocks(css=css)
with interface: with interface:
dummy_true = gr.Label(value=True, visible=False) dummy_true = gr.Label(value=True, visible=False)
dummy_false = gr.Label(value=False, visible=False) dummy_false = gr.Label(value=False, visible=False)
gr.Markdown('Enter kohya finetuner parameter using this interface.') with gr.Tab('Dreambooth'):
with gr.Accordion('Configuration File Load/Save', open=False): gr.Markdown('Enter kohya finetuner parameter using this interface.')
with gr.Row(): with gr.Accordion('Configuration File Load/Save', open=False):
button_open_config = gr.Button('Open 📂', elem_id='open_folder') with gr.Row():
button_save_config = gr.Button('Save 💾', elem_id='open_folder') button_open_config = gr.Button('Open 📂', elem_id='open_folder')
button_save_as_config = gr.Button( button_save_config = gr.Button('Save 💾', elem_id='open_folder')
'Save as... 💾', elem_id='open_folder' button_save_as_config = gr.Button(
'Save as... 💾', elem_id='open_folder'
)
config_file_name = gr.Textbox(
label='',
placeholder="type the configuration file path or use the 'Open' button above to select it...",
) )
config_file_name = gr.Textbox( config_file_name.change(
label='', remove_doublequote,
placeholder="type the configuration file path or use the 'Open' button above to select it...", inputs=[config_file_name],
) outputs=[config_file_name],
config_file_name.change(
remove_doublequote,
inputs=[config_file_name],
outputs=[config_file_name],
)
with gr.Tab('Source model'):
# Define the input elements
with gr.Row():
pretrained_model_name_or_path_input = gr.Textbox(
label='Pretrained model name or path',
placeholder='enter the path to custom model or name of pretrained model',
) )
pretrained_model_name_or_path_fille = gr.Button( with gr.Tab('Source model'):
document_symbol, elem_id='open_folder_small' # Define the input elements
with gr.Row():
pretrained_model_name_or_path_input = gr.Textbox(
label='Pretrained model name or path',
placeholder='enter the path to custom model or name of pretrained model',
)
pretrained_model_name_or_path_fille = gr.Button(
document_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_fille.click(
get_file_path, inputs=[pretrained_model_name_or_path_input], outputs=pretrained_model_name_or_path_input
)
pretrained_model_name_or_path_folder = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_folder.click(
get_folder_path, outputs=pretrained_model_name_or_path_input
)
model_list = gr.Dropdown(
label='(Optional) Model Quick Pick',
choices=[
'custom',
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
],
)
save_model_as_dropdown = gr.Dropdown(
label='Save trained model as',
choices=[
'same as source model',
'ckpt',
'diffusers',
"diffusers_safetensors",
'safetensors',
],
value='same as source model'
)
with gr.Row():
v2_input = gr.Checkbox(label='v2', value=True)
v_parameterization_input = gr.Checkbox(
label='v_parameterization', value=False
)
pretrained_model_name_or_path_input.change(
remove_doublequote,
inputs=[pretrained_model_name_or_path_input],
outputs=[pretrained_model_name_or_path_input],
) )
pretrained_model_name_or_path_fille.click( model_list.change(
get_file_path, inputs=[pretrained_model_name_or_path_input], outputs=pretrained_model_name_or_path_input set_pretrained_model_name_or_path_input,
) inputs=[model_list, v2_input, v_parameterization_input],
pretrained_model_name_or_path_folder = gr.Button( outputs=[
folder_symbol, elem_id='open_folder_small' pretrained_model_name_or_path_input,
) v2_input,
pretrained_model_name_or_path_folder.click( v_parameterization_input,
get_folder_path, outputs=pretrained_model_name_or_path_input
)
model_list = gr.Dropdown(
label='(Optional) Model Quick Pick',
choices=[
'custom',
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
], ],
) )
save_model_as_dropdown = gr.Dropdown(
label='Save trained model as',
choices=[
'same as source model',
'ckpt',
'diffusers',
"diffusers_safetensors",
'safetensors',
],
value='same as source model'
)
with gr.Row():
v2_input = gr.Checkbox(label='v2', value=True)
v_parameterization_input = gr.Checkbox(
label='v_parameterization', value=False
)
pretrained_model_name_or_path_input.change(
remove_doublequote,
inputs=[pretrained_model_name_or_path_input],
outputs=[pretrained_model_name_or_path_input],
)
model_list.change(
set_pretrained_model_name_or_path_input,
inputs=[model_list, v2_input, v_parameterization_input],
outputs=[
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
],
)
with gr.Tab('Directories'):
with gr.Row():
train_data_dir_input = gr.Textbox(
label='Image folder',
placeholder='Folder where the training folders containing the images are located',
)
train_data_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
train_data_dir_input_folder.click(
get_folder_path, outputs=train_data_dir_input
)
reg_data_dir_input = gr.Textbox(
label='Regularisation folder',
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
)
reg_data_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
reg_data_dir_input_folder.click(
get_folder_path, outputs=reg_data_dir_input
)
with gr.Row():
output_dir_input = gr.Textbox(
label='Output folder',
placeholder='Folder to output trained model',
)
output_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
output_dir_input_folder.click(
get_folder_path, outputs=output_dir_input
)
logging_dir_input = gr.Textbox(
label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this folder',
)
logging_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
logging_dir_input_folder.click(
get_folder_path, outputs=logging_dir_input
)
train_data_dir_input.change(
remove_doublequote,
inputs=[train_data_dir_input],
outputs=[train_data_dir_input],
)
reg_data_dir_input.change(
remove_doublequote,
inputs=[reg_data_dir_input],
outputs=[reg_data_dir_input],
)
output_dir_input.change(
remove_doublequote,
inputs=[output_dir_input],
outputs=[output_dir_input],
)
logging_dir_input.change(
remove_doublequote,
inputs=[logging_dir_input],
outputs=[logging_dir_input],
)
with gr.Tab('Training parameters'):
with gr.Row():
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
lr_scheduler_input = gr.Dropdown(
label='LR Scheduler',
choices=[
'constant',
'constant_with_warmup',
'cosine',
'cosine_with_restarts',
'linear',
'polynomial',
],
value='constant',
)
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
with gr.Row():
train_batch_size_input = gr.Slider(
minimum=1,
maximum=32,
label='Train batch size',
value=1,
step=1,
)
epoch_input = gr.Textbox(label='Epoch', value=1)
save_every_n_epochs_input = gr.Textbox(
label='Save every N epochs', value=1
)
with gr.Row():
mixed_precision_input = gr.Dropdown(
label='Mixed precision',
choices=[
'no',
'fp16',
'bf16',
],
value='fp16',
)
save_precision_input = gr.Dropdown(
label='Save precision',
choices=[
'float',
'fp16',
'bf16',
],
value='fp16',
)
num_cpu_threads_per_process_input = gr.Slider(
minimum=1,
maximum=os.cpu_count(),
step=1,
label='Number of CPU threads per process',
value=os.cpu_count(),
)
with gr.Row():
seed_input = gr.Textbox(label='Seed', value=1234)
max_resolution_input = gr.Textbox(
label='Max resolution', value='512,512', placeholder='512,512'
)
with gr.Row():
caption_extention_input = gr.Textbox(
label='Caption Extension',
placeholder='(Optional) Extension for caption files. default: .caption',
)
stop_text_encoder_training_input = gr.Slider(
minimum=0,
maximum=100,
value=0,
step=1,
label='Stop text encoder training',
)
with gr.Row():
full_fp16_input = gr.Checkbox(
label='Full fp16 training (experimental)', value=False
)
no_token_padding_input = gr.Checkbox(
label='No token padding', value=False
)
gradient_checkpointing_input = gr.Checkbox( with gr.Tab('Directories'):
label='Gradient checkpointing', value=False with gr.Row():
train_data_dir_input = gr.Textbox(
label='Image folder',
placeholder='Folder where the training folders containing the images are located',
)
train_data_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
train_data_dir_input_folder.click(
get_folder_path, outputs=train_data_dir_input
)
reg_data_dir_input = gr.Textbox(
label='Regularisation folder',
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
)
reg_data_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
reg_data_dir_input_folder.click(
get_folder_path, outputs=reg_data_dir_input
)
with gr.Row():
output_dir_input = gr.Textbox(
label='Output folder',
placeholder='Folder to output trained model',
)
output_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
output_dir_input_folder.click(
get_folder_path, outputs=output_dir_input
)
logging_dir_input = gr.Textbox(
label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this folder',
)
logging_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
)
logging_dir_input_folder.click(
get_folder_path, outputs=logging_dir_input
)
train_data_dir_input.change(
remove_doublequote,
inputs=[train_data_dir_input],
outputs=[train_data_dir_input],
) )
with gr.Row(): reg_data_dir_input.change(
enable_bucket_input = gr.Checkbox( remove_doublequote,
label='Enable buckets', value=True inputs=[reg_data_dir_input],
outputs=[reg_data_dir_input],
) )
cache_latent_input = gr.Checkbox(label='Cache latent', value=True) output_dir_input.change(
use_8bit_adam_input = gr.Checkbox( remove_doublequote,
label='Use 8bit adam', value=True inputs=[output_dir_input],
outputs=[output_dir_input],
) )
xformers_input = gr.Checkbox(label='Use xformers', value=True) logging_dir_input.change(
remove_doublequote,
inputs=[logging_dir_input],
outputs=[logging_dir_input],
)
with gr.Tab('Training parameters'):
with gr.Row():
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
lr_scheduler_input = gr.Dropdown(
label='LR Scheduler',
choices=[
'constant',
'constant_with_warmup',
'cosine',
'cosine_with_restarts',
'linear',
'polynomial',
],
value='constant',
)
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
with gr.Row():
train_batch_size_input = gr.Slider(
minimum=1,
maximum=32,
label='Train batch size',
value=1,
step=1,
)
epoch_input = gr.Textbox(label='Epoch', value=1)
save_every_n_epochs_input = gr.Textbox(
label='Save every N epochs', value=1
)
with gr.Row():
mixed_precision_input = gr.Dropdown(
label='Mixed precision',
choices=[
'no',
'fp16',
'bf16',
],
value='fp16',
)
save_precision_input = gr.Dropdown(
label='Save precision',
choices=[
'float',
'fp16',
'bf16',
],
value='fp16',
)
num_cpu_threads_per_process_input = gr.Slider(
minimum=1,
maximum=os.cpu_count(),
step=1,
label='Number of CPU threads per process',
value=os.cpu_count(),
)
with gr.Row():
seed_input = gr.Textbox(label='Seed', value=1234)
max_resolution_input = gr.Textbox(
label='Max resolution', value='512,512', placeholder='512,512'
)
with gr.Row():
caption_extention_input = gr.Textbox(
label='Caption Extension',
placeholder='(Optional) Extension for caption files. default: .caption',
)
stop_text_encoder_training_input = gr.Slider(
minimum=0,
maximum=100,
value=0,
step=1,
label='Stop text encoder training',
)
with gr.Row():
full_fp16_input = gr.Checkbox(
label='Full fp16 training (experimental)', value=False
)
no_token_padding_input = gr.Checkbox(
label='No token padding', value=False
)
gradient_checkpointing_input = gr.Checkbox(
label='Gradient checkpointing', value=False
)
with gr.Row():
enable_bucket_input = gr.Checkbox(
label='Enable buckets', value=True
)
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
use_8bit_adam_input = gr.Checkbox(
label='Use 8bit adam', value=True
)
xformers_input = gr.Checkbox(label='Use xformers', value=True)
button_run = gr.Button('Train model')
with gr.Tab('Utilities'): with gr.Tab('Utilities'):
# Dreambooth folder creation tab with gr.Tab('Captioning'):
gradio_basic_caption_gui_tab()
gradio_blip_caption_gui_tab()
gradio_wd14_caption_gui_tab()
gradio_dreambooth_folder_creation_tab( gradio_dreambooth_folder_creation_tab(
train_data_dir_input, train_data_dir_input,
reg_data_dir_input, reg_data_dir_input,
output_dir_input, output_dir_input,
logging_dir_input, logging_dir_input,
) )
# Captionning tab
gradio_caption_gui_tab()
gradio_dataset_balancing_tab() gradio_dataset_balancing_tab()
gradio_convert_model_tab() gradio_convert_model_tab()
# with gr.Tab('Model conversion'):
# convert_to_safetensors_input = gr.Checkbox(
# label='Convert to SafeTensors', value=True
# )
# convert_to_ckpt_input = gr.Checkbox(
# label='Convert to CKPT', value=False
# )
button_run = gr.Button('Train model')
button_open_config.click( button_open_config.click(
open_configuration, open_configuration,

View File

@ -41,10 +41,10 @@ def caption_images(
### ###
def gradio_caption_gui_tab(): def gradio_basic_caption_gui_tab():
with gr.Tab('Captioning'): with gr.Tab('Basic Captioning'):
gr.Markdown( gr.Markdown(
'This utility will allow the creation of caption files for each images in a folder.' 'This utility will allow the creation of simple caption files for each images in a folder.'
) )
with gr.Row(): with gr.Row():
caption_text_input = gr.Textbox( caption_text_input = gr.Textbox(
@ -64,7 +64,7 @@ def gradio_caption_gui_tab():
) )
with gr.Row(): with gr.Row():
images_dir_input = gr.Textbox( images_dir_input = gr.Textbox(
label='Image forder to caption', label='Image folder to caption',
placeholder='Directory containing the images to caption', placeholder='Directory containing the images to caption',
interactive=True, interactive=True,
) )

111
library/blip_caption_gui.py Normal file
View File

@ -0,0 +1,111 @@
import gradio as gr
from easygui import msgbox
import subprocess
from .common_gui import get_folder_path
def caption_images(
train_data_dir,
caption_file_ext,
batch_size,
num_beams,
top_p,
max_length,
min_length,
beam_search,
):
# Check for caption_text_input
# if caption_text_input == "":
# msgbox("Caption text is missing...")
# return
# Check for images_dir_input
if train_data_dir == '':
msgbox('Image folder is missing...')
return
print(f'Captioning files in {train_data_dir}...')
run_cmd = f'.\\venv\\Scripts\\python.exe "./BLIP_caption/make_captions.py"'
run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += f' --num_beams="{int(num_beams)}"'
run_cmd += f' --top_p="{top_p}"'
run_cmd += f' --max_length="{int(max_length)}"'
run_cmd += f' --min_length="{int(min_length)}"'
if beam_search:
run_cmd += f' --beam_search'
if caption_file_ext != '':
run_cmd += f' --caption_extension="{caption_file_ext}"'
run_cmd += f' "{train_data_dir}"'
run_cmd += f' "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
print('...captioning done')
###
# Gradio UI
###
def gradio_blip_caption_gui_tab():
with gr.Tab('BLIP Captioning'):
gr.Markdown(
'This utility will use BLIP to caption files for each images in a folder.'
)
with gr.Row():
train_data_dir = gr.Textbox(
label='Image folder to caption',
placeholder='Directory containing the images to caption',
interactive=True,
)
button_train_data_dir_input = gr.Button(
'📂', elem_id='open_folder_small'
)
button_train_data_dir_input.click(
get_folder_path, outputs=train_data_dir
)
caption_file_ext = gr.Textbox(
label='Caption file extension',
placeholder='(Optional) Default: .caption',
interactive=True,
)
batch_size = gr.Number(
value=1, label='Batch size', interactive=True
)
with gr.Row():
beam_search = gr.Checkbox(
label='Use beam search', interactive=True, value=True
)
num_beams = gr.Number(
value=1, label='Number of beams', interactive=True
)
top_p = gr.Number(value=0.9, label='Top p', interactive=True)
max_length = gr.Number(
value=75, label='Max length', interactive=True
)
min_length = gr.Number(
value=5, label='Min length', interactive=True
)
caption_button = gr.Button('Caption images')
caption_button.click(
caption_images,
inputs=[
train_data_dir,
caption_file_ext,
batch_size,
num_beams,
top_p,
max_length,
min_length,
beam_search,
],
)

View File

@ -1,16 +1,17 @@
from easygui import diropenbox, fileopenbox from tkinter import filedialog, Tk
def get_file_path(file_path='', defaultextension='.json'):
def get_folder_path(): current_file_path = file_path
folder_path = diropenbox('Select the directory to use') # print(f'current file path: {current_file_path}')
return folder_path root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
def get_file_path(file_path): file_path = filedialog.askopenfilename(filetypes = (("Config files", "*.json"), ("All files", "*")), defaultextension=defaultextension)
file_path = fileopenbox( root.destroy()
'Select the config file to load', default=file_path, filetypes='*.json',
) if file_path == '':
file_path = current_file_path
return file_path return file_path
@ -20,3 +21,34 @@ def remove_doublequote(file_path):
file_path = file_path.replace('"', '') file_path = file_path.replace('"', '')
return file_path return file_path
def get_folder_path(folder_path=''):
current_folder_path = folder_path
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
folder_path = filedialog.askdirectory()
root.destroy()
if folder_path == '':
folder_path = current_folder_path
return folder_path
def get_saveasfile_path(file_path='', defaultextension='.json'):
current_file_path = file_path
# print(f'current file path: {current_file_path}')
root = Tk()
root.wm_attributes('-topmost', 1)
root.withdraw()
file_path = filedialog.asksaveasfile(filetypes = (("Config files", "*.json"), ("All files", "*")), defaultextension=defaultextension)
root.destroy()
file_path = file_path.name
if file_path == '':
file_path = current_file_path
return file_path

View File

@ -50,12 +50,15 @@ def convert_model(source_model_input, source_model_type, target_model_folder_inp
if not target_save_precision_type == 'unspecified': if not target_save_precision_type == 'unspecified':
run_cmd += f' --{target_save_precision_type}' run_cmd += f' --{target_save_precision_type}'
if target_model_type == "diffuser": if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
run_cmd += f' --reference_model="{source_model_type}"' run_cmd += f' --reference_model="{source_model_type}"'
if target_model_type == 'diffuser_safetensors':
run_cmd += ' --use_safetensors'
run_cmd += f' "{source_model_input}"' run_cmd += f' "{source_model_input}"'
if target_model_type == "diffuser": if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
target_model_path = os.path.join(target_model_folder_input, target_model_name_input) target_model_path = os.path.join(target_model_folder_input, target_model_name_input)
run_cmd += f' "{target_model_path}"' run_cmd += f' "{target_model_path}"'
else: else:
@ -67,7 +70,7 @@ def convert_model(source_model_input, source_model_type, target_model_folder_inp
# Run the command # Run the command
subprocess.run(run_cmd) subprocess.run(run_cmd)
if not target_model_type == "diffuser": if not target_model_type == "diffuser" or target_model_type == "diffuser_safetensors":
v2_models = ['stabilityai/stable-diffusion-2-1-base', v2_models = ['stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',] 'stabilityai/stable-diffusion-2-base',]
@ -171,6 +174,7 @@ def gradio_convert_model_tab():
) )
target_model_type = gr.Dropdown(label="Target model type", choices=[ target_model_type = gr.Dropdown(label="Target model type", choices=[
'diffuser', 'diffuser',
'diffuser_safetensors',
'ckpt', 'ckpt',
'safetensors', 'safetensors',
],) ],)

View File

@ -0,0 +1,73 @@
import gradio as gr
from easygui import msgbox
import subprocess
from .common_gui import get_folder_path
def caption_images(train_data_dir, caption_extension, batch_size, thresh):
# Check for caption_text_input
# if caption_text_input == "":
# msgbox("Caption text is missing...")
# return
# Check for images_dir_input
if train_data_dir == '':
msgbox('Image folder is missing...')
return
print(f'Captioning files in {train_data_dir}...')
run_cmd = f'accelerate launch "./script/tag_images_by_wd14_tagger.py"'
run_cmd += f' --batch_size="{int(batch_size)}"'
run_cmd += f' --thresh="{thresh}"'
if caption_extension != '':
run_cmd += f' --caption_extension="{caption_extension}"'
run_cmd += f' "{train_data_dir}"'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
print('...captioning done')
###
# Gradio UI
###
def gradio_wd14_caption_gui_tab():
with gr.Tab('WD14 Captioning'):
gr.Markdown(
'This utility will use WD14 to caption files for each images in a folder.'
)
with gr.Row():
train_data_dir = gr.Textbox(
label='Image folder to caption',
placeholder='Directory containing the images to caption',
interactive=True,
)
button_train_data_dir_input = gr.Button(
'📂', elem_id='open_folder_small'
)
button_train_data_dir_input.click(
get_folder_path, outputs=train_data_dir
)
caption_extension = gr.Textbox(
label='Caption file extension',
placeholder='(Optional) Default: .caption',
interactive=True,
)
thresh = gr.Number(value=0.35, label='Threshold')
batch_size = gr.Number(
value=1, label='Batch size', interactive=True
)
caption_button = gr.Button('Caption images')
caption_button.click(
caption_images,
inputs=[train_data_dir, caption_extension, batch_size, thresh],
)

609
mytraining.ps Normal file
View File

@ -0,0 +1,609 @@
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6.py `
--pretrained_model_name_or_path="D:\models\v1-5-pruned.ckpt" `
--train_data_dir="D:\dreambooth\train_bernard\train_man" `
--reg_data_dir="D:\dreambooth\train_bernard\reg_man" `
--output_dir="D:\dreambooth\train_bernard" `
--prior_loss_weight=1.0 `
--resolution="512,512" `
--train_batch_size=1 `
--learning_rate=1e-6 `
--max_train_steps=3000 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--gradient_checkpointing `
--save_every_n_epochs=1
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6.py `
--pretrained_model_name_or_path="D:\models\bernard\asd man-3000-remgb-sd15.ckpt" `
--train_data_dir="D:\dreambooth\train_bernard\train_man" `
--reg_data_dir="D:\dreambooth\train_bernard\reg_man" `
--output_dir="D:\dreambooth\train_bernard" `
--prior_loss_weight=1.0 `
--resolution="512,512" `
--train_batch_size=1 `
--learning_rate=1e-6 `
--max_train_steps=1500 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--gradient_checkpointing `
--save_every_n_epochs=1
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6.py `
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
--train_data_dir="D:\dreambooth\train_bernard\train_man" `
--reg_data_dir="D:\dreambooth\train_bernard\reg_man" `
--output_dir="D:\dreambooth\train_bernard" `
--prior_loss_weight=1.0 `
--resolution="512,512" `
--train_batch_size=1 `
--learning_rate=1e-6 `
--max_train_steps=4500 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--gradient_checkpointing `
--no_token_padding `
--save_every_n_epochs=1
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6.py `
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
--train_data_dir="D:\dreambooth\source\alex\train" `
--output_dir="D:\dreambooth\train_alex" `
--prior_loss_weight=1.0 `
--resolution="448,640" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=4500 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--shuffle_caption
# -fine_tuning
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6.py `
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
--train_data_dir="D:\dreambooth\source\alex\train\50_portrait-pp" `
--output_dir="D:\dreambooth\train_alex" `
--resolution="448,640" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=4500 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--shuffle_caption
Resume:
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
--train_data_dir="D:\dreambooth\source\alet_et_bernard\landscape-pp" `
--output_dir="D:\dreambooth\train_alex_and_bernard" `
--resolution="640,448" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=550 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--fine_tuning_repeat=200 `
--seed=23 `
--save_half
# Mollie Monger
e1:
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
--train_data_dir="D:\dreambooth\train_mollie_monger\landscape-pp" `
--output_dir="D:\dreambooth\train_mollie_monger\output" `
--resolution="640,448" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=625 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--fine_tuning_repeat=200 `
--seed=23 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
--pretrained_model_name_or_path="D:\models\mollie_monger-kohya-l-200-sd15.ckpt" `
--train_data_dir="D:\dreambooth\train_mollie_monger\portrait-pp" `
--output_dir="D:\dreambooth\train_mollie_monger\output" `
--resolution="448,640" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=1275 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--fine_tuning_repeat=200 `
--seed=23 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
--pretrained_model_name_or_path="D:\models\mollie_monger-kohya-l+p-200-sd15.ckpt" `
--train_data_dir="D:\dreambooth\train_mollie_monger\square-pp" `
--output_dir="D:\dreambooth\train_mollie_monger\output" `
--resolution="512,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=500 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--fine_tuning_repeat=200 `
--seed=23 `
--save_half
e2:
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
--pretrained_model_name_or_path="D:\models\mollie_monger\mollie_monger-kohya-l+p+s-r200-e1-sd15.ckpt" `
--train_data_dir="D:\dreambooth\train_mollie_monger\landscape-pp" `
--output_dir="D:\dreambooth\train_mollie_monger\output" `
--resolution="640,448" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=625 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--fine_tuning_repeat=200 `
--seed=23 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
--pretrained_model_name_or_path="D:\models\mollie_monger\last.ckpt" `
--train_data_dir="D:\dreambooth\train_mollie_monger\portrait-pp" `
--output_dir="D:\dreambooth\train_mollie_monger\output" `
--resolution="448,640" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=1275 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--fine_tuning_repeat=200 `
--seed=23 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
--pretrained_model_name_or_path="D:\models\mollie_monger\last.ckpt" `
--train_data_dir="D:\dreambooth\train_mollie_monger\square-pp" `
--output_dir="D:\dreambooth\train_mollie_monger\output" `
--resolution="512,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=500 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--fine_tuning_repeat=200 `
--seed=23 `
--save_half
Midjourney images download:
https://storage.googleapis.com/dream-machines-output/2932e6e4-ddef-410e-947b-2a6275e31f35/0_3.png
# Midjourney
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
--train_data_dir="D:\dreambooth\train_midjourney_v4\all data" `
--output_dir="D:\dreambooth\train_midjourney_v4\model" `
--resolution="512,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=528 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--fine_tuning_repeat=12 `
--seed=23 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
--pretrained_model_name_or_path="D:\models\midjourney_v4-khoya-r100-e1-sd15.ckpt" `
--train_data_dir="D:\dreambooth\train_midjourney_v4\data2" `
--output_dir="D:\dreambooth\train_midjourney_v4\model" `
--resolution="512,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=850 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--fine_tuning_repeat=100 `
--seed=23 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
--pretrained_model_name_or_path="D:\models\midjourney_v4_finetune\epoch-000001.ckpt" `
--train_data_dir="D:\dreambooth\train_midjourney_v4\newdata3" `
--output_dir="D:\dreambooth\train_midjourney_v4\model" `
--resolution="512,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=159 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--fine_tuning_repeat=24 `
--seed=23 `
--save_half
# train n
# Midjourney
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v6-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\train_childrens_drawings\model\last2.ckpt" `
--train_data_dir="D:\dreambooth\train_childrens_drawings\data2-pp" `
--output_dir="D:\dreambooth\train_childrens_drawings\model" `
--resolution="704,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=312 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--fine_tuning_repeat=48 `
--seed=23 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\train_childrens_drawings\model\last2.ckpt" `
--train_data_dir="D:\dreambooth\train_childrens_drawings\data2-pp" `
--output_dir="D:\dreambooth\train_childrens_drawings\model" `
--resolution="704,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=312 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=48 `
--seed=23 `
--save_half
# twq
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
--train_data_dir="D:\dreambooth\source\bernardv2-ft" `
--output_dir="D:\dreambooth\train_bernard\model" `
--resolution="512,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=720 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=48 `
--save_half
# the white queen
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\landscape-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l" `
--resolution="704,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=520 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=40 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\portrait-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p" `
--resolution="512,704" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=260 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=40 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\square-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p+s" `
--resolution="512,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=220 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=40 `
--seed=23 `
--save_half
# the white queen slow progress init phase
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\landscape-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l" `
--resolution="704,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=260 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=80 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\portrait-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p" `
--resolution="512,704" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=130 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=80 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\square-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p+s" `
--resolution="512,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=90 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=80 `
--seed=23 `
--save_half
# the white queen slow progress extra steps phase
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p+s\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\landscape-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l" `
--resolution="704,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=130 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=40 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\portrait-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p" `
--resolution="512,704" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=65 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=40 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\square-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p+s" `
--resolution="512,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=45 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=40 `
--seed=23 `
--save_half
# the queen of heart init phase
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\models\v1-5-pruned-mse-vae.ckpt" `
--train_data_dir="D:\dreambooth\train_qoh\landscape-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l" `
--resolution="704,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=260 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=80 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\portrait-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p" `
--resolution="512,704" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=130 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=80 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\square-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p+s" `
--resolution="512,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=90 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=80 `
--seed=23 `
--save_half
# the white queen slow progress extra steps phase
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p+s\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\landscape-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l" `
--resolution="704,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=130 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=40 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\portrait-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p" `
--resolution="512,704" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=65 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=40 `
--save_half
accelerate launch --num_cpu_threads_per_process 6 train_db_fixed_v7-ber.py `
--pretrained_model_name_or_path="D:\dreambooth\training_twq\the_white_queen\model+l+p\last.ckpt" `
--train_data_dir="D:\dreambooth\training_twq\the_white_queen\square-ft" `
--output_dir="D:\dreambooth\training_twq\the_white_queen\model+l+p+s" `
--resolution="512,512" `
--train_batch_size=8 `
--learning_rate=1e-6 `
--max_train_steps=45 `
--use_8bit_adam `
--xformers `
--mixed_precision="fp16" `
--cache_latents `
--save_every_n_epochs=1 `
--fine_tuning `
--dataset_repeats=40 `
--seed=23 `
--save_half

View File

@ -12,4 +12,13 @@ safetensors==0.2.6
gradio gradio
altair altair
easygui easygui
tkinter
# for BLIP captioning
requests
timm
fairscale
# for WD14 captioning
tensorflow<2.11
huggingface-hub
# for kohya_ss library
. .

View File

@ -1,12 +1,17 @@
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion # convert Diffusers v1.x/v2.0 model to original Stable Diffusion
# v1: initial version # v1: initial version
# v2: support safetensors
# v3: fix to support another format
# v4: support safetensors in Diffusers
import argparse import argparse
import os import os
import torch import torch
from diffusers import StableDiffusionPipeline from diffusers import StableDiffusionPipeline
from library import model_util as model_util from library import model_util as model_util
def convert(args): def convert(args):
# 引数を確認する # 引数を確認する
load_dtype = torch.float16 if args.fp16 else None load_dtype = torch.float16 if args.fp16 else None
@ -56,7 +61,7 @@ def convert(args):
print(f"model saved. total converted state_dict keys: {key_count}") print(f"model saved. total converted state_dict keys: {key_count}")
else: else:
print(f"copy scheduler/tokenizer config from: {args.reference_model}") print(f"copy scheduler/tokenizer config from: {args.reference_model}")
model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae) model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
print(f"model saved.") print(f"model saved.")
@ -76,6 +81,8 @@ if __name__ == '__main__':
help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値') help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
parser.add_argument("--reference_model", type=str, default=None, parser.add_argument("--reference_model", type=str, default=None,
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要") help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
parser.add_argument("--use_safetensors", action='store_true',
help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存するcheckpointは拡張子で自動判定")
parser.add_argument("model_to_load", type=str, default=None, parser.add_argument("model_to_load", type=str, default=None,
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ") help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")