2022-09-26 16:02:48 +03:00
# This file is part of stable-diffusion-webui (https://github.com/sd-webui/stable-diffusion-webui/).
# Copyright 2022 sd-webui team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
2022-10-01 21:14:32 +03:00
# along with this program. If not, see <http://www.gnu.org/licenses/>.
2022-09-18 19:15:05 +03:00
# base webui import and utils.
from sd_utils import *
# streamlit imports
2022-09-21 06:50:29 +03:00
import streamlit . components . v1 as components
2022-09-18 19:15:05 +03:00
2022-09-21 06:50:29 +03:00
2022-09-18 19:15:05 +03:00
class plugin_info:
    """Static metadata describing this page (the Concept Library)."""

    # Internal identifier of the plugin.
    plugname = "concept_library"
    # Human-readable name of the plugin.
    description = "Concept Library"
    # Numeric display priority of the plugin.
    displayPriority = 4
2022-09-24 02:20:51 +03:00
# Declare the Vue.js concepts browser as a Streamlit custom component,
# served from its pre-built frontend bundle.
_component_func = components.declare_component(
    name="sd-concepts-browser",
    path="./frontend/dists/concept-browser/dist",
)
2022-09-24 02:20:51 +03:00
2022-09-21 06:50:29 +03:00
def sdConceptsBrowser(concepts, key=None):
    """Render the concepts browser component.

    Args:
        concepts: Value forwarded to the component as its ``concepts`` argument.
        key: Optional key forwarded to the component.

    Returns:
        Whatever the component call returns; ``""`` is passed as its default.
    """
    return _component_func(concepts=concepts, key=key, default="")
2022-09-24 02:20:51 +03:00
2022-09-25 10:03:05 +03:00
def _loadConceptThumbnails(imagesPath, folder, acceptedExtensions, maxImages=4):
    """Return up to ``maxImages`` base64-encoded thumbnails found in ``imagesPath``.

    Args:
        imagesPath: Folder holding the concept's images.
        folder: Concept folder name, only used in the error message.
        acceptedExtensions: Tuple of lowercase file-name suffixes to accept.
        maxImages: Maximum number of thumbnails to return.

    Returns:
        A list of base64 strings (possibly empty). Unreadable files are
        reported with ``print`` and skipped.
    """
    thumbnails = []

    # A concept without an images folder simply has no preview images.
    if not os.path.isdir(imagesPath):
        return thumbnails

    # os.listdir() gives no ordering guarantee; sort so the same images are
    # picked on every call.
    for file in sorted(os.listdir(imagesPath)):
        if len(thumbnails) >= maxImages:
            break

        filePath = os.path.join(imagesPath, file)
        if not os.path.isfile(filePath):
            continue
        # Compare case-insensitively so "photo.JPG" is accepted as well.
        if not file.lower().endswith(acceptedExtensions):
            continue

        try:
            # Work on a copy and close the original right away to avoid
            # keeping the file locked, even if a later step fails.
            with Image.open(filePath) as originalImage:
                resizedImage = originalImage.copy()

            # Maintain the aspect ratio (max 200x200). LANCZOS is the same
            # filter as the ANTIALIAS alias, which newer Pillow releases removed.
            resizedImage.thumbnail((200, 200), Image.LANCZOS)
            thumbnails.append(imageToBase64(resizedImage))
        except Exception:
            print("Error while loading image", file, "in concept", folder, "(The file may be corrupted). Skipping it.")

    return thumbnails


