python: improve handling of incomplete downloads (#2152)

* make sure encoding is identity for Range requests
* use a .part file for partial downloads
* verify using file size and MD5 from models3.json

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel 2024-03-21 11:33:41 -04:00 committed by GitHub
parent b4bcc5b37c
commit 71d7f34d1a
3 changed files with 92 additions and 33 deletions
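
In outline, the commit's three bullets compose into a single flow: stream the download into a .part file, resume dropped connections with HTTP Range requests (forcing identity encoding so byte offsets refer to file bytes), verify the size and MD5 against models3.json, and only then rename the file into place. A condensed, hypothetical sketch of that flow; the real diff below additionally validates the 206/Content-Range response, bounds its retries, and reports progress:

import hashlib
import os

import requests


def fetch(url: str, dest: str, size: int, md5: str) -> None:
    """Condensed sketch of the approach; not the real implementation."""
    part = dest + ".part"
    with open(part, "w+b") as f:
        offset = 0
        while offset < size:
            headers = {}
            if offset:
                # resume from where the last attempt stopped; identity
                # encoding keeps range offsets in file bytes, not gzip bytes
                headers = {"Range": f"bytes={offset}-", "Accept-Encoding": "identity"}
            try:
                resp = requests.get(url, stream=True, headers=headers)
                resp.raise_for_status()
                for chunk in resp.iter_content(2**20):
                    f.write(chunk)
            except requests.RequestException:
                pass  # dropped mid-read; loop retries from the new offset
            offset = f.tell()
        # verify before publishing
        h = hashlib.md5()
        f.seek(0)
        while chunk := f.read(2**20):
            h.update(chunk)
        if h.hexdigest() != md5.lower():
            raise ValueError("MD5 mismatch")
        f.flush()
        os.fsync(f.fileno())
    os.rename(part, dest)  # expose the file only once it is verified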

View File

@@ -299,7 +299,7 @@ class LLModel:
     @overload
     def generate_embeddings(
-        self, text: str, prefix: str, dimensionality: int, do_mean: bool, count_tokens: bool, atlas: bool,
+        self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
     ) -> EmbedResult[list[float]]: ...
     @overload
     def generate_embeddings(

View File

@@ -3,6 +3,7 @@ Python only API for running all GPT4All models.
 """
 from __future__ import annotations

+import hashlib
 import os
 import re
 import sys
@@ -10,7 +11,7 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Iterable, Literal, overload
+from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload

 import requests
 from requests.exceptions import ChunkedEncodingError
@@ -21,14 +22,17 @@ from . import _pyllmodel
 from ._pyllmodel import EmbedResult as EmbedResult

 if TYPE_CHECKING:
-    from typing import TypeAlias
+    from typing_extensions import TypeAlias

+if sys.platform == 'darwin':
+    import fcntl
+
 # TODO: move to config
 DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"

 DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"

-ConfigType: TypeAlias = 'dict[str, str]'
+ConfigType: TypeAlias = 'dict[str, Any]'
 MessageType: TypeAlias = 'dict[str, str]'
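
TypeAlias now comes from typing_extensions because typing.TypeAlias only exists on Python 3.10+, and since the import sits under TYPE_CHECKING, only type checkers ever evaluate it; that is also why typing-extensions lands in the dev extras in setup.py below rather than in the runtime requirements. A minimal sketch of the pattern, with a hypothetical alias name:

import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated by static type checkers only, never at runtime
    from typing_extensions import TypeAlias

# at runtime this is a plain string assignment; checkers read it as an alias
PathLikeStr: 'TypeAlias' = 'str | os.PathLike[str]'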
@@ -260,7 +264,11 @@ class GPT4All:
             print(f"Found model file at {str(model_dest)!r}", file=sys.stderr)
         elif allow_download:
             # If model file does not exist, download
-            config["path"] = str(cls.download_model(model_filename, model_path, verbose=verbose, url=config.get("url")))
+            filesize = config.get("filesize")
+            config["path"] = str(cls.download_model(
+                model_filename, model_path, verbose=verbose, url=config.get("url"),
+                expected_size=None if filesize is None else int(filesize), expected_md5=config.get("md5sum"),
+            ))
         else:
             raise FileNotFoundError(f"Model file does not exist: {model_dest!r}")
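
The filesize and md5sum keys are read from the model's models3.json manifest entry; filesize arrives as a string, hence the int(filesize) conversion. A hypothetical entry with illustrative values only, shown as the dict this code receives:

config = {
    "filename": "example-model.Q4_0.gguf",                  # hypothetical
    "url": "https://gpt4all.io/models/gguf/example-model.Q4_0.gguf",
    "filesize": "4108916384",                               # string in the manifest
    "md5sum": "0123456789abcdef0123456789abcdef",           # placeholder hash
}

expected_size = None if config.get("filesize") is None else int(config["filesize"])
expected_md5 = config.get("md5sum")                         # may be absent -> None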
@@ -272,6 +280,8 @@ class GPT4All:
         model_path: str | os.PathLike[str],
         verbose: bool = True,
         url: str | None = None,
+        expected_size: int | None = None,
+        expected_md5: str | None = None,
     ) -> str | os.PathLike[str]:
         """
         Download model from https://gpt4all.io.
@@ -281,13 +291,14 @@
             model_path: Path to download model to.
             verbose: If True (default), print debug messages.
             url: the models remote url (e.g. may be hosted on HF)
+            expected_size: The expected size of the download.
+            expected_md5: The expected MD5 hash of the download.

         Returns:
             Model file destination.
         """
         # Download model
-        download_path = Path(model_path) / model_filename
         if url is None:
             url = f"https://gpt4all.io/models/gguf/{model_filename}"
@@ -296,11 +307,14 @@
             if offset:
                 print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
                 headers['Range'] = f'bytes={offset}-'  # resume incomplete response
+                headers["Accept-Encoding"] = "identity"  # Content-Encoding changes meaning of ranges
             response = requests.get(url, stream=True, headers=headers)
             if response.status_code not in (200, 206):
                 raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
             if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
                 raise ValueError('Connection was interrupted and server does not support range requests')
+            if (enc := response.headers.get("Content-Encoding")) is not None:
+                raise ValueError(f"Expected identity Content-Encoding, got {enc}")
             return response

         response = make_request()
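
Standalone, the resume handshake implemented by make_request looks like this; a minimal sketch using only requests, with a hypothetical URL:

import requests

url = "https://example.com/big-file.bin"  # hypothetical
offset = 1_048_576                        # bytes already written to the .part file

headers = {
    "Range": f"bytes={offset}-",
    # without this a server may gzip the body, and the byte range would
    # then count compressed bytes instead of file bytes
    "Accept-Encoding": "identity",
}
resp = requests.get(url, stream=True, headers=headers)

# 206 plus a matching Content-Range proves the server honored the range;
# a plain 200 means it restarted from byte zero and resuming is unsafe
assert resp.status_code == 206
assert resp.headers.get("Content-Range", "").startswith(f"bytes {offset}-")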
@@ -308,41 +322,69 @@
         total_size_in_bytes = int(response.headers.get("content-length", 0))
         block_size = 2**20  # 1 MB

-        with open(download_path, "wb") as file, \
-                tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
+        partial_path = Path(model_path) / (model_filename + ".part")
+
+        with open(partial_path, "w+b") as partf:
             try:
-                while True:
-                    last_progress = progress_bar.n
-                    try:
-                        for data in response.iter_content(block_size):
-                            file.write(data)
-                            progress_bar.update(len(data))
-                    except ChunkedEncodingError as cee:
-                        if cee.args and isinstance(pe := cee.args[0], ProtocolError):
-                            if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
-                                assert progress_bar.n <= ir.partial  # urllib3 may be ahead of us but never behind
-                                # the socket was closed during a read - retry
-                                response = make_request(progress_bar.n)
-                                continue
-                        raise
-                    if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
-                        if progress_bar.n == last_progress:
-                            raise RuntimeError('Download not making progress, aborting.')
-                        # server closed connection prematurely - retry
-                        response = make_request(progress_bar.n)
-                        continue
-                    break
+                with tqdm(desc="Downloading", total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
+                    while True:
+                        last_progress = progress_bar.n
+                        try:
+                            for data in response.iter_content(block_size):
+                                partf.write(data)
+                                progress_bar.update(len(data))
+                        except ChunkedEncodingError as cee:
+                            if cee.args and isinstance(pe := cee.args[0], ProtocolError):
+                                if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
+                                    assert progress_bar.n <= ir.partial  # urllib3 may be ahead of us but never behind
+                                    # the socket was closed during a read - retry
+                                    response = make_request(progress_bar.n)
+                                    continue
+                            raise
+                        if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
+                            if progress_bar.n == last_progress:
+                                raise RuntimeError("Download not making progress, aborting.")
+                            # server closed connection prematurely - retry
+                            response = make_request(progress_bar.n)
+                            continue
+                        break
+
+                # verify file integrity
+                file_size = partf.tell()
+                if expected_size is not None and file_size != expected_size:
+                    raise ValueError(f"Expected file size of {expected_size} bytes, got {file_size}")
+                if expected_md5 is not None:
+                    partf.seek(0)
+                    hsh = hashlib.md5()
+                    with tqdm(desc="Verifying", total=file_size, unit="iB", unit_scale=True) as bar:
+                        while chunk := partf.read(block_size):
+                            hsh.update(chunk)
+                            bar.update(len(chunk))
+                    if hsh.hexdigest() != expected_md5.lower():
+                        raise ValueError(f"Expected MD5 hash of {expected_md5!r}, got {hsh.hexdigest()!r}")
             except Exception:
                 if verbose:
                     print("Cleaning up the interrupted download...", file=sys.stderr)
                 try:
-                    os.remove(download_path)
+                    os.remove(partial_path)
                 except OSError:
                     pass
                 raise

-        if os.name == 'nt':
-            time.sleep(2)  # Sleep for a little bit so Windows can remove file lock
+            # flush buffers and sync the inode
+            partf.flush()
+            _fsync(partf)
+
+        # move to final destination
+        download_path = Path(model_path) / model_filename
+        try:
+            os.rename(partial_path, download_path)
+        except FileExistsError:
+            try:
+                os.remove(partial_path)
+            except OSError:
+                pass
+            raise

         if verbose:
             print(f"Model downloaded to {str(download_path)!r}", file=sys.stderr)
@@ -561,3 +603,19 @@ def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
         model_name += ".gguf"
     return model_name
+
+
+class _HasFileno(Protocol):
+    def fileno(self) -> int: ...
+
+
+def _fsync(fd: int | _HasFileno) -> None:
+    if sys.platform == 'darwin':
+        # Apple's fsync does not flush the drive write cache
+        try:
+            fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
+        except OSError:
+            pass  # fall back to fsync
+        else:
+            return
+    os.fsync(fd)
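
The ordering enforced here, flush() then _fsync() then os.rename(), is what makes the result durable: flush() moves Python's userspace buffer into the OS page cache, fsync() forces the page cache to stable storage (with F_FULLFSYNC on macOS also flushing the drive's own write cache), and only then does the rename publish the file. A self-contained sketch of the pattern with a hypothetical helper:

import os


def atomic_write(dest: str, data: bytes) -> None:  # hypothetical helper
    part = dest + ".part"
    with open(part, "wb") as f:
        f.write(data)         # may sit in Python's userspace buffer
        f.flush()             # userspace buffer -> OS page cache
        os.fsync(f.fileno())  # page cache -> disk (use _fsync above on macOS)
    os.rename(part, dest)     # atomic on POSIX; raises on Windows if dest exists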

View File

@@ -102,7 +102,8 @@ setup(
             'mkdocstrings[python]',
             'mkdocs-jupyter',
             'black',
-            'isort'
+            'isort',
+            'typing-extensions>=3.10',
         ]
     },
     package_data={'llmodel': [os.path.join(DEST_CLIB_DIRECTORY, "*")]},