Merge branch 'master' into improved-hr-conflict-test
This commit is contained in:
commit
f5e4436453
@ -157,5 +157,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
|||||||
- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
|
- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
|
||||||
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
|
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
|
||||||
- Security advice - RyotaK
|
- Security advice - RyotaK
|
||||||
|
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
|
||||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||||
- (You)
|
- (You)
|
||||||
|
@ -3,7 +3,9 @@ import os
|
|||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import shared, devices, sd_models
|
from modules import shared, devices, sd_models, errors
|
||||||
|
|
||||||
|
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
||||||
|
|
||||||
re_digits = re.compile(r"\d+")
|
re_digits = re.compile(r"\d+")
|
||||||
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
|
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
|
||||||
@ -43,6 +45,23 @@ class LoraOnDisk:
|
|||||||
def __init__(self, name, filename):
|
def __init__(self, name, filename):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
|
self.metadata = {}
|
||||||
|
|
||||||
|
_, ext = os.path.splitext(filename)
|
||||||
|
if ext.lower() == ".safetensors":
|
||||||
|
try:
|
||||||
|
self.metadata = sd_models.read_metadata_from_safetensors(filename)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"reading lora {filename}")
|
||||||
|
|
||||||
|
if self.metadata:
|
||||||
|
m = {}
|
||||||
|
for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
|
||||||
|
m[k] = v
|
||||||
|
|
||||||
|
self.metadata = m
|
||||||
|
|
||||||
|
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
|
||||||
|
|
||||||
|
|
||||||
class LoraModule:
|
class LoraModule:
|
||||||
|
@ -15,21 +15,15 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
def list_items(self):
|
def list_items(self):
|
||||||
for name, lora_on_disk in lora.available_loras.items():
|
for name, lora_on_disk in lora.available_loras.items():
|
||||||
path, ext = os.path.splitext(lora_on_disk.filename)
|
path, ext = os.path.splitext(lora_on_disk.filename)
|
||||||
previews = [path + ".png", path + ".preview.png"]
|
|
||||||
|
|
||||||
preview = None
|
|
||||||
for file in previews:
|
|
||||||
if os.path.isfile(file):
|
|
||||||
preview = self.link_preview(file)
|
|
||||||
break
|
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": path,
|
"filename": path,
|
||||||
"preview": preview,
|
"preview": self.find_preview(path),
|
||||||
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
||||||
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||||
"local_preview": path + ".png",
|
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||||
|
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
<div class='card' {preview_html} onclick={card_clicked}>
|
<div class='card' {preview_html} onclick={card_clicked}>
|
||||||
|
{metadata_button}
|
||||||
|
|
||||||
<div class='actions'>
|
<div class='actions'>
|
||||||
<div class='additional'>
|
<div class='additional'>
|
||||||
<ul>
|
<ul>
|
||||||
@ -7,6 +9,7 @@
|
|||||||
<span style="display:none" class='search_term'>{search_term}</span>
|
<span style="display:none" class='search_term'>{search_term}</span>
|
||||||
</div>
|
</div>
|
||||||
<span class='name'>{name}</span>
|
<span class='name'>{name}</span>
|
||||||
|
<span class='description'>{description}</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
@ -417,3 +417,222 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|||||||
SOFTWARE.
|
SOFTWARE.
|
||||||
</pre>
|
</pre>
|
||||||
|
|
||||||
|
<h2><a href="https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/LICENSE">Scaled Dot Product Attention</a></h2>
|
||||||
|
<small>Some small amounts of code borrowed and reworked.</small>
|
||||||
|
<pre>
|
||||||
|
Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
</pre>
|
@ -5,12 +5,10 @@ function setupExtraNetworksForTab(tabname){
|
|||||||
var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
|
var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
|
||||||
var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
|
var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
|
||||||
var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
|
var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
|
||||||
var close = gradioApp().getElementById(tabname+'_extra_close')
|
|
||||||
|
|
||||||
search.classList.add('search')
|
search.classList.add('search')
|
||||||
tabs.appendChild(search)
|
tabs.appendChild(search)
|
||||||
tabs.appendChild(refresh)
|
tabs.appendChild(refresh)
|
||||||
tabs.appendChild(close)
|
|
||||||
|
|
||||||
search.addEventListener("input", function(evt){
|
search.addEventListener("input", function(evt){
|
||||||
searchTerm = search.value.toLowerCase()
|
searchTerm = search.value.toLowerCase()
|
||||||
@ -78,7 +76,7 @@ function cardClicked(tabname, textToAdd, allowNegativePrompt){
|
|||||||
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
|
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
|
||||||
|
|
||||||
if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
|
if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
|
||||||
textarea.value = textarea.value + " " + textToAdd
|
textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd
|
||||||
}
|
}
|
||||||
|
|
||||||
updateInput(textarea)
|
updateInput(textarea)
|
||||||
@ -104,4 +102,40 @@ function extraNetworksSearchButton(tabs_id, event){
|
|||||||
|
|
||||||
searchTextarea.value = text
|
searchTextarea.value = text
|
||||||
updateInput(searchTextarea)
|
updateInput(searchTextarea)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var globalPopup = null;
|
||||||
|
var globalPopupInner = null;
|
||||||
|
function popup(contents){
|
||||||
|
if(! globalPopup){
|
||||||
|
globalPopup = document.createElement('div')
|
||||||
|
globalPopup.onclick = function(){ globalPopup.style.display = "none"; };
|
||||||
|
globalPopup.classList.add('global-popup');
|
||||||
|
|
||||||
|
var close = document.createElement('div')
|
||||||
|
close.classList.add('global-popup-close');
|
||||||
|
close.onclick = function(){ globalPopup.style.display = "none"; };
|
||||||
|
close.title = "Close";
|
||||||
|
globalPopup.appendChild(close)
|
||||||
|
|
||||||
|
globalPopupInner = document.createElement('div')
|
||||||
|
globalPopupInner.onclick = function(event){ event.stopPropagation(); return false; };
|
||||||
|
globalPopupInner.classList.add('global-popup-inner');
|
||||||
|
globalPopup.appendChild(globalPopupInner)
|
||||||
|
|
||||||
|
gradioApp().appendChild(globalPopup);
|
||||||
|
}
|
||||||
|
|
||||||
|
globalPopupInner.innerHTML = '';
|
||||||
|
globalPopupInner.appendChild(contents);
|
||||||
|
|
||||||
|
globalPopup.style.display = "flex";
|
||||||
|
}
|
||||||
|
|
||||||
|
function extraNetworksShowMetadata(text){
|
||||||
|
elem = document.createElement('pre')
|
||||||
|
elem.classList.add('popup-metadata');
|
||||||
|
elem.textContent = text;
|
||||||
|
|
||||||
|
popup(elem);
|
||||||
|
}
|
||||||
|
@ -6,6 +6,7 @@ titles = {
|
|||||||
"GFPGAN": "Restore low quality faces using GFPGAN neural network",
|
"GFPGAN": "Restore low quality faces using GFPGAN neural network",
|
||||||
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
|
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
|
||||||
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
|
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
|
||||||
|
"UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
|
||||||
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
|
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
|
||||||
|
|
||||||
"Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
|
"Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
|
||||||
|
@ -11,7 +11,7 @@ function showModal(event) {
|
|||||||
if (modalImage.style.display === 'none') {
|
if (modalImage.style.display === 'none') {
|
||||||
lb.style.setProperty('background-image', 'url(' + source.src + ')');
|
lb.style.setProperty('background-image', 'url(' + source.src + ')');
|
||||||
}
|
}
|
||||||
lb.style.display = "block";
|
lb.style.display = "flex";
|
||||||
lb.focus()
|
lb.focus()
|
||||||
|
|
||||||
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
|
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
|
||||||
|
@ -15,7 +15,7 @@ onUiUpdate(function(){
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
|
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] img.h-full.w-full.overflow-hidden');
|
||||||
|
|
||||||
if (galleryPreviews == null) return;
|
if (galleryPreviews == null) return;
|
||||||
|
|
||||||
|
@ -139,7 +139,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
|
|
||||||
var divProgress = document.createElement('div')
|
var divProgress = document.createElement('div')
|
||||||
divProgress.className='progressDiv'
|
divProgress.className='progressDiv'
|
||||||
divProgress.style.display = opts.show_progressbar ? "" : "none"
|
divProgress.style.display = opts.show_progressbar ? "block" : "none"
|
||||||
var divInner = document.createElement('div')
|
var divInner = document.createElement('div')
|
||||||
divInner.className='progress'
|
divInner.className='progress'
|
||||||
|
|
||||||
|
44
launch.py
44
launch.py
@ -8,6 +8,14 @@ import platform
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(add_help=False)
|
||||||
|
parser.add_argument("--ui-settings-file", type=str, default='config.json')
|
||||||
|
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
args, _ = parser.parse_known_args(sys.argv)
|
||||||
|
|
||||||
|
script_path = os.path.dirname(__file__)
|
||||||
|
data_path = os.getcwd()
|
||||||
|
|
||||||
dir_repos = "repositories"
|
dir_repos = "repositories"
|
||||||
dir_extensions = "extensions"
|
dir_extensions = "extensions"
|
||||||
python = sys.executable
|
python = sys.executable
|
||||||
@ -122,7 +130,7 @@ def is_installed(package):
|
|||||||
|
|
||||||
|
|
||||||
def repo_dir(name):
|
def repo_dir(name):
|
||||||
return os.path.join(dir_repos, name)
|
return os.path.join(script_path, dir_repos, name)
|
||||||
|
|
||||||
|
|
||||||
def run_python(code, desc=None, errdesc=None):
|
def run_python(code, desc=None, errdesc=None):
|
||||||
@ -161,7 +169,17 @@ def git_clone(url, dir, name, commithash=None):
|
|||||||
if commithash is not None:
|
if commithash is not None:
|
||||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||||
|
|
||||||
|
|
||||||
|
def git_pull_recursive(dir):
|
||||||
|
for subdir, _, _ in os.walk(dir):
|
||||||
|
if os.path.exists(os.path.join(subdir, '.git')):
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
|
||||||
|
print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
|
||||||
|
|
||||||
|
|
||||||
def version_check(commit):
|
def version_check(commit):
|
||||||
try:
|
try:
|
||||||
import requests
|
import requests
|
||||||
@ -205,7 +223,7 @@ def list_extensions(settings_file):
|
|||||||
|
|
||||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||||
|
|
||||||
return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
|
return [x for x in os.listdir(os.path.join(data_path, dir_extensions)) if x not in disabled_extensions]
|
||||||
|
|
||||||
|
|
||||||
def run_extensions_installers(settings_file):
|
def run_extensions_installers(settings_file):
|
||||||
@ -242,11 +260,8 @@ def prepare_environment():
|
|||||||
|
|
||||||
sys.argv += shlex.split(commandline_args)
|
sys.argv += shlex.split(commandline_args)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(add_help=False)
|
|
||||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
|
|
||||||
args, _ = parser.parse_known_args(sys.argv)
|
|
||||||
|
|
||||||
sys.argv, _ = extract_arg(sys.argv, '-f')
|
sys.argv, _ = extract_arg(sys.argv, '-f')
|
||||||
|
sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
|
||||||
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
|
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
|
||||||
sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
|
sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
|
||||||
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
|
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
|
||||||
@ -295,7 +310,7 @@ def prepare_environment():
|
|||||||
if not is_installed("pyngrok") and ngrok:
|
if not is_installed("pyngrok") and ngrok:
|
||||||
run_pip("install pyngrok", "ngrok")
|
run_pip("install pyngrok", "ngrok")
|
||||||
|
|
||||||
os.makedirs(dir_repos, exist_ok=True)
|
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||||
|
|
||||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
||||||
@ -304,14 +319,19 @@ def prepare_environment():
|
|||||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||||
|
|
||||||
if not is_installed("lpips"):
|
if not is_installed("lpips"):
|
||||||
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
|
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
||||||
|
|
||||||
run_pip(f"install -r {requirements_file}", "requirements for Web UI")
|
if not os.path.isfile(requirements_file):
|
||||||
|
requirements_file = os.path.join(script_path, requirements_file)
|
||||||
|
run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")
|
||||||
|
|
||||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||||
|
|
||||||
if update_check:
|
if update_check:
|
||||||
version_check(commit)
|
version_check(commit)
|
||||||
|
|
||||||
|
if update_all_extensions:
|
||||||
|
git_pull_recursive(os.path.join(data_path, dir_extensions))
|
||||||
|
|
||||||
if "--exit" in sys.argv:
|
if "--exit" in sys.argv:
|
||||||
print("Exiting because of --exit argument")
|
print("Exiting because of --exit argument")
|
||||||
@ -327,7 +347,7 @@ def tests(test_dir):
|
|||||||
sys.argv.append("--api")
|
sys.argv.append("--api")
|
||||||
if "--ckpt" not in sys.argv:
|
if "--ckpt" not in sys.argv:
|
||||||
sys.argv.append("--ckpt")
|
sys.argv.append("--ckpt")
|
||||||
sys.argv.append("./test/test_files/empty.pt")
|
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
|
||||||
if "--skip-torch-cuda-test" not in sys.argv:
|
if "--skip-torch-cuda-test" not in sys.argv:
|
||||||
sys.argv.append("--skip-torch-cuda-test")
|
sys.argv.append("--skip-torch-cuda-test")
|
||||||
if "--disable-nan-check" not in sys.argv:
|
if "--disable-nan-check" not in sys.argv:
|
||||||
@ -336,7 +356,7 @@ def tests(test_dir):
|
|||||||
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
|
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
|
||||||
|
|
||||||
os.environ['COMMANDLINE_ARGS'] = ""
|
os.environ['COMMANDLINE_ARGS'] = ""
|
||||||
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
|
with open(os.path.join(script_path, 'test/stdout.txt'), "w", encoding="utf8") as stdout, open(os.path.join(script_path, 'test/stderr.txt'), "w", encoding="utf8") as stderr:
|
||||||
proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
|
proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
|
||||||
|
|
||||||
import test.server_poll
|
import test.server_poll
|
||||||
|
@ -150,6 +150,7 @@ class Api:
|
|||||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
||||||
|
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
||||||
|
|
||||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||||
if shared.cmd_opts.api_auth:
|
if shared.cmd_opts.api_auth:
|
||||||
@ -163,47 +164,98 @@ class Api:
|
|||||||
|
|
||||||
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
|
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
|
||||||
|
|
||||||
def get_script(self, script_name, script_runner):
|
def get_selectable_script(self, script_name, script_runner):
|
||||||
if script_name is None:
|
if script_name is None or script_name == "":
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
if not script_runner.scripts:
|
|
||||||
script_runner.initialize_scripts(False)
|
|
||||||
ui.create_ui()
|
|
||||||
|
|
||||||
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
||||||
script = script_runner.selectable_scripts[script_idx]
|
script = script_runner.selectable_scripts[script_idx]
|
||||||
return script, script_idx
|
return script, script_idx
|
||||||
|
|
||||||
|
def get_scripts_list(self):
    """Return the lower-cased script titles known to each pipeline.

    Reads ``titles`` from the txt2img and img2img script runners and wraps
    both lists in a ``ScriptsList`` response.
    """
    def lowered_titles(runner):
        # Same per-title normalisation for both runners.
        return [str(name.lower()) for name in runner.titles]

    txt2img_titles = lowered_titles(scripts.scripts_txt2img)
    img2img_titles = lowered_titles(scripts.scripts_img2img)
    return ScriptsList(txt2img=txt2img_titles, img2img=img2img_titles)
|
||||||
|
|
||||||
|
def get_script(self, script_name, script_runner):
    """Look up a script by name among all scripts of ``script_runner``.

    Args:
        script_name: Title of the script to find; ``None`` or ``""`` means
            "no script requested".
        script_runner: Runner whose ``scripts`` list is searched.

    Returns:
        The matching script object, or ``None`` when no name was given.
        A single value is returned in every case so callers can compare the
        result against ``None`` directly.
    """
    if script_name is None or script_name == "":
        # Previously this returned the tuple (None, None), which never
        # compares equal to None and broke the caller's "not found" check.
        return None

    script_idx = script_name_to_index(script_name, script_runner.scripts)
    return script_runner.scripts[script_idx]
|
||||||
|
|
||||||
|
def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner):
    """Build the flat positional argument list handed to the script runner.

    Args:
        request: API request exposing ``script_args`` (list) and
            ``alwayson_scripts`` (mapping of script name to a dict that may
            contain an ``"args"`` list).
        selectable_scripts: The selected selectable script, or a falsy value
            when none was requested.
        selectable_idx: Index of that script among the selectable scripts.
        script_runner: Runner whose ``scripts`` define the argument layout
            through their ``args_from`` / ``args_to`` attributes.

    Returns:
        A list whose slot 0 is ``selectable_idx + 1`` (or ``0`` when no
        selectable script runs) and whose remaining slots hold each script's
        arguments, ``None`` where nothing was supplied.

    Raises:
        HTTPException: 422 when an always-on entry cannot be resolved or
            names a script that is not an always-on script.
    """
    # Size the list to the largest args_to of any script; slot 0 always exists.
    last_arg_index = max([1] + [script.args_to for script in script_runner.scripts])
    script_args = [None] * last_arg_index

    # Slot 0 is the 1-based index of the selectable script to run; 0 means none.
    if selectable_scripts:
        script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
        script_args[0] = selectable_idx + 1
    else:
        script_args[0] = 0

    # Overlay the arguments of any always-on scripts named in the request.
    alwayson_scripts = request.alwayson_scripts
    if alwayson_scripts:
        for alwayson_script_name, params in alwayson_scripts.items():
            alwayson_script = self.get_script(alwayson_script_name, script_runner)
            if alwayson_script is None:
                raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
            # A selectable script must not be passed through the always-on params.
            if not alwayson_script.alwayson:
                raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
            # Always-on scripts without "args" still run; there is just nothing to fill in.
            if "args" in params:
                script_args[alwayson_script.args_from:alwayson_script.args_to] = params["args"]

    return script_args
|
||||||
|
|
||||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||||
script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
|
script_runner = scripts.scripts_txt2img
|
||||||
|
if not script_runner.scripts:
|
||||||
|
script_runner.initialize_scripts(False)
|
||||||
|
ui.create_ui()
|
||||||
|
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
|
||||||
|
|
||||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||||
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
|
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
|
||||||
"do_not_save_samples": True,
|
"do_not_save_samples": not txt2imgreq.save_images,
|
||||||
"do_not_save_grid": True
|
"do_not_save_grid": not txt2imgreq.save_images,
|
||||||
}
|
})
|
||||||
)
|
|
||||||
if populate.sampler_name:
|
if populate.sampler_name:
|
||||||
populate.sampler_index = None # prevent a warning later on
|
populate.sampler_index = None # prevent a warning later on
|
||||||
|
|
||||||
args = vars(populate)
|
args = vars(populate)
|
||||||
args.pop('script_name', None)
|
args.pop('script_name', None)
|
||||||
|
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
||||||
|
args.pop('alwayson_scripts', None)
|
||||||
|
|
||||||
|
script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)
|
||||||
|
|
||||||
|
send_images = args.pop('send_images', True)
|
||||||
|
args.pop('save_images', None)
|
||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
|
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
|
||||||
|
p.scripts = script_runner
|
||||||
|
p.outpath_grids = opts.outdir_txt2img_grids
|
||||||
|
p.outpath_samples = opts.outdir_txt2img_samples
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
if script is not None:
|
if selectable_scripts != None:
|
||||||
p.outpath_grids = opts.outdir_txt2img_grids
|
p.script_args = script_args
|
||||||
p.outpath_samples = opts.outdir_txt2img_samples
|
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||||
p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
|
|
||||||
processed = scripts.scripts_txt2img.run(p, *p.script_args)
|
|
||||||
else:
|
else:
|
||||||
|
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||||
|
|
||||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||||
|
|
||||||
@ -212,41 +264,53 @@ class Api:
|
|||||||
if init_images is None:
|
if init_images is None:
|
||||||
raise HTTPException(status_code=404, detail="Init image not found")
|
raise HTTPException(status_code=404, detail="Init image not found")
|
||||||
|
|
||||||
script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)
|
|
||||||
|
|
||||||
mask = img2imgreq.mask
|
mask = img2imgreq.mask
|
||||||
if mask:
|
if mask:
|
||||||
mask = decode_base64_to_image(mask)
|
mask = decode_base64_to_image(mask)
|
||||||
|
|
||||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
script_runner = scripts.scripts_img2img
|
||||||
|
if not script_runner.scripts:
|
||||||
|
script_runner.initialize_scripts(True)
|
||||||
|
ui.create_ui()
|
||||||
|
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
|
||||||
|
|
||||||
|
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||||
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
|
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
|
||||||
"do_not_save_samples": True,
|
"do_not_save_samples": not img2imgreq.save_images,
|
||||||
"do_not_save_grid": True,
|
"do_not_save_grid": not img2imgreq.save_images,
|
||||||
"mask": mask
|
"mask": mask,
|
||||||
}
|
})
|
||||||
)
|
|
||||||
if populate.sampler_name:
|
if populate.sampler_name:
|
||||||
populate.sampler_index = None # prevent a warning later on
|
populate.sampler_index = None # prevent a warning later on
|
||||||
|
|
||||||
args = vars(populate)
|
args = vars(populate)
|
||||||
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
|
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
|
||||||
args.pop('script_name', None)
|
args.pop('script_name', None)
|
||||||
|
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
||||||
|
args.pop('alwayson_scripts', None)
|
||||||
|
|
||||||
|
script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)
|
||||||
|
|
||||||
|
send_images = args.pop('send_images', True)
|
||||||
|
args.pop('save_images', None)
|
||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
|
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
|
||||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||||
|
p.scripts = script_runner
|
||||||
|
p.outpath_grids = opts.outdir_img2img_grids
|
||||||
|
p.outpath_samples = opts.outdir_img2img_samples
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
if script is not None:
|
if selectable_scripts != None:
|
||||||
p.outpath_grids = opts.outdir_img2img_grids
|
p.script_args = script_args
|
||||||
p.outpath_samples = opts.outdir_img2img_samples
|
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||||
p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
|
|
||||||
processed = scripts.scripts_img2img.run(p, *p.script_args)
|
|
||||||
else:
|
else:
|
||||||
|
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||||
|
|
||||||
if not img2imgreq.include_init_images:
|
if not img2imgreq.include_init_images:
|
||||||
img2imgreq.init_images = None
|
img2imgreq.init_images = None
|
||||||
|
@ -14,8 +14,8 @@ API_NOT_ALLOWED = [
|
|||||||
"outpath_samples",
|
"outpath_samples",
|
||||||
"outpath_grids",
|
"outpath_grids",
|
||||||
"sampler_index",
|
"sampler_index",
|
||||||
"do_not_save_samples",
|
# "do_not_save_samples",
|
||||||
"do_not_save_grid",
|
# "do_not_save_grid",
|
||||||
"extra_generation_params",
|
"extra_generation_params",
|
||||||
"overlay_images",
|
"overlay_images",
|
||||||
"do_not_reload_embeddings",
|
"do_not_reload_embeddings",
|
||||||
@ -100,13 +100,31 @@ class PydanticModelGenerator:
|
|||||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||||
"StableDiffusionProcessingTxt2Img",
|
"StableDiffusionProcessingTxt2Img",
|
||||||
StableDiffusionProcessingTxt2Img,
|
StableDiffusionProcessingTxt2Img,
|
||||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
|
[
|
||||||
|
{"key": "sampler_index", "type": str, "default": "Euler"},
|
||||||
|
{"key": "script_name", "type": str, "default": None},
|
||||||
|
{"key": "script_args", "type": list, "default": []},
|
||||||
|
{"key": "send_images", "type": bool, "default": True},
|
||||||
|
{"key": "save_images", "type": bool, "default": False},
|
||||||
|
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
||||||
|
]
|
||||||
).generate_model()
|
).generate_model()
|
||||||
|
|
||||||
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||||
"StableDiffusionProcessingImg2Img",
|
"StableDiffusionProcessingImg2Img",
|
||||||
StableDiffusionProcessingImg2Img,
|
StableDiffusionProcessingImg2Img,
|
||||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
|
[
|
||||||
|
{"key": "sampler_index", "type": str, "default": "Euler"},
|
||||||
|
{"key": "init_images", "type": list, "default": None},
|
||||||
|
{"key": "denoising_strength", "type": float, "default": 0.75},
|
||||||
|
{"key": "mask", "type": str, "default": None},
|
||||||
|
{"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
|
||||||
|
{"key": "script_name", "type": str, "default": None},
|
||||||
|
{"key": "script_args", "type": list, "default": []},
|
||||||
|
{"key": "send_images", "type": bool, "default": True},
|
||||||
|
{"key": "save_images", "type": bool, "default": False},
|
||||||
|
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
||||||
|
]
|
||||||
).generate_model()
|
).generate_model()
|
||||||
|
|
||||||
class TextToImageResponse(BaseModel):
|
class TextToImageResponse(BaseModel):
|
||||||
@ -267,3 +285,7 @@ class EmbeddingsResponse(BaseModel):
|
|||||||
class MemoryResponse(BaseModel):
|
class MemoryResponse(BaseModel):
|
||||||
ram: dict = Field(title="RAM", description="System memory stats")
|
ram: dict = Field(title="RAM", description="System memory stats")
|
||||||
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
||||||
|
|
||||||
|
class ScriptsList(BaseModel):
    # Response body of the scripts listing endpoint: one list of script
    # titles per pipeline. Both fields default to None.
    txt2img: list = Field(
        default=None,
        title="Txt2img",
        description="Titles of scripts (txt2img)",
    )
    img2img: list = Field(
        default=None,
        title="Img2img",
        description="Titles of scripts (img2img)",
    )
