Automatic torch install for amd on linux

This commit allows the launch script to automatically download rocm's torch version for AMD GPUs using an external GPU detection script. It also prints the operative system and GPU in use.
This commit is contained in:
DaniAndTheWeb 2023-01-13 19:22:23 +01:00 committed by GitHub
parent eaebcf6383
commit a407c9f014
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,6 +7,7 @@ import shlex
import platform import platform
import argparse import argparse
import json import json
import detection
dir_repos = "repositories" dir_repos = "repositories"
dir_extensions = "extensions" dir_extensions = "extensions"
@ -15,6 +16,12 @@ git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "") index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None stored_commit_hash = None
# Get the GPU vendor and the operating system
gpu = detection.check_gpu()
if os.name == "posix":
os_name = platform.uname().system
else:
os_name = os.name
def commit_hash(): def commit_hash():
global stored_commit_hash global stored_commit_hash
@ -173,7 +180,11 @@ def run_extensions_installers(settings_file):
def prepare_environment(): def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") if gpu == "AMD" and os_name !="nt":
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2")
else:
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
@ -295,6 +306,8 @@ def tests(test_dir):
def start(): def start():
print(f"Operating System: {os_name}")
print(f"GPU: {gpu}")
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui import webui
if '--nowebui' in sys.argv: if '--nowebui' in sys.argv: