2022-09-16 22:18:30 +03:00
import math
import numpy as np
import skimage
import modules . scripts as scripts
import gradio as gr
from PIL import Image , ImageDraw
from modules import images , processing , devices
from modules . processing import Processed , process_images
from modules . shared import opts , cmd_opts , state
# get_matched_noise is taken from https://github.com/parlance-zz/g-diffuser-bot
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
    """Fill the masked part of an image with noise shaped like the unmasked part.

    Seeded random noise is low-pass filtered, given the magnitude spectrum
    (raised to ``noise_q``) and the phase of the source image, histogram-matched
    to the unmasked pixels and finally blended into the source using the mask.

    Args:
        _np_src_image: float array indexed ``[axis0, axis1, channel]``. The
            result is clipped to ``[0, 1]``, so values are expected in that range.
        np_mask_rgb: float array broadcastable against the source image with the
            channel on axis 2. Values near 1 mark pixels to replace with noise,
            values near 0 mark pixels that are kept and used as reference.
        noise_q: exponent applied to the source magnitude spectrum.
        color_variation: blend factor between per-channel noise (1.0) and noise
            that is identical in every channel (0.0).

    Returns:
        Array with the shape of the source image, clipped to ``[0, 1]``.
    """

    # helper fft routines that keep ortho normalization and auto-shift before and after fft;
    # transforming over axes (0, 1) covers both the single-channel and the multi-channel case
    def _fft2(data):
        shifted = np.fft.fftshift(data, axes=(0, 1))
        spectrum = np.fft.fft2(shifted, axes=(0, 1), norm="ortho")
        return np.fft.ifftshift(spectrum, axes=(0, 1)).astype(np.complex128, copy=False)

    def _ifft2(data):
        shifted = np.fft.fftshift(data, axes=(0, 1))
        signal = np.fft.ifft2(shifted, axes=(0, 1), norm="ortho")
        return np.fft.ifftshift(signal, axes=(0, 1)).astype(np.complex128, copy=False)

    def _get_gaussian_window(width, height, std=3.14, mode=0):
        # coordinates run over [-1, 1) on the shorter side and are scaled up on the longer one
        shortest = min(width, height)
        x = (np.arange(width) / width * 2. - 1.) * float(width / shortest)
        y = (np.arange(height) / height * 2. - 1.) * float(height / shortest)
        x_sq = x[:, np.newaxis] ** 2
        y_sq = y[np.newaxis, :] ** 2

        if mode == 0:
            return np.exp(-(x_sq + y_sq) * std)

        return (1 / ((x_sq + 1.) * (y_sq + 1.))) ** (std / 3.14)  # hey wait a minute that's not gaussian

    # the original code calls these "width" and "height"; they are simply the first two array axes
    size_0 = _np_src_image.shape[0]
    size_1 = _np_src_image.shape[1]
    num_channels = _np_src_image.shape[2]

    np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.
    img_mask = np_mask_grey > 1e-6  # pixels that receive generated noise
    ref_mask = np_mask_grey < 1e-3  # untouched pixels used as histogram reference

    windowed_image = _np_src_image * (1. - np_mask_grey[:, :, np.newaxis])
    peak = np.max(windowed_image)
    if peak != 0:  # an all-black source would otherwise turn everything into NaN
        windowed_image /= peak
    # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
    windowed_image += np.average(_np_src_image) * np_mask_rgb

    src_fft = _fft2(windowed_image)  # get feature statistics from masked src img
    src_dist = np.absolute(src_fft)
    # the phase is undefined where the magnitude is zero; use 0 there instead of producing 0/0 = NaN
    src_phase = np.divide(src_fft, src_dist, out=np.zeros_like(src_fft), where=src_dist != 0)

    # create a generator with a static seed to make outpainting deterministic / only follow global seed
    rng = np.random.default_rng(0)

    noise_window = _get_gaussian_window(size_0, size_1, mode=1)  # start with simple gaussian noise
    noise_rgb = rng.random((size_0, size_1, num_channels))
    noise_grey = np.sum(noise_rgb, axis=2) / 3.
    # the colorfulness of the starting noise is blended to greyscale with a parameter
    noise_rgb = noise_rgb * color_variation + ((1. - color_variation) * noise_grey)[:, :, np.newaxis]

    noise_fft = _fft2(noise_rgb)
    noise_fft *= noise_window[:, :, np.newaxis]
    noise_rgb = np.real(_ifft2(noise_fft))

    shaped_noise_fft = _fft2(noise_rgb)
    shaped_noise_fft = np.absolute(shaped_noise_fft) ** 2 * (src_dist ** noise_q) * src_phase  # perform the actual shaping

    brightness_variation = 0.  # color_variation # todo: temporarily tieing brightness variation to color variation for now
    contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.

    shaped_noise = np.real(_ifft2(shaped_noise_fft))
    shaped_noise -= np.min(shaped_noise)
    noise_range = np.max(shaped_noise)
    if noise_range != 0:  # perfectly flat noise cannot be rescaled
        shaped_noise /= noise_range

    # scikit-image is used for histogram matching, very convenient!
    # it needs both pixels to fill and pixels to take the reference histogram from
    if img_mask.any() and ref_mask.any():
        shaped_noise[img_mask, :] = skimage.exposure.match_histograms(
            shaped_noise[img_mask, :],
            contrast_adjusted_np_src[ref_mask, :],
            channel_axis=1,
        )

    shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb

    return np.clip(shaped_noise, 0., 1.)
