Re-merge #611 - View/Cancel in-progress diffusions (#796)

* JobManager: Re-merge #611

PR #611 seems to have got lost in the shuffle after
the transition to 'dev'.

This commit re-merges the feature branch. It adds
support for viewing preview images while an image
generates and for cancelling in-progress generations,
along with a couple of fixes and clean-ups.
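
For reference, the preview/refresh flow boils down to an Event
handshake between the gradio handler and the sampler callback. A
minimal sketch, assuming a hypothetical decode_preview() helper (the
field names follow the new JobInfo fields in the diff below):

    from dataclasses import dataclass, field
    from threading import Event
    from typing import Callable, Optional

    @dataclass
    class PreviewState:
        # Mirrors the request/done Events and active_image added to JobInfo
        refresh_active_image_requested: Event = field(default_factory=Event)
        refresh_active_image_done: Event = field(default_factory=Event)
        active_image: Optional[object] = None  # a PIL.Image in the real code

    def ui_request_preview(state: PreviewState, timeout: float = 20.0):
        """UI side: ask the running job for its latest preview image."""
        state.refresh_active_image_requested.set()
        if state.refresh_active_image_done.wait(timeout=timeout):
            state.refresh_active_image_done.clear()
            return state.active_image
        return None  # timed out waiting on the sampler

    def on_sampler_step(state: PreviewState, decode_preview: Callable[[], object]):
        """Sampler side: called each iteration; answer a pending preview request."""
        if state.refresh_active_image_requested.is_set():
            state.active_image = decode_preview()  # hypothetical decode helper
            state.refresh_active_image_done.set()
            state.refresh_active_image_requested.clear()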

* JobManager: Clear jobs that fail to start

Sometimes a job that fails to start gets stuck in the active job
list. This commit ensures that jobs which raise exceptions are
cleared, and adds a start timeout so jobs that fail to start within
a reasonable amount of time are removed.
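
The stale-job check is roughly the following standalone sketch (in
the diff it lives inside wrapped_func and operates on JobInfo; a stub
stands in here so the snippet is self-contained):

    import time
    from dataclasses import dataclass, field
    from threading import Event

    JOB_MAX_START_TIME = 5.0  # seconds a job may sit 'starting' before assumed hung

    @dataclass
    class StubJob:  # illustration-only stand-in for JobInfo
        started: bool = False
        timestamp: float = field(default_factory=time.time)
        should_stop: Event = field(default_factory=Event)
        stop_cur_iter: Event = field(default_factory=Event)

    def clear_if_hung(jobs: dict, func_key) -> bool:
        """Drop a job that never started within the grace period."""
        job = jobs[func_key]
        if not job.started and time.time() > job.timestamp + JOB_MAX_START_TIME:
            job.should_stop.set()    # ask the job, if it ever runs, to stop
            job.stop_cur_iter.set()  # and to abandon any active iteration
            jobs.pop(func_key)
            return True
        return False
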
cobryan05 2022-09-14 15:48:56 -05:00 committed by GitHub
parent 8bc8b006fd
commit 81f58d58d0
2 changed files with 260 additions and 53 deletions

View File

@@ -1,7 +1,7 @@
''' Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations '''
from __future__ import annotations
import gradio as gr
from gradio.components import Component, Gallery
from gradio.components import Component, Gallery, Slider
from threading import Event, Timer
from typing import Callable, List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, field
@@ -9,6 +9,7 @@ from functools import partial
from PIL.Image import Image
import uuid
import traceback
import time
@dataclass(eq=True, frozen=True)
@@ -30,9 +31,21 @@ class JobInfo:
session_key: str
job_token: Optional[int] = None
images: List[Image] = field(default_factory=list)
active_image: Image = None
rec_steps_enabled: bool = False
rec_steps_imgs: List[Image] = field(default_factory=list)
rec_steps_intrvl: int = None
rec_steps_to_gallery: bool = False
rec_steps_to_file: bool = False
should_stop: Event = field(default_factory=Event)
refresh_active_image_requested: Event = field(default_factory=Event)
refresh_active_image_done: Event = field(default_factory=Event)
stop_cur_iter: Event = field(default_factory=Event)
active_iteration_cnt: int = field(default_factory=int)
job_status: str = field(default_factory=str)
finished: bool = False
started: bool = False
timestamp: float = None
removed_output_idxs: List[int] = field(default_factory=list)
@@ -76,7 +89,7 @@ class JobManagerUi:
'''
return self._job_manager._wrap_func(
func=func, inputs=inputs, outputs=outputs,
refresh_btn=self._refresh_btn, stop_btn=self._stop_btn, status_text=self._status_text
job_ui=self
)
_refresh_btn: gr.Button
@@ -84,10 +97,19 @@ class JobManagerUi:
_status_text: gr.Textbox
_stop_all_session_btn: gr.Button
_free_done_sessions_btn: gr.Button
_active_image: gr.Image
_active_image_stop_btn: gr.Button
_active_image_refresh_btn: gr.Button
_rec_steps_intrvl_sldr: gr.Slider
_rec_steps_checkbox: gr.Checkbox
_save_rec_steps_to_gallery_chkbx: gr.Checkbox
_save_rec_steps_to_file_chkbx: gr.Checkbox
_job_manager: JobManager
class JobManager:
JOB_MAX_START_TIME = 5.0 # How long can a job be stuck 'starting' before assuming it isn't running
def __init__(self, max_jobs: int):
self._max_jobs: int = max_jobs
self._avail_job_tokens: List[Any] = list(range(max_jobs))
@@ -102,11 +124,23 @@ class JobManager:
'''
assert gr.context.Context.block is not None, "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context"
with gr.Tabs():
with gr.TabItem("Current Session"):
with gr.TabItem("Job Controls"):
with gr.Row():
stop_btn = gr.Button("Stop", elem_id="stop", variant="secondary")
refresh_btn = gr.Button("Refresh", elem_id="refresh", variant="secondary")
stop_btn = gr.Button("Stop All Batches", elem_id="stop", variant="secondary")
refresh_btn = gr.Button("Refresh Finished Batches", elem_id="refresh", variant="secondary")
status_text = gr.Textbox(placeholder="Job Status", interactive=False, show_label=False)
with gr.Row():
active_image_stop_btn = gr.Button("Skip Active Batch", variant="secondary")
active_image_refresh_btn = gr.Button("View Batch Progress", variant="secondary")
active_image = gr.Image(type="pil", interactive=False, visible=False, elem_id="active_iteration_image")
with gr.TabItem("Batch Progress Settings"):
with gr.Row():
record_steps_checkbox = gr.Checkbox(value=False, label="Enable Batch Progress Grid")
record_steps_interval_slider = gr.Slider(
value=3, label="Record Interval (steps)", minimum=1, maximum=25, step=1)
with gr.Row() as record_steps_box:
steps_to_gallery_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to Gallery")
steps_to_file_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to File")
with gr.TabItem("Maintenance"):
with gr.Row():
gr.Markdown(
@@ -118,9 +152,15 @@ class JobManager:
free_done_sessions_btn = gr.Button(
"Clear Finished Jobs", elem_id="clear_finished", variant="secondary"
)
return JobManagerUi(_refresh_btn=refresh_btn, _stop_btn=stop_btn, _status_text=status_text,
_stop_all_session_btn=stop_all_sessions_btn, _free_done_sessions_btn=free_done_sessions_btn,
_job_manager=self)
_active_image=active_image, _active_image_stop_btn=active_image_stop_btn,
_active_image_refresh_btn=active_image_refresh_btn,
_rec_steps_checkbox=record_steps_checkbox,
_save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox,
_save_rec_steps_to_file_chkbx=steps_to_file_checkbox,
_rec_steps_intrvl_sldr=record_steps_interval_slider, _job_manager=self)
def clear_all_finished_jobs(self):
''' Removes all currently finished jobs, across all sessions.
@@ -134,6 +174,7 @@ class JobManager:
for session in self._sessions.values():
for job in session.jobs.values():
job.should_stop.set()
job.stop_cur_iter.set()
def _get_job_token(self, block: bool = False) -> Optional[int]:
''' Attempts to acquire a job token, optionally blocking until available '''
@@ -175,6 +216,26 @@ class JobManager:
job_info.should_stop.set()
return "Stopping after current batch finishes"
def _refresh_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
''' Updates information from the active iteration '''
session_info, job_info = self._get_call_info(func_key, session_key)
if job_info is None:
return [None, f"Session {session_key} was not running function {func_key}"]
job_info.refresh_active_image_requested.set()
if job_info.refresh_active_image_done.wait(timeout=20.0):
job_info.refresh_active_image_done.clear()
return [gr.Image.update(value=job_info.active_image, visible=True), f"Sample iteration {job_info.active_iteration_cnt}"]
return [gr.Image.update(visible=False), "Timed out getting image"]
def _stop_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
''' Marks that the active iteration should be stopped'''
session_info, job_info = self._get_call_info(func_key, session_key)
if job_info is None:
return [None, f"Session {session_key} was not running function {func_key}"]
job_info.stop_cur_iter.set()
return [gr.Image.update(visible=False), "Stopping current iteration"]
def _get_call_info(self, func_key: FuncKey, session_key: str) -> Tuple[SessionInfo, JobInfo]:
''' Helper to get the SessionInfo and JobInfo. '''
session_info = self._sessions.get(session_key, None)
@@ -207,7 +268,8 @@ class JobManager:
def _pre_call_func(
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
status_text: gr.Textbox, session_key: str) -> List[Component]:
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
session_key: str) -> List[Component]:
''' Called when a job is about to start '''
session_info, job_info = self._get_call_info(func_key, session_key)
@@ -219,7 +281,9 @@ class JobManager:
return {output_dummy_obj: triggerChangeEvent(),
refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value),
stop_btn: gr.Button.update(variant="primary", value=stop_btn.value),
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' for updates")
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' to see finished images, 'View Batch Progress' for active images"),
active_refresh_btn: gr.Button.update(variant="primary", value=active_refresh_btn.value),
active_stop_btn: gr.Button.update(variant="primary", value=active_stop_btn.value),
}
def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
@@ -228,12 +292,19 @@ class JobManager:
if session_info is None or job_info is None:
return []
job_info.started = True
try:
if job_info.should_stop.is_set():
raise Exception(f"Job {job_info} requested a stop before execution began")
outputs = job_info.func(*job_info.inputs, job_info=job_info)
except Exception as e:
job_info.job_status = f"Error: {e}"
print(f"Exception processing job {job_info}: {e}\n{traceback.format_exc()}")
outputs = []
raise
finally:
job_info.finished = True
session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key)
self._release_job_token(job_info.job_token)
# Filter the function output for any removed outputs
filtered_output = []
@@ -241,11 +312,6 @@ class JobManager:
if idx not in job_info.removed_output_idxs:
filtered_output.append(output)
job_info.finished = True
session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key)
self._release_job_token(job_info.job_token)
# The wrapper added a dummy JSON output. Append a random text string
# to fire the dummy objects 'change' event to notify that the job is done
filtered_output.append(triggerChangeEvent())
@@ -254,12 +320,16 @@ class JobManager:
def _post_call_func(
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
status_text: gr.Textbox, session_key: str) -> List[Component]:
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
session_key: str) -> List[Component]:
''' Called when a job completes '''
return {output_dummy_obj: triggerChangeEvent(),
refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value),
stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value),
status_text: gr.Textbox.update(value="Generation has finished!")
status_text: gr.Textbox.update(value="Generation has finished!"),
active_refresh_btn: gr.Button.update(variant="secondary", value=active_refresh_btn.value),
active_stop_btn: gr.Button.update(variant="secondary", value=active_stop_btn.value),
active_image: gr.Image.update(visible=False)
}
def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Component]:
@@ -270,21 +340,17 @@ class JobManager:
if session_info is None or job_info is None:
return []
if job_info.finished:
session_info.finished_jobs.pop(func_key)
return job_info.images
def _wrap_func(
self, func: Callable, inputs: List[Component], outputs: List[Component],
refresh_btn: gr.Button = None, stop_btn: gr.Button = None,
status_text: Optional[gr.Textbox] = None) -> Tuple[Callable, List[Component]]:
def _wrap_func(self, func: Callable, inputs: List[Component],
outputs: List[Component],
job_ui: JobManagerUi) -> Tuple[Callable, List[Component]]:
''' handles JobManageUI's wrap_func'''
assert gr.context.Context.block is not None, "wrap_func must be called within a 'gr.Blocks' 'with' context"
# Create a unique key for this job
func_key = FuncKey(job_id=uuid.uuid4(), func=func)
func_key = FuncKey(job_id=uuid.uuid4().hex, func=func)
# Create a unique session key (next gradio release can use gr.State, see https://gradio.app/state_in_blocks/)
if self._session_key is None:
@@ -302,9 +368,6 @@ class JobManager:
del outputs[idx]
break
# Add the session key to the inputs
inputs += [self._session_key]
# Create dummy objects
update_gallery_obj = gr.JSON(visible=False, elem_id="JobManagerDummyObject")
update_gallery_obj.change(
@@ -314,21 +377,49 @@ class JobManager:
queue=False
)
if refresh_btn:
refresh_btn.variant = 'secondary'
refresh_btn.click(
if job_ui._refresh_btn:
job_ui._refresh_btn.variant = 'secondary'
job_ui._refresh_btn.click(
partial(self._refresh_func, func_key),
[self._session_key],
[update_gallery_obj, status_text],
[update_gallery_obj, job_ui._status_text],
queue=False
)
if stop_btn:
stop_btn.variant = 'secondary'
stop_btn.click(
if job_ui._stop_btn:
job_ui._stop_btn.variant = 'secondary'
job_ui._stop_btn.click(
partial(self._stop_wrapped_func, func_key),
[self._session_key],
[status_text],
[job_ui._status_text],
queue=False
)
if job_ui._active_image and job_ui._active_image_refresh_btn:
job_ui._active_image_refresh_btn.click(
partial(self._refresh_cur_iter_func, func_key),
[self._session_key],
[job_ui._active_image, job_ui._status_text],
queue=False
)
if job_ui._active_image_stop_btn:
job_ui._active_image_stop_btn.click(
partial(self._stop_cur_iter_func, func_key),
[self._session_key],
[job_ui._active_image, job_ui._status_text],
queue=False
)
if job_ui._stop_all_session_btn:
job_ui._stop_all_session_btn.click(
self.stop_all_jobs, [], [],
queue=False
)
if job_ui._free_done_sessions_btn:
job_ui._free_done_sessions_btn.click(
self.clear_all_finished_jobs, [], [],
queue=False
)
@@ -346,7 +437,8 @@ class JobManager:
# Since some parameters are optional it makes sense to use the 'dict' return value type, which requires
# the Component as a key... so group together the UI components that the event listeners are going to update
# to make it easy to append to function calls and outputs
job_ui_params = [refresh_btn, stop_btn, status_text]
job_ui_params = [job_ui._refresh_btn, job_ui._stop_btn, job_ui._status_text,
job_ui._active_image, job_ui._active_image_refresh_btn, job_ui._active_image_stop_btn]
job_ui_outputs = [comp for comp in job_ui_params if comp is not None]
# Here a chain is constructed that will make a 'pre' call, a 'run' call, and a 'post' call,
@@ -375,27 +467,50 @@ class JobManager:
queue=False
)
# Now replace the original function with one that creates a JobInfo and triggers the dummy obj
# Add any components that we want the runtime values for
added_inputs = [self._session_key, job_ui._rec_steps_checkbox, job_ui._save_rec_steps_to_gallery_chkbx,
job_ui._save_rec_steps_to_file_chkbx, job_ui._rec_steps_intrvl_sldr]
def wrapped_func(*inputs):
session_key = inputs[-1]
inputs = inputs[:-1]
# Now replace the original function with one that creates a JobInfo and triggers the dummy obj
def wrapped_func(*wrapped_inputs):
# Remove the added_inputs (pop opposite order of list)
wrapped_inputs = list(wrapped_inputs)
rec_steps_interval: int = wrapped_inputs.pop()
save_rec_steps_file: bool = wrapped_inputs.pop()
save_rec_steps_grid: bool = wrapped_inputs.pop()
record_steps_enabled: bool = wrapped_inputs.pop()
session_key: str = wrapped_inputs.pop()
job_inputs = tuple(wrapped_inputs)
# Get or create a session for this key
session_info = self._sessions.setdefault(session_key, SessionInfo())
# Is this session already running this job?
if func_key in session_info.jobs:
return {status_text: "This session is already running that function!"}
job_info = session_info.jobs[func_key]
# If the job seems stuck in 'starting' then go ahead and toss it
if not job_info.started and time.time() > job_info.timestamp + JobManager.JOB_MAX_START_TIME:
job_info.should_stop.set()
job_info.stop_cur_iter.set()
session_info.jobs.pop(func_key)
return {job_ui._status_text: "Canceled possibly hung job. Try again"}
return {job_ui._status_text: "This session is already running that function!"}
# Is this a new run of a previously finished job? Clear old info
if func_key in session_info.finished_jobs:
session_info.finished_jobs.pop(func_key)
job_token = self._get_job_token(block=False)
job = JobInfo(inputs=inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
job_token=job_token)
job = JobInfo(
inputs=job_inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
job_token=job_token, rec_steps_enabled=record_steps_enabled, rec_steps_intrvl=rec_steps_interval,
rec_steps_to_gallery=save_rec_steps_grid, rec_steps_to_file=save_rec_steps_file, timestamp=time.time())
session_info.jobs[func_key] = job
ret = {pre_call_dummyobj: triggerChangeEvent()}
if job_token is None:
ret[status_text] = "Job is queued"
ret[job_ui._status_text] = "Job is queued"
return ret
return wrapped_func, inputs, [pre_call_dummyobj, status_text]
return wrapped_func, inputs + added_inputs, [pre_call_dummyobj, job_ui._status_text]

View File

@@ -74,6 +74,7 @@ import copy
from typing import List, Union, Dict, Callable, Any, Optional
from pathlib import Path
from collections import namedtuple
from functools import partial
# tell the user which GPU the code is actually using
if os.getenv("SD_WEBUI_DEBUG", 'False').lower() in ('true', '1', 'y'):
@@ -307,15 +308,21 @@ class KDiffusionSampler:
self.schedule = sampler
def get_sampler_name(self):
return self.schedule
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback: Callable = None ):
sigmas = self.model_wrap.get_sigmas(S)
x = x_T * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model_wrap)
samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback))
return samples_ddim, None
@classmethod
def img_callback_wrapper(cls, callback: Callable, *args):
''' Converts a KDiffusion callback to the standard img_callback '''
if callback:
arg_dict = args[0]
callback(image_sample=arg_dict['denoised'], iter_num=arg_dict['i'])
def create_random_tensors(shape, seeds):
xs = []
@@ -1003,6 +1010,7 @@ def process_images(
if job_info:
job_info.job_status = f"Processing Iteration {n+1}/{n_iter}. Batch size {batch_size}"
job_info.rec_steps_imgs.clear()
for idx,(p,s) in enumerate(zip(prompts,seeds)):
job_info.job_status += f"\nItem {idx}: Seed {s}\nPrompt: {p}"
print(f"Current prompt: {p}")
@@ -1057,7 +1065,78 @@ def process_images(
# finally, slerp base_x noise to target_x noise for creating a variant
x = slerp(device, max(0.0, min(1.0, cur_variant_amount)), base_x, target_x)
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
# If optimized then use first stage for preview and store it on cpu until needed
if opt.optimized:
step_preview_model = modelFS
step_preview_model.cpu()
else:
step_preview_model = model
def sample_iteration_callback(image_sample: torch.Tensor, iter_num: int):
''' Called from the sampler every iteration '''
if job_info:
job_info.active_iteration_cnt = iter_num
record_periodic_image = job_info.rec_steps_enabled and (0 == iter_num % job_info.rec_steps_intrvl)
if record_periodic_image or job_info.refresh_active_image_requested.is_set():
preview_start_time = time.time()
if opt.optimized:
step_preview_model.to(device)
decoded_batch: List[torch.Tensor] = []
# Break up batch to save VRAM
for sample in image_sample:
sample = sample[None, :] # expands the tensor as if it still had a batch dimension
decoded_sample = step_preview_model.decode_first_stage(sample)[0]
decoded_sample = torch.clamp((decoded_sample + 1.0) / 2.0, min=0.0, max=1.0)
decoded_sample = decoded_sample.cpu()
decoded_batch.append(decoded_sample)
batch_size = len(decoded_batch)
if opt.optimized:
step_preview_model.cpu()
images: List[Image.Image] = []
# Convert tensor to image (copied from code below)
for ddim in decoded_batch:
x_sample = 255. * rearrange(ddim.numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
images.append(image)
caption = f"Iter {iter_num}"
grid = image_grid(images, len(images), force_n_rows=1, captions=[caption]*len(images))
# Save the images if recording steps, and append existing saved steps
if job_info.rec_steps_enabled:
gallery_img_size = tuple(int(0.25*dim) for dim in images[0].size)
job_info.rec_steps_imgs.append(grid.resize(gallery_img_size))
# Notify the requester that the image is updated
if job_info.refresh_active_image_requested.is_set():
if job_info.rec_steps_enabled:
grid_rows = None if batch_size == 1 else len(job_info.rec_steps_imgs)
grid = image_grid(imgs=job_info.rec_steps_imgs[::-1], batch_size=1, force_n_rows=grid_rows)
job_info.active_image = grid
job_info.refresh_active_image_done.set()
job_info.refresh_active_image_requested.clear()
preview_elapsed_timed = time.time() - preview_start_time
if preview_elapsed_timed / job_info.rec_steps_intrvl > 1:
print(
f"Warning: Preview generation is slowing image generation. It took {preview_elapsed_timed:.2f}s to generate progress images for batch of {batch_size} images!")
# Interrupt current iteration?
if job_info.stop_cur_iter.is_set():
job_info.stop_cur_iter.clear()
raise StopIteration()
try:
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name, img_callback=sample_iteration_callback)
except StopIteration:
print("Skipping iteration")
job_info.job_status = "Skipping iteration"
continue
if opt.optimized:
modelFS.to(device)
@@ -1196,6 +1275,19 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
if simple_templating:
grid_captions.append( captions[i] )
# Save the progress images?
if job_info:
if job_info.rec_steps_enabled and (job_info.rec_steps_to_file or job_info.rec_steps_to_gallery):
steps_grid = image_grid(job_info.rec_steps_imgs, 1)
if job_info.rec_steps_to_gallery:
gallery_img_size = tuple(2*dim for dim in image.size)
output_images.append( steps_grid.resize( gallery_img_size ) )
if job_info.rec_steps_to_file:
steps_grid_filename = f"{original_filename}_step_grid"
save_sample(steps_grid, sample_path_i, steps_grid_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
if opt.optimized:
mem = torch.cuda.memory_allocated()/1e6
modelFS.to("cpu")
@@ -1215,7 +1307,7 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
import traceback
print("Error creating prompt_matrix text:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
elif batch_size > 1 or n_iter > 1:
elif len(output_images) > 0 and (batch_size > 1 or n_iter > 1):
grid = image_grid(output_images, batch_size)
if grid is not None:
grid_count = get_next_sequence_number(outpath, 'grid-')
@@ -1308,8 +1400,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
def init():
pass
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None):
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=img_callback)
return samples_ddim
try:
@@ -1573,7 +1665,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
return init_latent, mask,
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None):
t_enc_steps = t_enc
obliterate = False
if ddim_steps == t_enc_steps:
@@ -1595,7 +1687,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback))
else:
x0, z_mask = init_data