@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
def getConceptsFromPath(page, conceptPerPage, searchText=""):
    """Return one page of concepts found in the concepts library folder.

    A concept is a sub-folder of the library folder that contains both
    ``type_of_concept.txt`` and ``learned_embeds.bin``; other folders are
    ignored and do not take up a slot in the pagination.

    Args:
        page: 1-based page number.
        conceptPerPage: Number of concepts per page.
        searchText: Case-insensitive substring the folder name must contain;
            an empty string disables filtering.

    Returns:
        A list of dicts with the keys ``name`` (folder name), ``token``
        (``"<" + name + ">"``), ``images`` (up to 4 base64 thumbnails) and
        ``type`` (content of ``type_of_concept.txt``). The list is empty when
        the library folder does not exist or nothing matches.
    """
    # get the path where the concepts are stored
    path = os.path.join(
        os.getcwd(), st.session_state['defaults'].general.sd_concepts_library_folder)
    acceptedExtensions = ('jpeg', 'jpg', "png")
    concepts = []

    if not os.path.isdir(path):
        return concepts

    # List all folders (concepts) in the path. Sorted because every page is
    # produced by a separate call and os.listdir() order is arbitrary.
    folders = sorted(
        f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f)))

    # Filter the folders by the search text
    if searchText != "":
        loweredSearch = searchText.lower()
        folders = [f for f in folders if loweredSearch in f.lower()]

    firstIndex = (page - 1) * conceptPerPage
    lastIndex = page * conceptPerPage
    validCount = 0  # number of valid concepts seen so far (1-based position)

    for folder in folders:
        # The requested page is complete: stop scanning.
        if validCount >= lastIndex:
            break

        # type of concept is inside type_of_concept.txt
        typePath = os.path.join(path, folder, "type_of_concept.txt")
        binFile = os.path.join(path, folder, "learned_embeds.bin")

        # Skip invalid concepts / failed downloads BEFORE counting them, so
        # every page is computed over the same sequence of valid concepts and
        # no concept can show up on two pages.
        if not os.path.exists(typePath) or not os.path.exists(binFile):
            continue

        validCount += 1
        # Concept belongs to an earlier page.
        if validCount <= firstIndex:
            continue

        with open(typePath, "r") as f:
            conceptType = f.read()

        concepts.append({
            "name": folder,
            "token": "<" + folder + ">",
            # Retrieve only the 4 first images
            "images": _loadConceptThumbnails(
                os.path.join(path, folder, "concept_images"), folder, acceptedExtensions),
            "type": conceptType,
        })

    return concepts
@st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True)
def imageToBase64(image):
    """Serialize ``image`` as PNG and return the bytes base64-encoded.

    Args:
        image: Object exposing ``save(fp, format=...)``; it is saved with
            ``format="PNG"`` into an in-memory buffer.

    Returns:
        The base64 text of the PNG data, decoded as UTF-8.
    """
    from base64 import b64encode
    from io import BytesIO

    pngBuffer = BytesIO()
    image.save(pngBuffer, format="PNG")
    encodedPng = b64encode(pngBuffer.getvalue())
    return encodedPng.decode("utf-8")
2022-09-24 02:20:51 +03:00
2022-09-25 10:03:05 +03:00
@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
def getTotalNumberOfConcepts(searchText=""):
    """Count the concepts in the library folder that match ``searchText``.

    Only folders containing both ``type_of_concept.txt`` and
    ``learned_embeds.bin`` are counted. These are the same validity rules
    ``getConceptsFromPath`` applies before listing a concept, so the total
    (and the page count derived from it) matches what can be displayed.

    Args:
        searchText: Case-insensitive substring the folder name must contain;
            an empty string matches every folder.

    Returns:
        The number of matching valid concepts, ``0`` when the library folder
        does not exist.
    """
    # get the path where the concepts are stored
    path = os.path.join(
        os.getcwd(), st.session_state['defaults'].general.sd_concepts_library_folder)

    if not os.path.isdir(path):
        return 0

    loweredSearch = searchText.lower()
    total = 0
    for folder in os.listdir(path):
        folderPath = os.path.join(path, folder)
        if not os.path.isdir(folderPath):
            continue
        # An empty search text is contained in every name, so no special case.
        if loweredSearch not in folder.lower():
            continue
        # Ignore invalid concepts / failed downloads.
        if not os.path.exists(os.path.join(folderPath, "type_of_concept.txt")):
            continue
        if not os.path.exists(os.path.join(folderPath, "learned_embeds.bin")):
            continue
        total += 1

    return total
