2023-07-27 07:11:09 +03:00
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: AGPL-3.0
2023-07-21 02:05:24 +03:00
2023-07-01 00:13:46 +03:00
import math
import cv2
import os
import torch
2023-07-26 06:57:29 +03:00
import time
2023-07-27 07:11:09 +03:00
import hashlib
2023-07-01 00:13:46 +03:00
import functools
import gradio as gr
import numpy as np
2023-07-21 02:05:24 +03:00
import modules
2023-07-01 00:13:46 +03:00
import modules . paths as paths
import modules . scripts as scripts
2023-07-27 07:11:09 +03:00
from modules import images , devices , extra_networks , masking , shared
from modules . processing import (
StableDiffusionProcessing , Processed , apply_overlay , apply_color_correction ,
get_fixed_seed , create_random_tensors , create_infotext , setup_color_correction
)
from modules . sd_models import CheckpointInfo
from modules . shared import Shared , opts , state
from PIL import Image , ImageOps
import openvino . frontend . pytorch . torchdynamo . backend
from openvino . frontend . pytorch . torchdynamo . execute import partitioned_modules , compiled_cache
2023-07-19 17:34:56 +03:00
from openvino . runtime import Core
2023-07-01 00:13:46 +03:00
from diffusers import (
2023-07-27 07:11:09 +03:00
StableDiffusionPipeline ,
2023-07-01 00:13:46 +03:00
DDIMScheduler ,
DPMSolverMultistepScheduler ,
EulerAncestralDiscreteScheduler ,
EulerDiscreteScheduler ,
HeunDiscreteScheduler ,
2023-07-19 17:34:56 +03:00
LMSDiscreteScheduler ,
2023-07-01 00:13:46 +03:00
PNDMScheduler ,
)
2023-07-26 06:57:29 +03:00
class ModelState:
    """Mutable record of what the OpenVINO-compiled pipeline was last built for.

    Attributes (all plain instance attributes):
        recompile: 1 while the pipeline must be reloaded/recompiled before the
            next generation, 0 once a generation has run with the current one.
        device: OpenVINO device name selected in the UI.
        height, width, batch_size: geometry the pipeline was compiled for.
    """

    def __init__(self):
        # Start flagged so the very first generation loads and compiles the pipeline.
        self.recompile = 1
        self.device = "CPU"
        self.height, self.width, self.batch_size = 512, 512, 1


# Single process-wide instance read/written by the generation loop and the UI callbacks.
model_state = ModelState()
2023-07-01 00:13:46 +03:00
2023-07-27 07:11:09 +03:00
def openvino_clear_caches():
    """Empty ``compiled_cache`` and ``partitioned_modules`` from the OpenVINO torchdynamo backend.

    Both containers are emptied in place with ``.clear()``; the module-level
    names are never rebound, so no ``global`` declaration is required.
    """
    for cache in (compiled_cache, partitioned_modules):
        cache.clear()
2023-07-19 17:34:56 +03:00
def sd_diffusers_model(self):
    """Return ``modules.sd_models.model_data.get_sd_model()``.

    ``self`` is accepted but unused. The import is deferred to call time so
    that ``modules.sd_models`` is resolved only when this is actually invoked.
    """
    from modules.sd_models import model_data
    return model_data.get_sd_model()
2023-07-01 00:13:46 +03:00
def cond_stage_key(self):
    """Always report no conditioning-stage key.

    Bound onto the diffusers pipeline with ``functools.partial`` in
    ``get_diffusers_sd_model``; ``self`` is ignored.
    """
    return None
