Merge branch 'AUTOMATIC1111:master' into sub-quad_attn_opt

This commit is contained in:
brkirch 2023-01-04 00:40:16 -05:00
commit f6ab5a39d7
40 changed files with 1567 additions and 663 deletions

View File

@ -127,6 +127,8 @@ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki). The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
## Credits ## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers - Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- GFPGAN - https://github.com/TencentARC/GFPGAN.git - GFPGAN - https://github.com/TencentARC/GFPGAN.git

View File

@ -0,0 +1,72 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: modules.xlmr.BertSeriesModelWithTransformation
params:
name: "XLMR-Large"

View File

@ -0,0 +1,50 @@
import random
from modules import script_callbacks, shared
import gradio as gr
art_symbol = '\U0001f3a8' # 🎨
global_prompt = None
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
def roll_artist(prompt):
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
return prompt + ", " + artist.name if prompt != '' else artist.name
def add_roll_button(prompt):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
roll.click(
fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[
prompt,
],
outputs=[
prompt,
]
)
def after_component(component, **kwargs):
global global_prompt
elem_id = kwargs.get('elem_id', None)
if elem_id not in related_ids:
return
if elem_id == "txt2img_prompt":
global_prompt = component
elif elem_id == "txt2img_clear_prompt":
add_roll_button(global_prompt)
elif elem_id == "img2img_prompt":
global_prompt = component
elif elem_id == "img2img_clear_prompt":
add_roll_button(global_prompt)
script_callbacks.on_after_component(after_component)

9
html/footer.html Normal file
View File

@ -0,0 +1,9 @@
<div>
<a href="/docs">API</a>
 • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
 • 
<a href="https://gradio.app">Gradio</a>
 • 
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
</div>

392
html/licenses.html Normal file
View File

@ -0,0 +1,392 @@
<style>
#licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}
#licenses small {font-size: 0.95em; opacity: 0.85;}
#licenses pre { margin: 1em 0 2em 0;}
</style>
<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
<pre>
S-Lab License 1.0
Copyright 2022 S-Lab
Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.
</pre>
<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
<small>Code for architecture and reading models copied.</small>
<pre>
MIT License
Copyright (c) 2021 victorca25
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
<small>Some code is copied to support ESRGAN models.</small>
<pre>
BSD 3-Clause License
Copyright (c) 2021, Xintao Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
</pre>
<h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
<small>Some code for compatibility with OSX is taken from lstein's repository.</small>
<pre>
MIT License
Copyright (c) 2022 InvokeAI Team
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
<h2><a href="https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE">LDSR</a></h2>
<small>Code added by contirubtors, most likely copied from this repository.</small>
<pre>
MIT License
Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
<h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE">CLIP Interrogator</a></h2>
<small>Some small amounts of code borrowed and reworked.</small>
<pre>
MIT License
Copyright (c) 2022 pharmapsychotic
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
<small>Code added by contirubtors, most likely copied from this repository.</small>
<pre>
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2021] [SwinIR Authors]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
</pre>

View File

@ -19,7 +19,7 @@ function selected_gallery_index(){
function extract_image_from_gallery(gallery){ function extract_image_from_gallery(gallery){
if(gallery.length == 1){ if(gallery.length == 1){
return gallery[0] return [gallery[0]]
} }
index = selected_gallery_index() index = selected_gallery_index()
@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){
return [null] return [null]
} }
return gallery[index]; return [gallery[index]];
} }
function args_to_array(args){ function args_to_array(args){
@ -188,6 +188,17 @@ onUiUpdate(function(){
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
} }
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
settings_tabs = gradioApp().querySelector('#settings div')
if(show_all_pages && settings_tabs){
settings_tabs.appendChild(show_all_pages)
show_all_pages.onclick = function(){
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
elem.style.display = "block";
})
}
}
}) })
let txt2img_textarea, img2img_textarea = undefined; let txt2img_textarea, img2img_textarea = undefined;

View File

@ -100,6 +100,7 @@ class Api:
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
@ -121,7 +122,6 @@ class Api:
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
populate = txt2imgreq.copy(update={ # Override __init__ params populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True "do_not_save_grid": True
@ -129,15 +129,14 @@ class Api:
) )
if populate.sampler_name: if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on populate.sampler_index = None # prevent a warning later on
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
shared.state.begin()
with self.queue_lock: with self.queue_lock:
processed = process_images(p) p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
shared.state.begin()
processed = process_images(p)
shared.state.end()
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) b64images = list(map(encode_pil_to_base64, processed.images))
@ -153,7 +152,6 @@ class Api:
mask = decode_base64_to_image(mask) mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True, "do_not_save_grid": True,
@ -165,16 +163,14 @@ class Api:
args = vars(populate) args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
p = StableDiffusionProcessingImg2Img(**args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
shared.state.begin()
with self.queue_lock: with self.queue_lock:
processed = process_images(p) p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
shared.state.end() shared.state.begin()
processed = process_images(p)
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) b64images = list(map(encode_pil_to_base64, processed.images))
@ -332,6 +328,26 @@ class Api:
def get_artists(self): def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists] return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db
def convert_embedding(embedding):
return {
"step": embedding.step,
"sd_checkpoint": embedding.sd_checkpoint,
"sd_checkpoint_name": embedding.sd_checkpoint_name,
"shape": embedding.shape,
"vectors": embedding.vectors,
}
def convert_embeddings(embeddings):
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
return {
"loaded": convert_embeddings(db.word_embeddings),
"skipped": convert_embeddings(db.skipped_embeddings),
}
def refresh_checkpoints(self): def refresh_checkpoints(self):
shared.refresh_checkpoints() shared.refresh_checkpoints()

View File

@ -249,3 +249,13 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score") score: float = Field(title="Score")
category: str = Field(title="Category") category: str = Field(title="Category")
class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class EmbeddingsResponse(BaseModel):
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")

View File

@ -303,6 +303,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
result_is_inpainting_model = True result_is_inpainting_model = True
else: else:
assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
theta_0[key] = theta_func2(a, b, multiplier) theta_0[key] = theta_func2(a, b, multiplier)
if save_as_half: if save_as_half:

View File