|
@ -55,7 +55,7 @@ def setup_model(dirname):
|
|||||||
if self.net is not None and self.face_helper is not None:
|
if self.net is not None and self.face_helper is not None:
|
||||||
self.net.to(devices.device_codeformer)
|
self.net.to(devices.device_codeformer)
|
||||||
return self.net, self.face_helper
|
return self.net, self.face_helper
|
||||||
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
|
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
|
||||||
if len(model_paths) != 0:
|
if len(model_paths) != 0:
|
||||||
ckpt_path = model_paths[0]
|
ckpt_path = model_paths[0]
|
||||||
else:
|
else:
|
||||||
|
@ -66,7 +66,7 @@ class Extension:
|
|||||||
|
|
||||||
def check_updates(self):
|
def check_updates(self):
|
||||||
repo = git.Repo(self.path)
|
repo = git.Repo(self.path)
|
||||||
for fetch in repo.remote().fetch("--dry-run"):
|
for fetch in repo.remote().fetch(dry_run=True):
|
||||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||||
self.can_update = True
|
self.can_update = True
|
||||||
self.status = "behind"
|
self.status = "behind"
|
||||||
@ -79,8 +79,8 @@ class Extension:
|
|||||||
repo = git.Repo(self.path)
|
repo = git.Repo(self.path)
|
||||||
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
||||||
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
||||||
repo.git.fetch('--all')
|
repo.git.fetch(all=True)
|
||||||
repo.git.reset('--hard', 'origin')
|
repo.git.reset('origin', hard=True)
|
||||||
|
|
||||||
|
|
||||||
def list_extensions():
|
def list_extensions():
|
||||||
|
@ -23,13 +23,14 @@ registered_param_bindings = []
|
|||||||
|
|
||||||
|
|
||||||
class ParamBinding:
|
class ParamBinding:
|
||||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
|
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
|
||||||
self.paste_button = paste_button
|
self.paste_button = paste_button
|
||||||
self.tabname = tabname
|
self.tabname = tabname
|
||||||
self.source_text_component = source_text_component
|
self.source_text_component = source_text_component
|
||||||
self.source_image_component = source_image_component
|
self.source_image_component = source_image_component
|
||||||
self.source_tabname = source_tabname
|
self.source_tabname = source_tabname
|
||||||
self.override_settings_component = override_settings_component
|
self.override_settings_component = override_settings_component
|
||||||
|
self.paste_field_names = paste_field_names
|
||||||
|
|
||||||
|
|
||||||
def reset():
|
def reset():
|
||||||
@ -134,7 +135,7 @@ def connect_paste_params_buttons():
|
|||||||
connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
|
connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
|
||||||
|
|
||||||
if binding.source_tabname is not None and fields is not None:
|
if binding.source_tabname is not None and fields is not None:
|
||||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
|
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
|
||||||
binding.paste_button.click(
|
binding.paste_button.click(
|
||||||
fn=lambda *x: x,
|
fn=lambda *x: x,
|
||||||
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
||||||
@ -292,6 +293,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
|
|
||||||
settings_map = {}
|
settings_map = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
infotext_to_setting_name_mapping = [
|
infotext_to_setting_name_mapping = [
|
||||||
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
||||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
('Conditional mask weight', 'inpainting_mask_weight'),
|
||||||
@ -300,7 +303,11 @@ infotext_to_setting_name_mapping = [
|
|||||||
('Noise multiplier', 'initial_noise_multiplier'),
|
('Noise multiplier', 'initial_noise_multiplier'),
|
||||||
('Eta', 'eta_ancestral'),
|
('Eta', 'eta_ancestral'),
|
||||||
('Eta DDIM', 'eta_ddim'),
|
('Eta DDIM', 'eta_ddim'),
|
||||||
('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
|
('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
|
||||||
|
('UniPC variant', 'uni_pc_variant'),
|
||||||
|
('UniPC skip type', 'uni_pc_skip_type'),
|
||||||
|
('UniPC order', 'uni_pc_order'),
|
||||||
|
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -556,7 +556,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
elif image_to_save.mode == 'I;16':
|
elif image_to_save.mode == 'I;16':
|
||||||
image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
||||||
|
|
||||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
|
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
|
||||||
|
|
||||||
if opts.enable_pnginfo and info is not None:
|
if opts.enable_pnginfo and info is not None:
|
||||||
exif_bytes = piexif.dump({
|
exif_bytes = piexif.dump({
|
||||||
@ -573,6 +573,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
os.replace(temp_file_path, filename_without_extension + extension)
|
os.replace(temp_file_path, filename_without_extension + extension)
|
||||||
|
|
||||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||||
|
if hasattr(os, 'statvfs'):
|
||||||
|
max_name_len = os.statvfs(path).f_namemax
|
||||||
|
fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
|
||||||
|
params.filename = fullfn_without_extension + extension
|
||||||
|
fullfn = params.filename
|
||||||
_atomically_save_image(image, fullfn_without_extension, extension)
|
_atomically_save_image(image, fullfn_without_extension, extension)
|
||||||
|
|
||||||
image.already_saved_as = fullfn
|
image.already_saved_as = fullfn
|
||||||
|
@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
|||||||
output_dtype = kwargs.get('dtype', input.dtype)
|
output_dtype = kwargs.get('dtype', input.dtype)
|
||||||
if output_dtype == torch.int64:
|
if output_dtype == torch.int64:
|
||||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||||
return cumsum_func(input, *args, **kwargs)
|
return cumsum_func(input, *args, **kwargs)
|
||||||
|
|
||||||
@ -45,7 +45,6 @@ if has_mps:
|
|||||||
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
||||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
|
||||||
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
||||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||||
|
@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread):
|
|||||||
self.data = defaultdict(int)
|
self.data = defaultdict(int)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
torch.cuda.mem_get_info()
|
self.cuda_mem_get_info()
|
||||||
torch.cuda.memory_stats(self.device)
|
torch.cuda.memory_stats(self.device)
|
||||||
except Exception as e: # AMD or whatever
|
except Exception as e: # AMD or whatever
|
||||||
print(f"Warning: caught exception '{e}', memory monitor disabled")
|
print(f"Warning: caught exception '{e}', memory monitor disabled")
|
||||||
self.disabled = True
|
self.disabled = True
|
||||||
|
|
||||||
|
def cuda_mem_get_info(self):
    """Return ``torch.cuda.mem_get_info`` for the monitored device.

    Uses ``self.device.index`` when it is set and falls back to the current
    CUDA device otherwise.
    """
    device_index = self.device.index
    if device_index is None:
        device_index = torch.cuda.current_device()
    return torch.cuda.mem_get_info(device_index)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
if self.disabled:
|
if self.disabled:
|
||||||
return
|
return
|
||||||
@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread):
|
|||||||
self.run_flag.clear()
|
self.run_flag.clear()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.data["min_free"] = torch.cuda.mem_get_info()[0]
|
self.data["min_free"] = self.cuda_mem_get_info()[0]
|
||||||
|
|
||||||
while self.run_flag.is_set():
|
while self.run_flag.is_set():
|
||||||
free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
|
free, total = self.cuda_mem_get_info()
|
||||||
self.data["min_free"] = min(self.data["min_free"], free)
|
self.data["min_free"] = min(self.data["min_free"], free)
|
||||||
|
|
||||||
time.sleep(1 / self.opts.memmon_poll_rate)
|
time.sleep(1 / self.opts.memmon_poll_rate)
|
||||||
@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread):
|
|||||||
|
|
||||||
def read(self):
|
def read(self):
|
||||||
if not self.disabled:
|
if not self.disabled:
|
||||||
free, total = torch.cuda.mem_get_info()
|
free, total = self.cuda_mem_get_info()
|
||||||
self.data["free"] = free
|
self.data["free"] = free
|
||||||
self.data["total"] = total
|
self.data["total"] = total
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.upscaler import Upscaler
|
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||||
from modules.paths import script_path, models_path
|
from modules.paths import script_path, models_path
|
||||||
|
|
||||||
|
|
||||||
@ -169,4 +169,8 @@ def load_upscalers():
|
|||||||
scaler = cls(commandline_options.get(cmd_name, None))
|
scaler = cls(commandline_options.get(cmd_name, None))
|
||||||
datas += scaler.scalers
|
datas += scaler.scalers
|
||||||
|
|
||||||
shared.sd_upscalers = datas
|
shared.sd_upscalers = sorted(
|
||||||
|
datas,
|
||||||
|
# Special case for UpscalerNone keeps it at the beginning of the list.
|
||||||
|
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
|
||||||
|
)
|
||||||
|
1
modules/models/diffusion/uni_pc/__init__.py
Normal file
1
modules/models/diffusion/uni_pc/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .sampler import UniPCSampler
|
100
modules/models/diffusion/uni_pc/sampler.py
Normal file
100
modules/models/diffusion/uni_pc/sampler.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
"""SAMPLING ONLY."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
|
||||||
|
from modules import shared, devices
|
||||||
|
|
||||||
|
|
||||||
|
class UniPCSampler(object):
    """Sampling-only wrapper that drives the UniPC solver for a diffusion model.

    The wrapped ``model`` must expose ``alphas_cumprod``, ``device``,
    ``betas``, ``parameterization`` and ``apply_model``; these are the only
    attributes this class reads.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        self.model = model
        # Hooks are optional. All three start as None so sample() can be
        # called even when set_hooks() never was.
        # NOTE(review): assumes UniPC accepts after_update=None like the other hooks - confirm.
        self.before_sample = None
        self.after_sample = None
        self.after_update = None
        alphas_cumprod = model.alphas_cumprod.clone().detach().to(torch.float32).to(model.device)
        self.register_buffer('alphas_cumprod', alphas_cumprod)

    def register_buffer(self, name, attr):
        """Store ``attr`` on the sampler as ``name``, moving tensors to ``devices.device``."""
        if isinstance(attr, torch.Tensor) and attr.device != devices.device:
            attr = attr.to(devices.device)
        setattr(self, name, attr)

    def set_hooks(self, before_sample, after_sample, after_update):
        """Install the three callbacks that are forwarded to the UniPC solver."""
        self.before_sample = before_sample
        self.after_sample = after_sample
        self.after_update = after_update

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Run UniPC multistep sampling for ``S`` steps.

        Args:
            S: Number of solver steps.
            batch_size: Number of samples to draw.
            shape: ``(C, H, W)`` of one sample.
            conditioning: Conditioning as a tensor, list of tensors or dict;
                only used here for a batch-size sanity warning and then
                passed on to the solver.
            x_T: Optional starting noise; drawn with ``torch.randn`` if None.
            unconditional_guidance_scale: Classifier-free guidance scale.
            unconditional_conditioning: Conditioning for the unconditional
                branch, in the same format as ``conditioning``.

        The remaining parameters are accepted but not read by this method.

        Returns:
            ``(samples, None)`` with the samples moved to the device of
            ``model.betas``.
        """
        if conditioning is not None:
            # Warn (do not fail) when the conditioning batch differs from batch_size.
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        # Report this entry's own batch size; `cbs` is not defined on this path.
                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")

            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        device = self.model.betas.device
        if x_T is None:
            img = torch.randn(size, device=device)
        else:
            img = x_T

        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

        # SD 1.X is "noise", SD 2.X is "v"
        model_type = "v" if self.model.parameterization == "v" else "noise"

        # The conditions are handed to UniPC below rather than to model_wrapper.
        model_fn = model_wrapper(
            lambda x, t, c: self.model.apply_model(x, t, c),
            ns,
            model_type=model_type,
            guidance_type="classifier-free",
            guidance_scale=unconditional_guidance_scale,
        )

        uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
        x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)

        return x.to(device), None
|
857
modules/models/diffusion/uni_pc/uni_pc.py
Normal file
857
modules/models/diffusion/uni_pc/uni_pc.py
Normal file
@ -0,0 +1,857 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
from tqdm.auto import trange
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseScheduleVP:
    def __init__(
            self,
            schedule='discrete',
            betas=None,
            alphas_cumprod=None,
            continuous_beta_0=0.1,
            continuous_beta_1=20.,
        ):
        r"""Create a wrapper class for the forward SDE (VP type).

        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
                We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
        ***

        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).

        1. For discrete-time DPMs:

            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.

            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)

            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
            If both are given, `betas` takes precedence.

            **Important**:  Please pay special attention for the args for `alphas_cumprod`:
                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).

        2. For continuous-time DPMs:

            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM).

            Args:
                continuous_beta_0: A `float` number. The smallest beta for the linear schedule.
                continuous_beta_1: A `float` number. The largest beta for the linear schedule.

            The cosine-schedule hyperparameters are not configurable here; they are fixed to
            cosine_s = 0.008 and cosine_beta_max = 999. The ending time T is 1 for the linear
            schedule and 0.9946 for the cosine schedule.

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                    'linear' or 'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).
        Raises:
            ValueError: If `schedule` is unsupported, or if `schedule` is 'discrete' and neither
                `betas` nor `alphas_cumprod` is given.

        ===============================================================

        Example:

        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)

        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)

        """

        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            elif alphas_cumprod is not None:
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            else:
                # An explicit raise (rather than `assert`) so the check survives `python -O`.
                raise ValueError("The 'discrete' schedule needs either `betas` or `alphas_cumprod`.")
            self.total_N = len(log_alphas)
            self.T = 1.
            # Continuous time labels t_i = (i + 1) / N paired with log(alpha_{t_i}); both stored as (1, N) rows.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1,))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            log_cos_t = torch.log(torch.cos((t + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            # Normalised by the value at t = 0.
            return log_cos_t - self.cosine_log_alpha_0

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            # Both arrays are flipped along the sample axis before interpolating t as a function of log_alpha.
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            # Inverts the cosine-schedule formula used in marginal_log_mean_coeff.
            return torch.arccos(torch.exp(log_alpha + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
||||||
|
|
||||||
|
|
||||||
|
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs=None,
    guidance_type="uncond",
    guidance_scale=1.,
    classifier_fn=None,
    classifier_kwargs=None,
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.

    We support four types of the diffusion model by setting `model_type`:

        1. "noise": noise prediction model. (Trained by predicting noise).

        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).

        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].

            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).

        4. "score": marginal score function. (Trained by denoising score matching).
            Note that the score function and the noise prediction model follows a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```

    We support three types of guided sampling by DPMs by setting `guidance_type`.
    In every case the wrapped `model` is called with the condition as its third positional
    argument (which is `None` when no condition applies):
    ``
        model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
    ``

        1. "uncond": unconditional sampling by DPMs. `cond` is always `None`.

        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
            `model` is called with `cond=None`, and the input `classifier_fn` has the following format:
            ``
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ``

            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.

        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
            If cond == `unconditional_condition`, the model output is the unconditional DPM output.

            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).

    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).

    The returned function takes the conditions per call and outputs the predicted noise:
    ``
        def model_fn(x, t_continuous, condition, unconditional_condition) -> noise
    ``
    where `t_continuous` is the continuous time labels (i.e. epsilon to T), `condition` is used for
    the "classifier" and "classifier-free" guidance types, and `unconditional_condition` is used
    only for "classifier-free". Conditions may be tensors, lists of tensors, or dicts whose values
    are tensors or lists of tensors.

    ===============================================================

    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict` (or `None` for no extras). Other inputs of the model function.
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict` (or `None` for no extras). Other inputs of the classifier function.
    Returns:
        A noise prediction model that accepts the noised data, the continuous time and the
        two conditions as the inputs.
    """
    # Fail fast on unsupported configuration; "score" is documented above and handled in noise_pred_fn.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]

    # `None` defaults avoid sharing one dict object between calls.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """Run the model and convert its output to a noise prediction."""
        # A single time value is broadcast to one entry per batch element.
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input, condition):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous, condition, unconditional_condition):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input, condition)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Evaluate the unconditional and conditional branches in one doubled batch;
                # the unconditional half comes first, matching the chunk(2) order below.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                if isinstance(condition, dict):
                    assert isinstance(unconditional_condition, dict)
                    c_in = dict()
                    for k in condition:
                        if isinstance(condition[k], list):
                            c_in[k] = [torch.cat([
                                unconditional_condition[k][i],
                                condition[k][i]]) for i in range(len(condition[k]))]
                        else:
                            c_in[k] = torch.cat([
                                unconditional_condition[k],
                                condition[k]])
                elif isinstance(condition, list):
                    c_in = list()
                    assert isinstance(unconditional_condition, list)
                    for i in range(len(condition)):
                        c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
                else:
                    c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
|
||||||
|
|
||||||
|
|
||||||
|
class UniPC:
|
||||||
|
def __init__(
    self,
    model_fn,
    noise_schedule,
    predict_x0=True,
    thresholding=False,
    max_val=1.,
    variant='bh1',
    condition=None,
    unconditional_condition=None,
    before_sample=None,
    after_sample=None,
    after_update=None
):
    """Construct a UniPC.

    We support both data_prediction and noise_prediction.

    Every argument is stored unchanged on the instance; `model_fn` is kept as
    `self.model_fn_` (the `model_fn` method is the public prediction entry point).
    """
    # Wrapped network and the conditions handed to it on every evaluation.
    self.model_fn_ = model_fn
    self.condition = condition
    self.unconditional_condition = unconditional_condition

    # Solver configuration.
    self.noise_schedule = noise_schedule
    self.variant = variant
    self.predict_x0 = predict_x0

    # Thresholding of the predicted x0.
    self.thresholding = thresholding
    self.max_val = max_val

    # Optional hooks; `None` disables a hook.
    self.before_sample = before_sample
    self.after_sample = after_sample
    self.after_update = after_update
|
||||||
|
|
||||||
|
def dynamic_thresholding_fn(self, x0, t=None):
    """
    The dynamic thresholding method.

    Clamps `x0` per sample to [-s, s] and divides by s, where s is the larger of the
    p-quantile of |x0| (taken over all non-batch entries) and the maximum value.

    Args:
        x0: The predicted clean sample; the first axis is treated as the batch axis.
        t: Unused; kept for interface compatibility.
    Returns:
        The thresholded and rescaled `x0`.
    """
    dims = x0.dim()
    # The constructor does not define `dynamic_thresholding_ratio` / `thresholding_max_val`,
    # so fall back to the settings data_prediction_fn uses (0.995 and `self.max_val`)
    # instead of failing with AttributeError.
    p = getattr(self, 'dynamic_thresholding_ratio', 0.995)
    max_val = getattr(self, 'thresholding_max_val', self.max_val)
    s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
    s = expand_dims(torch.maximum(s, max_val * torch.ones_like(s)), dims)
    x0 = torch.clamp(x0, -s, s) / s
    return x0
|
||||||
|
|
||||||
|
def model(self, x, t):
    """Evaluate the wrapped model at (x, t), running the optional sample hooks.

    `before_sample` may replace (x, t, cond, uncond) before the call and
    `after_sample` may replace them, together with the result, afterwards.
    If the final result is a tuple, its second element is returned.
    """
    cond, uncond = self.condition, self.unconditional_condition

    pre_hook = self.before_sample
    if pre_hook is not None:
        x, t, cond, uncond = pre_hook(x, t, cond, uncond)

    res = self.model_fn_(x, t, cond, uncond)

    post_hook = self.after_sample
    if post_hook is not None:
        x, t, cond, uncond, res = post_hook(x, t, cond, uncond, res)

    # (None, pred_x0)
    return res[1] if isinstance(res, tuple) else res
|
||||||
|
|
||||||
|
def noise_prediction_fn(self, x, t):
    """
    Return the noise prediction model.

    Delegates directly to `self.model`.
    """
    predicted = self.model(x, t)
    return predicted
|
||||||
|
|
||||||
|
def data_prediction_fn(self, x, t):
    """
    Return the data prediction model (with thresholding).

    Converts the noise prediction at (x, t) to x0 = (x - sigma_t * noise) / alpha_t.
    When `self.thresholding` is set, x0 is additionally clamped per sample to [-s, s]
    and divided by s, with s the larger of the 0.995-quantile of |x0| and `self.max_val`.
    """
    eps = self.noise_prediction_fn(x, t)
    ndim = x.dim()
    schedule = self.noise_schedule
    alpha_t = schedule.marginal_alpha(t)
    sigma_t = schedule.marginal_std(t)
    x0 = (x - expand_dims(sigma_t, ndim) * eps) / expand_dims(alpha_t, ndim)
    if not self.thresholding:
        return x0

    p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
    flat_abs = torch.abs(x0).reshape((x0.shape[0], -1))
    s = torch.quantile(flat_abs, p, dim=1)
    floor = self.max_val * torch.ones_like(s).to(s.device)
    s = expand_dims(torch.maximum(s, floor), ndim)
    return torch.clamp(x0, -s, s) / s
|
||||||
|
|
||||||
|
def model_fn(self, x, t):
    """
    Convert the model to the noise prediction model or the data prediction model.

    Dispatches on `self.predict_x0`: data prediction when truthy, noise prediction otherwise.
    """
    predictor = self.data_prediction_fn if self.predict_x0 else self.noise_prediction_fn
    return predictor(x, t)
|
||||||
|
|
||||||
|
def get_time_steps(self, skip_type, t_T, t_0, N, device):
    """Compute the intermediate time steps for sampling.

    Returns N + 1 time values running from `t_T` to `t_0` on `device`, spaced
    uniformly in time ('time_uniform'), uniformly in sqrt(time) ('time_quadratic'),
    or uniformly in lambda as given by the noise schedule ('logSNR').
    Raises ValueError for any other `skip_type`.
    """
    n_points = N + 1
    if skip_type == 'time_uniform':
        return torch.linspace(t_T, t_0, n_points).to(device)
    if skip_type == 'time_quadratic':
        t_order = 2
        root = 1. / t_order
        grid = torch.linspace(t_T**root, t_0**root, n_points)
        return grid.pow(t_order).to(device)
    if skip_type == 'logSNR':
        schedule = self.noise_schedule
        lambda_T = schedule.marginal_lambda(torch.tensor(t_T).to(device))
        lambda_0 = schedule.marginal_lambda(torch.tensor(t_0).to(device))
        start, stop = lambda_T.cpu().item(), lambda_0.cpu().item()
        logSNR_steps = torch.linspace(start, stop, n_points).to(device)
        return schedule.inverse_lambda(logSNR_steps)
    raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