2023-07-27 07:11:09 +03:00
# Placeholder installed at import time; get_diffusers_sd_model() later replaces it
# with the loaded diffusers pipeline. NOTE(review): this stores a plain function
# (not a property), so pipeline attributes such as `.vae` are unavailable on it.
setattr(shared, "sd_diffusers_model", sd_diffusers_model)
2023-07-19 17:34:56 +03:00
def set_scheduler(sd_model, sampler_name):
    """Install the diffusers scheduler that corresponds to a webui sampler name.

    The new scheduler is created from the *current* scheduler's config, so
    options stored there by an earlier selection are inherited unless they are
    overridden explicitly below. Unknown sampler names fall back to "Euler a".

    Args:
        sd_model: diffusers pipeline whose ``scheduler`` attribute is replaced.
        sampler_name: one of the sampler names offered by the Script UI.

    Returns:
        The scheduler now installed on ``sd_model``.
    """
    # sampler name -> (scheduler class, overrides passed to from_config)
    schedulers = {
        "Euler a": (EulerAncestralDiscreteScheduler, {}),
        "Euler": (EulerDiscreteScheduler, {}),
        # Reset explicitly: after "LMS Karras" (or "DPM++ 2M Karras") the inherited
        # config carries use_karras_sigmas=True, which plain "LMS" must not keep.
        "LMS": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
        "Heun": (HeunDiscreteScheduler, {}),
        "DPM++ 2M": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": False}),
        "LMS Karras": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
        "DPM++ 2M Karras": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": True}),
        "DDIM": (DDIMScheduler, {}),
        "PLMS": (PNDMScheduler, {}),
    }
    scheduler_cls, overrides = schedulers.get(sampler_name, schedulers["Euler a"])
    sd_model.scheduler = scheduler_cls.from_config(sd_model.scheduler.config, **overrides)
    return sd_model.scheduler