@ -1,12 +1,13 @@
import base64 import base64
import io import io
import math
import os import os
import re import re
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
from modules.shared import script_path from modules.shared import script_path
from modules import shared from modules import shared, ui_tempdir
import tempfile import tempfile
from PIL import Image from PIL import Image
@ -36,9 +37,12 @@ def quote(text):
def image_from_url_text(filedata): def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]: if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
filedata = filedata[0]
if type(filedata) == dict and filedata.get("is_file", False):
filename = filedata["name"] filename = filedata["name"]
is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs) is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories' assert is_in_right_dir, 'trying to open image file outside of allowed directories'
return Image.open(filename) return Image.open(filename)
@ -93,7 +97,7 @@ def integrate_settings_paste_fields(component_dict):
def create_buttons(tabs_list): def create_buttons(tabs_list):
buttons = {} buttons = {}
for tab in tabs_list: for tab in tabs_list:
buttons[tab] = gr.Button(f"Send to {tab}") buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
return buttons return buttons
@ -102,35 +106,57 @@ def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info]) bind_list.append([buttons, send_image, send_generate_info])
def send_image_and_dimensions(x):
if isinstance(x, Image.Image):
img = x
else:
img = image_from_url_text(x)
if shared.opts.send_size and isinstance(img, Image.Image):
w = img.width
h = img.height
else:
w = gr.update()
h = gr.update()
return img, w, h
def run_bind(): def run_bind():
for buttons, send_image, send_generate_info in bind_list: for buttons, source_image_component, send_generate_info in bind_list:
for tab in buttons: for tab in buttons:
button = buttons[tab] button = buttons[tab]
if send_image and paste_fields[tab]["init_img"]: destination_image_component = paste_fields[tab]["init_img"]
if type(send_image) == gr.Gallery: fields = paste_fields[tab]["fields"]
button.click(
fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery",
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
else:
button.click(
fn=lambda x: x,
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
if send_generate_info and paste_fields[tab]["fields"] is not None: destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
if source_image_component and destination_image_component:
if isinstance(source_image_component, gr.Gallery):
func = send_image_and_dimensions if destination_width_component else image_from_url_text
jsfunc = "extract_image_from_gallery"
else:
func = send_image_and_dimensions if destination_width_component else lambda x: x
jsfunc = None
button.click(
fn=func,
_js=jsfunc,
inputs=[source_image_component],
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
)
if send_generate_info and fields is not None:
if send_generate_info in paste_fields: if send_generate_info in paste_fields:
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else []) paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
button.click( button.click(
fn=lambda *x: x, fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names], outputs=[field for field, name in fields if name in paste_field_names],
) )
else: else:
connect_paste(button, paste_fields[tab]["fields"], send_generate_info) connect_paste(button, fields, send_generate_info)
button.click( button.click(
fn=None, fn=None,
@ -164,6 +190,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
return None return None
def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale"""
firstpass_width = res.get('First pass size-1', None)
firstpass_height = res.get('First pass size-2', None)
if firstpass_width is None or firstpass_height is None:
return
firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
width = int(res.get("Size-1", 512))
height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0:
# old algorithm for auto-calculating first pass size
desired_pixel_count = 512 * 512
actual_pixel_count = width * height
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
firstpass_width = math.ceil(scale * width / 64) * 64
firstpass_height = math.ceil(scale * height / 64) * 64
hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
res['Size-1'] = firstpass_width
res['Size-2'] = firstpass_height
res['Hires upscale'] = hr_scale
def parse_generation_parameters(x: str): def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI: """parses generation parameters string, the one you see in text field under the picture in UI:
``` ```
@ -221,6 +276,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
hypernet_hash = res.get("Hypernet hash", None) hypernet_hash = res.get("Hypernet hash", None)
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
restore_old_hires_fix_params(res)
return res return res

View File

@ -39,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
cols = math.ceil(len(imgs) / rows) cols = math.ceil(len(imgs) / rows)
w, h = imgs[0].size params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
grid = Image.new('RGB', size=(cols * w, rows * h), color='black') script_callbacks.image_grid_callback(params)
for i, img in enumerate(imgs): w, h = imgs[0].size
grid.paste(img, box=(i % cols * w, i // cols * h)) grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
for i, img in enumerate(params.imgs):
grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
return grid return grid
@ -227,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
return draw_grid_annotations(im, width, height, hor_texts, ver_texts) return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
def resize_image(resize_mode, im, width, height): def resize_image(resize_mode, im, width, height, upscaler_name=None):
"""
Resizes an image with the specified resize_mode, width, and height.
Args:
resize_mode: The mode to use when resizing the image.
0: Resize the image to the specified width and height.
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
im: The image to resize.
width: The width to resize the image to.
height: The height to resize the image to.
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
"""
upscaler_name = upscaler_name or opts.upscaler_for_img2img
def resize(im, w, h): def resize(im, w, h):
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L': if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS) return im.resize((w, h), resample=LANCZOS)
scale = max(w / im.width, h / im.height) scale = max(w / im.width, h / im.height)
if scale > 1.0: if scale > 1.0:
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img] upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}" assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
upscaler = upscalers[0] upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path) im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
@ -525,6 +544,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data) image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
elif extension.lower() in (".jpg", ".jpeg", ".webp"): elif extension.lower() in (".jpg", ".jpeg", ".webp"):
if image_to_save.mode == 'RGBA':
image_to_save = image_to_save.convert("RGB")
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality) image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
if opts.enable_pnginfo and info is not None: if opts.enable_pnginfo and info is not None:

View File

@ -162,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.do_not_show_images: if opts.do_not_show_images:
processed.images = [] processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)

View File

@ -135,7 +135,7 @@ class InterrogateModels:
return caption[0] return caption[0]
def interrogate(self, pil_image): def interrogate(self, pil_image):
res = None res = ""
try: try:

View File

@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
def read(self): def read(self):
if not self.disabled: if not self.disabled:
free, total = torch.cuda.mem_get_info() free, total = torch.cuda.mem_get_info()
self.data["free"] = free
self.data["total"] = total self.data["total"] = total
torch_stats = torch.cuda.memory_stats(self.device) torch_stats = torch.cuda.memory_stats(self.device)
self.data["active"] = torch_stats["active.all.current"]
self.data["active_peak"] = torch_stats["active_bytes.all.peak"] self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
self.data["system_peak"] = total - self.data["min_free"] self.data["system_peak"] = total - self.data["min_free"]

View File

@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass pass
builtin_upscaler_classes = []
forbidden_upscaler_classes = set()
def list_builtin_upscalers():
load_upscalers()
builtin_upscaler_classes.clear()
builtin_upscaler_classes.extend(Upscaler.__subclasses__())
def forbid_loaded_nonbuiltin_upscalers():
for cls in Upscaler.__subclasses__():
if cls not in builtin_upscaler_classes:
forbidden_upscaler_classes.add(cls)
def load_upscalers(): def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced, # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__ # so we'll try to import any _model.py files before looking in __subclasses__
@ -139,6 +156,9 @@ def load_upscalers():
datas = [] datas = []
commandline_options = vars(shared.cmd_opts) commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__(): for cls in Upscaler.__subclasses__():
if cls in forbidden_upscaler_classes:
continue
name = cls.__name__ name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None)) scaler = cls(commandline_options.get(cmd_name, None))

View File