class Script(scripts.Script):
    """img2img script that enlarges the input image and inpaints the new border.

    Each selected side is expanded in its own pass: the image is pasted onto a
    larger canvas, the new area is pre-filled by ``get_matched_noise`` and the
    result is handed to ``process_images`` together with a mask of the new area.
    """

    def title(self):
        """Return the name shown for this script."""
        return "Outpainting mk2"

    def show(self, is_img2img):
        """The script is only offered when ``is_img2img`` is true."""
        return is_img2img

    def ui(self, is_img2img):
        """Create the controls; their values arrive in ``run`` in the returned order."""
        if not is_img2img:
            return None

        info = gr.HTML("<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>")

        pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur"))
        direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
        noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q"))
        color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation"))

        return [info, pixels, mask_blur, direction, noise_q, color_variation]

    def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation):
        """Outpaint ``p.init_images[0]`` on every selected side.

        Args:
            p: processing object; its size, masks, init images and ``n_iter``
                are overwritten while the script runs.
            _: value of the ``info`` HTML element from ``ui``; unused.
            pixels: number of pixels requested per selected side.
            mask_blur: mask blur from the UI; ``p.mask_blur`` is set to four
                times this value.
            direction: collection containing any of "left", "right", "up", "down".
            noise_q: passed to ``get_matched_noise``.
            color_variation: passed to ``get_matched_noise``.

        Returns:
            ``Processed`` holding the outpainted images, preceded by a grid
            image when the grid options ask for one.
        """
        initial_seed_and_info = [None, None]

        # the size configured by the user limits the size of the region sent to the model
        process_width = p.width
        process_height = p.height

        p.mask_blur = mask_blur * 4
        p.inpaint_full_res = False
        p.inpainting_fill = 1
        # saving is handled at the end of this method instead of inside process_images
        p.do_not_save_samples = True
        p.do_not_save_grid = True

        left = pixels if "left" in direction else 0
        right = pixels if "right" in direction else 0
        up = pixels if "up" in direction else 0
        down = pixels if "down" in direction else 0

        # round the final size up to a multiple of 64 and distribute the extra pixels over the selected sides
        init_img = p.init_images[0]
        target_w = math.ceil((init_img.width + left + right) / 64) * 64
        target_h = math.ceil((init_img.height + up + down) / 64) * 64

        if left > 0:
            left = left * (target_w - init_img.width) // (left + right)

        if right > 0:
            right = target_w - init_img.width - left

        if up > 0:
            up = up * (target_h - init_img.height) // (up + down)

        if down > 0:
            down = target_h - init_img.height - up

        def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
            # grow the first `count` images of `init` by `expand_pixels` on one side and return the new images
            is_horiz = is_left or is_right
            is_vert = is_top or is_bottom
            pixels_horiz = expand_pixels if is_horiz else 0
            pixels_vert = expand_pixels if is_vert else 0

            images_to_process = []
            output_images = []
            for n in range(count):
                res_w = init[n].width + pixels_horiz
                res_h = init[n].height + pixels_vert
                process_res_w = math.ceil(res_w / 64) * 64
                process_res_h = math.ceil(res_h / 64) * 64

                # canvas with the source image pushed away from the side that grows
                img = Image.new("RGB", (process_res_w, process_res_h))
                img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))

                # white = area to generate, black = area to keep (inset by mask_blur on the growing side)
                mask = Image.new("RGB", (process_res_w, process_res_h), "white")
                draw = ImageDraw.Draw(mask)
                draw.rectangle((
                    expand_pixels + mask_blur if is_left else 0,
                    expand_pixels + mask_blur if is_top else 0,
                    mask.width - expand_pixels - mask_blur if is_right else res_w,
                    mask.height - expand_pixels - mask_blur if is_bottom else res_h,
                ), fill="black")

                np_image = (np.asarray(img) / 255.0).astype(np.float64)
                np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
                noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
                output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))

                # only a strip next to the growing side, at most as large as the configured size, is processed
                target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
                target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
                p.width = target_width
                p.height = target_height

                crop_region = (
                    0 if is_left else output_images[n].width - target_width,
                    0 if is_top else output_images[n].height - target_height,
                    target_width if is_left else output_images[n].width,
                    target_height if is_top else output_images[n].height,
                )
                mask = mask.crop(crop_region)
                p.image_mask = mask

                image_to_process = output_images[n].crop(crop_region)
                images_to_process.append(image_to_process)

            p.init_images = images_to_process

            # same layout as the image mask, but the kept area is inset twice as far
            latent_mask = Image.new("RGB", (p.width, p.height), "white")
            draw = ImageDraw.Draw(latent_mask)
            draw.rectangle((
                expand_pixels + mask_blur * 2 if is_left else 0,
                expand_pixels + mask_blur * 2 if is_top else 0,
                mask.width - expand_pixels - mask_blur * 2 if is_right else res_w,
                mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h,
            ), fill="black")
            p.latent_mask = latent_mask

            proc = process_images(p)

            # seed and info of the very first pass describe the whole run
            if initial_seed_and_info[0] is None:
                initial_seed_and_info[0] = proc.seed
                initial_seed_and_info[1] = proc.info

            # put the processed strip back onto the full canvas and drop the padding added for the multiple of 64
            for n in range(count):
                output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
                output_images[n] = output_images[n].crop((0, 0, res_w, res_h))

            return output_images

        # batches are iterated here, so process_images must run exactly once per call
        batch_count = p.n_iter
        batch_size = p.batch_size
        p.n_iter = 1
        state.job_count = batch_count * sum(1 for side in (left, right, up, down) if side > 0)
        all_processed_images = []

        for i in range(batch_count):
            imgs = [init_img] * batch_size
            state.job = f"Batch {i + 1} out of {batch_count}"

            if left > 0:
                imgs = expand(imgs, batch_size, left, is_left=True)
            if right > 0:
                imgs = expand(imgs, batch_size, right, is_right=True)
            if up > 0:
                imgs = expand(imgs, batch_size, up, is_top=True)
            if down > 0:
                imgs = expand(imgs, batch_size, down, is_bottom=True)

            all_processed_images += imgs

        all_images = all_processed_images

        combined_grid_image = images.image_grid(all_processed_images)
        unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
        if opts.return_grid and not unwanted_grid_because_of_img_count:
            all_images = [combined_grid_image] + all_processed_images

        res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])

        if opts.samples_save:
            # individual images belong in the samples format, not the grid format;
            # fall back to the grid format if the samples option is not available
            samples_format = getattr(opts, "samples_format", opts.grid_format)
            for img in all_processed_images:
                images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, samples_format, info=res.info, p=p)

        if opts.grid_save and not unwanted_grid_because_of_img_count:
            images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)

        return res