Merge pull request #1150 from ZeroCool940711/dev

Added the Home tab and gallery tab on txt2img made by @devilismyfriend
This commit is contained in:
ZeroCool 2022-09-15 06:31:38 -07:00 committed by GitHub
commit b7d6329dc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 247 additions and 105 deletions

View File

@ -93,3 +93,19 @@ button[kind="header"] {
/***********************************************************
* Additional CSS for other elements
************************************************************/
button[data-baseweb="tab"] {
font-size: 20px;
}
@media (min-width: 1200px){
h1 {
font-size: 1.75rem;
}
}
#tabs-1-tabpanel-0 > div:nth-child(1) > div > div.stTabs.css-0.exp6ofz0 {
width: 50rem;
align-self: center;
}
div.gallery:hover {
border: 1px solid #777;
}

View File

@ -1,5 +1,18 @@
from webui_streamlit import st, defaults
# base webui import and utils.
from webui_streamlit import st
from sd_utils import *
# streamlit imports
#other imports
# Temp imports
# end of imports
#---------------------------------------------------------------------------------------------------------------
import os
from PIL import Image
@ -30,15 +43,17 @@ def getLatestGeneratedImagesFromPath():
files.append(os.path.join(r, file))
#sort the files by date
files.sort(key=os.path.getmtime)
#reverse the list so the latest images are first
for f in files:
img = Image.open(f)
files[files.index(f)] = img
#get the latest 10 files
#get all the files with the .png or .jpg extension
#sort files by date
#get the latest 10 files
latestFiles = files[-10:]
latestFiles = files
#reverse the list
latestFiles.reverse()
return latestFiles
@ -98,41 +113,109 @@ def layout():
# create a tab for the gallery
#st.markdown("<h2 style='text-align: center; color: white;'>Gallery</h2>", unsafe_allow_html=True)
#st.markdown("<h2 style='text-align: center; color: white;'>Gallery</h2>", unsafe_allow_html=True)
history_tab, discover_tabs, settings_tab = st.tabs(["History","Discover","Settings"])
with discover_tabs:
st.markdown("<h1 style='text-align: center; color: white;'>Soon :)</h1>", unsafe_allow_html=True)
with settings_tab:
st.markdown("<h1 style='text-align: center; color: white;'>Soon :)</h1>", unsafe_allow_html=True)
with history_tab:
placeholder = st.empty()
history_tab, discover_tabs = st.tabs(["History","Discover"])
latestImages = getLatestGeneratedImagesFromPath()
st.session_state['latestImages'] = latestImages
with history_tab:
##---------------------------------------------------------
## image slideshow test
## Number of entries per screen
#slideshow_N = 9
#slideshow_page_number = 0
#slideshow_last_page = len(latestImages) // slideshow_N
## Add a next button and a previous button
#slideshow_prev, slideshow_image_col , slideshow_next = st.columns([1, 10, 1])
#with slideshow_image_col:
#slideshow_image = st.empty()
#slideshow_image.image(st.session_state['latestImages'][0])
#current_image = 0
#if slideshow_next.button("Next", key=1):
##print (current_image+1)
#current_image = current_image+1
#slideshow_image.image(st.session_state['latestImages'][current_image+1])
#if slideshow_prev.button("Previous", key=0):
##print ([current_image-1])
#current_image = current_image-1
#slideshow_image.image(st.session_state['latestImages'][current_image - 1])
#---------------------------------------------------------
placeholder = st.empty()
# image gallery
# Number of entries per screen
gallery_N = 9
gallery_page_number = 0
#gallery_last_page = len(latestImages) // gallery_N
# Add a next button and a previous button
#gallery_prev, gallery_pagination , gallery_next = st.columns([1, 10, 1])
# the pagination doesnt work for now so its better to disable the buttons.
#if gallery_next.button("Next", key=3):
#if gallery_page_number + 1 > gallery_last_page:
#gallery_page_number = 0
#else:
#gallery_page_number += 1
#if gallery_prev.button("Previous", key=2):
#if gallery_page_number - 1 < 0:
#gallery_page_number = gallery_last_page
#else:
#gallery_page_number -= 1
# Get start and end indices of the next page of the dataframe
gallery_start_idx = gallery_page_number * gallery_N
gallery_end_idx = (1 + gallery_page_number) * gallery_N
#---------------------------------------------------------
#populate the 3 images per column
with placeholder.container():
col1, col2, col3 = st.columns(3)
col1_cont = st.container()
col2_cont = st.container()
col3_cont = st.container()
#print (len(st.session_state['latestImages'][gallery_start_idx:gallery_end_idx]))
with col1_cont:
with col1:
st.image(st.session_state['latestImages'][0])
st.image(st.session_state['latestImages'][3])
st.image(st.session_state['latestImages'][6])
st.image(st.session_state['latestImages'][gallery_start_idx:gallery_end_idx][0])
st.image(st.session_state['latestImages'][gallery_start_idx:gallery_end_idx][3])
st.image(st.session_state['latestImages'][gallery_start_idx:gallery_end_idx][6])
with col2_cont:
with col2:
st.image(st.session_state['latestImages'][1])
st.image(st.session_state['latestImages'][4])
st.image(st.session_state['latestImages'][7])
st.image(st.session_state['latestImages'][gallery_start_idx:gallery_end_idx][1])
st.image(st.session_state['latestImages'][gallery_start_idx:gallery_end_idx][4])
st.image(st.session_state['latestImages'][gallery_start_idx:gallery_end_idx][7])
with col3_cont:
with col3:
st.image(st.session_state['latestImages'][2])
st.image(st.session_state['latestImages'][5])
st.image(st.session_state['latestImages'][8])
st.image(st.session_state['latestImages'][gallery_start_idx:gallery_end_idx][2])
st.image(st.session_state['latestImages'][gallery_start_idx:gallery_end_idx][5])
st.image(st.session_state['latestImages'][gallery_start_idx:gallery_end_idx][8])
st.session_state['historyTab'] = [history_tab,col1,col2,col3,placeholder,col1_cont,col2_cont,col3_cont]
with discover_tabs:
st.markdown("<h1 style='text-align: center; color: white;'>Soon :)</h1>", unsafe_allow_html=True)
#display the images
#add a button to the gallery
#st.markdown("<h2 style='text-align: center; color: white;'>Try it out</h2>", unsafe_allow_html=True)
@ -140,3 +223,4 @@ def layout():
#if st.button("Try it out"):
#if the button is clicked, go to the gallery
#st.experimental_rerun()