@ -239,7 +239,7 @@ class StableDiffusionProcessing():
class Processed: class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
self.images = images_list self.images = images_list
self.prompt = p.prompt self.prompt = p.prompt
self.negative_prompt = p.negative_prompt self.negative_prompt = p.negative_prompt
@ -247,6 +247,7 @@ class Processed:
self.subseed = subseed self.subseed = subseed
self.subseed_strength = p.subseed_strength self.subseed_strength = p.subseed_strength
self.info = info self.info = info
self.comments = comments
self.width = p.width self.width = p.width
self.height = p.height self.height = p.height
self.sampler_name = p.sampler_name self.sampler_name = p.sampler_name
@ -646,7 +647,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess(p, res) p.scripts.postprocess(p, res)
@ -657,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None sampler = None
def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs): def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.enable_hr = enable_hr self.enable_hr = enable_hr
self.denoising_strength = denoising_strength self.denoising_strength = denoising_strength
self.firstphase_width = firstphase_width self.hr_scale = hr_scale
self.firstphase_height = firstphase_height self.hr_upscaler = hr_upscaler
self.truncate_x = 0
self.truncate_y = 0 if firstphase_width != 0 or firstphase_height != 0:
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
self.hr_scale = self.width / firstphase_width
self.width = firstphase_width
self.height = firstphase_height
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr: if self.enable_hr:
@ -673,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else: else:
state.job_count = state.job_count * 2 state.job_count = state.job_count * 2
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" self.extra_generation_params["Hires upscale"] = self.hr_scale
if self.hr_upscaler is not None:
if self.firstphase_width == 0 or self.firstphase_height == 0: self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
desired_pixel_count = 512 * 512
actual_pixel_count = self.width * self.height
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
self.firstphase_width = math.ceil(scale * self.width / 64) * 64
self.firstphase_height = math.ceil(scale * self.height / 64) * 64
firstphase_width_truncated = int(scale * self.width)
firstphase_height_truncated = int(scale * self.height)
else:
width_ratio = self.width / self.firstphase_width
height_ratio = self.height / self.firstphase_height
if width_ratio > height_ratio:
firstphase_width_truncated = self.firstphase_width
firstphase_height_truncated = self.firstphase_width * self.height / self.width
else:
firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and latent_scale_mode is None:
assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
if not self.enable_hr: if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) target_width = int(self.width * self.hr_scale)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height)) target_height = int(self.height * self.hr_scale)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
def save_intermediate(image, index): def save_intermediate(image, index):
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return return
@ -722,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix: if latent_scale_mode is not None:
for i in range(samples.shape[0]): for i in range(samples.shape[0]):
save_intermediate(samples, i) save_intermediate(samples, i)
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
# Avoid making the inpainting conditioning unless necessary as # Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again. # this does need some extra compute to decode / encode the image again.
@ -746,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
save_intermediate(image, i) save_intermediate(image, i)
image = images.resize_image(0, image, self.width, self.height) image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0) image = np.moveaxis(image, 2, 0)
batch_images.append(image) batch_images.append(image)
@ -763,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory # GC now before running the next img2img to prevent running out of memory
x = None x = None

View File

@ -51,6 +51,13 @@ class UiTrainTabParams:
self.txt2img_preview_params = txt2img_preview_params self.txt2img_preview_params = txt2img_preview_params
class ImageGridLoopParams:
def __init__(self, imgs, cols, rows):
self.imgs = imgs
self.cols = cols
self.rows = rows
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict( callback_map = dict(
callbacks_app_started=[], callbacks_app_started=[],
@ -63,6 +70,7 @@ callback_map = dict(
callbacks_cfg_denoiser=[], callbacks_cfg_denoiser=[],
callbacks_before_component=[], callbacks_before_component=[],
callbacks_after_component=[], callbacks_after_component=[],
callbacks_image_grid=[],
) )
@ -155,6 +163,14 @@ def after_component_callback(component, **kwargs):
report_exception(c, 'after_component_callback') report_exception(c, 'after_component_callback')
def image_grid_callback(params: ImageGridLoopParams):
for c in callback_map['callbacks_image_grid']:
try:
c.callback(params)
except Exception:
report_exception(c, 'image_grid')
def add_callback(callbacks, fun): def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file' filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@ -255,3 +271,11 @@ def on_before_component(callback):
def on_after_component(callback): def on_after_component(callback):
"""register a function to be called after a component is created. See on_before_component for more.""" """register a function to be called after a component is created. See on_before_component for more."""
add_callback(callback_map['callbacks_after_component'], callback) add_callback(callback_map['callbacks_after_component'], callback)
def on_image_grid(callback):
"""register a function to be called before making an image grid.
The callback is called with one argument:
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
"""
add_callback(callback_map['callbacks_image_grid'], callback)

View File

@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
import ldm.modules.attention import ldm.modules.attention
import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.model
@ -65,6 +65,7 @@ def fix_checkpoint():
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
fixes = None fixes = None
comments = [] comments = []
@ -75,17 +76,25 @@ class StableDiffusionModelHijack:
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m): def hijack(self, m):
if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder: elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
apply_optimizations()
self.clip = m.cond_stage_model self.clip = m.cond_stage_model
apply_optimizations()
fix_checkpoint() fix_checkpoint()
def flatten(el): def flatten(el):
@ -98,7 +107,11 @@ class StableDiffusionModelHijack:
self.layers = flatten(m) self.layers = flatten(m)
def undo_hijack(self, m): def undo_hijack(self, m):
if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
@ -126,8 +139,8 @@ class StableDiffusionModelHijack:
def tokenize(self, text): def tokenize(self, text):
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
class EmbeddingsWithFixes(torch.nn.Module): class EmbeddingsWithFixes(torch.nn.Module):

View File

@ -5,7 +5,6 @@ import torch
from modules import prompt_parser, devices from modules import prompt_parser, devices
from modules.shared import opts from modules.shared import opts
def get_target_prompt_token_count(token_count): def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75 return math.ceil(max(token_count, 1) / 75) * 75
@ -254,10 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack): def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack) super().__init__(wrapped, hijack)
self.tokenizer = wrapped.tokenizer self.tokenizer = wrapped.tokenizer
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
vocab = self.tokenizer.get_vocab()
self.comma_token = vocab.get(',</w>', None)
self.token_mults = {} self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens: for text, ident in tokens_with_parens:
mult = 1.0 mult = 1.0
for c in text: for c in text:
@ -296,6 +298,6 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def encode_embedding_init_text(self, init_text, nvpt): def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.transformer.text_model.embeddings embedding_layer = self.wrapped.transformer.text_model.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded return embedded

View File

