2022-09-03 12:08:45 +03:00
import math
2022-09-08 05:35:26 +03:00
import cv2
import numpy as np
2022-09-07 23:37:54 +03:00
from PIL import Image , ImageOps , ImageChops
2022-09-03 12:08:45 +03:00
from modules . processing import Processed , StableDiffusionProcessingImg2Img , process_images
from modules . shared import opts , state
import modules . shared as shared
import modules . processing as processing
from modules . ui import plaintext_to_html
import modules . images as images
2022-09-03 17:21:15 +03:00
import modules . scripts
2022-09-03 12:08:45 +03:00
2022-09-11 12:18:06 +03:00
def _loopback_denoising_strength(strength, change_factor):
    """Return the denoising strength for the next loopback pass.

    The current strength is multiplied by ``change_factor`` and clamped to the
    range [0.1, 1].
    """
    return min(max(strength * change_factor, 0.1), 1)


def _split_batches(items, batch_size):
    """Split ``items`` into consecutive slices of at most ``batch_size`` elements.

    Every call builds fresh slice objects; the last slice holds the remainder
    and may be shorter. An empty ``items`` yields an empty list.
    """
    return [items[start:start + batch_size] for start in range(0, len(items), batch_size)]


def _run_loopback(p, init_img, n_iter, batch_size, denoising_strength_change_factor, prompt):
    """Run ``n_iter`` single-image img2img passes, feeding each output back in as the next input.

    After every pass the seed is advanced by one and the denoising strength is
    rescaled with ``_loopback_denoising_strength``. When scikit-image can be
    imported, each output is histogram-matched in LAB space against the
    original ``init_img`` before being fed back. The raw outputs of all passes
    are saved as a one-row grid and returned wrapped in ``Processed``.
    """
    history = []
    initial_seed = None
    initial_info = None

    state.job_count = n_iter

    # Color correction is best-effort. Catch Exception rather than using a
    # bare except, so Ctrl-C / SystemExit still propagate.
    try:
        from skimage import exposure
    except Exception:
        exposure = None
        print("Install scikit-image to perform color correction on loopback")

    # LAB-space reference colors, taken once from the original input image.
    correction_target = None
    if exposure is not None:
        correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)

    for i in range(n_iter):
        # One image per pass; a single combined grid is saved after the loop.
        p.n_iter = 1
        p.batch_size = 1
        p.do_not_save_grid = True

        state.job = f"Batch {i + 1} out of {n_iter}"
        processed = process_images(p)

        if initial_seed is None:
            initial_seed = processed.seed
            initial_info = processed.info

        next_img = processed.images[0]
        if correction_target is not None:
            next_lab = cv2.cvtColor(np.asarray(next_img), cv2.COLOR_RGB2LAB)
            matched = exposure.match_histograms(next_lab, correction_target, channel_axis=2)
            next_img = Image.fromarray(cv2.cvtColor(matched, cv2.COLOR_LAB2RGB).astype("uint8"))
        p.init_images = [next_img]

        p.seed = processed.seed + 1
        p.denoising_strength = _loopback_denoising_strength(p.denoising_strength, denoising_strength_change_factor)

        # History records the raw (uncorrected) output of each pass.
        history.append(processed.images[0])

    grid = images.image_grid(history, batch_size, rows=1)
    # Attach the first pass's generation info to the saved grid.
    images.save_image(grid, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)

    return Processed(p, history, initial_seed, initial_info)


