Added message to tell the user to accept the license on the huggingface site if they havent done so. (#1573)

This commit is contained in:
Alejandro Gil 2022-10-22 10:58:06 -07:00 committed by GitHub
commit d8bb0c4121
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,6 +20,7 @@ from sd_utils import *
#other imports
from requests.auth import HTTPBasicAuth
from requests import HTTPError
from stqdm import stqdm
# Temp imports
@ -43,11 +44,20 @@ def download_file(file_name, file_path, file_url):
)
raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.")
with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token), stream=True) as r:
r.raise_for_status()
with open(os.path.join(file_path, file_name), 'wb') as f:
for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"):
f.write(chunk)
try:
with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token), stream=True) as r:
r.raise_for_status()
with open(os.path.join(file_path, file_name), 'wb') as f:
for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"):
f.write(chunk)
except HTTPError:
if "huggingface.co" in file_url:
if "resolve"in file_url:
repo_url = file_url.split("resolve")[0]
st.session_state["progress_bar_text"].error(
f"You need to accept the license for the model in order to be able to download it. "
f"Please visit {repo_url} and accept the lincense there, then try again to download the model.")
else:
print(file_name + ' already exists.')