||||||
|
|
||||||
|
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
    """
    Get the order of each step for sampling by the singlestep DPM-Solver.

    Splits `steps` into per-step orders that sum to `steps`, each at most `order`,
    and returns `(timesteps_outer, orders)` where `timesteps_outer` holds the time
    values at the boundaries between those steps.
    Raises ValueError when `order` is not 1, 2 or 3.
    """
    if order == 3:
        whole, rest = divmod(steps, 3)
        K = whole + 1
        if rest == 0:
            orders = [3] * (K - 2) + [2, 1]
        elif rest == 1:
            orders = [3] * (K - 1) + [1]
        else:
            orders = [3] * (K - 1) + [2]
    elif order == 2:
        whole, rest = divmod(steps, 2)
        if rest == 0:
            K = whole
            orders = [2] * K
        else:
            K = whole + 1
            orders = [2] * (K - 1) + [1]
    elif order == 1:
        K = steps
        orders = [1] * steps
    else:
        raise ValueError("'order' must be '1' or '2' or '3'.")

    if skip_type == 'logSNR':
        # To reproduce the results in DPM-Solver paper
        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
    else:
        # Pick the fine-grid entries at the cumulative step boundaries.
        fine_grid = self.get_time_steps(skip_type, t_T, t_0, steps, device)
        boundaries = torch.cumsum(torch.tensor([0] + orders), 0).to(device)
        timesteps_outer = fine_grid[boundaries]
    return timesteps_outer, orders
|
||||||
|
|
||||||
|
def denoise_to_zero_fn(self, x, s):
    """
    Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.

    Implemented as a single data prediction at (x, s).
    """
    denoised = self.data_prediction_fn(x, s)
    return denoised
|
||||||
|
|
||||||
|
def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
    """Run one multistep UniPC update using the solver selected by `self.variant`.

    A zero-dimensional `t` is first viewed as a one-element tensor. Variants whose
    name contains 'bh' use the B(h) update; otherwise the variant must be
    'vary_coeff'. Extra keyword arguments are forwarded to the chosen update.
    """
    if not len(t.shape):
        t = t.view(-1)
    if 'bh' not in self.variant:
        assert self.variant == 'vary_coeff'
        return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
    return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||||
|
|
||||||
|
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
    """One multistep UniPC update from `t_prev_list[-1]` to `t` (solver type: vary coeff).

    Args:
        x: The current sample at time `t_prev_list[-1]`.
        model_prev_list: Earlier model outputs, most recent last; at least `order` entries.
        t_prev_list: The time values matching `model_prev_list`.
        t: The target time.
        order: How many of the most recent model outputs are used.
        use_corrector: When true, evaluate the model at the predicted sample and apply
            the corrector step.
    Returns:
        `(x_t, model_t)`: the updated sample and the model output evaluated at the
        predicted sample (`None` when `use_corrector` is false).
    """
    #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
    ns = self.noise_schedule
    assert order <= len(model_prev_list)

    # first compute rks
    t_prev_0 = t_prev_list[-1]
    lambda_prev_0 = ns.marginal_lambda(t_prev_0)
    lambda_t = ns.marginal_lambda(t)
    model_prev_0 = model_prev_list[-1]
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    log_alpha_t = ns.marginal_log_mean_coeff(t)
    alpha_t = torch.exp(log_alpha_t)

    # Step size measured in lambda.
    h = lambda_t - lambda_prev_0

    rks = []
    D1s = []
    for i in range(1, order):
        t_prev_i = t_prev_list[-(i + 1)]
        model_prev_i = model_prev_list[-(i + 1)]
        lambda_prev_i = ns.marginal_lambda(t_prev_i)
        # Offset of the i-th earlier point from the latest one, in units of h.
        rk = (lambda_prev_i - lambda_prev_0) / h
        rks.append(rk)
        # Difference of model outputs scaled by that offset.
        D1s.append((model_prev_i - model_prev_0) / rk)

    # The final entry (offset 1) stands for the target time t itself.
    rks.append(1.)
    rks = torch.tensor(rks, device=x.device)

    K = len(rks)
    # build C matrix
    C = []

    # Column k (1-based) ends up as rks**(k-1) / k!.
    col = torch.ones_like(rks)
    for k in range(1, K + 1):
        C.append(col)
        col = col * rks / (k + 1)
    C = torch.stack(C, dim=1)

    if len(D1s) > 0:
        # Stacked along dim 1; the einsum patterns below treat D1s as (b, k, c, h, w).
        D1s = torch.stack(D1s, dim=1) # (B, K)
        # The predictor leaves out the last row/column, i.e. the point at t.
        C_inv_p = torch.linalg.inv(C[:-1, :-1])
        A_p = C_inv_p

    if use_corrector:
        #print('using corrector')
        C_inv = torch.linalg.inv(C)
        A_c = C_inv

    # The sign of the step flips for data prediction.
    hh = -h if self.predict_x0 else h
    h_phi_1 = torch.expm1(hh)
    h_phi_ks = []
    factorial_k = 1
    h_phi_k = h_phi_1
    # Recurrence producing K + 1 coefficients, starting from expm1(hh).
    for k in range(1, K + 2):
        h_phi_ks.append(h_phi_k)
        h_phi_k = h_phi_k / hh - 1 / factorial_k
        factorial_k *= (k + 1)

    model_t = None
    if self.predict_x0:
        # Base update that uses only the latest model output.
        x_t_ = (
            sigma_t / sigma_prev_0 * x
            - alpha_t * h_phi_1 * model_prev_0
        )
        # now predictor
        x_t = x_t_
        if len(D1s) > 0:
            # compute the residuals for predictor
            for k in range(K - 1):
                x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
        # now corrector
        if use_corrector:
            model_t = self.model_fn(x_t, t)
            D1_t = (model_t - model_prev_0)
            # The corrector restarts from the base update, not from the predictor result.
            x_t = x_t_
            k = 0
            for k in range(K - 1):
                x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
            # NOTE(review): `k` here is the value left over from the loop above
            # (or 0 when the loop body never ran) - confirm A_c[k] is the intended row.
            x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
    else:
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        # Base update that uses only the latest model output.
        x_t_ = (
            (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
            - (sigma_t * h_phi_1) * model_prev_0
        )
        # now predictor
        x_t = x_t_
        if len(D1s) > 0:
            # compute the residuals for predictor
            for k in range(K - 1):
                x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
        # now corrector
        if use_corrector:
            model_t = self.model_fn(x_t, t)
            D1_t = (model_t - model_prev_0)
            # The corrector restarts from the base update, not from the predictor result.
            x_t = x_t_
            k = 0
            for k in range(K - 1):
                x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
            # NOTE(review): `k` is the leftover loop value here as well - confirm intended row.
            x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
    return x_t, model_t
|
||||||
|
|
||||||
|
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
    """One multistep UniPC update from `t_prev_list[-1]` to `t` (solver type: B(h)).

    Args:
        x: The current sample at time `t_prev_list[-1]`.
        model_prev_list: Earlier model outputs, most recent last; at least `order` entries.
        t_prev_list: The time values matching `model_prev_list`.
        t: The target time; indexed with `[0]` below, so it must have at least one dimension.
        order: How many of the most recent model outputs are used.
        x_t: Optional precomputed prediction at `t`; when given, the predictor step is skipped.
        use_corrector: When true, evaluate the model at the predicted sample and apply
            the corrector step.
    Returns:
        `(x_t, model_t)`: the updated sample and the model output evaluated at the
        predicted sample (`None` when `use_corrector` is false).
    Raises:
        NotImplementedError: If `self.variant` is neither 'bh1' nor 'bh2'.
    """
    #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
    ns = self.noise_schedule
    assert order <= len(model_prev_list)
    dims = x.dim()

    # first compute rks
    t_prev_0 = t_prev_list[-1]
    lambda_prev_0 = ns.marginal_lambda(t_prev_0)
    lambda_t = ns.marginal_lambda(t)
    model_prev_0 = model_prev_list[-1]
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
    alpha_t = torch.exp(log_alpha_t)

    # Step size measured in lambda.
    h = lambda_t - lambda_prev_0

    rks = []
    D1s = []
    for i in range(1, order):
        t_prev_i = t_prev_list[-(i + 1)]
        model_prev_i = model_prev_list[-(i + 1)]
        lambda_prev_i = ns.marginal_lambda(t_prev_i)
        # Only element 0 is used - presumably every batch entry shares the same time; verify against caller.
        rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
        rks.append(rk)
        # Difference of model outputs scaled by that offset.
        D1s.append((model_prev_i - model_prev_0) / rk)

    # The final entry (offset 1) stands for the target time t itself.
    rks.append(1.)
    rks = torch.tensor(rks, device=x.device)

    # R and b define the linear system R @ rhos = b solved below.
    R = []
    b = []

    # The sign of the step flips for data prediction.
    hh = -h[0] if self.predict_x0 else h[0]
    h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    if self.variant == 'bh1':
        B_h = hh
    elif self.variant == 'bh2':
        B_h = torch.expm1(hh)
    else:
        raise NotImplementedError()

    for i in range(1, order + 1):
        # Row i of R holds rks**(i-1).
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= (i + 1)
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.tensor(b, device=x.device)

    # now predictor
    use_predictor = len(D1s) > 0 and x_t is None
    if len(D1s) > 0:
        # Stacked along dim 1; the einsum patterns below treat D1s as (b, k, c, h, w).
        D1s = torch.stack(D1s, dim=1) # (B, K)
        if x_t is None:
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], device=b.device)
            else:
                # The predictor leaves out the last row/column, i.e. the point at t.
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
    else:
        D1s = None

    if use_corrector:
        #print('using corrector')
        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], device=b.device)
        else:
            rhos_c = torch.linalg.solve(R, b)

    model_t = None
    if self.predict_x0:
        # Base update that uses only the latest model output.
        x_t_ = (
            expand_dims(sigma_t / sigma_prev_0, dims) * x
            - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
        )

        if x_t is None:
            if use_predictor:
                pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res

        if use_corrector:
            model_t = self.model_fn(x_t, t)
            if D1s is not None:
                corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = (model_t - model_prev_0)
            # The corrector restarts from the base update and adds the new difference D1_t.
            x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
    else:
        # Base update that uses only the latest model output.
        x_t_ = (
            expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
            - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
        )
        if x_t is None:
            if use_predictor:
                pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res

        if use_corrector:
            model_t = self.model_fn(x_t, t)
            if D1s is not None:
                corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = (model_t - model_prev_0)
            # The corrector restarts from the base update and adds the new difference D1_t.
            x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
    return x_t, model_t
|
||||||
|
|
||||||
|
|
||||||
|
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
           method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
           atol=0.0078, rtol=0.05, corrector=False,
           ):
    """
    Run the multistep UniPC sampling loop starting from `x` and return the final sample.

    Args:
        x: initial tensor; its first dimension is used as the batch size and its device is used for the timesteps.
        steps: number of sampling steps; must be at least `order`.
        t_start: start time; when None, `self.noise_schedule.T` is used.
        t_end: end time; when None, `1. / self.noise_schedule.total_N` is used.
        order: order passed to `multistep_uni_pc_update` (possibly lowered near the end, see `lower_order_final`).
        skip_type: forwarded to `self.get_time_steps`.
        method: only 'multistep' is implemented; note that the default value 'singlestep' raises.
        lower_order_final: when true, the order of the trailing steps is capped at `steps + 1 - step`.
        denoise_to_zero: when true, `self.denoise_to_zero_fn` is applied to the result at time `t_end`.
        solver_type, atol, rtol, corrector: accepted for signature compatibility; not read by this method.
    Returns:
        The tensor produced by the last update (after the optional denoise-to-zero call).
    Raises:
        NotImplementedError: if `method` is not 'multistep'.
        AssertionError: if `steps < order`, or if `get_time_steps` does not return `steps + 1` timesteps.
    """
    t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
    t_T = self.noise_schedule.T if t_start is None else t_start
    device = x.device

    if method != 'multistep':
        raise NotImplementedError()

    # The check allows steps == order, so the message must say "<=" (it used to claim "<").
    assert steps >= order, "UniPC order must be <= sampling steps"
    timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
    assert timesteps.shape[0] - 1 == steps

    with torch.no_grad():
        vec_t = timesteps[0].expand((x.shape[0]))
        model_prev_list = [self.model_fn(x, vec_t)]
        t_prev_list = [vec_t]

        # Warm-up: build a history of `order` entries using updates of increasing order.
        for init_order in range(1, order):
            vec_t = timesteps[init_order].expand(x.shape[0])
            x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
            if model_x is None:
                model_x = self.model_fn(x, vec_t)
            if self.after_update is not None:
                self.after_update(x, model_x)
            model_prev_list.append(model_x)
            t_prev_list.append(vec_t)

        for step in trange(order, steps + 1):
            vec_t = timesteps[step].expand(x.shape[0])
            step_order = min(order, steps + 1 - step) if lower_order_final else order
            # The corrector is skipped on the last step.
            use_corrector = step != steps

            x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
            if self.after_update is not None:
                # model_x may still be None here (it is on the final step, where no corrector runs).
                self.after_update(x, model_x)

            # Slide the fixed-length history window one slot to the left, in place.
            t_prev_list[:-1] = t_prev_list[1:]
            model_prev_list[:-1] = model_prev_list[1:]
            t_prev_list[-1] = vec_t

            # We do not need to evaluate the final model value.
            if step < steps:
                if model_x is None:
                    model_x = self.model_fn(x, vec_t)
                model_prev_list[-1] = model_x

    if denoise_to_zero:
        x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
    return x