2023-07-01 00:13:46 +03:00
2023-07-27 07:11:09 +03:00
def get_diffusers_sd_model ( sampler_name , enable_caching , openvino_device ) :
2023-07-26 06:57:29 +03:00
if ( model_state . recompile == 1 ) :
2023-07-25 08:08:28 +03:00
torch . _dynamo . reset ( )
2023-07-27 07:11:09 +03:00
openvino_clear_caches ( )
2023-07-19 17:34:56 +03:00
curr_dir_path = os . getcwd ( )
model_path = " /models/Stable-diffusion/ "
checkpoint_name = shared . opts . sd_model_checkpoint . split ( " " ) [ 0 ]
2023-07-27 07:11:09 +03:00
checkpoint_path = curr_dir_path + model_path + checkpoint_name
2023-07-19 17:34:56 +03:00
sd_model = StableDiffusionPipeline . from_single_file ( checkpoint_path )
checkpoint_info = CheckpointInfo ( checkpoint_path )
sd_model . sd_checkpoint_info = checkpoint_info
sd_model . sd_model_hash = checkpoint_info . calculate_shorthash ( )
2023-07-01 00:13:46 +03:00
sd_model . safety_checker = None
2023-07-25 08:08:28 +03:00
sd_model . cond_stage_key = functools . partial ( cond_stage_key , shared . sd_model )
2023-07-19 17:34:56 +03:00
sd_model . scheduler = set_scheduler ( sd_model , sampler_name )
2023-07-21 02:05:24 +03:00
sd_model . unet = torch . compile ( sd_model . unet , backend = " openvino " )
sd_model . vae . decode = torch . compile ( sd_model . vae . decode , backend = " openvino " )
2023-07-19 17:34:56 +03:00
shared . sd_diffusers_model = sd_model
2023-07-21 02:05:24 +03:00
del sd_model
2023-07-27 07:11:09 +03:00
return shared . sd_diffusers_model
2023-07-01 00:13:46 +03:00
2023-07-19 17:34:56 +03:00
def init_new(self, all_prompts, all_seeds, all_subseeds):
    """Prepare an img2img job for the OpenVINO diffusers pipeline.

    ``Script.run`` binds this as ``p.init`` via ``functools.partial(init_new, p)``.
    Init images are encoded with ``shared.sd_diffusers_model.vae``, so that
    attribute must already hold a loaded pipeline.

    Attributes set on ``self``:
        mask_for_overlay, overlay_images: only when an image mask is set.
        paste_to: only for inpaint-at-full-resolution.
        color_corrections: when color correction is enabled and none were given.
        init_img_hash: when init images are saved.
        batch_size: shrunk when fewer init images than the batch size are given.
        init_latent: scaled VAE latents of the init images.
        mask, nmask: only when an image mask is set.

    Args:
        all_prompts, all_subseeds: unused; kept for the ``p.init`` call signature.
        all_seeds: seeds for the random latents of inpainting fill mode 2.

    Raises:
        RuntimeError: more init images than ``self.batch_size``.
    """
    crop_region = None

    image_mask = self.image_mask
    if image_mask is not None:
        image_mask = image_mask.convert('L')

        if self.inpainting_mask_invert:
            image_mask = ImageOps.invert(image_mask)

        # Separable Gaussian blur of the mask: a horizontal pass, then a vertical pass.
        if self.mask_blur_x > 0:
            np_mask = np.array(image_mask)
            kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
            np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
            image_mask = Image.fromarray(np_mask)

        if self.mask_blur_y > 0:
            np_mask = np.array(image_mask)
            kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
            np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
            image_mask = Image.fromarray(np_mask)

        if self.inpaint_full_res:
            # Work on the padded crop around the masked area and remember where to paste it back.
            self.mask_for_overlay = image_mask
            mask = image_mask.convert('L')
            crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
            crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
            x1, y1, x2, y2 = crop_region

            mask = mask.crop(crop_region)
            image_mask = images.resize_image(2, mask, self.width, self.height)
            self.paste_to = (x1, y1, x2 - x1, y2 - y1)
        else:
            image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
            np_mask = np.array(image_mask)
            np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
            self.mask_for_overlay = Image.fromarray(np_mask)

        self.overlay_images = []

    latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

    add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
    if add_color_corrections:
        self.color_corrections = []

    imgs = []
    for img in self.init_images:
        # Save init image
        if opts.save_init_img:
            self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
            images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)

        image = images.flatten(img, opts.img2img_background_color)

        # resize_mode 3 resizes in latent space instead (see below).
        if crop_region is None and self.resize_mode != 3:
            image = images.resize_image(self.resize_mode, image, self.width, self.height)

        if image_mask is not None:
            image_masked = Image.new('RGBa', (image.width, image.height))
            image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
            self.overlay_images.append(image_masked.convert('RGBA'))

        # crop_region is not None if we are doing inpaint full res
        if crop_region is not None:
            image = image.crop(crop_region)
            image = images.resize_image(2, image, self.width, self.height)

        if image_mask is not None and self.inpainting_fill != 1:
            image = masking.fill(image, latent_mask)

        if add_color_corrections:
            self.color_corrections.append(setup_color_correction(image))

        # (H, W, C) uint8 -> (C, H, W) float32 in [0, 1]
        image = np.array(image).astype(np.float32) / 255.0
        image = np.moveaxis(image, 2, 0)
        imgs.append(image)

    if len(imgs) == 1:
        # A single init image is repeated across the whole batch.
        batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
        if self.overlay_images is not None:
            self.overlay_images = self.overlay_images * self.batch_size
        if self.color_corrections is not None and len(self.color_corrections) == 1:
            self.color_corrections = self.color_corrections * self.batch_size
    elif len(imgs) <= self.batch_size:
        self.batch_size = len(imgs)
        batch_images = np.array(imgs)
    else:
        raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

    image = torch.from_numpy(batch_images)
    image = 2. * image - 1.  # [0, 1] -> [-1, 1]
    image = image.to(shared.device)

    vae = shared.sd_diffusers_model.vae
    # Scale the sampled latents by the VAE's scaling factor, as diffusers' img2img
    # pipeline does. Unscaled latents are ~1/scaling_factor too large relative to
    # the create_random_tensors noise mixed in below.
    self.init_latent = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor

    if self.resize_mode == 3:
        # Latents are 1/8 of the pixel resolution.
        self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // 8, self.width // 8), mode="bilinear")

    if image_mask is not None:
        init_mask = latent_mask
        latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
        latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
        latmask = latmask[0]
        latmask = np.around(latmask)
        latmask = np.tile(latmask[None], (4, 1, 1))

        self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(vae.dtype)
        self.nmask = torch.asarray(latmask).to(shared.device).type(vae.dtype)

        # this needs to be fixed to be done in sample() using actual seeds for batches
        if self.inpainting_fill == 2:
            self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
        elif self.inpainting_fill == 3:
            self.init_latent = self.init_latent * self.mask
