mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-14 14:52:31 +03:00
* JobManager: Re-merge #611 PR #611 seems to have got lost in the shuffle after the transition to 'dev'. This commit re-merges the feature branch. This adds support for viewing preview images as the image generates, as well as cancelling in-progress images and a couple fixes and clean-ups. * JobManager: Clear jobs that fail to start Sometimes if a job fails to start it will get stuck in the active job list. This commit ensures that jobs that raise exceptions are cleared, and also adds a start timer to clear out jobs that fail to start within a reasonable amount of time.
This commit is contained in:
parent
8bc8b006fd
commit
81f58d58d0
@ -1,7 +1,7 @@
|
||||
''' Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations '''
|
||||
from __future__ import annotations
|
||||
import gradio as gr
|
||||
from gradio.components import Component, Gallery
|
||||
from gradio.components import Component, Gallery, Slider
|
||||
from threading import Event, Timer
|
||||
from typing import Callable, List, Dict, Tuple, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
@ -9,6 +9,7 @@ from functools import partial
|
||||
from PIL.Image import Image
|
||||
import uuid
|
||||
import traceback
|
||||
import time
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
@ -30,9 +31,21 @@ class JobInfo:
|
||||
session_key: str
|
||||
job_token: Optional[int] = None
|
||||
images: List[Image] = field(default_factory=list)
|
||||
active_image: Image = None
|
||||
rec_steps_enabled: bool = False
|
||||
rec_steps_imgs: List[Image] = field(default_factory=list)
|
||||
rec_steps_intrvl: int = None
|
||||
rec_steps_to_gallery: bool = False
|
||||
rec_steps_to_file: bool = False
|
||||
should_stop: Event = field(default_factory=Event)
|
||||
refresh_active_image_requested: Event = field(default_factory=Event)
|
||||
refresh_active_image_done: Event = field(default_factory=Event)
|
||||
stop_cur_iter: Event = field(default_factory=Event)
|
||||
active_iteration_cnt: int = field(default_factory=int)
|
||||
job_status: str = field(default_factory=str)
|
||||
finished: bool = False
|
||||
started: bool = False
|
||||
timestamp: float = None
|
||||
removed_output_idxs: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
@ -76,7 +89,7 @@ class JobManagerUi:
|
||||
'''
|
||||
return self._job_manager._wrap_func(
|
||||
func=func, inputs=inputs, outputs=outputs,
|
||||
refresh_btn=self._refresh_btn, stop_btn=self._stop_btn, status_text=self._status_text
|
||||
job_ui=self
|
||||
)
|
||||
|
||||
_refresh_btn: gr.Button
|
||||
@ -84,10 +97,19 @@ class JobManagerUi:
|
||||
_status_text: gr.Textbox
|
||||
_stop_all_session_btn: gr.Button
|
||||
_free_done_sessions_btn: gr.Button
|
||||
_active_image: gr.Image
|
||||
_active_image_stop_btn: gr.Button
|
||||
_active_image_refresh_btn: gr.Button
|
||||
_rec_steps_intrvl_sldr: gr.Slider
|
||||
_rec_steps_checkbox: gr.Checkbox
|
||||
_save_rec_steps_to_gallery_chkbx: gr.Checkbox
|
||||
_save_rec_steps_to_file_chkbx: gr.Checkbox
|
||||
_job_manager: JobManager
|
||||
|
||||
|
||||
class JobManager:
|
||||
JOB_MAX_START_TIME = 5.0 # How long can a job be stuck 'starting' before assuming it isn't running
|
||||
|
||||
def __init__(self, max_jobs: int):
|
||||
self._max_jobs: int = max_jobs
|
||||
self._avail_job_tokens: List[Any] = list(range(max_jobs))
|
||||
@ -102,11 +124,23 @@ class JobManager:
|
||||
'''
|
||||
assert gr.context.Context.block is not None, "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context"
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("Current Session"):
|
||||
with gr.TabItem("Job Controls"):
|
||||
with gr.Row():
|
||||
stop_btn = gr.Button("Stop", elem_id="stop", variant="secondary")
|
||||
refresh_btn = gr.Button("Refresh", elem_id="refresh", variant="secondary")
|
||||
stop_btn = gr.Button("Stop All Batches", elem_id="stop", variant="secondary")
|
||||
refresh_btn = gr.Button("Refresh Finished Batches", elem_id="refresh", variant="secondary")
|
||||
status_text = gr.Textbox(placeholder="Job Status", interactive=False, show_label=False)
|
||||
with gr.Row():
|
||||
active_image_stop_btn = gr.Button("Skip Active Batch", variant="secondary")
|
||||
active_image_refresh_btn = gr.Button("View Batch Progress", variant="secondary")
|
||||
active_image = gr.Image(type="pil", interactive=False, visible=False, elem_id="active_iteration_image")
|
||||
with gr.TabItem("Batch Progress Settings"):
|
||||
with gr.Row():
|
||||
record_steps_checkbox = gr.Checkbox(value=False, label="Enable Batch Progress Grid")
|
||||
record_steps_interval_slider = gr.Slider(
|
||||
value=3, label="Record Interval (steps)", minimum=1, maximum=25, step=1)
|
||||
with gr.Row() as record_steps_box:
|
||||
steps_to_gallery_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to Gallery")
|
||||
steps_to_file_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to File")
|
||||
with gr.TabItem("Maintenance"):
|
||||
with gr.Row():
|
||||
gr.Markdown(
|
||||
@ -118,9 +152,15 @@ class JobManager:
|
||||
free_done_sessions_btn = gr.Button(
|
||||
"Clear Finished Jobs", elem_id="clear_finished", variant="secondary"
|
||||
)
|
||||
|
||||
return JobManagerUi(_refresh_btn=refresh_btn, _stop_btn=stop_btn, _status_text=status_text,
|
||||
_stop_all_session_btn=stop_all_sessions_btn, _free_done_sessions_btn=free_done_sessions_btn,
|
||||
_job_manager=self)
|
||||
_active_image=active_image, _active_image_stop_btn=active_image_stop_btn,
|
||||
_active_image_refresh_btn=active_image_refresh_btn,
|
||||
_rec_steps_checkbox=record_steps_checkbox,
|
||||
_save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox,
|
||||
_save_rec_steps_to_file_chkbx=steps_to_file_checkbox,
|
||||
_rec_steps_intrvl_sldr=record_steps_interval_slider, _job_manager=self)
|
||||
|
||||
def clear_all_finished_jobs(self):
|
||||
''' Removes all currently finished jobs, across all sessions.
|
||||
@ -134,6 +174,7 @@ class JobManager:
|
||||
for session in self._sessions.values():
|
||||
for job in session.jobs.values():
|
||||
job.should_stop.set()
|
||||
job.stop_cur_iter.set()
|
||||
|
||||
def _get_job_token(self, block: bool = False) -> Optional[int]:
|
||||
''' Attempts to acquire a job token, optionally blocking until available '''
|
||||
@ -175,6 +216,26 @@ class JobManager:
|
||||
job_info.should_stop.set()
|
||||
return "Stopping after current batch finishes"
|
||||
|
||||
def _refresh_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
''' Updates information from the active iteration '''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
if job_info is None:
|
||||
return [None, f"Session {session_key} was not running function {func_key}"]
|
||||
|
||||
job_info.refresh_active_image_requested.set()
|
||||
if job_info.refresh_active_image_done.wait(timeout=20.0):
|
||||
job_info.refresh_active_image_done.clear()
|
||||
return [gr.Image.update(value=job_info.active_image, visible=True), f"Sample iteration {job_info.active_iteration_cnt}"]
|
||||
return [gr.Image.update(visible=False), "Timed out getting image"]
|
||||
|
||||
def _stop_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
''' Marks that the active iteration should be stopped'''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
if job_info is None:
|
||||
return [None, f"Session {session_key} was not running function {func_key}"]
|
||||
job_info.stop_cur_iter.set()
|
||||
return [gr.Image.update(visible=False), "Stopping current iteration"]
|
||||
|
||||
def _get_call_info(self, func_key: FuncKey, session_key: str) -> Tuple[SessionInfo, JobInfo]:
|
||||
''' Helper to get the SessionInfo and JobInfo. '''
|
||||
session_info = self._sessions.get(session_key, None)
|
||||
@ -207,7 +268,8 @@ class JobManager:
|
||||
|
||||
def _pre_call_func(
|
||||
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
|
||||
status_text: gr.Textbox, session_key: str) -> List[Component]:
|
||||
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
|
||||
session_key: str) -> List[Component]:
|
||||
''' Called when a job is about to start '''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
|
||||
@ -219,7 +281,9 @@ class JobManager:
|
||||
return {output_dummy_obj: triggerChangeEvent(),
|
||||
refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value),
|
||||
stop_btn: gr.Button.update(variant="primary", value=stop_btn.value),
|
||||
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' for updates")
|
||||
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' to see finished images, 'View Batch Progress' for active images"),
|
||||
active_refresh_btn: gr.Button.update(variant="primary", value=active_refresh_btn.value),
|
||||
active_stop_btn: gr.Button.update(variant="primary", value=active_stop_btn.value),
|
||||
}
|
||||
|
||||
def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
@ -228,12 +292,19 @@ class JobManager:
|
||||
if session_info is None or job_info is None:
|
||||
return []
|
||||
|
||||
job_info.started = True
|
||||
try:
|
||||
if job_info.should_stop.is_set():
|
||||
raise Exception(f"Job {job_info} requested a stop before execution began")
|
||||
outputs = job_info.func(*job_info.inputs, job_info=job_info)
|
||||
except Exception as e:
|
||||
job_info.job_status = f"Error: {e}"
|
||||
print(f"Exception processing job {job_info}: {e}\n{traceback.format_exc()}")
|
||||
outputs = []
|
||||
raise
|
||||
finally:
|
||||
job_info.finished = True
|
||||
session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key)
|
||||
self._release_job_token(job_info.job_token)
|
||||
|
||||
# Filter the function output for any removed outputs
|
||||
filtered_output = []
|
||||
@ -241,11 +312,6 @@ class JobManager:
|
||||
if idx not in job_info.removed_output_idxs:
|
||||
filtered_output.append(output)
|
||||
|
||||
job_info.finished = True
|
||||
session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key)
|
||||
|
||||
self._release_job_token(job_info.job_token)
|
||||
|
||||
# The wrapper added a dummy JSON output. Append a random text string
|
||||
# to fire the dummy objects 'change' event to notify that the job is done
|
||||
filtered_output.append(triggerChangeEvent())
|
||||
@ -254,12 +320,16 @@ class JobManager:
|
||||
|
||||
def _post_call_func(
|
||||
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
|
||||
status_text: gr.Textbox, session_key: str) -> List[Component]:
|
||||
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
|
||||
session_key: str) -> List[Component]:
|
||||
''' Called when a job completes '''
|
||||
return {output_dummy_obj: triggerChangeEvent(),
|
||||
refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value),
|
||||
stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value),
|
||||
status_text: gr.Textbox.update(value="Generation has finished!")
|
||||
status_text: gr.Textbox.update(value="Generation has finished!"),
|
||||
active_refresh_btn: gr.Button.update(variant="secondary", value=active_refresh_btn.value),
|
||||
active_stop_btn: gr.Button.update(variant="secondary", value=active_stop_btn.value),
|
||||
active_image: gr.Image.update(visible=False)
|
||||
}
|
||||
|
||||
def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
@ -270,21 +340,17 @@ class JobManager:
|
||||
if session_info is None or job_info is None:
|
||||
return []
|
||||
|
||||
if job_info.finished:
|
||||
session_info.finished_jobs.pop(func_key)
|
||||
|
||||
return job_info.images
|
||||
|
||||
def _wrap_func(
|
||||
self, func: Callable, inputs: List[Component], outputs: List[Component],
|
||||
refresh_btn: gr.Button = None, stop_btn: gr.Button = None,
|
||||
status_text: Optional[gr.Textbox] = None) -> Tuple[Callable, List[Component]]:
|
||||
def _wrap_func(self, func: Callable, inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
job_ui: JobManagerUi) -> Tuple[Callable, List[Component]]:
|
||||
''' handles JobManageUI's wrap_func'''
|
||||
|
||||
assert gr.context.Context.block is not None, "wrap_func must be called within a 'gr.Blocks' 'with' context"
|
||||
|
||||
# Create a unique key for this job
|
||||
func_key = FuncKey(job_id=uuid.uuid4(), func=func)
|
||||
func_key = FuncKey(job_id=uuid.uuid4().hex, func=func)
|
||||
|
||||
# Create a unique session key (next gradio release can use gr.State, see https://gradio.app/state_in_blocks/)
|
||||
if self._session_key is None:
|
||||
@ -302,9 +368,6 @@ class JobManager:
|
||||
del outputs[idx]
|
||||
break
|
||||
|
||||
# Add the session key to the inputs
|
||||
inputs += [self._session_key]
|
||||
|
||||
# Create dummy objects
|
||||
update_gallery_obj = gr.JSON(visible=False, elem_id="JobManagerDummyObject")
|
||||
update_gallery_obj.change(
|
||||
@ -314,21 +377,49 @@ class JobManager:
|
||||
queue=False
|
||||
)
|
||||
|
||||
if refresh_btn:
|
||||
refresh_btn.variant = 'secondary'
|
||||
refresh_btn.click(
|
||||
if job_ui._refresh_btn:
|
||||
job_ui._refresh_btn.variant = 'secondary'
|
||||
job_ui._refresh_btn.click(
|
||||
partial(self._refresh_func, func_key),
|
||||
[self._session_key],
|
||||
[update_gallery_obj, status_text],
|
||||
[update_gallery_obj, job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if stop_btn:
|
||||
stop_btn.variant = 'secondary'
|
||||
stop_btn.click(
|
||||
if job_ui._stop_btn:
|
||||
job_ui._stop_btn.variant = 'secondary'
|
||||
job_ui._stop_btn.click(
|
||||
partial(self._stop_wrapped_func, func_key),
|
||||
[self._session_key],
|
||||
[status_text],
|
||||
[job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._active_image and job_ui._active_image_refresh_btn:
|
||||
job_ui._active_image_refresh_btn.click(
|
||||
partial(self._refresh_cur_iter_func, func_key),
|
||||
[self._session_key],
|
||||
[job_ui._active_image, job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._active_image_stop_btn:
|
||||
job_ui._active_image_stop_btn.click(
|
||||
partial(self._stop_cur_iter_func, func_key),
|
||||
[self._session_key],
|
||||
[job_ui._active_image, job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._stop_all_session_btn:
|
||||
job_ui._stop_all_session_btn.click(
|
||||
self.stop_all_jobs, [], [],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._free_done_sessions_btn:
|
||||
job_ui._free_done_sessions_btn.click(
|
||||
self.clear_all_finished_jobs, [], [],
|
||||
queue=False
|
||||
)
|
||||
|
||||
@ -346,7 +437,8 @@ class JobManager:
|
||||
# Since some parameters are optional it makes sense to use the 'dict' return value type, which requires
|
||||
# the Component as a key... so group together the UI components that the event listeners are going to update
|
||||
# to make it easy to append to function calls and outputs
|
||||
job_ui_params = [refresh_btn, stop_btn, status_text]
|
||||
job_ui_params = [job_ui._refresh_btn, job_ui._stop_btn, job_ui._status_text,
|
||||
job_ui._active_image, job_ui._active_image_refresh_btn, job_ui._active_image_stop_btn]
|
||||
job_ui_outputs = [comp for comp in job_ui_params if comp is not None]
|
||||
|
||||
# Here a chain is constructed that will make a 'pre' call, a 'run' call, and a 'post' call,
|
||||
@ -375,27 +467,50 @@ class JobManager:
|
||||
queue=False
|
||||
)
|
||||
|
||||
# Now replace the original function with one that creates a JobInfo and triggers the dummy obj
|
||||
# Add any components that we want the runtime values for
|
||||
added_inputs = [self._session_key, job_ui._rec_steps_checkbox, job_ui._save_rec_steps_to_gallery_chkbx,
|
||||
job_ui._save_rec_steps_to_file_chkbx, job_ui._rec_steps_intrvl_sldr]
|
||||
|
||||
def wrapped_func(*inputs):
|
||||
session_key = inputs[-1]
|
||||
inputs = inputs[:-1]
|
||||
# Now replace the original function with one that creates a JobInfo and triggers the dummy obj
|
||||
def wrapped_func(*wrapped_inputs):
|
||||
# Remove the added_inputs (pop opposite order of list)
|
||||
|
||||
wrapped_inputs = list(wrapped_inputs)
|
||||
rec_steps_interval: int = wrapped_inputs.pop()
|
||||
save_rec_steps_file: bool = wrapped_inputs.pop()
|
||||
save_rec_steps_grid: bool = wrapped_inputs.pop()
|
||||
record_steps_enabled: bool = wrapped_inputs.pop()
|
||||
session_key: str = wrapped_inputs.pop()
|
||||
job_inputs = tuple(wrapped_inputs)
|
||||
|
||||
# Get or create a session for this key
|
||||
session_info = self._sessions.setdefault(session_key, SessionInfo())
|
||||
|
||||
# Is this session already running this job?
|
||||
if func_key in session_info.jobs:
|
||||
return {status_text: "This session is already running that function!"}
|
||||
job_info = session_info.jobs[func_key]
|
||||
# If the job seems stuck in 'starting' then go ahead and toss it
|
||||
if not job_info.started and time.time() > job_info.timestamp + JobManager.JOB_MAX_START_TIME:
|
||||
job_info.should_stop.set()
|
||||
job_info.stop_cur_iter.set()
|
||||
session_info.jobs.pop(func_key)
|
||||
return {job_ui._status_text: "Canceled possibly hung job. Try again"}
|
||||
return {job_ui._status_text: "This session is already running that function!"}
|
||||
|
||||
# Is this a new run of a previously finished job? Clear old info
|
||||
if func_key in session_info.finished_jobs:
|
||||
session_info.finished_jobs.pop(func_key)
|
||||
|
||||
job_token = self._get_job_token(block=False)
|
||||
job = JobInfo(inputs=inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
|
||||
job_token=job_token)
|
||||
job = JobInfo(
|
||||
inputs=job_inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
|
||||
job_token=job_token, rec_steps_enabled=record_steps_enabled, rec_steps_intrvl=rec_steps_interval,
|
||||
rec_steps_to_gallery=save_rec_steps_grid, rec_steps_to_file=save_rec_steps_file, timestamp=time.time())
|
||||
session_info.jobs[func_key] = job
|
||||
|
||||
ret = {pre_call_dummyobj: triggerChangeEvent()}
|
||||
if job_token is None:
|
||||
ret[status_text] = "Job is queued"
|
||||
ret[job_ui._status_text] = "Job is queued"
|
||||
return ret
|
||||
|
||||
return wrapped_func, inputs, [pre_call_dummyobj, status_text]
|
||||
return wrapped_func, inputs + added_inputs, [pre_call_dummyobj, job_ui._status_text]
|
||||
|
108
scripts/webui.py
108
scripts/webui.py
@ -74,6 +74,7 @@ import copy
|
||||
from typing import List, Union, Dict, Callable, Any, Optional
|
||||
from pathlib import Path
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
|
||||
# tell the user which GPU the code is actually using
|
||||
if os.getenv("SD_WEBUI_DEBUG", 'False').lower() in ('true', '1', 'y'):
|
||||
@ -307,15 +308,21 @@ class KDiffusionSampler:
|
||||
self.schedule = sampler
|
||||
def get_sampler_name(self):
|
||||
return self.schedule
|
||||
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
|
||||
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback: Callable = None ):
|
||||
sigmas = self.model_wrap.get_sigmas(S)
|
||||
x = x_T * sigmas[0]
|
||||
model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback))
|
||||
|
||||
return samples_ddim, None
|
||||
|
||||
@classmethod
|
||||
def img_callback_wrapper(cls, callback: Callable, *args):
|
||||
''' Converts a KDiffusion callback to the standard img_callback '''
|
||||
if callback:
|
||||
arg_dict = args[0]
|
||||
callback(image_sample=arg_dict['denoised'], iter_num=arg_dict['i'])
|
||||
|
||||
def create_random_tensors(shape, seeds):
|
||||
xs = []
|
||||
@ -1003,6 +1010,7 @@ def process_images(
|
||||
|
||||
if job_info:
|
||||
job_info.job_status = f"Processing Iteration {n+1}/{n_iter}. Batch size {batch_size}"
|
||||
job_info.rec_steps_imgs.clear()
|
||||
for idx,(p,s) in enumerate(zip(prompts,seeds)):
|
||||
job_info.job_status += f"\nItem {idx}: Seed {s}\nPrompt: {p}"
|
||||
print(f"Current prompt: {p}")
|
||||
@ -1057,7 +1065,78 @@ def process_images(
|
||||
# finally, slerp base_x noise to target_x noise for creating a variant
|
||||
x = slerp(device, max(0.0, min(1.0, cur_variant_amount)), base_x, target_x)
|
||||
|
||||
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
|
||||
# If optimized then use first stage for preview and store it on cpu until needed
|
||||
if opt.optimized:
|
||||
step_preview_model = modelFS
|
||||
step_preview_model.cpu()
|
||||
else:
|
||||
step_preview_model = model
|
||||
|
||||
def sample_iteration_callback(image_sample: torch.Tensor, iter_num: int):
|
||||
''' Called from the sampler every iteration '''
|
||||
if job_info:
|
||||
job_info.active_iteration_cnt = iter_num
|
||||
record_periodic_image = job_info.rec_steps_enabled and (0 == iter_num % job_info.rec_steps_intrvl)
|
||||
if record_periodic_image or job_info.refresh_active_image_requested.is_set():
|
||||
preview_start_time = time.time()
|
||||
if opt.optimized:
|
||||
step_preview_model.to(device)
|
||||
|
||||
decoded_batch: List[torch.Tensor] = []
|
||||
# Break up batch to save VRAM
|
||||
for sample in image_sample:
|
||||
sample = sample[None, :] # expands the tensor as if it still had a batch dimension
|
||||
decoded_sample = step_preview_model.decode_first_stage(sample)[0]
|
||||
decoded_sample = torch.clamp((decoded_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
decoded_sample = decoded_sample.cpu()
|
||||
decoded_batch.append(decoded_sample)
|
||||
|
||||
batch_size = len(decoded_batch)
|
||||
|
||||
if opt.optimized:
|
||||
step_preview_model.cpu()
|
||||
|
||||
images: List[Image.Image] = []
|
||||
# Convert tensor to image (copied from code below)
|
||||
for ddim in decoded_batch:
|
||||
x_sample = 255. * rearrange(ddim.numpy(), 'c h w -> h w c')
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
image = Image.fromarray(x_sample)
|
||||
images.append(image)
|
||||
|
||||
caption = f"Iter {iter_num}"
|
||||
grid = image_grid(images, len(images), force_n_rows=1, captions=[caption]*len(images))
|
||||
|
||||
# Save the images if recording steps, and append existing saved steps
|
||||
if job_info.rec_steps_enabled:
|
||||
gallery_img_size = tuple(int(0.25*dim) for dim in images[0].size)
|
||||
job_info.rec_steps_imgs.append(grid.resize(gallery_img_size))
|
||||
|
||||
# Notify the requester that the image is updated
|
||||
if job_info.refresh_active_image_requested.is_set():
|
||||
if job_info.rec_steps_enabled:
|
||||
grid_rows = None if batch_size == 1 else len(job_info.rec_steps_imgs)
|
||||
grid = image_grid(imgs=job_info.rec_steps_imgs[::-1], batch_size=1, force_n_rows=grid_rows)
|
||||
job_info.active_image = grid
|
||||
job_info.refresh_active_image_done.set()
|
||||
job_info.refresh_active_image_requested.clear()
|
||||
|
||||
preview_elapsed_timed = time.time() - preview_start_time
|
||||
if preview_elapsed_timed / job_info.rec_steps_intrvl > 1:
|
||||
print(
|
||||
f"Warning: Preview generation is slowing image generation. It took {preview_elapsed_timed:.2f}s to generate progress images for batch of {batch_size} images!")
|
||||
|
||||
# Interrupt current iteration?
|
||||
if job_info.stop_cur_iter.is_set():
|
||||
job_info.stop_cur_iter.clear()
|
||||
raise StopIteration()
|
||||
|
||||
try:
|
||||
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name, img_callback=sample_iteration_callback)
|
||||
except StopIteration:
|
||||
print("Skipping iteration")
|
||||
job_info.job_status = "Skipping iteration"
|
||||
continue
|
||||
|
||||
if opt.optimized:
|
||||
modelFS.to(device)
|
||||
@ -1196,6 +1275,19 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
|
||||
if simple_templating:
|
||||
grid_captions.append( captions[i] )
|
||||
|
||||
# Save the progress images?
|
||||
if job_info:
|
||||
if job_info.rec_steps_enabled and (job_info.rec_steps_to_file or job_info.rec_steps_to_gallery):
|
||||
steps_grid = image_grid(job_info.rec_steps_imgs, 1)
|
||||
if job_info.rec_steps_to_gallery:
|
||||
gallery_img_size = tuple(2*dim for dim in image.size)
|
||||
output_images.append( steps_grid.resize( gallery_img_size ) )
|
||||
if job_info.rec_steps_to_file:
|
||||
steps_grid_filename = f"{original_filename}_step_grid"
|
||||
save_sample(steps_grid, sample_path_i, steps_grid_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
|
||||
|
||||
if opt.optimized:
|
||||
mem = torch.cuda.memory_allocated()/1e6
|
||||
modelFS.to("cpu")
|
||||
@ -1215,7 +1307,7 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
|
||||
import traceback
|
||||
print("Error creating prompt_matrix text:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
elif batch_size > 1 or n_iter > 1:
|
||||
elif len(output_images) > 0 and (batch_size > 1 or n_iter > 1):
|
||||
grid = image_grid(output_images, batch_size)
|
||||
if grid is not None:
|
||||
grid_count = get_next_sequence_number(outpath, 'grid-')
|
||||
@ -1308,8 +1400,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
|
||||
def init():
|
||||
pass
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None):
|
||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=img_callback)
|
||||
return samples_ddim
|
||||
|
||||
try:
|
||||
@ -1573,7 +1665,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
|
||||
|
||||
return init_latent, mask,
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None):
|
||||
t_enc_steps = t_enc
|
||||
obliterate = False
|
||||
if ddim_steps == t_enc_steps:
|
||||
@ -1595,7 +1687,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
|
||||
|
||||
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
|
||||
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback))
|
||||
else:
|
||||
|
||||
x0, z_mask = init_data
|
||||
|
Loading…
Reference in New Issue
Block a user