View File

@ -210,8 +210,8 @@ def layout():
with col3:
# If we have custom models available on the "models/custom"
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
if st.session_state["CustomModel_available"]:
st.session_state["custom_model"] = st.selectbox("Custom Model:", st.session_state["custom_models"],
if st.session_state.CustomModel_available:
st.session_state.custom_model = st.selectbox("Custom Model:", st.session_state.custom_models,
index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model),
help="Select the model you want to use. This option is only available if you have custom models \
on your 'models/custom' folder. The model name that will be shown here is the same as the name\
@ -313,6 +313,7 @@ def layout():
st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
except (StopException, KeyError):
print(f"Received Streamlit StopException")

View File

@ -481,7 +481,7 @@ def txt2vid(
Took { round(time_diff, 2) }s total ({ round(time_diff/(max_frames),2) }s per image)
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
return im, seeds, info, stats
return video_path, seeds, info, stats
#on import run init
def createHTMLGallery(images,info):
@ -679,7 +679,7 @@ def layout():
#load_models(False, False, False, st.session_state["RealESRGAN_model"], CustomModel_available=st.session_state["CustomModel_available"], custom_model=custom_model)
# run video generation
image, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
num_inference_steps=st.session_state.num_inference_steps,
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
@ -691,11 +691,48 @@ def layout():
#message.success('Done!', icon="✅")
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="")
history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
#if 'latestVideos' in st.session_state:
#for i in video:
##push the new image to the list of latest images and remove the oldest one
##remove the last index from the list\
#st.session_state['latestVideos'].pop()
##add the new image to the start of the list
#st.session_state['latestVideos'].insert(0, i)
#PlaceHolder.empty()
#with PlaceHolder.container():
#col1, col2, col3 = st.columns(3)
#col1_cont = st.container()
#col2_cont = st.container()
#col3_cont = st.container()
#with col1_cont:
#with col1:
#st.image(st.session_state['latestVideos'][0])
#st.image(st.session_state['latestVideos'][3])
#st.image(st.session_state['latestVideos'][6])
#with col2_cont:
#with col2:
#st.image(st.session_state['latestVideos'][1])
#st.image(st.session_state['latestVideos'][4])
#st.image(st.session_state['latestVideos'][7])
#with col3_cont:
#with col3:
#st.image(st.session_state['latestVideos'][2])
#st.image(st.session_state['latestVideos'][5])
#st.image(st.session_state['latestVideos'][8])
#historyGallery = st.empty()
## check if output_images length is the same as seeds length
#with gallery_tab:
#st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
#except (StopException, KeyError):
#print(f"Received Streamlit StopException")
# this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
# use the current col2 first tab to show the preview_img and update it as its generated.
#preview_image.image(output_images)

View File

@ -105,8 +105,12 @@ def layout():
iconName=['dashboard','model_training' ,'cloud_download', 'settings'], default_choice=0)
if tabs =='Stable Diffusion':
txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified",
home_tab, txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab = st.tabs(["Home","Text-to-Image Unified", "Image-to-Image Unified",
"Text-to-Video","Post-Processing"])
with home_tab:
from home import layout
layout()
with txt2img_tab:
from txt2img import layout
layout()