2023-07-25 08:08:28 +03:00
def process_images_openvino ( p : StableDiffusionProcessing , sampler_name , enable_caching , openvino_device ) - > Processed :
2023-07-01 00:13:46 +03:00
""" this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch """
if type ( p . prompt ) == list :
assert ( len ( p . prompt ) > 0 )
else :
assert p . prompt is not None
devices . torch_gc ( )
seed = get_fixed_seed ( p . seed )
subseed = get_fixed_seed ( p . subseed )
comments = { }
p . setup_prompts ( )
if type ( seed ) == list :
p . all_seeds = seed
else :
p . all_seeds = [ int ( seed ) + ( x if p . subseed_strength == 0 else 0 ) for x in range ( len ( p . all_prompts ) ) ]
if type ( subseed ) == list :
p . all_subseeds = subseed
else :
p . all_subseeds = [ int ( subseed ) + x for x in range ( len ( p . all_prompts ) ) ]
def infotext ( iteration = 0 , position_in_batch = 0 ) :
return create_infotext ( p , p . all_prompts , p . all_seeds , p . all_subseeds , comments , iteration , position_in_batch )
if p . scripts is not None :
p . scripts . process ( p )
infotexts = [ ]
output_images = [ ]
2023-07-27 07:11:09 +03:00
with torch . no_grad ( ) :
2023-07-01 00:13:46 +03:00
with devices . autocast ( ) :
p . init ( p . all_prompts , p . all_seeds , p . all_subseeds )
if state . job_count == - 1 :
state . job_count = p . n_iter
extra_network_data = None
for n in range ( p . n_iter ) :
p . iteration = n
if state . skipped :
state . skipped = False
if state . interrupted :
break
p . prompts = p . all_prompts [ n * p . batch_size : ( n + 1 ) * p . batch_size ]
p . negative_prompts = p . all_negative_prompts [ n * p . batch_size : ( n + 1 ) * p . batch_size ]
p . seeds = p . all_seeds [ n * p . batch_size : ( n + 1 ) * p . batch_size ]
p . subseeds = p . all_subseeds [ n * p . batch_size : ( n + 1 ) * p . batch_size ]
if p . scripts is not None :
p . scripts . before_process_batch ( p , batch_number = n , prompts = p . prompts , seeds = p . seeds , subseeds = p . subseeds )
if len ( p . prompts ) == 0 :
break
2023-07-27 07:11:09 +03:00
if ( model_state . height != p . height or model_state . width != p . width or model_state . batch_size != p . batch_size ) :
model_state . recompile = 1
model_state . height = p . height
model_state . width = p . width
model_state . batch_size = p . batch_size
2023-07-25 08:08:28 +03:00
shared . sd_diffusers_model = get_diffusers_sd_model ( sampler_name , enable_caching , openvino_device )
2023-07-26 06:57:29 +03:00
shared . sd_diffusers_model . scheduler = set_scheduler ( shared . sd_diffusers_model , sampler_name )
2023-07-01 00:13:46 +03:00
extra_network_data = p . parse_extra_network_prompts ( )
if not p . disable_extra_networks :
with devices . autocast ( ) :
extra_networks . activate ( p , p . extra_network_data )
2023-07-25 08:08:28 +03:00
# TODO: support multiplier
if ( ' lora ' in modules . extra_networks . extra_network_registry ) :
import lora
for lora_model in lora . loaded_loras :
shared . sd_diffusers_model . load_lora_weights ( os . getcwd ( ) + " /models/Lora/ " , weight_name = lora_model . name + " .safetensors " )
2023-07-01 00:13:46 +03:00
if p . scripts is not None :
p . scripts . process_batch ( p , batch_number = n , prompts = p . prompts , seeds = p . seeds , subseeds = p . subseeds )
# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
# Example: a wildcard processed by process_batch sets an extra model
# strength, which is saved as "Model Strength: 1.0" in the infotext
if n == 0 :
with open ( os . path . join ( paths . data_path , " params.txt " ) , " w " , encoding = " utf8 " ) as file :
processed = Processed ( p , [ ] , p . seed , " " )
file . write ( create_infotext ( p , p . all_prompts , p . all_seeds , p . all_subseeds , comments = [ ] , position_in_batch = 0 % p . batch_size , iteration = 0 / / p . batch_size ) )
if p . n_iter > 1 :
shared . state . job = f " Batch { n + 1 } out of { p . n_iter } "
generator = [ torch . Generator ( device = " cpu " ) . manual_seed ( s ) for s in p . seeds ]
2023-07-27 07:11:09 +03:00
2023-07-26 06:57:29 +03:00
time_stamps = [ ]
2023-07-25 08:08:28 +03:00
2023-07-26 06:57:29 +03:00
def callback ( iter , t , latents ) :
time_stamps . append ( time . time ( ) )
time_stamps . append ( time . time ( ) )
2023-07-19 17:34:56 +03:00
output = shared . sd_diffusers_model (
2023-07-01 00:13:46 +03:00
prompt = p . prompts ,
negative_prompt = p . negative_prompts ,
num_inference_steps = p . steps ,
guidance_scale = p . cfg_scale ,
height = p . height ,
width = p . width ,
generator = generator ,
output_type = " np " ,
2023-07-26 06:57:29 +03:00
callback = callback ,
callback_steps = 1
2023-07-01 00:13:46 +03:00
)
2023-07-26 06:57:29 +03:00
model_state . recompile = 0
warmup_duration = time_stamps [ 1 ] - time_stamps [ 0 ]
generation_rate = ( p . steps - 1 ) / ( time_stamps [ - 1 ] - time_stamps [ 1 ] )
2023-07-27 07:11:09 +03:00
x_samples_ddim = output . images
2023-07-01 00:13:46 +03:00
for i , x_sample in enumerate ( x_samples_ddim ) :
p . batch_index = i
x_sample = ( 255. * x_sample ) . astype ( np . uint8 )
if p . restore_faces :
if opts . save and not p . do_not_save_samples and opts . save_images_before_face_restoration :
images . save_image ( Image . fromarray ( x_sample ) , p . outpath_samples , " " , p . seeds [ i ] , p . prompts [ i ] , opts . samples_format , info = infotext ( n , i ) , p = p , suffix = " -before-face-restoration " )
devices . torch_gc ( )
x_sample = modules . face_restoration . restore_faces ( x_sample )
devices . torch_gc ( )
image = Image . fromarray ( x_sample )
if p . scripts is not None :
pp = scripts . PostprocessImageArgs ( image )
p . scripts . postprocess_image ( p , pp )
image = pp . image
if p . color_corrections is not None and i < len ( p . color_corrections ) :
if opts . save and not p . do_not_save_samples and opts . save_images_before_color_correction :
image_without_cc = apply_overlay ( image , p . paste_to , i , p . overlay_images )
images . save_image ( image_without_cc , p . outpath_samples , " " , p . seeds [ i ] , p . prompts [ i ] , opts . samples_format , info = infotext ( n , i ) , p = p , suffix = " -before-color-correction " )
image = apply_color_correction ( p . color_corrections [ i ] , image )
image = apply_overlay ( image , p . paste_to , i , p . overlay_images )
if opts . samples_save and not p . do_not_save_samples :
images . save_image ( image , p . outpath_samples , " " , p . seeds [ i ] , p . prompts [ i ] , opts . samples_format , info = infotext ( n , i ) , p = p )
text = infotext ( n , i )
infotexts . append ( text )
if opts . enable_pnginfo :
image . info [ " parameters " ] = text
output_images . append ( image )
if hasattr ( p , ' mask_for_overlay ' ) and p . mask_for_overlay and any ( [ opts . save_mask , opts . save_mask_composite , opts . return_mask , opts . return_mask_composite ] ) :
image_mask = p . mask_for_overlay . convert ( ' RGB ' )
image_mask_composite = Image . composite ( image . convert ( ' RGBA ' ) . convert ( ' RGBa ' ) , Image . new ( ' RGBa ' , image . size ) , images . resize_image ( 2 , p . mask_for_overlay , image . width , image . height ) . convert ( ' L ' ) ) . convert ( ' RGBA ' )
if opts . save_mask :
images . save_image ( image_mask , p . outpath_samples , " " , p . seeds [ i ] , p . prompts [ i ] , opts . samples_format , info = infotext ( n , i ) , p = p , suffix = " -mask " )
if opts . save_mask_composite :
images . save_image ( image_mask_composite , p . outpath_samples , " " , p . seeds [ i ] , p . prompts [ i ] , opts . samples_format , info = infotext ( n , i ) , p = p , suffix = " -mask-composite " )
if opts . return_mask :
output_images . append ( image_mask )
if opts . return_mask_composite :
output_images . append ( image_mask_composite )
del x_samples_ddim
devices . torch_gc ( )
state . nextjob ( )
p . color_corrections = None
index_of_first_image = 0
unwanted_grid_because_of_img_count = len ( output_images ) < 2 and opts . grid_only_if_multiple
if ( opts . return_grid or opts . grid_save ) and not p . do_not_save_grid and not unwanted_grid_because_of_img_count :
grid = images . image_grid ( output_images , p . batch_size )
if opts . return_grid :
text = infotext ( )
infotexts . insert ( 0 , text )
if opts . enable_pnginfo :
grid . info [ " parameters " ] = text
output_images . insert ( 0 , grid )
index_of_first_image = 1
if opts . grid_save :
images . save_image ( grid , p . outpath_grids , " grid " , p . all_seeds [ 0 ] , p . all_prompts [ 0 ] , opts . grid_format , info = infotext ( ) , short_filename = not opts . grid_extended_filename , p = p , grid = True )
if not p . disable_extra_networks and extra_network_data :
extra_networks . deactivate ( p , p . extra_network_data )
devices . torch_gc ( )
2023-07-27 07:11:09 +03:00
2023-07-01 00:13:46 +03:00
res = Processed (
p ,
images_list = output_images ,
seed = p . all_seeds [ 0 ] ,
info = infotext ( ) ,
comments = " " . join ( f " { comment } \n " for comment in comments ) ,
subseed = p . all_subseeds [ 0 ] ,
index_of_first_image = index_of_first_image ,
infotexts = infotexts ,
)
2023-07-26 06:57:29 +03:00
res . info = res . info + " , Warm up time: " + str ( round ( warmup_duration , 2 ) ) + " secs "
2023-07-27 07:11:09 +03:00
2023-07-26 06:57:29 +03:00
if ( generation_rate > = 1.0 ) :
res . info = res . info + " , Performance: " + str ( round ( generation_rate , 2 ) ) + " it/s "
else :
res . info = res . info + " , Performance: " + str ( round ( 1 / generation_rate , 2 ) ) + " s/it "
2023-07-27 07:11:09 +03:00
2023-07-26 06:57:29 +03:00
2023-07-01 00:13:46 +03:00
if p . scripts is not None :
p . scripts . postprocess ( p , res )
return res
class Script(scripts.Script):
    """webui script "Accelerate with OpenVINO".

    Runs txt2img/img2img through the OpenVINO-compiled diffusers pipeline
    (see ``process_images_openvino``).
    """

    # Sampler names handled by set_scheduler; shared by the UI choices and run()'s fallback.
    _SUPPORTED_SAMPLERS = ("Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS")

    def title(self):
        return "Accelerate with OpenVINO"

    def show(self, is_img2img):
        return True

    def ui(self, is_img2img):
        """Build the script's controls.

        Returns:
            The components whose values are passed to ``run``.
        """
        core = Core()
        openvino_device = gr.Dropdown(label="Select a device", choices=list(core.available_devices), value=model_state.device)
        override_sampler = gr.Checkbox(label="Override the sampling selection from the main UI (Recommended as only below sampling methods have been validated for OpenVINO)", value=True)
        sampler_name = gr.Radio(label="Select a sampling method", choices=list(self._SUPPORTED_SAMPLERS), value="Euler a")
        enable_caching = gr.Checkbox(label="Cache the compiled models on disk for faster model load in subsequent launches (Recommended)", value=True, elem_id=self.elem_id("enable_caching"))
        warmup_status = gr.Textbox(label="Device", interactive=False, visible=False)
        gr.Markdown(
            """
            ###
            ### Note:
            First inference involves compilation of the model for best performance.
            Excluding the first inference (or warm up inference) is recommended for
            performance measurements. When resolution, batchsize, or device is changed,
            or samplers like DPM++ or Karras are selected, model is recompiled. Subsequent
            iterations use the cached compiled model for faster inference.
            """)

        def device_change(choice):
            # A different device invalidates the compiled pipeline.
            if model_state.device == choice:
                return gr.update(value="Device selected is " + choice, visible=True)
            model_state.device = choice
            model_state.recompile = 1
            return gr.update(value="Device changed to " + choice + ". Model will be re-compiled", visible=True)

        openvino_device.change(device_change, openvino_device, warmup_status)

        return [openvino_device, override_sampler, sampler_name, enable_caching]

    def run(self, p, openvino_device, override_sampler, sampler_name, enable_caching):
        """Configure the OpenVINO backend via environment variables and run the generation."""
        os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = str(openvino_device)
        if enable_caching:
            os.environ["OPENVINO_TORCH_MODEL_CACHING"] = "1"
        else:
            # Previously the variable was never removed, so unchecking the box
            # left caching enabled for the rest of the process.
            os.environ.pop("OPENVINO_TORCH_MODEL_CACHING", None)

        if override_sampler:
            p.sampler_name = sampler_name
        elif p.sampler_name not in self._SUPPORTED_SAMPLERS:
            p.sampler_name = "Euler a"

        if not self.is_txt2img:
            # img2img: encode the init images with the diffusers VAE instead of webui's own init.
            p.init = functools.partial(init_new, p)

        return process_images_openvino(p, p.sampler_name, enable_caching, openvino_device)
2023-07-26 06:57:29 +03:00