@ -12,191 +12,6 @@ from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# =================================================================================================
# Monkey patch DDIMSampler methods from RunwayML repo directly.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
# =================================================================================================
@torch.no_grad()
def sample_ddim(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
# =================================================================================================
# Monkey patch PLMSSampler methods.
# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
# Adapted from:
# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
# =================================================================================================
@torch.no_grad()
def sample_plms(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad() @torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@ -280,44 +95,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
return x_prev, pred_x0, e_t return x_prev, pred_x0, e_t
# =================================================================================================
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
# =================================================================================================
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None:
xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
# todo: get null label from cond_stage_model
raise NotImplementedError()
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
return c
class LatentInpaintDiffusion(LatentDiffusion):
def __init__(
self,
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
def should_hijack_inpainting(checkpoint_info): def should_hijack_inpainting(checkpoint_info):
ckpt_basename = os.path.basename(checkpoint_info.filename).lower() ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
@ -326,15 +103,6 @@ def should_hijack_inpainting(checkpoint_info):
def do_inpainting_hijack(): def do_inpainting_hijack():
# most of this stuff seems to no longer be needed because it is already included into SD2.0
# p_sample_plms is needed because PLMS can't work with dicts as conditionings # p_sample_plms is needed because PLMS can't work with dicts as conditionings
# this file should be cleaned up later if everything turns out to work fine
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
# ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
# ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
# ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
# ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms

34
modules/sd_hijack_xlmr.py Normal file
View File

@ -0,0 +1,34 @@
import open_clip.tokenizer
import torch
from modules import sd_hijack_clip, devices
from modules.shared import opts
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.id_start = wrapped.config.bos_token_id
self.id_end = wrapped.config.eos_token_id
self.id_pad = wrapped.config.pad_token_id
self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
def encode_with_transformers(self, tokens):
# there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
# trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
# layer to work with - you have to use the last
attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
z = features['projection_state']
return z
def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.roberta.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
return embedded

View File

@ -228,6 +228,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.sd_model_checkpoint = checkpoint_file model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info model.sd_checkpoint_info = checkpoint_info
model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae() sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae() sd_vae.clear_loaded_vae()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
@ -322,9 +324,12 @@ def load_model(checkpoint_info=None):
sd_model.eval() sd_model.eval()
shared.sd_model = sd_model shared.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
script_callbacks.model_loaded_callback(sd_model) script_callbacks.model_loaded_callback(sd_model)
print("Model loaded.") print("Model loaded.")
return sd_model return sd_model

View File

@ -465,7 +465,9 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override: if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps) sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
else: else:
sigmas = self.model_wrap.get_sigmas(steps) sigmas = self.model_wrap.get_sigmas(steps)

View File

@ -1,5 +1,6 @@
import torch import torch
import os import os
import collections
from collections import namedtuple from collections import namedtuple
from modules import shared, devices, script_callbacks from modules import shared, devices, script_callbacks
from modules.paths import models_path from modules.paths import models_path
@ -30,6 +31,7 @@ base_vae = None
loaded_vae_file = None loaded_vae_file = None
checkpoint_info = None checkpoint_info = None
checkpoints_loaded = collections.OrderedDict()
def get_base_vae(model): def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
global first_load, vae_dict, vae_list, loaded_vae_file global first_load, vae_dict, vae_list, loaded_vae_file
# save_settings = False # save_settings = False
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
if vae_file: if vae_file:
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" if cache_enabled and vae_file in checkpoints_loaded:
print(f"Loading VAE weights from: {vae_file}") # use vae checkpoint cache
store_base_vae(model) print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) store_base_vae(model)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} _load_vae_dict(model, checkpoints_loaded[vae_file])
_load_vae_dict(model, vae_dict_1) else:
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}")
store_base_vae(model)
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
_load_vae_dict(model, vae_dict_1)
if cache_enabled:
# cache newly loaded vae
checkpoints_loaded[vae_file] = vae_dict_1.copy()
# clean up cache if limit is reached
if cache_enabled:
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
checkpoints_loaded.popitem(last=False) # LRU
# If vae used is not in dict, update it # If vae used is not in dict, update it
# It will be removed on refresh though # It will be removed on refresh though

View File

@ -23,7 +23,7 @@ demo = None
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
@ -113,6 +113,17 @@ restricted_opts = {
"outdir_save", "outdir_save",
} }
ui_reorder_categories = [
"sampler",
"dimensions",
"cfg",
"seed",
"checkboxes",
"hires_fix",
"batch",
"scripts",
]
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
@ -172,7 +183,7 @@ class State:
def dict(self): def dict(self):
obj = { obj = {
"skipped": self.skipped, "skipped": self.skipped,
"interrupted": self.skipped, "interrupted": self.interrupted,
"job": self.job, "job": self.job,
"job_count": self.job_count, "job_count": self.job_count,
"job_no": self.job_no, "job_no": self.job_no,
@ -331,7 +342,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration"), { options_templates.update(options_section(('face-restoration', "Face restoration"), {
@ -360,6 +370,7 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
@ -371,13 +382,17 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}), "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
})) }))
options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
@ -409,7 +424,10 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
"dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
})) }))
@ -543,6 +561,12 @@ opts = Options()
if os.path.exists(config_filename): if os.path.exists(config_filename):
opts.load(config_filename) opts.load(config_filename)
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
"Latent": "bilinear",
"Latent (nearest)": "nearest",
}
sd_upscalers = [] sd_upscalers = []
sd_model = None sd_model = None

View File

@ -23,6 +23,8 @@ class Embedding:
self.vec = vec self.vec = vec
self.name = name self.name = name
self.step = step self.step = step
self.shape = None
self.vectors = 0
self.cached_checksum = None self.cached_checksum = None
self.sd_checkpoint = None self.sd_checkpoint = None
self.sd_checkpoint_name = None self.sd_checkpoint_name = None
@ -57,8 +59,10 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir): def __init__(self, embeddings_dir):
self.ids_lookup = {} self.ids_lookup = {}
self.word_embeddings = {} self.word_embeddings = {}
self.skipped_embeddings = {}
self.dir_mtime = None self.dir_mtime = None
self.embeddings_dir = embeddings_dir self.embeddings_dir = embeddings_dir
self.expected_shape = -1
def register_embedding(self, embedding, model): def register_embedding(self, embedding, model):
@ -75,20 +79,24 @@ class EmbeddingDatabase:
return embedding return embedding
def load_textual_inversion_embeddings(self): def get_expected_shape(self):
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir) mt = os.path.getmtime(self.embeddings_dir)
if self.dir_mtime is not None and mt <= self.dir_mtime: if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
return return
self.dir_mtime = mt self.dir_mtime = mt
self.ids_lookup.clear() self.ids_lookup.clear()
self.word_embeddings.clear() self.word_embeddings.clear()
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
def process_file(path, filename): def process_file(path, filename):
name = os.path.splitext(filename)[0] name = os.path.splitext(filename)[0]
data = []
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']: if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path) embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
@ -122,7 +130,13 @@ class EmbeddingDatabase:
embedding.step = data.get('step', None) embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model) embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
else:
self.skipped_embeddings[name] = embedding
for fn in os.listdir(self.embeddings_dir): for fn in os.listdir(self.embeddings_dir):
try: try:
@ -137,8 +151,9 @@ class EmbeddingDatabase:
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
continue continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
print("Embeddings:", ', '.join(self.word_embeddings.keys())) if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset): def find_embedding_at_position(self, tokens, offset):
token = tokens[offset] token = tokens[offset]
@ -267,7 +282,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
return embedding, filename return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step) scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed old_parallel_processing_allowed = shared.parallel_processing_allowed
@ -295,7 +310,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
loss_step = 0 loss_step = 0
_loss_step = 0 #internal _loss_step = 0 #internal
last_saved_file = "<none>" last_saved_file = "<none>"
last_saved_image = "<none>" last_saved_image = "<none>"
forced_filename = "<none>" forced_filename = "<none>"

View File

@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
p = StableDiffusionProcessingTxt2Img( p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
tiling=tiling, tiling=tiling,
enable_hr=enable_hr, enable_hr=enable_hr,
denoising_strength=denoising_strength if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None,
firstphase_width=firstphase_width if enable_hr else None, hr_scale=hr_scale,
firstphase_height=firstphase_height if enable_hr else None, hr_upscaler=hr_upscaler,
) )
p.scripts = modules.scripts.scripts_txt2img p.scripts = modules.scripts.scripts_txt2img
@ -59,4 +59,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if opts.do_not_show_images: if opts.do_not_show_images:
processed.images = [] processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)

View File