|
||||||
|
|
||||||
|
|
||||||
|
#############################################################
|
||||||
|
# other utility functions
|
||||||
|
#############################################################
|
||||||
|
|
||||||
|
def interpolate_fn(x, xp, yp):
    """
    Evaluate the piecewise linear function through the keypoints (xp, yp) at the positions x.

    Only tensor operations are used, so the result stays usable with autograd.
    Queries outside the range of xp are evaluated on the line through the two outermost keypoints on that side.

    Args:
        x: PyTorch tensor with shape [N, C] (N = batch size, C = number of channels; C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The interpolated values, with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]

    # Slot 0 of the last axis holds the query; the keypoints follow. Sort them together.
    merged = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    merged_sorted, sort_order = torch.sort(merged, dim=2)

    # The query was index 0 before sorting, so the location of the smallest index is the query's rank.
    rank = torch.argmin(sort_order, dim=2)
    below = rank - 1
    at_left_edge = torch.eq(rank, 0)

    # Choice shared by both lookups for queries that are not left of every keypoint.
    inner = torch.where(
        torch.eq(rank, num_keys),
        torch.tensor(num_keys - 2, device=x.device),
        below,
    )

    # Bracketing x positions, read from the merged sorted array (skipping over the query's own slot).
    lo = torch.where(at_left_edge, torch.tensor(1, device=x.device), inner)
    hi = torch.where(torch.eq(lo, below), lo + 2, lo + 1)
    x_lo = torch.gather(merged_sorted, dim=2, index=lo.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(merged_sorted, dim=2, index=hi.unsqueeze(2)).squeeze(2)

    # Matching y values, read from yp (which does not contain the query slot).
    y_lo_idx = torch.where(at_left_edge, torch.tensor(0, device=x.device), inner)
    yp_batched = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(yp_batched, dim=2, index=y_lo_idx.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(yp_batched, dim=2, index=(y_lo_idx + 1).unsqueeze(2)).squeeze(2)

    offset = (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
    return y_lo + offset
|
||||||
|
|
||||||
|
|
||||||
|
def expand_dims(v, dims):
    """
    Append singleton axes to `v` until it has `dims` dimensions in total.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`, the total number of dimensions of the result.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    # Indexing with None inserts a new axis; Ellipsis keeps the existing axes in front.
    trailing_axes = (None,) * (dims - 1)
    return v[(Ellipsis, *trailing_axes)]
|
@ -597,6 +597,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if state.job_count == -1:
|
if state.job_count == -1:
|
||||||
state.job_count = p.n_iter
|
state.job_count = p.n_iter
|
||||||
|
|
||||||
|
extra_network_data = None
|
||||||
for n in range(p.n_iter):
|
for n in range(p.n_iter):
|
||||||
p.iteration = n
|
p.iteration = n
|
||||||
|
|
||||||
@ -620,6 +621,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
|
||||||
|
if p.scripts is not None:
|
||||||
|
p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
|
||||||
|
|
||||||
if len(prompts) == 0:
|
if len(prompts) == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -753,7 +757,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||||
|
|
||||||
if not p.disable_extra_networks:
|
if not p.disable_extra_networks and extra_network_data:
|
||||||
extra_networks.deactivate(p, extra_network_data)
|
extra_networks.deactivate(p, extra_network_data)
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
@ -944,7 +948,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
|
|
||||||
img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM
|
img2img_sampler_name = self.sampler_name
|
||||||
|
|
||||||
|
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
|
||||||
|
img2img_sampler_name = 'DDIM'
|
||||||
|
|
||||||
if self.hr_sampler == '---':
|
if self.hr_sampler == '---':
|
||||||
pass
|
pass
|
||||||
|
@ -29,7 +29,7 @@ class ImageSaveParams:
|
|||||||
|
|
||||||
|
|
||||||
class CFGDenoiserParams:
|
class CFGDenoiserParams:
|
||||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
|
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
||||||
self.x = x
|
self.x = x
|
||||||
"""Latent image representation in the process of being denoised"""
|
"""Latent image representation in the process of being denoised"""
|
||||||
|
|
||||||
@ -44,6 +44,12 @@ class CFGDenoiserParams:
|
|||||||
|
|
||||||
self.total_sampling_steps = total_sampling_steps
|
self.total_sampling_steps = total_sampling_steps
|
||||||
"""Total number of sampling steps planned"""
|
"""Total number of sampling steps planned"""
|
||||||
|
|
||||||
|
self.text_cond = text_cond
|
||||||
|
""" Encoder hidden states of text conditioning from prompt"""
|
||||||
|
|
||||||
|
self.text_uncond = text_uncond
|
||||||
|
""" Encoder hidden states of text conditioning from negative prompt"""
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoisedParams:
|
class CFGDenoisedParams:
|
||||||
|
@ -33,6 +33,11 @@ class Script:
|
|||||||
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
|
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
paste_field_names = None
|
||||||
|
"""if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
|
||||||
|
various "Send to <X>" buttons when clicked
|
||||||
|
"""
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||||
|
|
||||||
@ -80,6 +85,20 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def before_process_batch(self, p, *args, **kwargs):
    """
    Hook that runs before extra networks are parsed from the prompt; override it
    to inject new extra network keywords into the prompt.

    The default implementation does nothing.

    **kwargs will have those items:
      - batch_number - index of current batch, from 0 to number of batches-1
      - prompts - list of prompts for current batch; its contents may be edited, but changing the number of entries will likely break things
      - seeds - list of seeds for current batch
      - subseeds - list of subseeds for current batch
    """

    return None
|
||||||
|
|
||||||
def process_batch(self, p, *args, **kwargs):
|
def process_batch(self, p, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Same as process(), but called for every batch.
|
Same as process(), but called for every batch.
|
||||||
@ -256,6 +275,7 @@ class ScriptRunner:
|
|||||||
self.alwayson_scripts = []
|
self.alwayson_scripts = []
|
||||||
self.titles = []
|
self.titles = []
|
||||||
self.infotext_fields = []
|
self.infotext_fields = []
|
||||||
|
self.paste_field_names = []
|
||||||
|
|
||||||
def initialize_scripts(self, is_img2img):
|
def initialize_scripts(self, is_img2img):
|
||||||
from modules import scripts_auto_postprocessing
|
from modules import scripts_auto_postprocessing
|
||||||
@ -304,6 +324,9 @@ class ScriptRunner:
|
|||||||
if script.infotext_fields is not None:
|
if script.infotext_fields is not None:
|
||||||
self.infotext_fields += script.infotext_fields
|
self.infotext_fields += script.infotext_fields
|
||||||
|
|
||||||
|
if script.paste_field_names is not None:
|
||||||
|
self.paste_field_names += script.paste_field_names
|
||||||
|
|
||||||
inputs += controls
|
inputs += controls
|
||||||
inputs_alwayson += [script.alwayson for _ in controls]
|
inputs_alwayson += [script.alwayson for _ in controls]
|
||||||
script.args_to = len(inputs)
|
script.args_to = len(inputs)
|
||||||
@ -388,6 +411,15 @@ class ScriptRunner:
|
|||||||
print(f"Error running process: {script.filename}", file=sys.stderr)
|
print(f"Error running process: {script.filename}", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
def before_process_batch(self, p, **kwargs):
    """
    Invoke `before_process_batch` on every always-on script.

    Each script receives `p`, its own slice of `p.script_args`
    (from `args_from` up to `args_to`) and the keyword arguments given here.
    A script that raises is reported on stderr with its traceback and does not
    stop the remaining scripts from running.
    """
    for runner_script in self.alwayson_scripts:
        try:
            own_args = p.script_args[runner_script.args_from:runner_script.args_to]
            runner_script.before_process_batch(p, *own_args, **kwargs)
        except Exception:
            # Best effort on purpose: report the failure and carry on with the next script.
            header = f"Error running before_process_batch: {runner_script.filename}"
            print(header, file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
def process_batch(self, p, **kwargs):
|
def process_batch(self, p, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
try:
|
try:
|
||||||
|
@ -37,11 +37,23 @@ def apply_optimizations():
|
|||||||
|
|
||||||
optimization_method = None
|
optimization_method = None
|
||||||
|
|
||||||
|
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
|
||||||
|
|
||||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||||
print("Applying xformers cross attention optimization.")
|
print("Applying xformers cross attention optimization.")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
||||||
optimization_method = 'xformers'
|
optimization_method = 'xformers'
|
||||||
|
elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
|
||||||
|
print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
|
||||||
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
|
||||||
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
|
||||||
|
optimization_method = 'sdp-no-mem'
|
||||||
|
elif cmd_opts.opt_sdp_attention and can_use_sdp:
|
||||||
|
print("Applying scaled dot product cross attention optimization.")
|
||||||
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
|
||||||
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
|
||||||
|
optimization_method = 'sdp'
|
||||||
elif cmd_opts.opt_sub_quad_attention:
|
elif cmd_opts.opt_sub_quad_attention:
|
||||||
print("Applying sub-quadratic cross attention optimization.")
|
print("Applying sub-quadratic cross attention optimization.")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
|
||||||
|
@ -346,6 +346,52 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|||||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
|
||||||
|
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
|
||||||
|
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
    """
    Cross-attention forward pass built on torch.nn.functional.scaled_dot_product_attention.

    Args:
        x: input tensor unpacked as (batch_size, sequence_length, inner_dim); it is projected to the queries.
        context: tensor projected to keys and values after hypernetworks are applied; `x` is used when it is None.
        mask: optional attention mask; when given it is prepared through `self.prepare_attention_mask`
            and reshaped to one mask per head.
    Returns:
        The attention output after `self.to_out[0]` (linear projection) and `self.to_out[1]` (dropout).
    """
    batch_size, sequence_length, inner_dim = x.shape

    if mask is not None:
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    # Split the channel dimension into heads: (batch, seq, h * head_dim) -> (batch, h, seq, head_dim).
    head_dim = inner_dim // h
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        # scaled_dot_product_attention needs q, k and v in one dtype, so v is upcast together with q and k.
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    # Merge the heads back and return to the dtype the projections produced.
    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    hidden_states = hidden_states.to(dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)
    return hidden_states
|
||||||
|
|
||||||
|
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
    """Run scaled_dot_product_attention_forward with the memory-efficient SDP kernel disabled (flash and math stay enabled)."""
    kernel_selection = torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=True,
        enable_mem_efficient=False,
    )
    with kernel_selection:
        attended = scaled_dot_product_attention_forward(self, x, context, mask)
    return attended
|
||||||
|
|
||||||
def cross_attention_attnblock_forward(self, x):
|
def cross_attention_attnblock_forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
h_ = self.norm(h_)
|
h_ = self.norm(h_)
|
||||||
@ -427,6 +473,30 @@ def xformers_attnblock_forward(self, x):
|
|||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
return cross_attention_attnblock_forward(self, x)
|
return cross_attention_attnblock_forward(self, x)
|
||||||
|
|
||||||
|
def sdp_attnblock_forward(self, x):
    """
    AttnBlock forward pass built on torch.nn.functional.scaled_dot_product_attention.

    Args:
        x: input tensor; its q/k/v projections are unpacked as four-dimensional (b, c, h, w).
    Returns:
        `x` plus the projected attention output (residual connection), in the same layout as `x`.
    """
    h_ = self.norm(x)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # Only the height is needed later, to undo the flattening of the spatial axes.
    _, _, h, _ = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))

    dtype = q.dtype
    if shared.opts.upcast_attn:
        # scaled_dot_product_attention needs q, k and v in one dtype, so v is upcast together with q and k.
        q, k, v = q.float(), k.float(), v.float()

    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.to(dtype)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out
|
||||||
|
|
||||||
|
def sdp_no_mem_attnblock_forward(self, x):
    """Run sdp_attnblock_forward with the memory-efficient SDP kernel disabled (flash and math stay enabled)."""
    kernel_selection = torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=True,
        enable_mem_efficient=False,
    )
    with kernel_selection:
        attended = sdp_attnblock_forward(self, x)
    return attended
|
||||||
|
|
||||||
def sub_quad_attnblock_forward(self, x):
|
def sub_quad_attnblock_forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
h_ = self.norm(h_)
|
h_ = self.norm(h_)
|
||||||
|
@ -210,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||||||
return pl_sd
|
return pl_sd
|
||||||
|
|
||||||
|
|
||||||
|
def read_metadata_from_safetensors(filename):
    """
    Read the `__metadata__` section from the JSON header of a safetensors file.

    The file is expected to start with an 8-byte little-endian header length
    followed by that many bytes of JSON. Metadata values that are strings
    starting with '{' are additionally parsed as JSON; if that parse fails the
    original string is kept.

    Args:
        filename: path of the file to read.
    Returns:
        A dict with the metadata entries; empty when the header has no `__metadata__` key.
    Raises:
        AssertionError: if the file does not start like a safetensors header.
    """
    import json

    with open(filename, mode="rb") as file:
        metadata_len = int.from_bytes(file.read(8), "little")
        json_start = file.read(2)

        # Raised explicitly (not via `assert`) so the check still runs under `python -O`;
        # the exception type is kept so existing callers see the same error.
        if not (metadata_len > 2 and json_start in (b'{"', b"{'")):
            raise AssertionError(f"{filename} is not a safetensors file")

        json_data = json_start + file.read(metadata_len - 2)

    json_obj = json.loads(json_data)

    res = {}
    for k, v in json_obj.get("__metadata__", {}).items():
        res[k] = v
        if isinstance(v, str) and v.startswith('{'):
            try:
                res[k] = json.loads(v)
            except (ValueError, RecursionError):
                # Best effort: the value only looked like JSON, keep the raw string.
                pass

    return res
|
||||||
|
|
||||||
|
|
||||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||||
_, extension = os.path.splitext(checkpoint_file)
|
_, extension = os.path.splitext(checkpoint_file)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
|
@ -32,7 +32,7 @@ def set_samplers():
|
|||||||
global samplers, samplers_for_img2img
|
global samplers, samplers_for_img2img
|
||||||
|
|
||||||
hidden = set(shared.opts.hide_samplers)
|
hidden = set(shared.opts.hide_samplers)
|
||||||
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS'])
|
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
|
||||||
|
|
||||||
samplers = [x for x in all_samplers if x.name not in hidden]
|
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||||
|
@ -7,19 +7,27 @@ import torch
|
|||||||
|
|
||||||
from modules.shared import state
|
from modules.shared import state
|
||||||
from modules import sd_samplers_common, prompt_parser, shared
|
from modules import sd_samplers_common, prompt_parser, shared
|
||||||
|
import modules.models.diffusion.uni_pc
|
||||||
|
|
||||||
|
|
||||||
samplers_data_compvis = [
|
samplers_data_compvis = [
|
||||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||||
|
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class VanillaStableDiffusionSampler:
|
class VanillaStableDiffusionSampler:
|
||||||
def __init__(self, constructor, sd_model):
|
def __init__(self, constructor, sd_model):
|
||||||
self.sampler = constructor(sd_model)
|
self.sampler = constructor(sd_model)
|
||||||
|
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
|
||||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
|
||||||
|
self.orig_p_sample_ddim = None
|
||||||
|
if self.is_plms:
|
||||||
|
self.orig_p_sample_ddim = self.sampler.p_sample_plms
|
||||||
|
elif self.is_ddim:
|
||||||
|
self.orig_p_sample_ddim = self.sampler.p_sample_ddim
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.nmask = None
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
@ -45,6 +53,15 @@ class VanillaStableDiffusionSampler:
|
|||||||
return self.last_latent
|
return self.last_latent
|
||||||
|
|
||||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||||
|
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
||||||
|
|
||||||
|
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||||
|
|
||||||
|
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def before_sample(self, x, ts, cond, unconditional_conditioning):
|
||||||
if state.interrupted or state.skipped:
|
if state.interrupted or state.skipped:
|
||||||
raise sd_samplers_common.InterruptedException
|
raise sd_samplers_common.InterruptedException
|
||||||
|
|
||||||
@ -76,7 +93,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
x = img_orig * self.mask + self.nmask * x
|
||||||
|
|
||||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||||
# Note that they need to be lists because it just concatenates them later.
|
# Note that they need to be lists because it just concatenates them later.
|
||||||
@ -84,12 +101,13 @@ class VanillaStableDiffusionSampler:
|
|||||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||||
|
|
||||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
return x, ts, cond, unconditional_conditioning
|
||||||
|
|
||||||
|
def update_step(self, last_latent):
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
|
||||||
else:
|
else:
|
||||||
self.last_latent = res[1]
|
self.last_latent = last_latent
|
||||||
|
|
||||||
sd_samplers_common.store_latent(self.last_latent)
|
sd_samplers_common.store_latent(self.last_latent)
|
||||||
|
|
||||||
@ -97,26 +115,51 @@ class VanillaStableDiffusionSampler:
|
|||||||
state.sampling_step = self.step
|
state.sampling_step = self.step
|
||||||
shared.total_tqdm.update()
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
return res
|
def after_sample(self, x, ts, cond, uncond, res):
|
||||||
|
if not self.is_unipc:
|
||||||
|
self.update_step(res[1])
|
||||||
|
|
||||||
|
return x, ts, cond, uncond, res
|
||||||
|
|
||||||
|
def unipc_after_update(self, x, model_x):
|
||||||
|
self.update_step(x)
|
||||||
|
|
||||||
def initialize(self, p):
|
def initialize(self, p):
|
||||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||||
if self.eta != 0.0:
|
if self.eta != 0.0:
|
||||||
p.extra_generation_params["Eta DDIM"] = self.eta
|
p.extra_generation_params["Eta DDIM"] = self.eta
|
||||||
|
|
||||||
|
if self.is_unipc:
|
||||||
|
keys = [
|
||||||
|
('UniPC variant', 'uni_pc_variant'),
|
||||||
|
('UniPC skip type', 'uni_pc_skip_type'),
|
||||||
|
('UniPC order', 'uni_pc_order'),
|
||||||
|
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||||
|
]
|
||||||
|
|
||||||
|
for name, key in keys:
|
||||||
|
v = getattr(shared.opts, key)
|
||||||
|
if v != shared.opts.get_default(key):
|
||||||
|
p.extra_generation_params[name] = v
|
||||||
|
|
||||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||||
if hasattr(self.sampler, fieldname):
|
if hasattr(self.sampler, fieldname):
|
||||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||||
|
if self.is_unipc:
|
||||||
|
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
|
||||||
|
|
||||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
|
|
||||||
|
|
||||||
def adjust_steps_if_invalid(self, p, num_steps):
|
def adjust_steps_if_invalid(self, p, num_steps):
|
||||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
|
||||||
|
if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
|
||||||
|
num_steps = shared.opts.uni_pc_order
|
||||||
valid_step = 999 / (1000 // num_steps)
|
valid_step = 999 / (1000 // num_steps)
|
||||||
if valid_step == math.floor(valid_step):
|
if valid_step == math.floor(valid_step):
|
||||||
return int(valid_step) + 1
|
return int(valid_step) + 1
|
||||||
|
|
||||||
return num_steps
|
return num_steps
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
@ -101,11 +101,13 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
|
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
|
||||||
|
|
||||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
|
||||||
cfg_denoiser_callback(denoiser_params)
|
cfg_denoiser_callback(denoiser_params)
|
||||||
x_in = denoiser_params.x
|
x_in = denoiser_params.x
|
||||||
image_cond_in = denoiser_params.image_cond
|
image_cond_in = denoiser_params.image_cond
|
||||||
sigma_in = denoiser_params.sigma
|
sigma_in = denoiser_params.sigma
|
||||||
|
tensor = denoiser_params.text_cond
|
||||||
|
uncond = denoiser_params.text_uncond
|
||||||
|
|
||||||
if tensor.shape[1] == uncond.shape[1]:
|
if tensor.shape[1] == uncond.shape[1]:
|
||||||
if not is_edit_model:
|
if not is_edit_model:
|
||||||
|
@ -35,8 +35,11 @@ def model():
|
|||||||
global sd_vae_approx_model
|
global sd_vae_approx_model
|
||||||
|
|
||||||
if sd_vae_approx_model is None:
|
if sd_vae_approx_model is None:
|
||||||
|
model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
|
||||||
sd_vae_approx_model = VAEApprox()
|
sd_vae_approx_model = VAEApprox()
|
||||||
sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
|
if not os.path.exists(model_path):
|
||||||
|
model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
|
||||||
|
sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||||
sd_vae_approx_model.eval()
|
sd_vae_approx_model.eval()
|
||||||
sd_vae_approx_model.to(devices.device, devices.dtype)
|
sd_vae_approx_model.to(devices.device, devices.dtype)
|
||||||
|
|
||||||
|
@ -69,6 +69,8 @@ parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size fo
|
|||||||
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||||
|
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
|
||||||
|
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
|
||||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||||
@ -114,7 +116,10 @@ parser.add_argument("--no-download-sd-model", action='store_true', help="don't d
|
|||||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||||
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
|
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
||||||
|
cmd_opts = parser.parse_args()
|
||||||
|
else:
|
||||||
|
cmd_opts, _ = parser.parse_known_args()
|
||||||
|
|
||||||
restricted_opts = {
|
restricted_opts = {
|
||||||
"samples_filename_pattern",
|
"samples_filename_pattern",
|
||||||
@ -305,6 +310,7 @@ def list_samplers():
|
|||||||
|
|
||||||
|
|
||||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||||
|
tab_names = []
|
||||||
|
|
||||||
options_templates = {}
|
options_templates = {}
|
||||||
|
|
||||||
@ -327,9 +333,11 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||||
|
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
||||||
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
|
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
|
||||||
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
||||||
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
||||||
|
"img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number),
|
||||||
|
|
||||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||||
@ -440,6 +448,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
|||||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
||||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -460,6 +469,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||||
|
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
|
||||||
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
||||||
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||||
@ -485,6 +495,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
||||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
|
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
|
||||||
|
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
||||||
|
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
|
||||||
|
'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}),
|
||||||
|
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||||
@ -559,6 +573,15 @@ class Options:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def get_default(self, key):
|
||||||
|
"""returns the default value for the key"""
|
||||||
|
|
||||||
|
data_label = self.data_labels.get(key)
|
||||||
|
if data_label is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data_label.default
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
||||||
|
|
||||||
@ -691,6 +714,7 @@ class TotalTQDM:
|
|||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
if self._tqdm is not None:
|
if self._tqdm is not None:
|
||||||
|
self._tqdm.refresh()
|
||||||
self._tqdm.close()
|
self._tqdm.close()
|
||||||
self._tqdm = None
|
self._tqdm = None
|
||||||
|
|
||||||
|
@ -33,3 +33,6 @@ class Timer:
|
|||||||
res += ")"
|
res += ")"
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.__init__()
|
||||||
|
@ -957,7 +957,7 @@ def create_ui():
|
|||||||
)
|
)
|
||||||
|
|
||||||
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||||
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
|
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
|
||||||
|
|
||||||
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
||||||
|
|
||||||
@ -1581,6 +1581,10 @@ def create_ui():
|
|||||||
extensions_interface = ui_extensions.create_ui()
|
extensions_interface = ui_extensions.create_ui()
|
||||||
interfaces += [(extensions_interface, "Extensions", "extensions")]
|
interfaces += [(extensions_interface, "Extensions", "extensions")]
|
||||||
|
|
||||||
|
shared.tab_names = []
|
||||||
|
for _interface, label, _ifid in interfaces:
|
||||||
|
shared.tab_names.append(label)
|
||||||
|
|
||||||
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||||
with gr.Row(elem_id="quicksettings", variant="compact"):
|
with gr.Row(elem_id="quicksettings", variant="compact"):
|
||||||
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
|
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
|
||||||
@ -1591,6 +1595,8 @@ def create_ui():
|
|||||||
|
|
||||||
with gr.Tabs(elem_id="tabs") as tabs:
|
with gr.Tabs(elem_id="tabs") as tabs:
|
||||||
for interface, label, ifid in interfaces:
|
for interface, label, ifid in interfaces:
|
||||||
|
if label in shared.opts.hidden_tabs:
|
||||||
|
continue
|
||||||
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
|
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
|
||||||
interface.render()
|
interface.render()
|
||||||
|
|
||||||
@ -1763,7 +1769,8 @@ def create_ui():
|
|||||||
|
|
||||||
|
|
||||||
def reload_javascript():
|
def reload_javascript():
|
||||||
head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}?{os.path.getmtime("script.js")}"></script>\n'
|
script_js = os.path.join(script_path, "script.js")
|
||||||
|
head = f'<script type="text/javascript" src="file={os.path.abspath(script_js)}?{os.path.getmtime(script_js)}"></script>\n'
|
||||||
|
|
||||||
inline = f"{localization.localization_js(shared.opts.localization)};"
|
inline = f"{localization.localization_js(shared.opts.localization)};"
|
||||||
if cmd_opts.theme is not None:
|
if cmd_opts.theme is not None:
|
||||||
@ -1772,6 +1779,9 @@ def reload_javascript():
|
|||||||
for script in modules.scripts.list_scripts("javascript", ".js"):
|
for script in modules.scripts.list_scripts("javascript", ".js"):
|
||||||
head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
|
head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
|
||||||
|
|
||||||
|
for script in modules.scripts.list_scripts("javascript", ".mjs"):
|
||||||
|
head += f'<script type="module" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
|
||||||
|
|
||||||
head += f'<script type="text/javascript">{inline}</script>\n'
|
head += f'<script type="text/javascript">{inline}</script>\n'
|
||||||
|
|
||||||
def template_response(*args, **kwargs):
|
def template_response(*args, **kwargs):
|
||||||
|
@ -198,9 +198,16 @@ Requested path was: {f}
|
|||||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
|
||||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||||
|
|
||||||
|
paste_field_names = []
|
||||||
|
if tabname == "txt2img":
|
||||||
|
paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
|
||||||
|
elif tabname == "img2img":
|
||||||
|
paste_field_names = modules.scripts.scripts_img2img.paste_field_names
|
||||||
|
|
||||||
for paste_tabname, paste_button in buttons.items():
|
for paste_tabname, paste_button in buttons.items():
|
||||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||||
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
|
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
|
||||||
|
paste_field_names=paste_field_names
|
||||||
))
|
))
|
||||||
|
|
||||||
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
||||||
|
@ -304,7 +304,7 @@ def create_ui():
|
|||||||
with gr.TabItem("Available"):
|
with gr.TabItem("Available"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
|
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
|
||||||
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
|
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
|
||||||
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
|
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
|
||||||
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
||||||
|
|
||||||
|
@ -30,8 +30,8 @@ def add_pages_to_demo(app):
|
|||||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||||
|
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
if ext not in (".png", ".jpg"):
|
if ext not in (".png", ".jpg", ".webp"):
|
||||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
|
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
|
||||||
|
|
||||||
# would profit from returning 304
|
# would profit from returning 304
|
||||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||||
@ -124,19 +124,56 @@ class ExtraNetworksPage:
|
|||||||
if onclick is None:
|
if onclick is None:
|
||||||
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
||||||
|
|
||||||
|
metadata_button = ""
|
||||||
|
metadata = item.get("metadata")
|
||||||
|
if metadata:
|
||||||
|
metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"'
|
||||||
|
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick={metadata_onclick}></div>"
|
||||||
|
|
||||||
args = {
|
args = {
|
||||||
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
|
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
|
||||||
"prompt": item.get("prompt", None),
|
"prompt": item.get("prompt", None),
|
||||||
"tabname": json.dumps(tabname),
|
"tabname": json.dumps(tabname),
|
||||||
"local_preview": json.dumps(item["local_preview"]),
|
"local_preview": json.dumps(item["local_preview"]),
|
||||||
"name": item["name"],
|
"name": item["name"],
|
||||||
|
"description": (item.get("description") or ""),
|
||||||
"card_clicked": onclick,
|
"card_clicked": onclick,
|
||||||
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
|
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
|
||||||
"search_term": item.get("search_term", ""),
|
"search_term": item.get("search_term", ""),
|
||||||
|
"metadata_button": metadata_button,
|
||||||
}
|
}
|
||||||
|
|
||||||
return self.card_page.format(**args)
|
return self.card_page.format(**args)
|
||||||
|
|
||||||
|
def find_preview(self, path):
|
||||||
|
"""
|
||||||
|
Find a preview PNG for a given path (without extension) and call link_preview on it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
preview_extensions = ["png", "jpg", "webp"]
|
||||||
|
if shared.opts.samples_format not in preview_extensions:
|
||||||
|
preview_extensions.append(shared.opts.samples_format)
|
||||||
|
|
||||||
|
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
|
||||||
|
|
||||||
|
for file in potential_files:
|
||||||
|
if os.path.isfile(file):
|
||||||
|
return self.link_preview(file)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_description(self, path):
|
||||||
|
"""
|
||||||
|
Find and read a description file for a given path (without extension).
|
||||||
|
"""
|
||||||
|
for file in [f"{path}.txt", f"{path}.description.txt"]:
|
||||||
|
try:
|
||||||
|
with open(file, "r", encoding="utf-8", errors="replace") as f:
|
||||||
|
return f.read()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def intialize():
|
def intialize():
|
||||||
extra_pages.clear()
|
extra_pages.clear()
|
||||||
@ -183,7 +220,6 @@ def create_ui(container, button, tabname):
|
|||||||
|
|
||||||
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
|
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
|
||||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
|
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
|
||||||
button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
|
|
||||||
|
|
||||||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||||
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
||||||
@ -194,7 +230,6 @@ def create_ui(container, button, tabname):
|
|||||||
|
|
||||||
state_visible = gr.State(value=False)
|
state_visible = gr.State(value=False)
|
||||||
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
|
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
|
||||||
button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
|
|
||||||
|
|
||||||
def refresh():
|
def refresh():
|
||||||
res = []
|
res = []
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import html
|
import html
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import urllib.parse
|
|
||||||
|
|
||||||
from modules import shared, ui_extra_networks, sd_models
|
from modules import shared, ui_extra_networks, sd_models
|
||||||
|
|
||||||
@ -17,21 +16,14 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||||||
checkpoint: sd_models.CheckpointInfo
|
checkpoint: sd_models.CheckpointInfo
|
||||||
for name, checkpoint in sd_models.checkpoints_list.items():
|
for name, checkpoint in sd_models.checkpoints_list.items():
|
||||||
path, ext = os.path.splitext(checkpoint.filename)
|
path, ext = os.path.splitext(checkpoint.filename)
|
||||||
previews = [path + ".png", path + ".preview.png"]
|
|
||||||
|
|
||||||
preview = None
|
|
||||||
for file in previews:
|
|
||||||
if os.path.isfile(file):
|
|
||||||
preview = self.link_preview(file)
|
|
||||||
break
|
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"name": checkpoint.name_for_extra,
|
"name": checkpoint.name_for_extra,
|
||||||
"filename": path,
|
"filename": path,
|
||||||
"preview": preview,
|
"preview": self.find_preview(path),
|
||||||
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
||||||
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
|
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
|
||||||
"local_preview": path + ".png",
|
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||||
}
|
}
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
|
@ -14,21 +14,15 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
|||||||
def list_items(self):
|
def list_items(self):
|
||||||
for name, path in shared.hypernetworks.items():
|
for name, path in shared.hypernetworks.items():
|
||||||
path, ext = os.path.splitext(path)
|
path, ext = os.path.splitext(path)
|
||||||
previews = [path + ".png", path + ".preview.png"]
|
|
||||||
|
|
||||||
preview = None
|
|
||||||
for file in previews:
|
|
||||||
if os.path.isfile(file):
|
|
||||||
preview = self.link_preview(file)
|
|
||||||
break
|
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": path,
|
"filename": path,
|
||||||
"preview": preview,
|
"preview": self.find_preview(path),
|
||||||
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(path),
|
"search_term": self.search_terms_from_path(path),
|
||||||
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||||
"local_preview": path + ".png",
|
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||||
}
|
}
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from modules import ui_extra_networks, sd_hijack
|
from modules import ui_extra_networks, sd_hijack, shared
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||||
@ -15,19 +15,14 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
|||||||
def list_items(self):
|
def list_items(self):
|
||||||
for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
|
for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
|
||||||
path, ext = os.path.splitext(embedding.filename)
|
path, ext = os.path.splitext(embedding.filename)
|
||||||
preview_file = path + ".preview.png"
|
|
||||||
|
|
||||||
preview = None
|
|
||||||
if os.path.isfile(preview_file):
|
|
||||||
preview = self.link_preview(preview_file)
|
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"name": embedding.name,
|
"name": embedding.name,
|
||||||
"filename": embedding.filename,
|
"filename": embedding.filename,
|
||||||
"preview": preview,
|
"preview": self.find_preview(path),
|
||||||
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(embedding.filename),
|
"search_term": self.search_terms_from_path(embedding.filename),
|
||||||
"prompt": json.dumps(embedding.name),
|
"prompt": json.dumps(embedding.name),
|
||||||
"local_preview": path + ".preview.png",
|
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||||
}
|
}
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
|
@ -23,8 +23,8 @@ torchdiffeq==0.2.3
|
|||||||
kornia==0.6.7
|
kornia==0.6.7
|
||||||
lark==1.1.2
|
lark==1.1.2
|
||||||
inflection==0.5.1
|
inflection==0.5.1
|
||||||
GitPython==3.1.27
|
GitPython==3.1.30
|
||||||
torchsde==0.2.5
|
torchsde==0.2.5
|
||||||
safetensors==0.2.7
|
safetensors==0.2.7
|
||||||
httpcore<=0.15
|
httpcore<=0.15
|
||||||
fastapi==0.90.1
|
fastapi==0.94.0
|
||||||
|
@ -100,7 +100,7 @@ class Script(scripts.Script):
|
|||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
|
|
||||||
grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
|
grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
|
||||||
grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[1].height, prompt_matrix_parts, margin_size)
|
grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)
|
||||||
processed.images.insert(0, grid)
|
processed.images.insert(0, grid)
|
||||||
processed.index_of_first_image = 1
|
processed.index_of_first_image = 1
|
||||||
processed.infotexts.insert(0, processed.infotexts[0])
|
processed.infotexts.insert(0, processed.infotexts[0])
|
||||||
|
@ -128,6 +128,24 @@ def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
|
|||||||
p.styles.extend(x.split(','))
|
p.styles.extend(x.split(','))
|
||||||
|
|
||||||
|
|
||||||
|
def apply_uni_pc_order(p, x, xs):
|
||||||
|
opts.data["uni_pc_order"] = min(x, p.steps - 1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_face_restore(p, opt, x):
|
||||||
|
opt = opt.lower()
|
||||||
|
if opt == 'codeformer':
|
||||||
|
is_active = True
|
||||||
|
p.face_restoration_model = 'CodeFormer'
|
||||||
|
elif opt == 'gfpgan':
|
||||||
|
is_active = True
|
||||||
|
p.face_restoration_model = 'GFPGAN'
|
||||||
|
else:
|
||||||
|
is_active = opt in ('true', 'yes', 'y', '1')
|
||||||
|
|
||||||
|
p.restore_faces = is_active
|
||||||
|
|
||||||
|
|
||||||
def format_value_add_label(p, opt, x):
|
def format_value_add_label(p, opt, x):
|
||||||
if type(x) == float:
|
if type(x) == float:
|
||||||
x = round(x, 8)
|
x = round(x, 8)
|
||||||
@ -205,6 +223,8 @@ axis_options = [
|
|||||||
AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
|
AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
|
||||||
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
|
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
|
||||||
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
|
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
|
||||||
|
AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
|
||||||
|
AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -213,49 +233,47 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
|||||||
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
||||||
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
|
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
|
||||||
|
|
||||||
# Temporary list of all the images that are generated to be populated into the grid.
|
list_size = (len(xs) * len(ys) * len(zs))
|
||||||
# Will be filled with empty images for any individual step that fails to process properly
|
|
||||||
image_cache = [None] * (len(xs) * len(ys) * len(zs))
|
|
||||||
|
|
||||||
processed_result = None
|
processed_result = None
|
||||||
cell_mode = "P"
|
|
||||||
cell_size = (1, 1)
|
|
||||||
|
|
||||||
state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter
|
state.job_count = list_size * p.n_iter
|
||||||
|
|
||||||
def process_cell(x, y, z, ix, iy, iz):
|
def process_cell(x, y, z, ix, iy, iz):
|
||||||
nonlocal image_cache, processed_result, cell_mode, cell_size
|
nonlocal processed_result
|
||||||
|
|
||||||
def index(ix, iy, iz):
|
def index(ix, iy, iz):
|
||||||
return ix + iy * len(xs) + iz * len(xs) * len(ys)
|
return ix + iy * len(xs) + iz * len(xs) * len(ys)
|
||||||
|
|
||||||
state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}"
|
state.job = f"{index(ix, iy, iz) + 1} out of {list_size}"
|
||||||
|
|
||||||
processed: Processed = cell(x, y, z)
|
processed: Processed = cell(x, y, z)
|
||||||
|
|
||||||
try:
|
if processed_result is None:
|
||||||
# this dereference will throw an exception if the image was not processed
|
# Use our first processed result object as a template container to hold our full results
|
||||||
# (this happens in cases such as if the user stops the process from the UI)
|
processed_result = copy(processed)
|
||||||
processed_image = processed.images[0]
|
processed_result.images = [None] * list_size
|
||||||
|
processed_result.all_prompts = [None] * list_size
|
||||||
|
processed_result.all_seeds = [None] * list_size
|
||||||
|
processed_result.infotexts = [None] * list_size
|
||||||
|
processed_result.index_of_first_image = 1
|
||||||
|
|
||||||
if processed_result is None:
|
idx = index(ix, iy, iz)
|
||||||
# Use our first valid processed result as a template container to hold our full results
|
if processed.images:
|
||||||
processed_result = copy(processed)
|
# Non-empty list indicates some degree of success.
|
||||||
cell_mode = processed_image.mode
|
processed_result.images[idx] = processed.images[0]
|
||||||
cell_size = processed_image.size
|
processed_result.all_prompts[idx] = processed.prompt
|
||||||
processed_result.images = [Image.new(cell_mode, cell_size)]
|
processed_result.all_seeds[idx] = processed.seed
|
||||||
processed_result.all_prompts = [processed.prompt]
|
processed_result.infotexts[idx] = processed.infotexts[0]
|
||||||
processed_result.all_seeds = [processed.seed]
|
else:
|
||||||
processed_result.infotexts = [processed.infotexts[0]]
|
cell_mode = "P"
|
||||||
|
cell_size = (processed_result.width, processed_result.height)
|
||||||
|
if processed_result.images[0] is not None:
|
||||||
|
cell_mode = processed_result.images[0].mode
|
||||||
|
#This corrects size in case of batches:
|
||||||
|
cell_size = processed_result.images[0].size
|
||||||
|
processed_result.images[idx] = Image.new(cell_mode, cell_size)
|
||||||
|
|
||||||
image_cache[index(ix, iy, iz)] = processed_image
|
|
||||||
if include_lone_images:
|
|
||||||
processed_result.images.append(processed_image)
|
|
||||||
processed_result.all_prompts.append(processed.prompt)
|
|
||||||
processed_result.all_seeds.append(processed.seed)
|
|
||||||
processed_result.infotexts.append(processed.infotexts[0])
|
|
||||||
except:
|
|
||||||
image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)
|
|
||||||
|
|
||||||
if first_axes_processed == 'x':
|
if first_axes_processed == 'x':
|
||||||
for ix, x in enumerate(xs):
|
for ix, x in enumerate(xs):
|
||||||
@ -289,36 +307,48 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
|||||||
process_cell(x, y, z, ix, iy, iz)
|
process_cell(x, y, z, ix, iy, iz)
|
||||||
|
|
||||||
if not processed_result:
|
if not processed_result:
|
||||||
|
# Should never happen, I've only seen it on one of four open tabs and it needed to refresh.
|
||||||
|
print("Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.")
|
||||||
|
return Processed(p, [])
|
||||||
|
elif not any(processed_result.images):
|
||||||
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
|
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
|
||||||
return Processed(p, [])
|
return Processed(p, [])
|
||||||
|
|
||||||
sub_grids = [None] * len(zs)
|
z_count = len(zs)
|
||||||
for i in range(len(zs)):
|
sub_grids = [None] * z_count
|
||||||
start_index = i * len(xs) * len(ys)
|
for i in range(z_count):
|
||||||
|
start_index = (i * len(xs) * len(ys)) + i
|
||||||
end_index = start_index + len(xs) * len(ys)
|
end_index = start_index + len(xs) * len(ys)
|
||||||
grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
|
grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))
|
||||||
if draw_legend:
|
if draw_legend:
|
||||||
grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
|
grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size)
|
||||||
sub_grids[i] = grid
|
processed_result.images.insert(i, grid)
|
||||||
if include_sub_grids and len(zs) > 1:
|
processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])
|
||||||
processed_result.images.insert(i+1, grid)
|
processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])
|
||||||
|
processed_result.infotexts.insert(i, processed_result.infotexts[start_index])
|
||||||
|
|
||||||
sub_grid_size = sub_grids[0].size
|
sub_grid_size = processed_result.images[0].size
|
||||||
z_grid = images.image_grid(sub_grids, rows=1)
|
z_grid = images.image_grid(processed_result.images[:z_count], rows=1)
|
||||||
if draw_legend:
|
if draw_legend:
|
||||||
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
||||||
processed_result.images[0] = z_grid
|
processed_result.images.insert(0, z_grid)
|
||||||
|
#TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
|
||||||
|
#processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
|
||||||
|
#processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
|
||||||
|
processed_result.infotexts.insert(0, processed_result.infotexts[0])
|
||||||
|
|
||||||
return processed_result, sub_grids
|
return processed_result
|
||||||
|
|
||||||
|
|
||||||
class SharedSettingsStackHelper(object):
|
class SharedSettingsStackHelper(object):
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
|
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
|
||||||
self.vae = opts.sd_vae
|
self.vae = opts.sd_vae
|
||||||
|
self.uni_pc_order = opts.uni_pc_order
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, tb):
|
def __exit__(self, exc_type, exc_value, tb):
|
||||||
opts.data["sd_vae"] = self.vae
|
opts.data["sd_vae"] = self.vae
|
||||||
|
opts.data["uni_pc_order"] = self.uni_pc_order
|
||||||
modules.sd_models.reload_model_weights()
|
modules.sd_models.reload_model_weights()
|
||||||
modules.sd_vae.reload_vae_weights()
|
modules.sd_vae.reload_vae_weights()
|
||||||
|
|
||||||
@ -418,7 +448,7 @@ class Script(scripts.Script):
|
|||||||
if opt.label == 'Nothing':
|
if opt.label == 'Nothing':
|
||||||
return [0]
|
return [0]
|
||||||
|
|
||||||
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
|
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
|
||||||
|
|
||||||
if opt.type == int:
|
if opt.type == int:
|
||||||
valslist_ext = []
|
valslist_ext = []
|
||||||
@ -484,6 +514,10 @@ class Script(scripts.Script):
|
|||||||
z_opt = self.current_axis_options[z_type]
|
z_opt = self.current_axis_options[z_type]
|
||||||
zs = process_axis(z_opt, z_values)
|
zs = process_axis(z_opt, z_values)
|
||||||
|
|
||||||
|
# this could be moved to common code, but unlikely to be ever triggered anywhere else
|
||||||
|
grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
|
||||||
|
assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'
|
||||||
|
|
||||||
def fix_axis_seeds(axis_opt, axis_list):
|
def fix_axis_seeds(axis_opt, axis_list):
|
||||||
if axis_opt.label in ['Seed', 'Var. seed']:
|
if axis_opt.label in ['Seed', 'Var. seed']:
|
||||||
return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
|
return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
|
||||||
@ -533,7 +567,7 @@ class Script(scripts.Script):
|
|||||||
# If one of the axes is very slow to change between (like SD model
|
# If one of the axes is very slow to change between (like SD model
|
||||||
# checkpoint), then make sure it is in the outer iteration of the nested
|
# checkpoint), then make sure it is in the outer iteration of the nested
|
||||||
# `for` loop.
|
# `for` loop.
|
||||||
first_axes_processed = 'x'
|
first_axes_processed = 'z'
|
||||||
second_axes_processed = 'y'
|
second_axes_processed = 'y'
|
||||||
if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
|
if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
|
||||||
first_axes_processed = 'x'
|
first_axes_processed = 'x'
|
||||||
@ -593,7 +627,7 @@ class Script(scripts.Script):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
with SharedSettingsStackHelper():
|
with SharedSettingsStackHelper():
|
||||||
processed, sub_grids = draw_xyz_grid(
|
processed = draw_xyz_grid(
|
||||||
p,
|
p,
|
||||||
xs=xs,
|
xs=xs,
|
||||||
ys=ys,
|
ys=ys,
|
||||||
@ -610,11 +644,30 @@ class Script(scripts.Script):
|
|||||||
margin_size=margin_size
|
margin_size=margin_size
|
||||||
)
|
)
|
||||||
|
|
||||||
if opts.grid_save and len(sub_grids) > 1:
|
if not processed.images:
|
||||||
for sub_grid in sub_grids:
|
# It broke, no further handling needed.
|
||||||
images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
return processed
|
||||||
|
|
||||||
|
z_count = len(zs)
|
||||||
|
|
||||||
|
if not include_lone_images:
|
||||||
|
# Don't need sub-images anymore, drop from list:
|
||||||
|
processed.images = processed.images[:z_count+1]
|
||||||
|
|
||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
# Auto-save main and sub-grids:
|
||||||
|
grid_count = z_count + 1 if z_count > 1 else 1
|
||||||
|
for g in range(grid_count):
|
||||||
|
#TODO: See previous comment about intentional data misalignment.
|
||||||
|
adj_g = g-1 if g > 0 else g
|
||||||
|
images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
|
||||||
|
|
||||||
|
if not include_sub_grids:
|
||||||
|
# Done with sub-grids, drop all related information:
|
||||||
|
for sg in range(z_count):
|
||||||
|
del processed.images[1]
|
||||||
|
del processed.all_prompts[1]
|
||||||
|
del processed.all_seeds[1]
|
||||||
|
del processed.infotexts[1]
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
78
style.css
78
style.css
@ -362,6 +362,46 @@ input[type="range"]{
|
|||||||
height: 100%;
|
height: 100%;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.popup-metadata{
|
||||||
|
color: black;
|
||||||
|
background: white;
|
||||||
|
display: inline-block;
|
||||||
|
padding: 1em;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.global-popup{
|
||||||
|
display: flex;
|
||||||
|
position: fixed;
|
||||||
|
z-index: 1001;
|
||||||
|
left: 0;
|
||||||
|
top: 0;
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
overflow: auto;
|
||||||
|
background-color: rgba(20, 20, 20, 0.95);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
.global-popup-close:before {
|
||||||
|
content: "×";
|
||||||
|
}
|
||||||
|
|
||||||
|
.global-popup-close{
|
||||||
|
position: fixed;
|
||||||
|
right: 0.25em;
|
||||||
|
top: 0;
|
||||||
|
cursor: pointer;
|
||||||
|
color: white;
|
||||||
|
font-size: 32pt;
|
||||||
|
}
|
||||||
|
|
||||||
|
.global-popup-inner{
|
||||||
|
display: inline-block;
|
||||||
|
margin: auto;
|
||||||
|
padding: 2em;
|
||||||
|
}
|
||||||
|
|
||||||
#lightboxModal{
|
#lightboxModal{
|
||||||
display: none;
|
display: none;
|
||||||
position: fixed;
|
position: fixed;
|
||||||
@ -436,9 +476,7 @@ input[type="range"]{
|
|||||||
|
|
||||||
#modalImage {
|
#modalImage {
|
||||||
display: block;
|
display: block;
|
||||||
margin-left: auto;
|
margin: auto;
|
||||||
margin-right: auto;
|
|
||||||
margin-top: auto;
|
|
||||||
width: auto;
|
width: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -839,6 +877,27 @@ footer {
|
|||||||
margin-left: 0.5em;
|
margin-left: 0.5em;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
.extra-network-cards .card .metadata-button:before, .extra-network-thumbs .card .metadata-button:before{
|
||||||
|
content: "🛈";
|
||||||
|
}
|
||||||
|
.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{
|
||||||
|
display: none;
|
||||||
|
position: absolute;
|
||||||
|
right: 0;
|
||||||
|
color: white;
|
||||||
|
text-shadow: 2px 2px 3px black;
|
||||||
|
padding: 0.25em;
|
||||||
|
font-size: 22pt;
|
||||||
|
}
|
||||||
|
.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{
|
||||||
|
display: inline-block;
|
||||||
|
}
|
||||||
|
.extra-network-cards .card .metadata-button:hover, .extra-network-thumbs .card .metadata-button:hover{
|
||||||
|
color: red;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
.extra-network-thumbs {
|
.extra-network-thumbs {
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-flow: row wrap;
|
flex-flow: row wrap;
|
||||||
@ -856,7 +915,7 @@ footer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
.extra-network-thumbs .card:hover .additional a {
|
.extra-network-thumbs .card:hover .additional a {
|
||||||
display: block;
|
display: inline-block;
|
||||||
}
|
}
|
||||||
|
|
||||||
.extra-network-thumbs .actions .additional a {
|
.extra-network-thumbs .actions .additional a {
|
||||||
@ -939,6 +998,17 @@ footer {
|
|||||||
line-break: anywhere;
|
line-break: anywhere;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card .actions .description {
|
||||||
|
display: block;
|
||||||
|
max-height: 3em;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
line-height: 1.1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.extra-network-cards .card .actions .description:hover {
|
||||||
|
max-height: none;
|
||||||
|
}
|
||||||
|
|
||||||
.extra-network-cards .card .actions:hover .additional{
|
.extra-network-cards .card .actions:hover .additional{
|
||||||
display: block;
|
display: block;
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import requests
|
import requests
|
||||||
from gradio.processing_utils import encode_pil_to_base64
|
from gradio.processing_utils import encode_pil_to_base64
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from modules.paths import script_path
|
||||||
|
|
||||||
class TestExtrasWorking(unittest.TestCase):
|
class TestExtrasWorking(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -19,7 +21,7 @@ class TestExtrasWorking(unittest.TestCase):
|
|||||||
"upscaler_1": "None",
|
"upscaler_1": "None",
|
||||||
"upscaler_2": "None",
|
"upscaler_2": "None",
|
||||||
"extras_upscaler_2_visibility": 0,
|
"extras_upscaler_2_visibility": 0,
|
||||||
"image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
|
"image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_simple_upscaling_performed(self):
|
def test_simple_upscaling_performed(self):
|
||||||
@ -31,7 +33,7 @@ class TestPngInfoWorking(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image"
|
self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image"
|
||||||
self.png_info = {
|
self.png_info = {
|
||||||
"image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
|
"image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_png_info_performed(self):
|
def test_png_info_performed(self):
|
||||||
@ -42,7 +44,7 @@ class TestInterrogateWorking(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image"
|
self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image"
|
||||||
self.interrogate = {
|
self.interrogate = {
|
||||||
"image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")),
|
"image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))),
|
||||||
"model": "clip"
|
"model": "clip"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,14 +1,16 @@
|
|||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import requests
|
import requests
|
||||||
from gradio.processing_utils import encode_pil_to_base64
|
from gradio.processing_utils import encode_pil_to_base64
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from modules.paths import script_path
|
||||||
|
|
||||||
|
|
||||||
class TestImg2ImgWorking(unittest.TestCase):
|
class TestImg2ImgWorking(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
|
self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
|
||||||
self.simple_img2img = {
|
self.simple_img2img = {
|
||||||
"init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))],
|
"init_images": [encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))],
|
||||||
"resize_mode": 0,
|
"resize_mode": 0,
|
||||||
"denoising_strength": 0.75,
|
"denoising_strength": 0.75,
|
||||||
"mask": None,
|
"mask": None,
|
||||||
@ -47,11 +49,11 @@ class TestImg2ImgWorking(unittest.TestCase):
|
|||||||
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
|
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
|
||||||
|
|
||||||
def test_inpainting_masked_performed(self):
|
def test_inpainting_masked_performed(self):
|
||||||
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
|
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
|
||||||
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
|
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
|
||||||
|
|
||||||
def test_inpainting_with_inverted_masked_performed(self):
|
def test_inpainting_with_inverted_masked_performed(self):
|
||||||
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
|
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
|
||||||
self.simple_img2img["inpainting_mask_invert"] = True
|
self.simple_img2img["inpainting_mask_invert"] = True
|
||||||
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
|
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
|
||||||
|
|
||||||
|
@ -66,6 +66,8 @@ class TestTxt2ImgWorking(unittest.TestCase):
|
|||||||
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
||||||
self.simple_txt2img["sampler_index"] = "DDIM"
|
self.simple_txt2img["sampler_index"] = "DDIM"
|
||||||
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
||||||
|
self.simple_txt2img["sampler_index"] = "UniPC"
|
||||||
|
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
|
||||||
|
|
||||||
def test_txt2img_multiple_batches_performed(self):
|
def test_txt2img_multiple_batches_performed(self):
|
||||||
self.simple_txt2img["n_iter"] = 2
|
self.simple_txt2img["n_iter"] = 2
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import requests
|
import requests
|
||||||
import time
|
import time
|
||||||
|
import os
|
||||||
|
from modules.paths import script_path
|
||||||
|
|
||||||
|
|
||||||
def run_tests(proc, test_dir):
|
def run_tests(proc, test_dir):
|
||||||
@ -15,8 +17,8 @@ def run_tests(proc, test_dir):
|
|||||||
break
|
break
|
||||||
if proc.poll() is None:
|
if proc.poll() is None:
|
||||||
if test_dir is None:
|
if test_dir is None:
|
||||||
test_dir = "test"
|
test_dir = os.path.join(script_path, "test")
|
||||||
suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
|
suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir=test_dir)
|
||||||
result = unittest.TextTestRunner(verbosity=2).run(suite)
|
result = unittest.TextTestRunner(verbosity=2).run(suite)
|
||||||
return len(result.failures) + len(result.errors)
|
return len(result.failures) + len(result.errors)
|
||||||
else:
|
else:
|
||||||
|
75
webui.py
75
webui.py
@ -12,11 +12,22 @@ from packaging import version
|
|||||||
import logging
|
import logging
|
||||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
|
|
||||||
from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
|
from modules import paths, timer, import_hook, errors
|
||||||
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
|
||||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
startup_timer = timer.Timer()
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
startup_timer.record("import torch")
|
||||||
|
|
||||||
|
import gradio
|
||||||
|
startup_timer.record("import gradio")
|
||||||
|
|
||||||
|
import ldm.modules.encoders.modules
|
||||||
|
startup_timer.record("import ldm")
|
||||||
|
|
||||||
|
from modules import extra_networks, ui_extra_networks_checkpoints
|
||||||
|
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
||||||
|
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||||
|
|
||||||
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
||||||
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||||
@ -30,7 +41,6 @@ import modules.gfpgan_model as gfpgan
|
|||||||
import modules.img2img
|
import modules.img2img
|
||||||
|
|
||||||
import modules.lowvram
|
import modules.lowvram
|
||||||
import modules.paths
|
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
@ -45,6 +55,8 @@ from modules import modelloader
|
|||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
import modules.hypernetworks.hypernetwork
|
import modules.hypernetworks.hypernetwork
|
||||||
|
|
||||||
|
startup_timer.record("other imports")
|
||||||
|
|
||||||
|
|
||||||
if cmd_opts.server_name:
|
if cmd_opts.server_name:
|
||||||
server_name = cmd_opts.server_name
|
server_name = cmd_opts.server_name
|
||||||
@ -88,6 +100,7 @@ def initialize():
|
|||||||
|
|
||||||
extensions.list_extensions()
|
extensions.list_extensions()
|
||||||
localization.list_localizations(cmd_opts.localizations_dir)
|
localization.list_localizations(cmd_opts.localizations_dir)
|
||||||
|
startup_timer.record("list extensions")
|
||||||
|
|
||||||
if cmd_opts.ui_debug_mode:
|
if cmd_opts.ui_debug_mode:
|
||||||
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
||||||
@ -96,16 +109,28 @@ def initialize():
|
|||||||
|
|
||||||
modelloader.cleanup_models()
|
modelloader.cleanup_models()
|
||||||
modules.sd_models.setup_model()
|
modules.sd_models.setup_model()
|
||||||
|
startup_timer.record("list SD models")
|
||||||
|
|
||||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||||
|
startup_timer.record("setup codeformer")
|
||||||
|
|
||||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||||
|
startup_timer.record("setup gfpgan")
|
||||||
|
|
||||||
modelloader.list_builtin_upscalers()
|
modelloader.list_builtin_upscalers()
|
||||||
|
startup_timer.record("list builtin upscalers")
|
||||||
|
|
||||||
modules.scripts.load_scripts()
|
modules.scripts.load_scripts()
|
||||||
|
startup_timer.record("load scripts")
|
||||||
|
|
||||||
modelloader.load_upscalers()
|
modelloader.load_upscalers()
|
||||||
|
startup_timer.record("load upscalers")
|
||||||
|
|
||||||
modules.sd_vae.refresh_vae_list()
|
modules.sd_vae.refresh_vae_list()
|
||||||
|
startup_timer.record("refresh VAE")
|
||||||
|
|
||||||
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
|
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
|
||||||
|
startup_timer.record("refresh textual inversion templates")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
modules.sd_models.load_model()
|
modules.sd_models.load_model()
|
||||||
@ -114,6 +139,7 @@ def initialize():
|
|||||||
print("", file=sys.stderr)
|
print("", file=sys.stderr)
|
||||||
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
|
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
startup_timer.record("load SD checkpoint")
|
||||||
|
|
||||||
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
|
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
|
||||||
|
|
||||||
@ -121,8 +147,10 @@ def initialize():
|
|||||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||||
|
startup_timer.record("opts onchange")
|
||||||
|
|
||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
|
startup_timer.record("reload hypernets")
|
||||||
|
|
||||||
ui_extra_networks.intialize()
|
ui_extra_networks.intialize()
|
||||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||||
@ -131,6 +159,7 @@ def initialize():
|
|||||||
|
|
||||||
extra_networks.initialize()
|
extra_networks.initialize()
|
||||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||||
|
startup_timer.record("extra networks")
|
||||||
|
|
||||||
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
|
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
|
||||||
|
|
||||||
@ -144,6 +173,7 @@ def initialize():
|
|||||||
print("TLS setup invalid, running webui without TLS")
|
print("TLS setup invalid, running webui without TLS")
|
||||||
else:
|
else:
|
||||||
print("Running with TLS")
|
print("Running with TLS")
|
||||||
|
startup_timer.record("TLS")
|
||||||
|
|
||||||
# make the program just exit at ctrl+c without waiting for anything
|
# make the program just exit at ctrl+c without waiting for anything
|
||||||
def sigint_handler(sig, frame):
|
def sigint_handler(sig, frame):
|
||||||
@ -153,13 +183,16 @@ def initialize():
|
|||||||
signal.signal(signal.SIGINT, sigint_handler)
|
signal.signal(signal.SIGINT, sigint_handler)
|
||||||
|
|
||||||
|
|
||||||
def setup_cors(app):
|
def setup_middleware(app):
|
||||||
|
app.middleware_stack = None # reset current middleware to allow modifying user provided list
|
||||||
|
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||||
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
|
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
|
||||||
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||||
elif cmd_opts.cors_allow_origins:
|
elif cmd_opts.cors_allow_origins:
|
||||||
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||||
elif cmd_opts.cors_allow_origins_regex:
|
elif cmd_opts.cors_allow_origins_regex:
|
||||||
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||||
|
app.build_middleware_stack() # rebuild middleware stack on-the-fly
|
||||||
|
|
||||||
|
|
||||||
def create_api(app):
|
def create_api(app):
|
||||||
@ -183,12 +216,12 @@ def api_only():
|
|||||||
initialize()
|
initialize()
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
setup_cors(app)
|
setup_middleware(app)
|
||||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
|
||||||
api = create_api(app)
|
api = create_api(app)
|
||||||
|
|
||||||
modules.script_callbacks.app_started_callback(None, app)
|
modules.script_callbacks.app_started_callback(None, app)
|
||||||
|
|
||||||
|
print(f"Startup time: {startup_timer.summary()}.")
|
||||||
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
||||||
|
|
||||||
|
|
||||||
@ -199,21 +232,24 @@ def webui():
|
|||||||
while 1:
|
while 1:
|
||||||
if shared.opts.clean_temp_dir_at_start:
|
if shared.opts.clean_temp_dir_at_start:
|
||||||
ui_tempdir.cleanup_tmpdr()
|
ui_tempdir.cleanup_tmpdr()
|
||||||
|
startup_timer.record("cleanup temp dir")
|
||||||
|
|
||||||
modules.script_callbacks.before_ui_callback()
|
modules.script_callbacks.before_ui_callback()
|
||||||
|
startup_timer.record("scripts before_ui_callback")
|
||||||
|
|
||||||
shared.demo = modules.ui.create_ui()
|
shared.demo = modules.ui.create_ui()
|
||||||
|
startup_timer.record("create ui")
|
||||||
|
|
||||||
if cmd_opts.gradio_queue:
|
if cmd_opts.gradio_queue:
|
||||||
shared.demo.queue(64)
|
shared.demo.queue(64)
|
||||||
|
|
||||||
gradio_auth_creds = []
|
gradio_auth_creds = []
|
||||||
if cmd_opts.gradio_auth:
|
if cmd_opts.gradio_auth:
|
||||||
gradio_auth_creds += cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',')
|
gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
|
||||||
if cmd_opts.gradio_auth_path:
|
if cmd_opts.gradio_auth_path:
|
||||||
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
|
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
|
||||||
for line in file.readlines():
|
for line in file.readlines():
|
||||||
gradio_auth_creds += [x.strip() for x in line.split(',')]
|
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
|
||||||
|
|
||||||
app, local_url, share_url = shared.demo.launch(
|
app, local_url, share_url = shared.demo.launch(
|
||||||
share=cmd_opts.share,
|
share=cmd_opts.share,
|
||||||
@ -229,15 +265,15 @@ def webui():
|
|||||||
# after initial launch, disable --autolaunch for subsequent restarts
|
# after initial launch, disable --autolaunch for subsequent restarts
|
||||||
cmd_opts.autolaunch = False
|
cmd_opts.autolaunch = False
|
||||||
|
|
||||||
|
startup_timer.record("gradio launch")
|
||||||
|
|
||||||
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
|
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
|
||||||
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
|
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
|
||||||
# running web ui and do whatever the attacker wants, including installing an extension and
|
# running web ui and do whatever the attacker wants, including installing an extension and
|
||||||
# running its code. We disable this here. Suggested by RyotaK.
|
# running its code. We disable this here. Suggested by RyotaK.
|
||||||
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
|
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
|
||||||
|
|
||||||
setup_cors(app)
|
setup_middleware(app)
|
||||||
|
|
||||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
|
||||||
|
|
||||||
modules.progress.setup_progress_api(app)
|
modules.progress.setup_progress_api(app)
|
||||||
|
|
||||||
@ -247,28 +283,42 @@ def webui():
|
|||||||
ui_extra_networks.add_pages_to_demo(app)
|
ui_extra_networks.add_pages_to_demo(app)
|
||||||
|
|
||||||
modules.script_callbacks.app_started_callback(shared.demo, app)
|
modules.script_callbacks.app_started_callback(shared.demo, app)
|
||||||
|
startup_timer.record("scripts app_started_callback")
|
||||||
|
|
||||||
|
print(f"Startup time: {startup_timer.summary()}.")
|
||||||
|
|
||||||
wait_on_server(shared.demo)
|
wait_on_server(shared.demo)
|
||||||
print('Restarting UI...')
|
print('Restarting UI...')
|
||||||
|
|
||||||
|
startup_timer.reset()
|
||||||
|
|
||||||
sd_samplers.set_samplers()
|
sd_samplers.set_samplers()
|
||||||
|
|
||||||
modules.script_callbacks.script_unloaded_callback()
|
modules.script_callbacks.script_unloaded_callback()
|
||||||
extensions.list_extensions()
|
extensions.list_extensions()
|
||||||
|
startup_timer.record("list extensions")
|
||||||
|
|
||||||
localization.list_localizations(cmd_opts.localizations_dir)
|
localization.list_localizations(cmd_opts.localizations_dir)
|
||||||
|
|
||||||
modelloader.forbid_loaded_nonbuiltin_upscalers()
|
modelloader.forbid_loaded_nonbuiltin_upscalers()
|
||||||
modules.scripts.reload_scripts()
|
modules.scripts.reload_scripts()
|
||||||
|
startup_timer.record("load scripts")
|
||||||
|
|
||||||
modules.script_callbacks.model_loaded_callback(shared.sd_model)
|
modules.script_callbacks.model_loaded_callback(shared.sd_model)
|
||||||
|
startup_timer.record("model loaded callback")
|
||||||
|
|
||||||
modelloader.load_upscalers()
|
modelloader.load_upscalers()
|
||||||
|
startup_timer.record("load upscalers")
|
||||||
|
|
||||||
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
|
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
|
||||||
importlib.reload(module)
|
importlib.reload(module)
|
||||||
|
startup_timer.record("reload script modules")
|
||||||
|
|
||||||
modules.sd_models.list_models()
|
modules.sd_models.list_models()
|
||||||
|
startup_timer.record("list SD models")
|
||||||
|
|
||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
|
startup_timer.record("reload hypernetworks")
|
||||||
|
|
||||||
ui_extra_networks.intialize()
|
ui_extra_networks.intialize()
|
||||||
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
|
||||||
@ -277,6 +327,7 @@ def webui():
|
|||||||
|
|
||||||
extra_networks.initialize()
|
extra_networks.initialize()
|
||||||
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
|
||||||
|
startup_timer.record("initialize extra networks")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user