2022-09-18 19:15:05 +03:00
2022-09-21 06:50:29 +03:00
2022-09-24 02:20:51 +03:00
def layout():
    """Render the Concept Library page.

    Builds two tabs: "Library" (search bar, refresh button, paginated concepts
    browser) and "Download Manager" (placeholder). The page state lives in
    ``st.session_state`` under the keys ``results``, ``cl_current_page``,
    ``cl_search_text`` and ``cl_search_results_count``.

    Returns:
        Always ``False``.
    """
    state = st.session_state

    # 2 tabs, one for Concept Library and one for the Download Manager
    tab_library, tab_downloader = st.tabs(["Library", "Download Manager"])

    # Concept Library
    with tab_library:
        downloaded_concepts_count = getTotalNumberOfConcepts()
        concepts_per_page = state["defaults"].concepts_library.concepts_per_page

        if "results" not in state:
            state["results"] = getConceptsFromPath(1, concepts_per_page, "")

        # Pagination controls
        if "cl_current_page" not in state:
            state["cl_current_page"] = 1

        # Search
        if 'cl_search_text' not in state:
            state["cl_search_text"] = ""

        if 'cl_search_results_count' not in state:
            state["cl_search_results_count"] = downloaded_concepts_count

        # Search bar
        _search_col, _refresh_col = st.columns([10, 2])
        with _search_col:
            search_text_input = st.text_input("Search", "", placeholder=f'Search for a concept ({downloaded_concepts_count} available)', label_visibility="hidden")
            if search_text_input != state["cl_search_text"]:
                # Search text has changed
                state["cl_search_text"] = search_text_input
                state["cl_current_page"] = 1
                state["cl_search_results_count"] = getTotalNumberOfConcepts(state["cl_search_text"])
                state["results"] = getConceptsFromPath(1, concepts_per_page, state["cl_search_text"])

        with _refresh_col:
            # Super weird fix to align the refresh button with the search bar ( Please streamlit, add css support.. )
            _refresh_col.write("")
            _refresh_col.write("")
            if st.button("Refresh concepts", key="refresh_concepts", help="Refresh the concepts folders. Use this if you have added new concepts manually or deleted some."):
                getTotalNumberOfConcepts.clear()
                getConceptsFromPath.clear()
                # Clearing the caches is not enough: the listing and its count
                # are kept in the session state, so recompute them for the
                # current search, otherwise the stale page would be shown again.
                state["cl_current_page"] = 1
                state["cl_search_results_count"] = getTotalNumberOfConcepts(state["cl_search_text"])
                state["results"] = getConceptsFromPath(1, concepts_per_page, state["cl_search_text"])
                st.experimental_rerun()

        # Show results
        results_empty = st.empty()

        # Pagination
        pagination_empty = st.empty()

        # Layouts
        with pagination_empty:
            with st.container():
                if len(state["results"]) > 0:
                    # Never report fewer than one page while results are shown.
                    last_page = max(1, math.ceil(state["cl_search_results_count"] / concepts_per_page))
                    _1, _2, _3, _4, _previous_page, _current_page, _next_page, _9, _10, _11, _12 = st.columns([1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1])

                    # Previous page
                    with _previous_page:
                        if st.button("Previous", key="cl_previous_page"):
                            state["cl_current_page"] -= 1
                            # Wrap around to the last page
                            if state["cl_current_page"] <= 0:
                                state["cl_current_page"] = last_page
                            state["results"] = getConceptsFromPath(state["cl_current_page"], concepts_per_page, state["cl_search_text"])

                    # Current page (filled in below, once both buttons have been handled)
                    with _current_page:
                        _current_page_container = st.empty()

                    # Next page
                    with _next_page:
                        if st.button("Next", key="cl_next_page"):
                            state["cl_current_page"] += 1
                            # Wrap around to the first page
                            if state["cl_current_page"] > last_page:
                                state["cl_current_page"] = 1
                            state["results"] = getConceptsFromPath(state["cl_current_page"], concepts_per_page, state["cl_search_text"])

                    # Current page
                    with _current_page_container:
                        st.markdown(f'<p style="text-align: center">Page {state["cl_current_page"]} of {last_page}</p>', unsafe_allow_html=True)

        with results_empty:
            with st.container():
                if downloaded_concepts_count == 0:
                    st.write("You don't have any concepts in your library")
                    st.markdown(
                        "To add concepts to your library, download some from the [sd-concepts-library](https://github.com/sd-webui/sd-concepts-library) "
                        "repository and save the content of `sd-concepts-library` into ```./models/custom/sd-concepts-library``` or just create your own concepts :wink:.",
                        unsafe_allow_html=False)
                elif len(state["results"]) == 0:
                    st.write("No concept found in the library matching your search: " + state["cl_search_text"])
                else:
                    # display number of results
                    if state["cl_search_text"]:
                        st.write(f"Found {state['cl_search_results_count']} {'concepts' if state['cl_search_results_count'] > 1 else 'concept'} matching your search")
                    sdConceptsBrowser(state['results'], key="results")

    with tab_downloader:
        st.write("Not implemented yet")

    return False