@ -20,6 +20,7 @@ from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
from modules.ui_components import FormRow, FormGroup, ToolButton
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts from modules.shared import opts, cmd_opts, restricted_opts
@ -80,7 +81,6 @@ css_hide_progressbar = """
# Important that they exactly match script.js for tooltip to work. # Important that they exactly match script.js for tooltip to work.
random_symbol = '\U0001f3b2\ufe0f' # 🎲️ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙ paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
@ -159,7 +159,7 @@ def save_files(js_data, images, do_make_zip, index):
zip_file.writestr(filenames[i], f.read()) zip_file.writestr(filenames[i], f.read())
fullfns.insert(0, zip_filepath) fullfns.insert(0, zip_filepath)
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
@ -234,13 +234,6 @@ def check_progress_call_initial(id_part):
return check_progress_call(id_part) return check_progress_call(id_part)
def roll_artist(prompt):
allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories])
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
return prompt + ", " + artist.name if prompt != '' else artist.name
def visit(x, func, path=""): def visit(x, func, path=""):
if hasattr(x, 'children'): if hasattr(x, 'children'):
for c in x.children: for c in x.children:
@ -280,35 +273,31 @@ def interrogate_deepbooru(image):
return gr_show(True) if prompt is None else prompt return gr_show(True) if prompt is None else prompt
def create_seed_inputs(): def create_seed_inputs(target_interface):
with gr.Row(): with FormRow(elem_id=target_interface + '_seed_row'):
with gr.Box(): seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
with gr.Row(elem_id='seed_row'): seed.style(container=False)
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1) random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
seed.style(container=False) reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
random_seed = gr.Button(random_symbol, elem_id='random_seed')
reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed')
with gr.Box(elem_id='subseed_show_box'): with gr.Group(elem_id=target_interface + '_subseed_show_box'):
seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False) seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
# Components to show/hide based on the 'Extra' checkbox # Components to show/hide based on the 'Extra' checkbox
seed_extras = [] seed_extras = []
with gr.Row(visible=False) as seed_extra_row_1: with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
seed_extras.append(seed_extra_row_1) seed_extras.append(seed_extra_row_1)
with gr.Box(): subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
with gr.Row(elem_id='subseed_row'): subseed.style(container=False)
subseed = gr.Number(label='Variation seed', value=-1) random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
subseed.style(container=False) reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
random_subseed = gr.Button(random_symbol, elem_id='random_subseed') subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed')
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01)
with gr.Row(visible=False) as seed_extra_row_2: with FormRow(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2) seed_extras.append(seed_extra_row_2)
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0) seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0) seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@ -403,7 +392,6 @@ def create_toprow(is_img2img):
) )
with gr.Column(scale=1, elem_id="roll_col"): with gr.Column(scale=1, elem_id="roll_col"):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste") paste = gr.Button(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create") save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
@ -452,7 +440,7 @@ def create_toprow(is_img2img):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
prompt_style2.save_to_config = True prompt_style2.save_to_config = True
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None): def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@ -500,7 +488,7 @@ def apply_setting(key, value):
return return
valtype = type(opts.data_labels[key].default) valtype = type(opts.data_labels[key].default)
oldval = opts.data[key] oldval = opts.data.get(key, None)
opts.data[key] = valtype(value) if valtype != type(None) else value opts.data[key] = valtype(value) if valtype != type(None) else value
if oldval != value and opts.data_labels[key].onchange is not None: if oldval != value and opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange() opts.data_labels[key].onchange()
@ -532,7 +520,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
return gr.update(**(args or {})) return gr.update(**(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
refresh_button.click( refresh_button.click(
fn=refresh, fn=refresh,
inputs=[], inputs=[],
@ -570,13 +558,14 @@ Requested path was: {f}
generation_info = None generation_info = None
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row(elem_id=f"image_buttons_{tabname}"):
open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder')
if tabname != "extras": if tabname != "extras":
save = gr.Button('Save', elem_id=f'save_{tabname}') save = gr.Button('Save', elem_id=f'save_{tabname}')
save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_folder_button = gr.Button(folder_symbol, elem_id=button_id)
open_folder_button.click( open_folder_button.click(
fn=lambda: open_folder(opts.outdir_samples or outdir), fn=lambda: open_folder(opts.outdir_samples or outdir),
@ -585,14 +574,13 @@ Requested path was: {f}
) )
if tabname != "extras": if tabname != "extras":
with gr.Row():
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
with gr.Row(): with gr.Row():
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
with gr.Group(): with gr.Group():
html_info = gr.HTML() html_info = gr.HTML()
html_log = gr.HTML()
generation_info = gr.Textbox(visible=False) generation_info = gr.Textbox(visible=False)
if tabname == 'txt2img' or tabname == 'img2img': if tabname == 'txt2img' or tabname == 'img2img':
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
@ -606,25 +594,61 @@ Requested path was: {f}
save.click( save.click(
fn=wrap_gradio_call(save_files), fn=wrap_gradio_call(save_files),
_js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
inputs=[ inputs=[
generation_info, generation_info,
result_gallery, result_gallery,
do_make_zip, html_info,
html_info, html_info,
], ],
outputs=[ outputs=[
download_files, download_files,
html_info, html_log,
html_info,
html_info,
] ]
) )
save_zip.click(
fn=wrap_gradio_call(save_files),
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
inputs=[
generation_info,
result_gallery,
html_info,
html_info,
],
outputs=[
download_files,
html_log,
]
)
else: else:
html_info_x = gr.HTML() html_info_x = gr.HTML()
html_info = gr.HTML() html_info = gr.HTML()
html_log = gr.HTML()
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
return steps, sampler_index
def ordered_ui_categories():
user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
yield category
def create_ui(): def create_ui():
@ -639,14 +663,11 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
with gr.Row(elem_id='txt2img_progress_row'): with gr.Row(elem_id='txt2img_progress_row'):
with gr.Column(scale=1): with gr.Column(scale=1):
pass pass
@ -658,42 +679,57 @@ def create_ui():
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='panel', elem_id="txt2img_settings"): with gr.Column(variant='panel', elem_id="txt2img_settings"):
steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) for category in ordered_ui_categories():
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
with gr.Group(): elif category == "dimensions":
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512) with FormRow():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512) with gr.Column(elem_id="txt2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Row(): if opts.dimensions_and_batch_together:
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) with gr.Column(elem_id="txt2img_column_batch"):
tiling = gr.Checkbox(label='Tiling', value=False) batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
enable_hr = gr.Checkbox(label='Highres. fix', value=False) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
with gr.Row(visible=False) as hr_options: elif category == "cfg":
firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0) cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
with gr.Row(equal_height=True): elif category == "seed":
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) elif category == "checkboxes":
with FormRow(elem_id="txt2img_checkboxes"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() elif category == "hires_fix":
with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options:
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
with gr.Group(): elif category == "batch":
custom_inputs = modules.scripts.scripts_txt2img.setup_ui() if not opts.dimensions_and_batch_together:
with FormRow(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) elif category == "scripts":
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict( txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit", _js="submit",
inputs=[ inputs=[
txt2img_prompt, txt2img_prompt,
@ -713,14 +749,15 @@ def create_ui():
width, width,
enable_hr, enable_hr,
denoising_strength, denoising_strength,
firstphase_width, hr_scale,
firstphase_height, hr_upscaler,
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
txt2img_gallery, txt2img_gallery,
generation_info, generation_info,
html_info html_info,
html_log,
], ],
show_progress=False, show_progress=False,
) )
@ -745,17 +782,6 @@ def create_ui():
outputs=[hr_options], outputs=[hr_options],
) )
roll.click(
fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[
txt2img_prompt,
],
outputs=[
txt2img_prompt,
]
)
txt2img_paste_fields = [ txt2img_paste_fields = [
(txt2img_prompt, "Prompt"), (txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"), (txt2img_negative_prompt, "Negative prompt"),
@ -774,8 +800,8 @@ def create_ui():
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d), (enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"), (hr_scale, "Hires upscale"),
(firstphase_height, "First pass size-2"), (hr_upscaler, "Hires upscaler"),
*modules.scripts.scripts_txt2img.infotext_fields *modules.scripts.scripts_txt2img.infotext_fields
] ]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
@ -797,8 +823,7 @@ def create_ui():
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'): with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
@ -811,14 +836,14 @@ def create_ui():
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
setup_progressbar(progressbar, img2img_preview, 'img2img') setup_progressbar(progressbar, img2img_preview, 'img2img')
with gr.Row().style(equal_height=False): with FormRow().style(equal_height=False):
with gr.Column(variant='panel', elem_id="img2img_settings"): with gr.Column(variant='panel', elem_id="img2img_settings"):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'): with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
with gr.TabItem('Inpaint', id='inpaint'): with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
init_img_with_mask_orig = gr.State(None) init_img_with_mask_orig = gr.State(None)
@ -836,54 +861,72 @@ def create_ui():
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
with gr.Row(): with FormRow():
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch) mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
with gr.Row(): with FormRow():
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index") inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index") with FormRow():
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
with gr.Row(): with FormRow():
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) with gr.Column():
inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
with gr.TabItem('Batch img2img', id='batch'): with gr.Column(scale=4):
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>") gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
with gr.Row(): with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) for category in ordered_ui_categories():
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
with gr.Group(): elif category == "dimensions":
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") with FormRow():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") with gr.Column(elem_id="img2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
with gr.Row(): if opts.dimensions_and_batch_together:
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) with gr.Column(elem_id="img2img_column_batch"):
tiling = gr.Checkbox(label='Tiling', value=False) batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
with gr.Row(): elif category == "cfg":
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) with FormGroup():
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
with gr.Group(): elif category == "seed":
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() elif category == "checkboxes":
with FormRow(elem_id="img2img_checkboxes"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
with gr.Group(): elif category == "batch":
custom_inputs = modules.scripts.scripts_img2img.setup_ui() if not opts.dimensions_and_batch_together:
with FormRow(elem_id="img2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) elif category == "scripts":
with FormGroup(elem_id="img2img_script_container"):
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
@ -915,7 +958,7 @@ def create_ui():
) )
img2img_args = dict( img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img), fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img", _js="submit_img2img",
inputs=[ inputs=[
dummy_component, dummy_component,
@ -954,7 +997,8 @@ def create_ui():
outputs=[ outputs=[
img2img_gallery, img2img_gallery,
generation_info, generation_info,
html_info html_info,
html_log,
], ],
show_progress=False, show_progress=False,
) )
@ -974,18 +1018,6 @@ def create_ui():
outputs=[img2img_prompt], outputs=[img2img_prompt],
) )
roll.click(
fn=roll_artist,
_js="update_img2img_tokens",
inputs=[
img2img_prompt,
],
outputs=[
img2img_prompt,
]
)
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
@ -1038,50 +1070,50 @@ def create_ui():
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
with gr.Tabs(elem_id="mode_extras"): with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image'): with gr.TabItem('Single Image', elem_id="extras_single_tab"):
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil") extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
with gr.TabItem('Batch Process'): with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
with gr.TabItem('Batch from Directory'): with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.") extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.") extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
show_extras_results = gr.Checkbox(label='Show result images', value=True) show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Tabs(elem_id="extras_resize_mode"): with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'): with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4) upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
with gr.TabItem('Scale to'): with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
with gr.Group(): with gr.Group():
with gr.Row(): with gr.Row():
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with gr.Group(): with gr.Group():
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group(): with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
with gr.Group(): with gr.Group():
gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan) gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
with gr.Group(): with gr.Group():
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer) codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer) codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
with gr.Group(): with gr.Group():
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
submit.click( submit.click(
fn=wrap_gradio_gpu_call(modules.extras.run_extras), fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
_js="get_extras_tab_index", _js="get_extras_tab_index",
inputs=[ inputs=[
dummy_component, dummy_component,
@ -1123,7 +1155,7 @@ def create_ui():
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
html = gr.HTML() html = gr.HTML()
generation_info = gr.Textbox(visible=False) generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
html2 = gr.HTML() html2 = gr.HTML()
with gr.Row(): with gr.Row():
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
@ -1142,23 +1174,27 @@ def create_ui():
with gr.Row(): with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
custom_name = gr.Textbox(label="Custom Name (Optional)") create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with gr.Row(): with gr.Row():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format") checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16") save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
@ -1167,58 +1203,58 @@ def create_ui():
with gr.Tabs(elem_id="train_tabs"): with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding"): with gr.Tab(label="Create embedding"):
new_embedding_name = gr.Textbox(label="Name") new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
initialization_text = gr.Textbox(label="Initialization text", value="*") initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
gr.HTML(value="") gr.HTML(value="")
with gr.Column(): with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary') create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
with gr.Tab(label="Create hypernetwork"): with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"]) new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
gr.HTML(value="") gr.HTML(value="")
with gr.Column(): with gr.Column():
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
with gr.Tab(label="Preprocess images"): with gr.Tab(label="Preprocess images"):
process_src = gr.Textbox(label='Source directory') process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
process_dst = gr.Textbox(label='Destination directory') process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512) process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512) process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
with gr.Row(): with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies') process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
process_split = gr.Checkbox(label='Split oversized images') process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
process_focal_crop = gr.Checkbox(label='Auto focal point crop') process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
process_caption = gr.Checkbox(label='Use BLIP for caption') process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True) process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
with gr.Row(visible=False) as process_split_extra_row: with gr.Row(visible=False) as process_split_extra_row:
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
with gr.Row(visible=False) as process_focal_crop_row: with gr.Row(visible=False) as process_focal_crop_row:
process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
process_focal_crop_debug = gr.Checkbox(label='Create debug image') process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
@ -1226,8 +1262,8 @@ def create_ui():
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
interrupt_preprocessing = gr.Button("Interrupt") interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
run_preprocess = gr.Button(value="Preprocess", variant='primary') run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
process_split.change( process_split.change(
fn=lambda show: gr_show(show), fn=lambda show: gr_show(show),
@ -1250,31 +1286,31 @@ def create_ui():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
with gr.Row(): with gr.Row():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
batch_size = gr.Number(label='Batch size', value=1, precision=0) batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512) training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
steps = gr.Number(label='Max steps', value=100000, precision=0) steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
with gr.Row(): with gr.Row():
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0) tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
with gr.Row(): with gr.Row():
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random']) latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
with gr.Row(): with gr.Row():
interrupt_training = gr.Button(value="Interrupt") interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
train_embedding = gr.Button(value="Train Embedding", variant='primary') train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
params = script_callbacks.UiTrainTabParams(txt2img_preview_params) params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
@ -1447,7 +1483,7 @@ def create_ui():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else: else:
with gr.Row(variant="compact"): with FormRow():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else: else:
@ -1492,41 +1528,36 @@ def create_ui():
return gr.update(value=value), opts.dumpjson() return gr.update(value=value), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface: with gr.Blocks(analytics_enabled=False) as settings_interface:
settings_submit = gr.Button(value="Apply settings", variant='primary') with gr.Row():
result = gr.HTML() with gr.Column(scale=6):
settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
with gr.Column():
restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
settings_cols = 3 result = gr.HTML(elem_id="settings_result")
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
quicksettings_list = [] quicksettings_list = []
cols_displayed = 0
items_displayed = 0
previous_section = None previous_section = None
column = None current_tab = None
with gr.Row(elem_id="settings").style(equal_height=False): with gr.Tabs(elem_id="settings"):
for i, (k, item) in enumerate(opts.data_labels.items()): for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None section_must_be_skipped = item.section[0] is None
if previous_section != item.section and not section_must_be_skipped: if previous_section != item.section and not section_must_be_skipped:
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): elem_id, text = item.section
if column is not None:
column.__exit__()
column = gr.Column(variant='panel') if current_tab is not None:
column.__enter__() current_tab.__exit__()
items_displayed = 0 current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
cols_displayed += 1 current_tab.__enter__()
previous_section = item.section previous_section = item.section
elem_id, text = item.section
gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='<h1 class="gr-button-lg">{}</h1>'.format(text))
if k in quicksettings_names and not shared.cmd_opts.freeze_settings: if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item)) quicksettings_list.append((i, k, item))
components.append(dummy_component) components.append(dummy_component)
@ -1536,15 +1567,21 @@ def create_ui():
component = create_setting_component(k) component = create_setting_component(k)
component_dict[k] = component component_dict[k] = component
components.append(component) components.append(component)
items_displayed += 1
with gr.Row(): if current_tab is not None:
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") current_tab.__exit__()
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
with gr.Row(): with gr.TabItem("Actions"):
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
if os.path.exists("html/licenses.html"):
with open("html/licenses.html", encoding="utf8") as file:
with gr.TabItem("Licenses"):
gr.HTML(file.read(), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
request_notifications.click( request_notifications.click(
fn=lambda: None, fn=lambda: None,
@ -1581,9 +1618,6 @@ def create_ui():
outputs=[], outputs=[],
) )
if column is not None:
column.__exit__()
interfaces = [ interfaces = [
(txt2img_interface, "txt2img", "txt2img"), (txt2img_interface, "txt2img", "txt2img"),
(img2img_interface, "img2img", "img2img"), (img2img_interface, "img2img", "img2img"),
@ -1617,7 +1651,7 @@ def create_ui():
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"): with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list: for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True) component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component component_dict[k] = component
@ -1632,6 +1666,10 @@ def create_ui():
if os.path.exists(os.path.join(script_path, "notification.mp3")): if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
if os.path.exists("html/footer.html"):
with open("html/footer.html", encoding="utf8") as file:
gr.HTML(file.read(), elem_id="footer")
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click( settings_submit.click(
fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
@ -1666,7 +1704,7 @@ def create_ui():
print("Error loading/saving model file:", file=sys.stderr) print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() # to remove the potentially missing models from the list modules.sd_models.list_models() # to remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
return results return results
modelmerger_merge.click( modelmerger_merge.click(

25
modules/ui_components.py Normal file
View File

@ -0,0 +1,25 @@
import gradio as gr
class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""
def __init__(self, **kwargs):
super().__init__(variant="tool", **kwargs)
def get_block_name(self):
return "button"
class FormRow(gr.Row, gr.components.FormComponent):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "row"
class FormGroup(gr.Group, gr.components.FormComponent):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "group"

View File

@ -1,6 +1,7 @@
import os import os
import tempfile import tempfile
from collections import namedtuple from collections import namedtuple
from pathlib import Path
import gradio as gr import gradio as gr
@ -12,10 +13,29 @@ from modules import shared
Savedfile = namedtuple("Savedfile", ["name"]) Savedfile = namedtuple("Savedfile", ["name"])
def register_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
if hasattr(gradio, 'temp_dirs'): # gradio 3.9
gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
def check_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'):
return any([filename in fileset for fileset in gradio.temp_file_sets])
if hasattr(gradio, 'temp_dirs'):
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
return False
def save_pil_to_file(pil_image, dir=None): def save_pil_to_file(pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None) already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as): if already_saved_as and os.path.isfile(already_saved_as):
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))} register_tmp_file(shared.demo, already_saved_as)
file_obj = Savedfile(already_saved_as) file_obj = Savedfile(already_saved_as)
return file_obj return file_obj
@ -44,7 +64,7 @@ def on_tmpdir_changed():
os.makedirs(shared.opts.temp_dir, exist_ok=True) os.makedirs(shared.opts.temp_dir, exist_ok=True)
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)} register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
def cleanup_tmpdr(): def cleanup_tmpdr():

View File

@ -53,10 +53,10 @@ class Upscaler:
def do_upscale(self, img: PIL.Image, selected_model: str): def do_upscale(self, img: PIL.Image, selected_model: str):
return img return img
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None): def upscale(self, img: PIL.Image, scale, selected_model: str = None):
self.scale = scale self.scale = scale
dest_w = img.width * scale dest_w = int(img.width * scale)
dest_h = img.height * scale dest_h = int(img.height * scale)
for i in range(3): for i in range(3):
shape = (img.width, img.height) shape = (img.width, img.height)

137
modules/xlmr.py Normal file
View File

@ -0,0 +1,137 @@
from transformers import BertPreTrainedModel,BertModel,BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class RobertaSeriesConfig(XLMRobertaConfig):
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class BertSeriesModelWithTransformation(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
config_class = BertSeriesConfig
def __init__(self, config=None, **kargs):
# modify initialization for autoloading
if config is None:
config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1
config.bos_token_id=0
config.eos_token_id=2
config.hidden_act='gelu'
config.hidden_dropout_prob=0.1
config.hidden_size=1024
config.initializer_range=0.02
config.intermediate_size=4096
config.layer_norm_eps=1e-05
config.max_position_embeddings=514
config.num_attention_heads=16
config.num_hidden_layers=24
config.output_past=True
config.pad_token_id=1
config.position_embedding_type= "absolute"
config.type_vocab_size= 1
config.use_cache=True
config.vocab_size= 250002
config.project_dim = 768
config.learn_encoder = False
super().__init__(config)
self.roberta = XLMRobertaModel(config)
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
self.pooler = lambda x: x[:,0]
self.post_init()
def encode(self,c):
device = next(self.parameters()).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt")
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device)
features = self(**text)
return features['projection_state']
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) :
r"""
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=return_dict,
)
# last module outputs
sequence_output = outputs[0]
# project every module
sequence_output_ln = self.pre_LN(sequence_output)
# pooler
pooler_output = self.pooler(sequence_output_ln)
pooler_output = self.transformation(pooler_output)
projection_state = self.transformation(outputs.last_hidden_state)
return {
'pooler_output':pooler_output,
'last_hidden_state':outputs.last_hidden_state,
'hidden_states':outputs.hidden_states,
'attentions':outputs.attentions,
'projection_state':projection_state,
'sequence_out': sequence_output
}
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta'
config_class= RobertaSeriesConfig

View File

@ -5,7 +5,7 @@ fairscale==0.4.4
fonts fonts
font-roboto font-roboto
gfpgan gfpgan
gradio==3.9 gradio==3.15.0
invisible-watermark invisible-watermark
numpy numpy
omegaconf omegaconf

View File

@ -3,7 +3,7 @@ transformers==4.19.2
accelerate==0.12.0 accelerate==0.12.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.9 gradio==3.15.0
numpy==1.23.3 numpy==1.23.3
Pillow==9.2.0 Pillow==9.2.0
realesrgan==0.3.0 realesrgan==0.3.0

View File

@ -19,7 +19,7 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>") info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64) overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64)
scale_factor = gr.Slider(minimum=1, maximum=4, step=1, label='Scale Factor', value=2) scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0)
upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
return [info, overlap, upscaler_index, scale_factor] return [info, overlap, upscaler_index, scale_factor]

View File

@ -202,7 +202,7 @@ axis_options = [
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None), AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
AxisOption("VAE", str, apply_vae, format_value_add_label, None), AxisOption("VAE", str, apply_vae, format_value_add_label, None),
AxisOption("Styles", str, apply_styles, format_value_add_label, None), AxisOption("Styles", str, apply_styles, format_value_add_label, None),
@ -267,7 +267,6 @@ class SharedSettingsStackHelper(object):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.hypernetwork = opts.sd_hypernetwork self.hypernetwork = opts.sd_hypernetwork
self.model = shared.sd_model self.model = shared.sd_model
self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix
self.vae = opts.sd_vae self.vae = opts.sd_vae
def __exit__(self, exc_type, exc_value, tb): def __exit__(self, exc_type, exc_value, tb):
@ -278,7 +277,6 @@ class SharedSettingsStackHelper(object):
hypernetwork.apply_strength() hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")