def _run_sd_upscale(p, init_img, upscaler_index, width, height, upscale_overlap, prompt):
    """Upscale ``init_img`` 2x, then refine it tile by tile with img2img.

    The image is enlarged with ``shared.sd_upscalers[upscaler_index]`` and split
    into ``width`` x ``height`` tiles overlapping by ``upscale_overlap``. The
    tiles are run through img2img in batches of ``p.batch_size`` and recombined.
    The whole procedure is repeated ``p.n_iter`` times, the n-th repetition
    starting from seed ``seed + n``.
    """
    initial_info = None

    processing.fix_seed(p)
    seed = p.seed

    upscaler = shared.sd_upscalers[upscaler_index]
    img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)

    processing.torch_gc()

    grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)

    upscale_count = p.n_iter
    p.n_iter = 1
    p.do_not_save_grid = True
    p.do_not_save_samples = True

    # Flatten the tile images row by row; grid.tiles holds (y, h, row) entries.
    work = [tiledata[2] for _y, _h, row in grid.tiles for tiledata in row]
    batch_count = math.ceil(len(work) / p.batch_size)
    state.job_count = batch_count * upscale_count

    print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")

    result_images = []
    for n in range(upscale_count):
        start_seed = seed + n
        p.seed = start_seed

        work_results = []
        for i, batch in enumerate(_split_batches(work, p.batch_size)):
            p.init_images = batch

            # Global batch number, consistent with state.job_count across all passes.
            state.job = f"Batch {n * batch_count + i + 1} out of {state.job_count}"
            processed = process_images(p)

            if initial_info is None:
                initial_info = processed.info

            p.seed = processed.seed + 1
            work_results += processed.images

        # Write results back into the grid in the same row-major order; pad any
        # missing tiles with blank images.
        image_index = 0
        for _y, _h, row in grid.tiles:
            for tiledata in row:
                tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
                image_index += 1

        combined_image = images.combine_grid(grid)
        result_images.append(combined_image)

        if opts.samples_save:
            images.save_image(combined_image, p.outpath_samples, "", start_seed, prompt, opts.grid_format, info=initial_info)

    return Processed(p, result_images, seed, initial_info)


def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init_img_with_mask, init_mask, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, denoising_strength_change_factor: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
    """Run an img2img job and return its images, JS payload and HTML info.

    ``mode`` selects the workflow:
      * 1 - inpaint. If ``mask_mode == 0``, ``init_img_with_mask`` is a dict
        with ``'image'`` and ``'mask'``. Pixels whose last band (assumed to be
        alpha, so RGBA input is expected) is not fully opaque are added to the
        mask. Otherwise ``init_img`` / ``init_mask`` are used as given.
      * 2 - loopback: ``n_iter`` passes, each fed the previous output.
      * 3 - SD upscale: 2x upscale, then tiled img2img refinement.
      * anything else - the active img2img script, falling back to a plain
        ``process_images`` call if the script returns None.

    Returns:
        ``(processed.images, processed.js(), plaintext_to_html(processed.info))``.
    """
    is_inpaint = mode == 1
    is_loopback = mode == 2
    is_upscale = mode == 3

    if is_inpaint:
        if mask_mode == 0:
            image = init_img_with_mask['image']
            mask = init_img_with_mask['mask']
            # Inverted last band thresholded: any non-opaque pixel -> 255 (masked).
            alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
            mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
            image = image.convert('RGB')
        else:
            image = init_img
            mask = init_mask
    else:
        image = init_img
        mask = None

    # NOTE: assert is skipped under `python -O`; kept for its AssertionError contract.
    assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    p = StableDiffusionProcessingImg2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
        prompt=prompt,
        negative_prompt=negative_prompt,
        prompt_style=prompt_style,
        seed=seed,
        subseed=subseed,
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
        sampler_index=sampler_index,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        restore_faces=restore_faces,
        tiling=tiling,
        init_images=[image],
        mask=mask,
        mask_blur=mask_blur,
        inpainting_fill=inpainting_fill,
        resize_mode=resize_mode,
        denoising_strength=denoising_strength,
        inpaint_full_res=inpaint_full_res,
        inpainting_mask_invert=inpainting_mask_invert,
        extra_generation_params={
            "Denoising strength": denoising_strength,
            "Denoising strength change factor": (denoising_strength_change_factor if is_loopback else None)
        }
    )

    print(f"\nimg2img: {prompt}", file=shared.progress_print_out)

    if is_loopback:
        processed = _run_loopback(p, init_img, n_iter, batch_size, denoising_strength_change_factor, prompt)
    elif is_upscale:
        processed = _run_sd_upscale(p, init_img, upscaler_index, width, height, upscale_overlap, prompt)
    else:
        processed = modules.scripts.scripts_img2img.run(p, *args)

        if processed is None:
            processed = process_images(p)

    shared.total_tqdm.clear()

    return processed.images, processed.js(), plaintext_to_html(processed.info)