python: improve handling of incomplete downloads (#2152)

* make sure encoding is identity for Range requests
* use a .part file for partial downloads
* verify using file size and MD5 from models3.json

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel, 2024-03-21 11:33:41 -04:00 (committed via GitHub)
commit 71d7f34d1a, parent b4bcc5b37c
3 changed files with 92 additions and 33 deletions
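
Taken together, the three bullet points form one pipeline: stream the response into a "<name>.part" file, verify its size and MD5 against the models3.json metadata, flush it to disk, and only then rename it into place. Below is a condensed sketch of that flow, not the shipped code; the helper name, chunk size, and cleanup behavior are our own simplifications:

import hashlib
import os
from pathlib import Path

import requests


def fetch_verified(url: str, dest: Path, md5: str | None = None) -> Path:
    part = dest.with_name(dest.name + ".part")      # partial download lives next to the target
    with open(part, "w+b") as f:
        with requests.get(url, stream=True) as resp:
            resp.raise_for_status()
            for chunk in resp.iter_content(2**20):  # 1 MiB blocks, as in the commit
                f.write(chunk)
        if md5 is not None:
            f.seek(0)                               # re-read the .part file to hash it
            hasher = hashlib.md5()
            while block := f.read(2**20):
                hasher.update(block)
            if hasher.hexdigest() != md5.lower():
                part.unlink(missing_ok=True)
                raise ValueError("MD5 mismatch")
        f.flush()
        os.fsync(f.fileno())                        # the commit goes further on macOS (F_FULLFSYNC)
    os.rename(part, dest)                           # atomic replace on POSIX
    return dest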

File: gpt4all/_pyllmodel.py

@@ -299,7 +299,7 @@ class LLModel:
     @overload
     def generate_embeddings(
-        self, text: str, prefix: str, dimensionality: int, do_mean: bool, count_tokens: bool, atlas: bool,
+        self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
     ) -> EmbedResult[list[float]]: ...
 
     @overload
     def generate_embeddings(

File: gpt4all/gpt4all.py

@@ -3,6 +3,7 @@ Python only API for running all GPT4All models.
 """
 from __future__ import annotations
 
+import hashlib
 import os
 import re
 import sys
@@ -10,7 +11,7 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Iterable, Literal, overload
+from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
 
 import requests
 from requests.exceptions import ChunkedEncodingError
@@ -21,14 +22,17 @@ from . import _pyllmodel
 from ._pyllmodel import EmbedResult as EmbedResult
 
 if TYPE_CHECKING:
-    from typing import TypeAlias
+    from typing_extensions import TypeAlias
 
+if sys.platform == 'darwin':
+    import fcntl
+
 
 # TODO: move to config
 DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
 
 DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"
 
-ConfigType: TypeAlias = 'dict[str, str]'
+ConfigType: TypeAlias = 'dict[str, Any]'
 MessageType: TypeAlias = 'dict[str, str]'
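
A note on the typing churn above: TypeAlias only joined the typing module in Python 3.10, so on older interpreters it must come from typing_extensions (hence the new setup.py dependency further down). Because the import is guarded by TYPE_CHECKING and the module uses deferred annotations, nothing extra is imported at runtime. A minimal, self-contained sketch of the pattern:

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from typing_extensions import TypeAlias  # typing.TypeAlias requires Python 3.10+

# with deferred annotations, 'TypeAlias' here is never evaluated at runtime
ConfigType: TypeAlias = 'dict[str, Any]'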
@@ -260,7 +264,11 @@ class GPT4All:
                 print(f"Found model file at {str(model_dest)!r}", file=sys.stderr)
 
         elif allow_download:
             # If model file does not exist, download
-            config["path"] = str(cls.download_model(model_filename, model_path, verbose=verbose, url=config.get("url")))
+            filesize = config.get("filesize")
+            config["path"] = str(cls.download_model(
+                model_filename, model_path, verbose=verbose, url=config.get("url"),
+                expected_size=None if filesize is None else int(filesize), expected_md5=config.get("md5sum"),
+            ))
         else:
             raise FileNotFoundError(f"Model file does not exist: {model_dest!r}")
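
The int(filesize) dance is needed because the filesize field evidently reaches this code as a string rather than a number; md5sum is passed through untouched. An illustrative models3.json entry, using the field names retrieve_model reads but with made-up values:

entry = {
    "filename": "example-model.Q4_0.gguf",
    "filesize": "4108927744",                         # arrives as a string, hence int(...)
    "md5sum": "f692417a22405d80573ac10cb0cd6c6a",     # made-up hash
    "url": "https://example.com/example-model.Q4_0.gguf",
}

expected_size = None if entry.get("filesize") is None else int(entry["filesize"])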
@@ -272,6 +280,8 @@
         model_path: str | os.PathLike[str],
         verbose: bool = True,
         url: str | None = None,
+        expected_size: int | None = None,
+        expected_md5: str | None = None,
     ) -> str | os.PathLike[str]:
         """
         Download model from https://gpt4all.io.
@@ -281,13 +291,14 @@
             model_path: Path to download model to.
             verbose: If True (default), print debug messages.
             url: the models remote url (e.g. may be hosted on HF)
+            expected_size: The expected size of the download.
+            expected_md5: The expected MD5 hash of the download.
 
         Returns:
             Model file destination.
         """
 
         # Download model
-        download_path = Path(model_path) / model_filename
         if url is None:
             url = f"https://gpt4all.io/models/gguf/{model_filename}"
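
With the two new keyword arguments, a caller that already knows the expected size and hash can request verification directly. A hypothetical invocation (the filename, size, and hash below are placeholders, not real model metadata):

from gpt4all import GPT4All

path = GPT4All.download_model(
    "example-model.Q4_0.gguf",                        # placeholder filename
    model_path="/tmp/models",
    expected_size=4108927744,                         # placeholder size in bytes
    expected_md5="f692417a22405d80573ac10cb0cd6c6a",  # placeholder hash
)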
@@ -296,11 +307,14 @@
             if offset:
                 print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
                 headers['Range'] = f'bytes={offset}-'  # resume incomplete response
+                headers["Accept-Encoding"] = "identity"  # Content-Encoding changes meaning of ranges
             response = requests.get(url, stream=True, headers=headers)
             if response.status_code not in (200, 206):
                 raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
             if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
                 raise ValueError('Connection was interrupted and server does not support range requests')
+            if (enc := response.headers.get("Content-Encoding")) is not None:
+                raise ValueError(f"Expected identity Content-Encoding, got {enc}")
             return response
 
         response = make_request()
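
The Accept-Encoding line is the heart of the first bullet: HTTP byte ranges address the encoded response body, so if the server chose, say, gzip for the resumed request, byte offset N of the response would no longer correspond to byte N of the file on disk. Pinning the encoding to identity, and rejecting any response that still carries a Content-Encoding header, keeps the server's offsets and the local file position in lockstep. The same checks, distilled into a standalone helper for illustration (ours, not the shipped one):

import requests


def resume(url: str, offset: int) -> requests.Response:
    headers = {
        "Range": f"bytes={offset}-",
        "Accept-Encoding": "identity",   # ranges are defined on the encoded body
    }
    resp = requests.get(url, stream=True, headers=headers)
    if resp.status_code != 206 or str(offset) not in resp.headers.get("Content-Range", ""):
        raise ValueError("server ignored the range request")
    if resp.headers.get("Content-Encoding") is not None:
        raise ValueError("server applied a content coding anyway")
    return resp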
@@ -308,41 +322,69 @@
         total_size_in_bytes = int(response.headers.get("content-length", 0))
         block_size = 2**20  # 1 MB
 
-        with open(download_path, "wb") as file, \
-                tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
+        partial_path = Path(model_path) / (model_filename + ".part")
+
+        with open(partial_path, "w+b") as partf:
             try:
-                while True:
-                    last_progress = progress_bar.n
-                    try:
-                        for data in response.iter_content(block_size):
-                            file.write(data)
-                            progress_bar.update(len(data))
-                    except ChunkedEncodingError as cee:
-                        if cee.args and isinstance(pe := cee.args[0], ProtocolError):
-                            if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
-                                assert progress_bar.n <= ir.partial  # urllib3 may be ahead of us but never behind
-                                # the socket was closed during a read - retry
-                                response = make_request(progress_bar.n)
-                                continue
-                        raise
-                    if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
-                        if progress_bar.n == last_progress:
-                            raise RuntimeError('Download not making progress, aborting.')
-                        # server closed connection prematurely - retry
-                        response = make_request(progress_bar.n)
-                        continue
-                    break
+                with tqdm(desc="Downloading", total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
+                    while True:
+                        last_progress = progress_bar.n
+                        try:
+                            for data in response.iter_content(block_size):
+                                partf.write(data)
+                                progress_bar.update(len(data))
+                        except ChunkedEncodingError as cee:
+                            if cee.args and isinstance(pe := cee.args[0], ProtocolError):
+                                if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
+                                    assert progress_bar.n <= ir.partial  # urllib3 may be ahead of us but never behind
+                                    # the socket was closed during a read - retry
+                                    response = make_request(progress_bar.n)
+                                    continue
+                            raise
+                        if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
+                            if progress_bar.n == last_progress:
+                                raise RuntimeError("Download not making progress, aborting.")
+                            # server closed connection prematurely - retry
+                            response = make_request(progress_bar.n)
+                            continue
+                        break
+
+                # verify file integrity
+                file_size = partf.tell()
+                if expected_size is not None and file_size != expected_size:
+                    raise ValueError(f"Expected file size of {expected_size} bytes, got {file_size}")
+                if expected_md5 is not None:
+                    partf.seek(0)
+                    hsh = hashlib.md5()
+                    with tqdm(desc="Verifying", total=file_size, unit="iB", unit_scale=True) as bar:
+                        while chunk := partf.read(block_size):
+                            hsh.update(chunk)
+                            bar.update(len(chunk))
+                    if hsh.hexdigest() != expected_md5.lower():
+                        raise ValueError(f"Expected MD5 hash of {expected_md5!r}, got {hsh.hexdigest()!r}")
             except Exception:
                 if verbose:
                     print("Cleaning up the interrupted download...", file=sys.stderr)
                 try:
-                    os.remove(download_path)
+                    os.remove(partial_path)
                 except OSError:
                     pass
                 raise
 
-        if os.name == 'nt':
-            time.sleep(2)  # Sleep for a little bit so Windows can remove file lock
+            # flush buffers and sync the inode
+            partf.flush()
+            _fsync(partf)
+
+        # move to final destination
+        download_path = Path(model_path) / model_filename
+        try:
+            os.rename(partial_path, download_path)
+        except FileExistsError:
+            try:
+                os.remove(partial_path)
+            except OSError:
+                pass
+            raise
 
         if verbose:
             print(f"Model downloaded to {str(download_path)!r}", file=sys.stderr)
@@ -561,3 +603,19 @@ def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
         model_name += ".gguf"
     return model_name
+
+
+class _HasFileno(Protocol):
+    def fileno(self) -> int: ...
+
+
+def _fsync(fd: int | _HasFileno) -> None:
+    if sys.platform == 'darwin':
+        # Apple's fsync does not flush the drive write cache
+        try:
+            fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
+        except OSError:
+            pass  # fall back to fsync
+        else:
+            return
+    os.fsync(fd)
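
_fsync makes the file's contents durable before the rename; on macOS a plain fsync only pushes data to the drive, so F_FULLFSYNC asks the drive to flush its own cache as well, with fsync as the fallback. What the commit does not do is sync the directory entry, so the rename itself could still be lost on power failure; a common extension of the pattern, shown here purely as an assumption and not as part of this change:

import os


def fsync_dir(path: str) -> None:
    # POSIX only: directories cannot be opened like this on Windows
    fd = os.open(path, os.O_RDONLY)
    try:
        os.fsync(fd)    # persist the directory entry created by os.rename
    finally:
        os.close(fd)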

File: setup.py

@@ -102,7 +102,8 @@ setup(
             'mkdocstrings[python]',
             'mkdocs-jupyter',
             'black',
-            'isort'
+            'isort',
+            'typing-extensions>=3.10',
         ]
     },
     package_data={'llmodel': [os.path.join(DEST_CLIB_DIRECTORY, "*")]},