mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-26 22:00:32 +03:00
python: improve handling of incomplete downloads (#2152)
* make sure the encoding is identity for Range requests
* use a .part file for partial downloads
* verify the download using the file size and MD5 from models3.json

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
parent
b4bcc5b37c
commit
71d7f34d1a
@ -299,7 +299,7 @@ class LLModel:
|
||||
|
||||
@overload
|
||||
def generate_embeddings(
|
||||
self, text: str, prefix: str, dimensionality: int, do_mean: bool, count_tokens: bool, atlas: bool,
|
||||
self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
|
||||
) -> EmbedResult[list[float]]: ...
|
||||
@overload
|
||||
def generate_embeddings(
|
||||
|
@ -3,6 +3,7 @@ Python only API for running all GPT4All models.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
@ -10,7 +11,7 @@ import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Iterable, Literal, overload
|
||||
from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
|
||||
|
||||
import requests
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
@ -21,14 +22,17 @@ from . import _pyllmodel
|
||||
from ._pyllmodel import EmbedResult as EmbedResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
if sys.platform == 'darwin':
|
||||
import fcntl
|
||||
|
||||
# TODO: move to config
|
||||
DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"
|
||||
|
||||
ConfigType: TypeAlias = 'dict[str, str]'
|
||||
ConfigType: TypeAlias = 'dict[str, Any]'
|
||||
MessageType: TypeAlias = 'dict[str, str]'
|
||||
|
||||
|
||||
@ -260,7 +264,11 @@ class GPT4All:
|
||||
print(f"Found model file at {str(model_dest)!r}", file=sys.stderr)
|
||||
elif allow_download:
|
||||
# If model file does not exist, download
|
||||
config["path"] = str(cls.download_model(model_filename, model_path, verbose=verbose, url=config.get("url")))
|
||||
filesize = config.get("filesize")
|
||||
config["path"] = str(cls.download_model(
|
||||
model_filename, model_path, verbose=verbose, url=config.get("url"),
|
||||
expected_size=None if filesize is None else int(filesize), expected_md5=config.get("md5sum"),
|
||||
))
|
||||
else:
|
||||
raise FileNotFoundError(f"Model file does not exist: {model_dest!r}")
|
||||
|
||||
@ -272,6 +280,8 @@ class GPT4All:
|
||||
model_path: str | os.PathLike[str],
|
||||
verbose: bool = True,
|
||||
url: str | None = None,
|
||||
expected_size: int | None = None,
|
||||
expected_md5: str | None = None,
|
||||
) -> str | os.PathLike[str]:
|
||||
"""
|
||||
Download model from https://gpt4all.io.
|
||||
@ -281,13 +291,14 @@ class GPT4All:
|
||||
model_path: Path to download model to.
|
||||
verbose: If True (default), print debug messages.
|
||||
url: the models remote url (e.g. may be hosted on HF)
|
||||
expected_size: The expected size of the download.
|
||||
expected_md5: The expected MD5 hash of the download.
|
||||
|
||||
Returns:
|
||||
Model file destination.
|
||||
"""
|
||||
|
||||
# Download model
|
||||
download_path = Path(model_path) / model_filename
|
||||
if url is None:
|
||||
url = f"https://gpt4all.io/models/gguf/{model_filename}"
|
||||
|
||||
@ -296,11 +307,14 @@ class GPT4All:
|
||||
if offset:
|
||||
print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
|
||||
headers['Range'] = f'bytes={offset}-' # resume incomplete response
|
||||
headers["Accept-Encoding"] = "identity" # Content-Encoding changes meaning of ranges
|
||||
response = requests.get(url, stream=True, headers=headers)
|
||||
if response.status_code not in (200, 206):
|
||||
raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
|
||||
if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
|
||||
raise ValueError('Connection was interrupted and server does not support range requests')
|
||||
if (enc := response.headers.get("Content-Encoding")) is not None:
|
||||
raise ValueError(f"Expected identity Content-Encoding, got {enc}")
|
||||
return response
|
||||
|
||||
response = make_request()
|
||||
@ -308,41 +322,69 @@ class GPT4All:
|
||||
total_size_in_bytes = int(response.headers.get("content-length", 0))
|
||||
block_size = 2**20 # 1 MB
|
||||
|
||||
with open(download_path, "wb") as file, \
|
||||
tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
|
||||
partial_path = Path(model_path) / (model_filename + ".part")
|
||||
|
||||
with open(partial_path, "w+b") as partf:
|
||||
try:
|
||||
while True:
|
||||
last_progress = progress_bar.n
|
||||
try:
|
||||
for data in response.iter_content(block_size):
|
||||
file.write(data)
|
||||
progress_bar.update(len(data))
|
||||
except ChunkedEncodingError as cee:
|
||||
if cee.args and isinstance(pe := cee.args[0], ProtocolError):
|
||||
if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
|
||||
assert progress_bar.n <= ir.partial # urllib3 may be ahead of us but never behind
|
||||
# the socket was closed during a read - retry
|
||||
response = make_request(progress_bar.n)
|
||||
continue
|
||||
raise
|
||||
if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
|
||||
if progress_bar.n == last_progress:
|
||||
raise RuntimeError('Download not making progress, aborting.')
|
||||
# server closed connection prematurely - retry
|
||||
response = make_request(progress_bar.n)
|
||||
continue
|
||||
break
|
||||
with tqdm(desc="Downloading", total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
|
||||
while True:
|
||||
last_progress = progress_bar.n
|
||||
try:
|
||||
for data in response.iter_content(block_size):
|
||||
partf.write(data)
|
||||
progress_bar.update(len(data))
|
||||
except ChunkedEncodingError as cee:
|
||||
if cee.args and isinstance(pe := cee.args[0], ProtocolError):
|
||||
if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
|
||||
assert progress_bar.n <= ir.partial # urllib3 may be ahead of us but never behind
|
||||
# the socket was closed during a read - retry
|
||||
response = make_request(progress_bar.n)
|
||||
continue
|
||||
raise
|
||||
if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
|
||||
if progress_bar.n == last_progress:
|
||||
raise RuntimeError("Download not making progress, aborting.")
|
||||
# server closed connection prematurely - retry
|
||||
response = make_request(progress_bar.n)
|
||||
continue
|
||||
break
|
||||
|
||||
# verify file integrity
|
||||
file_size = partf.tell()
|
||||
if expected_size is not None and file_size != expected_size:
|
||||
raise ValueError(f"Expected file size of {expected_size} bytes, got {file_size}")
|
||||
if expected_md5 is not None:
|
||||
partf.seek(0)
|
||||
hsh = hashlib.md5()
|
||||
with tqdm(desc="Verifying", total=file_size, unit="iB", unit_scale=True) as bar:
|
||||
while chunk := partf.read(block_size):
|
||||
hsh.update(chunk)
|
||||
bar.update(len(chunk))
|
||||
if hsh.hexdigest() != expected_md5.lower():
|
||||
raise ValueError(f"Expected MD5 hash of {expected_md5!r}, got {hsh.hexdigest()!r}")
|
||||
except Exception:
|
||||
if verbose:
|
||||
print("Cleaning up the interrupted download...", file=sys.stderr)
|
||||
try:
|
||||
os.remove(download_path)
|
||||
os.remove(partial_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
if os.name == 'nt':
|
||||
time.sleep(2) # Sleep for a little bit so Windows can remove file lock
|
||||
# flush buffers and sync the inode
|
||||
partf.flush()
|
||||
_fsync(partf)
|
||||
|
||||
# move to final destination
|
||||
download_path = Path(model_path) / model_filename
|
||||
try:
|
||||
os.rename(partial_path, download_path)
|
||||
except FileExistsError:
|
||||
try:
|
||||
os.remove(partial_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
if verbose:
|
||||
print(f"Model downloaded to {str(download_path)!r}", file=sys.stderr)
|
||||
@ -561,3 +603,19 @@ def append_extension_if_missing(model_name):
|
||||
def append_extension_if_missing(model_name):
    """Return *model_name*, appending ".gguf" unless it already ends in a known model-file extension."""
    known_suffixes = (".bin", ".gguf")
    return model_name if model_name.endswith(known_suffixes) else model_name + ".gguf"
|
||||
|
||||
|
||||
class _HasFileno(Protocol):
    """Structural type: any object exposing a ``fileno()`` method (e.g. an open file)."""
    # Accepted by _fsync alongside raw integer file descriptors, matching the
    # signature of os.fsync and fcntl.fcntl.
    def fileno(self) -> int: ...
|
||||
|
||||
|
||||
def _fsync(fd: int | _HasFileno) -> None:
    """Flush *fd*'s data all the way to stable storage.

    Args:
        fd: An integer file descriptor or any object with a ``fileno()`` method.

    On macOS, fsync(2) does not flush the drive's write cache, so
    ``F_FULLFSYNC`` is attempted first; plain ``os.fsync`` is used as a
    fallback only if the fcntl fails (e.g. unsupported filesystem).
    NOTE(review): the scraped diff lost indentation here — the ``else``
    must bind to the ``try`` (return on F_FULLFSYNC success), not to the
    ``if``, otherwise non-darwin platforms would skip the fsync entirely.
    """
    if sys.platform == 'darwin':
        # Apple's fsync does not flush the drive write cache
        try:
            fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
        except OSError:
            pass  # fall back to fsync
        else:
            return  # full sync succeeded; nothing more to do
    os.fsync(fd)
|
||||
|
@ -102,7 +102,8 @@ setup(
|
||||
'mkdocstrings[python]',
|
||||
'mkdocs-jupyter',
|
||||
'black',
|
||||
'isort'
|
||||
'isort',
|
||||
'typing-extensions>=3.10',
|
||||
]
|
||||
},
|
||||
package_data={'llmodel': [os.path.join(DEST_CLIB_DIRECTORY, "*")]},
|
||||
|
Loading…
Reference in New Issue
Block a user