2022-10-24 03:31:41 +03:00
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
2022-09-26 16:02:48 +03:00
2022-10-24 03:17:50 +03:00
# Copyright 2022 Sygil-Dev team.
2022-09-26 16:02:48 +03:00
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
2022-10-20 23:17:06 +03:00
# along with this program. If not, see <http://www.gnu.org/licenses/>.
2022-09-14 14:19:24 +03:00
# base webui import and utils.
2022-11-03 10:04:32 +03:00
from sd_utils import st , logger
2022-09-14 14:19:24 +03:00
# streamlit imports
#other imports
2022-11-03 10:04:32 +03:00
import os , requests
2022-10-21 05:50:40 +03:00
from requests . auth import HTTPBasicAuth
2022-10-22 20:57:22 +03:00
from requests import HTTPError
2022-10-21 05:50:40 +03:00
from stqdm import stqdm
2022-09-14 14:19:24 +03:00
2022-10-20 23:17:06 +03:00
# Temp imports
2022-09-14 14:19:24 +03:00
# end of imports
#---------------------------------------------------------------------------------------------------------------
2022-10-04 17:25:47 +03:00
def download_file(file_name, file_path, file_url):
    """Download ``file_url`` to ``os.path.join(file_path, file_name)`` unless it already exists.

    The directory ``file_path`` is created if missing. Data is streamed into a
    temporary ``<target>.part`` file that is moved into place only after the whole
    body was written, so an interrupted download is never mistaken for a finished
    one by the existence check on the next call.

    For URLs on huggingface.co the configured Hugging Face token is sent with
    HTTP basic auth. HTTP errors are logged (and, for Hugging Face "resolve"
    URLs, the user is asked in the UI to accept the model license). They are
    not re-raised.

    Args:
        file_name: Name of the file to create inside ``file_path``.
        file_path: Destination directory.
        file_url: URL to download from.

    Raises:
        OSError: if ``file_name`` is "Stable Diffusion v1.5" and no Hugging Face
            token is configured.
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, file_name)

    if os.path.exists(target):
        print(file_name + ' already exists.')
        return

    print('Downloading ' + file_name + '...')

    if file_name == "Stable Diffusion v1.5":
        # NOTE(review): this checks for a "huggingface_token" session key but reads the
        # token from defaults.general — confirm the session key is actually set somewhere.
        if "huggingface_token" not in st.session_state or st.session_state["defaults"].general.huggingface_token == "None":
            message = "You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."
            if "progress_bar_text" in st.session_state:
                st.session_state["progress_bar_text"].error(message)
            raise OSError(message)

    # Only Hugging Face downloads carry the token.
    auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token) if "huggingface.co" in file_url else None

    partial_path = target + '.part'
    try:
        with requests.get(file_url, auth=auth, stream=True) as r:
            r.raise_for_status()
            with open(partial_path, 'wb') as f:
                # Each progress tick corresponds to one 8192-byte chunk.
                for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"):
                    f.write(chunk)
        # Publish the file only once it is complete.
        os.replace(partial_path, target)
    except HTTPError as e:
        # Gated Hugging Face models fail until the user accepts the license on the repo page.
        if "huggingface.co" in file_url and "resolve" in file_url and "progress_bar_text" in st.session_state:
            repo_url = file_url.split("resolve")[0]
            st.session_state["progress_bar_text"].error(
                f"You need to accept the license for the model in order to be able to download it. "
                f"Please visit {repo_url} and accept the license there, then try again to download the model.")
        logger.error(e)
    finally:
        # Remove a leftover partial file from a failed or interrupted download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
2022-10-27 03:21:21 +03:00
2022-10-04 17:25:47 +03:00
def download_model(models, model_name):
    """Fetch every file registered for ``model_name`` in ``models``.

    ``models[model_name]`` is iterated as a sequence of mappings. Each mapping is
    read for ``file_name``, ``file_path`` and ``file_url``, which are passed to
    ``download_file`` in that order.

    NOTE(review): this layout (a list of entries keyed by ``file_url``) differs from
    the ``model_manager.models`` layout rendered by ``layout`` (a ``files`` mapping
    keyed by ``download_link``) — confirm which schema callers pass here.
    """
    file_entries = models[model_name]
    for entry in file_entries:
        name, directory, url = entry['file_name'], entry['file_path'], entry['file_url']
        download_file(name, directory, url)
2022-09-14 14:19:24 +03:00
2022-09-14 00:08:40 +03:00
def _file_save_location(model, file_info):
    """Return a model file's directory: its own ``save_location`` if set, else the model's."""
    if "save_location" in file_info:
        return file_info['save_location']
    return model['save_location']


def layout():
    """Render the Model Manager table: one row per configured model.

    Models come from ``st.session_state["defaults"].model_manager.models``. For each
    model, the row shows its index, ``model_name`` and ``save_location``.

    If every file in the model's ``files`` mapping already exists on disk, the row
    shows a check mark. Otherwise it shows a Download button. Clicking it downloads
    only the missing files (each to its own ``save_location`` or the model's) from
    their ``download_link``, then reruns the script.
    """
    # Header and rows share one width tuple so the header cells line up with the
    # row cells.
    column_widths = (1, 3, 3, 3, 6)
    header = ["№", 'Model Name', 'Save Location', "Download", 'Download Link']

    for col, field_name in zip(st.columns(column_widths), header):
        col.write(field_name)

    models = st.session_state["defaults"].model_manager.models

    for index, model_name in enumerate(models):
        model = models[model_name]
        # NOTE(review): the fifth ("Download Link") cell is never filled.
        col1, col2, col3, col4, _col5 = st.columns(column_widths)
        col1.write(index)
        col2.write(model['model_name'])
        col3.write(model['save_location'])

        with col4:
            # Only files not yet present on disk need downloading.
            files_needed = [
                file_info
                for file_info in model['files'].values()
                if not os.path.exists(os.path.join(_file_save_location(model, file_info), file_info['file_name']))
            ]
            if not files_needed:
                st.write('✅')
            elif st.button('Download', key=model['model_name'], help='Download ' + model['model_name']):
                for file_info in files_needed:
                    download_file(file_info['file_name'], _file_save_location(model, file_info), file_info['download_link'])
                # Re-render so the row reflects the newly downloaded files.
                st.experimental_rerun()
            else:
                st.empty()