diff --git a/scripts/ModelManager.py b/scripts/ModelManager.py index 6a49ea7..7391b90 100644 --- a/scripts/ModelManager.py +++ b/scripts/ModelManager.py @@ -20,6 +20,7 @@ from sd_utils import * #other imports from requests.auth import HTTPBasicAuth +from requests import HTTPError from stqdm import stqdm # Temp imports @@ -43,11 +44,20 @@ def download_file(file_name, file_path, file_url): ) raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.") - with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token), stream=True) as r: - r.raise_for_status() - with open(os.path.join(file_path, file_name), 'wb') as f: - for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"): - f.write(chunk) + try: + with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token), stream=True) as r: + r.raise_for_status() + with open(os.path.join(file_path, file_name), 'wb') as f: + for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"): + f.write(chunk) + except HTTPError: + if "huggingface.co" in file_url: + if "resolve"in file_url: + repo_url = file_url.split("resolve")[0] + + st.session_state["progress_bar_text"].error( + f"You need to accept the license for the model in order to be able to download it. " + f"Please visit {repo_url} and accept the lincense there, then try again to download the model.") else: print(file_name + ' already exists.')