126
style.css
View File

@ -73,8 +73,9 @@
margin-right: auto; margin-right: auto;
} }
#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{ [id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{
min-width: auto; min-width: 2.3em;
height: 2.5em;
flex-grow: 0; flex-grow: 0;
padding-left: 0.25em; padding-left: 0.25em;
padding-right: 0.25em; padding-right: 0.25em;
@ -84,27 +85,28 @@
display: none; display: none;
} }
#seed_row, #subseed_row{ [id$=_seed_row], [id$=_subseed_row]{
gap: 0.5rem; gap: 0.5rem;
padding: 0.6em;
} }
#subseed_show_box{ [id$=_subseed_show_box]{
min-width: auto; min-width: auto;
flex-grow: 0; flex-grow: 0;
} }
#subseed_show_box > div{ [id$=_subseed_show_box] > div{
border: 0; border: 0;
height: 100%; height: 100%;
} }
#subseed_show{ [id$=_subseed_show]{
min-width: auto; min-width: auto;
flex-grow: 0; flex-grow: 0;
padding: 0; padding: 0;
} }
#subseed_show label{ [id$=_subseed_show] label{
height: 100%; height: 100%;
} }
@ -206,24 +208,24 @@ button{
fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{
position: absolute; position: absolute;
top: -0.6em; top: -0.7em;
line-height: 1.2em; line-height: 1.2em;
padding: 0 0.5em; padding: 0;
margin: 0; margin: 0 0.5em;
background-color: white; background-color: white;
border-top: 1px solid #eee; box-shadow: 6px 0 6px 0px white, -6px 0 6px 0px white;
border-left: 1px solid #eee;
border-right: 1px solid #eee;
z-index: 300; z-index: 300;
} }
.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{ .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
background-color: rgb(31, 41, 55); background-color: rgb(31, 41, 55);
border-top: 1px solid rgb(55 65 81); box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55);
border-left: 1px solid rgb(55 65 81); }
border-right: 1px solid rgb(55 65 81);
#txt2img_column_batch, #img2img_column_batch{
min-width: min(13.5em, 100%) !important;
} }
#settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
@ -232,22 +234,40 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
margin-right: 8em; margin-right: 8em;
} }
.gr-panel div.flex-col div.justify-between label span{
margin: 0;
}
#settings .gr-panel div.flex-col div.justify-between div{ #settings .gr-panel div.flex-col div.justify-between div{
position: relative; position: relative;
z-index: 200; z-index: 200;
} }
input[type="range"]{ #settings{
margin: 0.5em 0 -0.3em 0; display: block;
} }
#txt2img_sampling label{ #settings > div{
padding-left: 0.6em; border: none;
padding-right: 0.6em; margin-left: 10em;
}
#settings > div.flex-wrap{
float: left;
display: block;
margin-left: 0;
width: 10em;
}
#settings > div.flex-wrap button{
display: block;
border: none;
text-align: left;
}
#settings_result{
height: 1.4em;
margin: 0 1.2em;
}
input[type="range"]{
margin: 0.5em 0 -0.3em 0;
} }
#mask_bug_info { #mask_bug_info {
@ -501,13 +521,6 @@ input[type="range"]{
padding: 0; padding: 0;
} }
#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
max-width: 2.5em;
min-width: 2.5em;
height: 2.4em;
}
canvas[key="mask"] { canvas[key="mask"] {
z-index: 12 !important; z-index: 12 !important;
filter: invert(); filter: invert();
@ -521,7 +534,7 @@ canvas[key="mask"] {
position: absolute; position: absolute;
right: 0.5em; right: 0.5em;
top: -0.6em; top: -0.6em;
z-index: 200; z-index: 400;
width: 8em; width: 8em;
} }
#quicksettings .gr-box > div > div > input.gr-text-input { #quicksettings .gr-box > div > div > input.gr-text-input {
@ -568,6 +581,53 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
font-size: 95%; font-size: 95%;
} }
#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{
min-width: auto;
padding-left: 0.5em;
padding-right: 0.5em;
}
.gr-form{
background-color: white;
}
.dark .gr-form{
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
}
.gr-button-tool{
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.4em;
margin: 0.55em 0;
}
#quicksettings .gr-button-tool{
margin: 0;
}
#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
padding-top: 0.9em;
}
#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form{
border: none;
padding-bottom: 0.5em;
}
footer {
display: none !important;
}
#footer{
text-align: center;
}
#footer div{
display: inline-block;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic. /* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js. The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running If you change anything above, you need to make sure it is RTL compliant by just running

68
v2-inference-v.yaml Normal file
View File

@ -0,0 +1,68 @@
model:
base_learning_rate: 1.0e-4
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False # we set this to false because this is an inference only config
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
use_fp16: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View File

@ -1,4 +1,5 @@
import os import os
import sys
import threading import threading
import time import time
import importlib import importlib
@ -55,8 +56,8 @@ def initialize():
gfpgan.setup_model(cmd_opts.gfpgan_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration()) shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.list_builtin_upscalers()
modules.scripts.load_scripts() modules.scripts.load_scripts()
modelloader.load_upscalers() modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list() modules.sd_vae.refresh_vae_list()
@ -169,23 +170,22 @@ def webui():
modules.script_callbacks.app_started_callback(shared.demo, app) modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(shared.demo) wait_on_server(shared.demo)
print('Restarting UI...')
sd_samplers.set_samplers() sd_samplers.set_samplers()
print('Reloading extensions')
extensions.list_extensions() extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)
print('Reloading custom scripts') modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts() modules.scripts.reload_scripts()
modelloader.load_upscalers() modelloader.load_upscalers()
print('Reloading modules: modules.ui') for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(modules.ui) importlib.reload(module)
print('Refreshing Model List')
modules.sd_models.list_models() modules.sd_models.list_models()
print('Restarting Gradio')
if __name__ == "__main__": if __name__ == "__main__":