big rework of progressbar/preview system to allow multiple users to prompts at the same time and do not get previews of each other
This commit is contained in:
parent
ebfdd7baeb
commit
d8b90ac121
@ -1,82 +1,25 @@
|
||||
// code related to showing and updating progressbar shown as the image is being made
|
||||
global_progressbars = {}
|
||||
|
||||
|
||||
galleries = {}
|
||||
storedGallerySelections = {}
|
||||
galleryObservers = {}
|
||||
|
||||
// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
|
||||
timeoutIds = {}
|
||||
|
||||
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
||||
// gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
|
||||
// every time you use gr.HTML(elem_id='xxx'), so we handle this here
|
||||
var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
|
||||
var progressbarParent
|
||||
if(progressbar){
|
||||
progressbarParent = gradioApp().querySelector("#"+id_progressbar)
|
||||
} else{
|
||||
progressbar = gradioApp().getElementById(id_progressbar)
|
||||
progressbarParent = null
|
||||
function rememberGallerySelection(id_gallery){
|
||||
storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery)
|
||||
}
|
||||
|
||||
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
|
||||
var interrupt = gradioApp().getElementById(id_interrupt)
|
||||
function getGallerySelectedIndex(id_gallery){
|
||||
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
|
||||
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
|
||||
|
||||
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
|
||||
if(progressbar.innerText){
|
||||
let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
|
||||
if(document.title != newtitle){
|
||||
document.title = newtitle;
|
||||
}
|
||||
}else{
|
||||
let newtitle = 'Stable Diffusion'
|
||||
if(document.title != newtitle){
|
||||
document.title = newtitle;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
|
||||
global_progressbars[id_progressbar] = progressbar
|
||||
|
||||
var mutationObserver = new MutationObserver(function(m){
|
||||
if(timeoutIds[id_part]) return;
|
||||
|
||||
preview = gradioApp().getElementById(id_preview)
|
||||
gallery = gradioApp().getElementById(id_gallery)
|
||||
|
||||
if(preview != null && gallery != null){
|
||||
preview.style.width = gallery.clientWidth + "px"
|
||||
preview.style.height = gallery.clientHeight + "px"
|
||||
if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
|
||||
|
||||
//only watch gallery if there is a generation process going on
|
||||
check_gallery(id_gallery);
|
||||
|
||||
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
||||
if(progressDiv){
|
||||
timeoutIds[id_part] = window.setTimeout(function() {
|
||||
timeoutIds[id_part] = null
|
||||
requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
|
||||
}, 500)
|
||||
} else{
|
||||
if (skip) {
|
||||
skip.style.display = "none"
|
||||
}
|
||||
interrupt.style.display = "none"
|
||||
|
||||
//disconnect observer once generation finished, so user can close selected image if they want
|
||||
if (galleryObservers[id_gallery]) {
|
||||
galleryObservers[id_gallery].disconnect();
|
||||
galleries[id_gallery] = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
});
|
||||
mutationObserver.observe( progressbar, { childList:true, subtree:true })
|
||||
}
|
||||
let currentlySelectedIndex = -1
|
||||
galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } })
|
||||
|
||||
return currentlySelectedIndex
|
||||
}
|
||||
|
||||
// this is a workaround for https://github.com/gradio-app/gradio/issues/2984
|
||||
function check_gallery(id_gallery){
|
||||
let gallery = gradioApp().getElementById(id_gallery)
|
||||
// if gallery has no change, no need to setting up observer again.
|
||||
@ -85,10 +28,16 @@ function check_gallery(id_gallery){
|
||||
if(galleryObservers[id_gallery]){
|
||||
galleryObservers[id_gallery].disconnect();
|
||||
}
|
||||
let prevSelectedIndex = selected_gallery_index();
|
||||
|
||||
storedGallerySelections[id_gallery] = -1
|
||||
|
||||
galleryObservers[id_gallery] = new MutationObserver(function (){
|
||||
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
|
||||
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
|
||||
let currentlySelectedIndex = getGallerySelectedIndex(id_gallery)
|
||||
prevSelectedIndex = storedGallerySelections[id_gallery]
|
||||
storedGallerySelections[id_gallery] = -1
|
||||
|
||||
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
|
||||
// automatically re-open previously selected index (if exists)
|
||||
activeElement = gradioApp().activeElement;
|
||||
@ -120,30 +69,150 @@ function check_gallery(id_gallery){
|
||||
}
|
||||
|
||||
onUiUpdate(function(){
|
||||
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
||||
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
||||
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
|
||||
check_gallery('txt2img_gallery')
|
||||
check_gallery('img2img_gallery')
|
||||
})
|
||||
|
||||
function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
|
||||
btn = gradioApp().getElementById(id_part+"_check_progress");
|
||||
if(btn==null) return;
|
||||
|
||||
btn.click();
|
||||
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
||||
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
|
||||
var interrupt = gradioApp().getElementById(id_interrupt)
|
||||
if(progressDiv && interrupt){
|
||||
if (skip) {
|
||||
skip.style.display = "block"
|
||||
function request(url, data, handler, errorHandler){
|
||||
var xhr = new XMLHttpRequest();
|
||||
var url = url;
|
||||
xhr.open("POST", url, true);
|
||||
xhr.setRequestHeader("Content-Type", "application/json");
|
||||
xhr.onreadystatechange = function () {
|
||||
if (xhr.readyState === 4) {
|
||||
if (xhr.status === 200) {
|
||||
var js = JSON.parse(xhr.responseText);
|
||||
handler(js)
|
||||
} else{
|
||||
errorHandler()
|
||||
}
|
||||
interrupt.style.display = "block"
|
||||
}
|
||||
};
|
||||
var js = JSON.stringify(data);
|
||||
xhr.send(js);
|
||||
}
|
||||
|
||||
function pad2(x){
|
||||
return x<10 ? '0'+x : x
|
||||
}
|
||||
|
||||
function formatTime(secs){
|
||||
if(secs > 3600){
|
||||
return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60)
|
||||
} else if(secs > 60){
|
||||
return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60)
|
||||
} else{
|
||||
return Math.floor(secs) + "s"
|
||||
}
|
||||
}
|
||||
|
||||
function requestProgress(id_part){
|
||||
btn = gradioApp().getElementById(id_part+"_check_progress_initial");
|
||||
if(btn==null) return;
|
||||
|
||||
btn.click();
|
||||
function randomId(){
|
||||
return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
|
||||
}
|
||||
|
||||
// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
|
||||
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
|
||||
// calls onProgress every time there is a progress update
|
||||
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
|
||||
var dateStart = new Date()
|
||||
var wasEverActive = false
|
||||
var parentProgressbar = progressbarContainer.parentNode
|
||||
var parentGallery = gallery.parentNode
|
||||
|
||||
var divProgress = document.createElement('div')
|
||||
divProgress.className='progressDiv'
|
||||
var divInner = document.createElement('div')
|
||||
divInner.className='progress'
|
||||
|
||||
divProgress.appendChild(divInner)
|
||||
parentProgressbar.insertBefore(divProgress, progressbarContainer)
|
||||
|
||||
var livePreview = document.createElement('div')
|
||||
livePreview.className='livePreview'
|
||||
parentGallery.insertBefore(livePreview, gallery)
|
||||
|
||||
var removeProgressBar = function(){
|
||||
parentProgressbar.removeChild(divProgress)
|
||||
parentGallery.removeChild(livePreview)
|
||||
atEnd()
|
||||
}
|
||||
|
||||
var fun = function(id_task, id_live_preview){
|
||||
request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
|
||||
console.log(res)
|
||||
|
||||
if(res.completed){
|
||||
removeProgressBar()
|
||||
return
|
||||
}
|
||||
|
||||
var rect = progressbarContainer.getBoundingClientRect()
|
||||
|
||||
if(rect.width){
|
||||
divProgress.style.width = rect.width + "px";
|
||||
}
|
||||
|
||||
progressText = ""
|
||||
|
||||
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
|
||||
|
||||
if(res.progress > 0){
|
||||
progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
|
||||
}
|
||||
|
||||
if(res.eta){
|
||||
progressText += " ETA: " + formatTime(res.eta)
|
||||
} else if(res.textinfo){
|
||||
progressText += " " + res.textinfo
|
||||
}
|
||||
|
||||
divInner.textContent = progressText
|
||||
|
||||
var elapsedFromStart = (new Date() - dateStart) / 1000
|
||||
|
||||
if(res.active) wasEverActive = true;
|
||||
|
||||
if(! res.active && wasEverActive){
|
||||
removeProgressBar()
|
||||
return
|
||||
}
|
||||
|
||||
if(elapsedFromStart > 5 && !res.queued && !res.active){
|
||||
removeProgressBar()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
if(res.live_preview){
|
||||
var img = new Image();
|
||||
img.onload = function() {
|
||||
var rect = gallery.getBoundingClientRect()
|
||||
if(rect.width){
|
||||
livePreview.style.width = rect.width + "px"
|
||||
livePreview.style.height = rect.height + "px"
|
||||
}
|
||||
|
||||
livePreview.innerHTML = ''
|
||||
livePreview.appendChild(img)
|
||||
if(livePreview.childElementCount > 2){
|
||||
livePreview.removeChild(livePreview.firstElementChild)
|
||||
}
|
||||
}
|
||||
img.src = res.live_preview;
|
||||
}
|
||||
|
||||
|
||||
if(onProgress){
|
||||
onProgress(res)
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
fun(id_task, res.id_live_preview);
|
||||
}, 500)
|
||||
}, function(){
|
||||
removeProgressBar()
|
||||
})
|
||||
}
|
||||
|
||||
fun(id_task, 0)
|
||||
}
|
||||
|
@ -1,8 +1,17 @@
|
||||
|
||||
|
||||
|
||||
function start_training_textual_inversion(){
|
||||
requestProgress('ti')
|
||||
gradioApp().querySelector('#ti_error').innerHTML=''
|
||||
|
||||
return args_to_array(arguments)
|
||||
var id = randomId()
|
||||
requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
|
||||
gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
|
||||
})
|
||||
|
||||
var res = args_to_array(arguments)
|
||||
|
||||
res[0] = id
|
||||
|
||||
return res
|
||||
}
|
||||
|
@ -126,18 +126,41 @@ function create_submit_args(args){
|
||||
return res
|
||||
}
|
||||
|
||||
function submit(){
|
||||
requestProgress('txt2img')
|
||||
function showSubmitButtons(tabname, show){
|
||||
gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block"
|
||||
gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
|
||||
}
|
||||
|
||||
return create_submit_args(arguments)
|
||||
function submit(){
|
||||
rememberGallerySelection('txt2img_gallery')
|
||||
showSubmitButtons('txt2img', false)
|
||||
|
||||
var id = randomId()
|
||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
|
||||
showSubmitButtons('txt2img', true)
|
||||
|
||||
})
|
||||
|
||||
var res = create_submit_args(arguments)
|
||||
|
||||
res[0] = id
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
function submit_img2img(){
|
||||
requestProgress('img2img')
|
||||
rememberGallerySelection('img2img_gallery')
|
||||
showSubmitButtons('img2img', false)
|
||||
|
||||
res = create_submit_args(arguments)
|
||||
var id = randomId()
|
||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
||||
showSubmitButtons('img2img', true)
|
||||
})
|
||||
|
||||
res[0] = get_tab_index('mode_img2img')
|
||||
var res = create_submit_args(arguments)
|
||||
|
||||
res[0] = id
|
||||
res[1] = get_tab_index('mode_img2img')
|
||||
|
||||
return res
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import threading
|
||||
import traceback
|
||||
import time
|
||||
|
||||
from modules import shared
|
||||
from modules import shared, progress
|
||||
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
@ -22,10 +22,21 @@ def wrap_queued_call(func):
|
||||
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
def f(*args, **kwargs):
|
||||
|
||||
shared.state.begin()
|
||||
# if the first argument is a string that says "task(...)", it is treated as a job id
|
||||
if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
|
||||
id_task = args[0]
|
||||
progress.add_task_to_queue(id_task)
|
||||
else:
|
||||
id_task = None
|
||||
|
||||
with queue_lock:
|
||||
shared.state.begin()
|
||||
progress.start_task(id_task)
|
||||
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
finally:
|
||||
progress.finish_task(id_task)
|
||||
|
||||
shared.state.end()
|
||||
|
||||
|
@ -453,7 +453,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
|
||||
@ -629,7 +629,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||
|
||||
description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
|
||||
pbar.set_description(description)
|
||||
shared.state.textinfo = description
|
||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||
@ -701,7 +700,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||
torch.cuda.set_rng_state_all(cuda_rng_state)
|
||||
hypernetwork.train()
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
shared.state.assign_current_image(image)
|
||||
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
|
@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args):
|
||||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
is_batch = mode == 5
|
||||
|
||||
if mode == 0: # img2img
|
||||
|
96
modules/progress.py
Normal file
96
modules/progress.py
Normal file
@ -0,0 +1,96 @@
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
current_task = None
|
||||
pending_tasks = {}
|
||||
finished_tasks = []
|
||||
|
||||
|
||||
def start_task(id_task):
|
||||
global current_task
|
||||
|
||||
current_task = id_task
|
||||
pending_tasks.pop(id_task, None)
|
||||
|
||||
|
||||
def finish_task(id_task):
|
||||
global current_task
|
||||
|
||||
if current_task == id_task:
|
||||
current_task = None
|
||||
|
||||
finished_tasks.append(id_task)
|
||||
if len(finished_tasks) > 16:
|
||||
finished_tasks.pop(0)
|
||||
|
||||
|
||||
def add_task_to_queue(id_job):
|
||||
pending_tasks[id_job] = time.time()
|
||||
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
|
||||
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
|
||||
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
active: bool = Field(title="Whether the task is being worked on right now")
|
||||
queued: bool = Field(title="Whether the task is in queue")
|
||||
completed: bool = Field(title="Whether the task has already finished")
|
||||
progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta: float = Field(default=None, title="ETA in secs")
|
||||
live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
|
||||
id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
|
||||
textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
|
||||
|
||||
|
||||
def setup_progress_api(app):
|
||||
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
|
||||
|
||||
|
||||
def progressapi(req: ProgressRequest):
|
||||
active = req.id_task == current_task
|
||||
queued = req.id_task in pending_tasks
|
||||
completed = req.id_task in finished_tasks
|
||||
|
||||
if not active:
|
||||
return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
|
||||
|
||||
progress = 0
|
||||
|
||||
if shared.state.job_count > 0:
|
||||
progress += shared.state.job_no / shared.state.job_count
|
||||
if shared.state.sampling_steps > 0:
|
||||
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
||||
|
||||
progress = min(progress, 1)
|
||||
|
||||
elapsed_since_start = time.time() - shared.state.time_start
|
||||
predicted_duration = elapsed_since_start / progress if progress > 0 else None
|
||||
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
|
||||
|
||||
id_live_preview = req.id_live_preview
|
||||
shared.state.set_current_image()
|
||||
if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
|
||||
image = shared.state.current_image
|
||||
if image is not None:
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="png")
|
||||
live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
|
||||
id_live_preview = shared.state.id_live_preview
|
||||
else:
|
||||
live_preview = None
|
||||
else:
|
||||
live_preview = None
|
||||
|
||||
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
||||
|
@ -140,7 +140,7 @@ def store_latent(decoded):
|
||||
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||
if not shared.parallel_processing_allowed:
|
||||
shared.state.current_image = sample_to_image(decoded)
|
||||
shared.state.assign_current_image(sample_to_image(decoded))
|
||||
|
||||
|
||||
class InterruptedException(BaseException):
|
||||
|
@ -152,6 +152,7 @@ def reload_hypernetworks():
|
||||
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
|
||||
|
||||
|
||||
|
||||
class State:
|
||||
skipped = False
|
||||
interrupted = False
|
||||
@ -165,6 +166,7 @@ class State:
|
||||
current_latent = None
|
||||
current_image = None
|
||||
current_image_sampling_step = 0
|
||||
id_live_preview = 0
|
||||
textinfo = None
|
||||
time_start = None
|
||||
need_restart = False
|
||||
@ -207,6 +209,7 @@ class State:
|
||||
self.current_latent = None
|
||||
self.current_image = None
|
||||
self.current_image_sampling_step = 0
|
||||
self.id_live_preview = 0
|
||||
self.skipped = False
|
||||
self.interrupted = False
|
||||
self.textinfo = None
|
||||
@ -220,8 +223,8 @@ class State:
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
||||
def set_current_image(self):
|
||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
||||
if not parallel_processing_allowed:
|
||||
return
|
||||
|
||||
@ -234,12 +237,16 @@ class State:
|
||||
|
||||
import modules.sd_samplers
|
||||
if opts.show_progress_grid:
|
||||
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
|
||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||
else:
|
||||
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
|
||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
def assign_current_image(self, image):
|
||||
self.current_image = image
|
||||
self.id_live_preview += 1
|
||||
|
||||
|
||||
state = State()
|
||||
state.server_start = time.time()
|
||||
@ -424,8 +431,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
@ -446,6 +451,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
|
||||
options_templates.update(options_section(('ui', "Live previews"), {
|
||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||
|
@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts
|
||||
from modules.textual_inversion import autocrop
|
||||
|
||||
|
||||
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
|
||||
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
|
||||
try:
|
||||
if process_caption:
|
||||
shared.interrogator.load()
|
||||
|
@ -345,7 +345,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||
assert log_directory, "Log directory is empty"
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
template_file = textual_inversion_templates.get(template_filename, None)
|
||||
@ -510,7 +510,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
|
||||
description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
|
||||
pbar.set_description(description)
|
||||
shared.state.textinfo = description
|
||||
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||
@ -560,7 +559,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
shared.state.assign_current_image(image)
|
||||
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
|
@ -8,7 +8,7 @@ import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
|
||||
|
||||
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
|
||||
p = StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||
|
@ -356,7 +356,7 @@ def create_toprow(is_img2img):
|
||||
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
with gr.Row():
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box"):
|
||||
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
|
||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
|
||||
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||
@ -384,9 +384,7 @@ def create_toprow(is_img2img):
|
||||
|
||||
|
||||
def setup_progressbar(*args, **kwargs):
|
||||
import modules.ui_progress
|
||||
|
||||
modules.ui_progress.setup_progressbar(*args, **kwargs)
|
||||
pass
|
||||
|
||||
|
||||
def apply_setting(key, value):
|
||||
@ -479,8 +477,8 @@ Requested path was: {f}
|
||||
else:
|
||||
sp.Popen(["xdg-open", path])
|
||||
|
||||
with gr.Column(variant='panel'):
|
||||
with gr.Group():
|
||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
|
||||
|
||||
generation_info = None
|
||||
@ -595,15 +593,6 @@ def create_ui():
|
||||
dummy_component = gr.Label(visible=False)
|
||||
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
|
||||
|
||||
with gr.Row(elem_id='txt2img_progress_row'):
|
||||
with gr.Column(scale=1):
|
||||
pass
|
||||
|
||||
with gr.Column(scale=1):
|
||||
progressbar = gr.HTML(elem_id="txt2img_progressbar")
|
||||
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
|
||||
setup_progressbar(progressbar, txt2img_preview, 'txt2img')
|
||||
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='panel', elem_id="txt2img_settings"):
|
||||
for category in ordered_ui_categories():
|
||||
@ -682,6 +671,7 @@ def create_ui():
|
||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
||||
_js="submit",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
txt2img_prompt,
|
||||
txt2img_negative_prompt,
|
||||
txt2img_prompt_style,
|
||||
@ -782,17 +772,8 @@ def create_ui():
|
||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||
img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
|
||||
|
||||
with gr.Row(elem_id='img2img_progress_row'):
|
||||
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
|
||||
|
||||
with gr.Column(scale=1):
|
||||
pass
|
||||
|
||||
with gr.Column(scale=1):
|
||||
progressbar = gr.HTML(elem_id="img2img_progressbar")
|
||||
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
|
||||
setup_progressbar(progressbar, img2img_preview, 'img2img')
|
||||
|
||||
with FormRow().style(equal_height=False):
|
||||
with gr.Column(variant='panel', elem_id="img2img_settings"):
|
||||
copy_image_buttons = []
|
||||
@ -958,6 +939,7 @@ def create_ui():
|
||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
||||
_js="submit_img2img",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
dummy_component,
|
||||
img2img_prompt,
|
||||
img2img_negative_prompt,
|
||||
@ -1335,15 +1317,11 @@ def create_ui():
|
||||
|
||||
script_callbacks.ui_train_tabs_callback(params)
|
||||
|
||||
with gr.Column():
|
||||
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||
with gr.Column(elem_id='ti_gallery_container'):
|
||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||
|
||||
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
|
||||
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
|
||||
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
||||
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
||||
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
|
||||
|
||||
create_embedding.click(
|
||||
fn=modules.textual_inversion.ui.create_embedding,
|
||||
@ -1384,6 +1362,7 @@ def create_ui():
|
||||
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
|
||||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
process_src,
|
||||
process_dst,
|
||||
process_width,
|
||||
@ -1411,6 +1390,7 @@ def create_ui():
|
||||
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
|
||||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
train_embedding_name,
|
||||
embedding_learn_rate,
|
||||
batch_size,
|
||||
@ -1443,6 +1423,7 @@ def create_ui():
|
||||
fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
|
||||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
train_hypernetwork_name,
|
||||
hypernetwork_learn_rate,
|
||||
batch_size,
|
||||
|
@ -1,101 +0,0 @@
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
def calc_time_left(progress, threshold, label, force_display, show_eta):
|
||||
if progress == 0:
|
||||
return ""
|
||||
else:
|
||||
time_since_start = time.time() - shared.state.time_start
|
||||
eta = (time_since_start/progress)
|
||||
eta_relative = eta-time_since_start
|
||||
if (eta_relative > threshold and show_eta) or force_display:
|
||||
if eta_relative > 3600:
|
||||
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
|
||||
elif eta_relative > 60:
|
||||
return label + time.strftime('%M:%S', time.gmtime(eta_relative))
|
||||
else:
|
||||
return label + time.strftime('%Ss', time.gmtime(eta_relative))
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
def check_progress_call(id_part):
|
||||
if shared.state.job_count == 0:
|
||||
return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
||||
|
||||
progress = 0
|
||||
|
||||
if shared.state.job_count > 0:
|
||||
progress += shared.state.job_no / shared.state.job_count
|
||||
if shared.state.sampling_steps > 0:
|
||||
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
||||
|
||||
# Show progress percentage and time left at the same moment, and base it also on steps done
|
||||
show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
|
||||
|
||||
time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
|
||||
if time_left != "":
|
||||
shared.state.time_left_force_display = True
|
||||
|
||||
progress = min(progress, 1)
|
||||
|
||||
progressbar = ""
|
||||
if opts.show_progressbar:
|
||||
progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>"""
|
||||
|
||||
image = gr.update(visible=False)
|
||||
preview_visibility = gr.update(visible=False)
|
||||
|
||||
if opts.live_previews_enable:
|
||||
shared.state.set_current_image()
|
||||
image = shared.state.current_image
|
||||
|
||||
if image is None:
|
||||
image = gr.update(value=None)
|
||||
else:
|
||||
preview_visibility = gr.update(visible=True)
|
||||
|
||||
if shared.state.textinfo is not None:
|
||||
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
|
||||
else:
|
||||
textinfo_result = gr.update(visible=False)
|
||||
|
||||
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
|
||||
|
||||
|
||||
def check_progress_call_initial(id_part):
|
||||
shared.state.job_count = -1
|
||||
shared.state.current_latent = None
|
||||
shared.state.current_image = None
|
||||
shared.state.textinfo = None
|
||||
shared.state.time_start = time.time()
|
||||
shared.state.time_left_force_display = False
|
||||
|
||||
return check_progress_call(id_part)
|
||||
|
||||
|
||||
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||
if textinfo is None:
|
||||
textinfo = gr.HTML(visible=False)
|
||||
|
||||
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
|
||||
check_progress.click(
|
||||
fn=lambda: check_progress_call(id_part),
|
||||
show_progress=False,
|
||||
inputs=[],
|
||||
outputs=[progressbar, preview, preview, textinfo],
|
||||
)
|
||||
|
||||
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
|
||||
check_progress_initial.click(
|
||||
fn=lambda: check_progress_call_initial(id_part),
|
||||
show_progress=False,
|
||||
inputs=[],
|
||||
outputs=[progressbar, preview, preview, textinfo],
|
||||
)
|
42
style.css
42
style.css
@ -305,10 +305,11 @@ input[type="range"]{
|
||||
}
|
||||
|
||||
.progressDiv{
|
||||
width: 100%;
|
||||
position: absolute;
|
||||
height: 20px;
|
||||
top: -20px;
|
||||
background: #b4c0cc;
|
||||
border-radius: 8px;
|
||||
border-radius: 8px !important;
|
||||
}
|
||||
|
||||
.dark .progressDiv{
|
||||
@ -325,6 +326,21 @@ input[type="range"]{
|
||||
padding: 0 8px 0 0;
|
||||
text-align: right;
|
||||
border-radius: 8px;
|
||||
overflow: visible;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.livePreview{
|
||||
position: absolute;
|
||||
z-index: 300;
|
||||
background-color: white;
|
||||
margin: -4px;
|
||||
}
|
||||
|
||||
.livePreview img{
|
||||
object-fit: contain;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
#lightboxModal{
|
||||
@ -450,23 +466,25 @@ input[type="range"]{
|
||||
display:none
|
||||
}
|
||||
|
||||
#txt2img_interrupt, #img2img_interrupt{
|
||||
#txt2img_generate_box, #img2img_generate_box{
|
||||
position: relative;
|
||||
}
|
||||
|
||||
#txt2img_interrupt, #img2img_interrupt, #txt2img_skip, #img2img_skip{
|
||||
position: absolute;
|
||||
width: 50%;
|
||||
height: 72px;
|
||||
height: 100%;
|
||||
background: #b4c0cc;
|
||||
border-radius: 0px;
|
||||
display: none;
|
||||
}
|
||||
|
||||
#txt2img_interrupt, #img2img_interrupt{
|
||||
right: 0;
|
||||
border-radius: 0 0.5rem 0.5rem 0;
|
||||
}
|
||||
#txt2img_skip, #img2img_skip{
|
||||
position: absolute;
|
||||
width: 50%;
|
||||
right: 0px;
|
||||
height: 72px;
|
||||
background: #b4c0cc;
|
||||
border-radius: 0px;
|
||||
display: none;
|
||||
left: 0;
|
||||
border-radius: 0.5rem 0 0 0.5rem;
|
||||
}
|
||||
|
||||
.red {
|
||||
|
3
webui.py
3
webui.py
@ -34,6 +34,7 @@ import modules.sd_vae
|
||||
import modules.txt2img
|
||||
import modules.script_callbacks
|
||||
import modules.textual_inversion.textual_inversion
|
||||
import modules.progress
|
||||
|
||||
import modules.ui
|
||||
from modules import modelloader
|
||||
@ -181,6 +182,8 @@ def webui():
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
modules.progress.setup_progress_api(app)
|
||||
|
||||
if launch_api:
|
||||
create_api(app)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user