Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2024-09-11 21:27:37 +03:00)
python: improve handling of incomplete downloads (#2152)
* make sure encoding is identity for Range requests
* use a .part file for partial downloads
* verify using file size and MD5 from models3.json

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
parent b4bcc5b37c
commit 71d7f34d1a
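The three changes reinforce each other: a ranged resume is only byte-accurate over an identity-encoded body, the `.part` suffix keeps half-written files from being mistaken for complete models, and the size/MD5 check catches anything that still slips through. As a minimal sketch of the resume pattern in isolation (hypothetical `url` and `dest`; this is not the binding code itself):

```python
import os
import requests

def resume_download(url: str, dest: str) -> None:
    # Pick up from however many bytes are already on disk.
    offset = os.path.getsize(dest) if os.path.exists(dest) else 0
    headers = {}
    if offset:
        headers["Range"] = f"bytes={offset}-"    # request only the remainder
        headers["Accept-Encoding"] = "identity"  # Range offsets address unencoded bytes
    with requests.get(url, stream=True, headers=headers, timeout=30) as resp:
        resp.raise_for_status()
        if offset and resp.status_code != 206:
            raise ValueError("server ignored the Range request")
        with open(dest, "ab") as f:  # append to what we already have
            for chunk in resp.iter_content(2**20):
                f.write(chunk)
```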
gpt4all-bindings/python/gpt4all/_pyllmodel.py

@@ -299,7 +299,7 @@ class LLModel:
     @overload
     def generate_embeddings(
-        self, text: str, prefix: str, dimensionality: int, do_mean: bool, count_tokens: bool, atlas: bool,
+        self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
     ) -> EmbedResult[list[float]]: ...
 
     @overload
     def generate_embeddings(
gpt4all-bindings/python/gpt4all/gpt4all.py

@@ -3,6 +3,7 @@ Python only API for running all GPT4All models.
 """
 from __future__ import annotations
 
+import hashlib
 import os
 import re
 import sys
@@ -10,7 +11,7 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Iterable, Literal, overload
+from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
 
 import requests
 from requests.exceptions import ChunkedEncodingError
@@ -21,14 +22,17 @@ from . import _pyllmodel
 from ._pyllmodel import EmbedResult as EmbedResult
 
 if TYPE_CHECKING:
-    from typing import TypeAlias
+    from typing_extensions import TypeAlias
 
+if sys.platform == 'darwin':
+    import fcntl
+
 # TODO: move to config
 DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
 
 DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"
 
-ConfigType: TypeAlias = 'dict[str, str]'
+ConfigType: TypeAlias = 'dict[str, Any]'
 MessageType: TypeAlias = 'dict[str, str]'
 
 
@@ -260,7 +264,11 @@ class GPT4All:
                 print(f"Found model file at {str(model_dest)!r}", file=sys.stderr)
         elif allow_download:
             # If model file does not exist, download
-            config["path"] = str(cls.download_model(model_filename, model_path, verbose=verbose, url=config.get("url")))
+            filesize = config.get("filesize")
+            config["path"] = str(cls.download_model(
+                model_filename, model_path, verbose=verbose, url=config.get("url"),
+                expected_size=None if filesize is None else int(filesize), expected_md5=config.get("md5sum"),
+            ))
         else:
             raise FileNotFoundError(f"Model file does not exist: {model_dest!r}")
 
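For reference, `filesize` and `md5sum` come from the models3.json registry, where `filesize` is stored as a string; hence the `int(filesize)` conversion above. A hypothetical entry and the resulting arguments (values are made up for illustration, not a real model):

```python
# Hypothetical models3.json entry; the values are illustrative only:
config = {
    "filename": "example-model.Q4_0.gguf",
    "filesize": "4108916384",                      # a string in the JSON
    "md5sum": "0123456789abcdef0123456789abcdef",  # hex digest of the full file
    "url": "https://example.com/example-model.Q4_0.gguf",
}

filesize = config.get("filesize")
expected_size = None if filesize is None else int(filesize)  # 4108916384
expected_md5 = config.get("md5sum")                          # may be None
```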
@@ -272,6 +280,8 @@ class GPT4All:
         model_path: str | os.PathLike[str],
         verbose: bool = True,
         url: str | None = None,
+        expected_size: int | None = None,
+        expected_md5: str | None = None,
     ) -> str | os.PathLike[str]:
         """
         Download model from https://gpt4all.io.
@@ -281,13 +291,14 @@
             model_path: Path to download model to.
             verbose: If True (default), print debug messages.
             url: the models remote url (e.g. may be hosted on HF)
+            expected_size: The expected size of the download.
+            expected_md5: The expected MD5 hash of the download.
 
         Returns:
             Model file destination.
         """
 
         # Download model
-        download_path = Path(model_path) / model_filename
         if url is None:
             url = f"https://gpt4all.io/models/gguf/{model_filename}"
 
@@ -296,11 +307,14 @@
             if offset:
                 print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
                 headers['Range'] = f'bytes={offset}-'  # resume incomplete response
+                headers["Accept-Encoding"] = "identity"  # Content-Encoding changes meaning of ranges
             response = requests.get(url, stream=True, headers=headers)
             if response.status_code not in (200, 206):
                 raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
             if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
                 raise ValueError('Connection was interrupted and server does not support range requests')
+            if (enc := response.headers.get("Content-Encoding")) is not None:
+                raise ValueError(f"Expected identity Content-Encoding, got {enc}")
             return response
 
         response = make_request()
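When resuming, the server must answer `206 Partial Content` with a `Content-Range` that echoes the requested offset; a plain `200` means it restarted from byte 0, and appending that body would corrupt the file. A standalone (and slightly stricter) variant of the checks `make_request` performs, as a sketch:

```python
import requests

def check_resumed(response: requests.Response, offset: int) -> None:
    # A valid resume looks like: "Content-Range: bytes 1048576-4108916383/4108916384"
    if response.status_code != 206:
        raise ValueError("server restarted the download from byte 0")
    if not response.headers.get("Content-Range", "").startswith(f"bytes {offset}-"):
        raise ValueError("server returned a different byte range than requested")
    if response.headers.get("Content-Encoding") is not None:
        raise ValueError("encoded body; byte offsets no longer match the file on disk")
```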
@@ -308,41 +322,69 @@
         total_size_in_bytes = int(response.headers.get("content-length", 0))
         block_size = 2**20  # 1 MB
 
-        with open(download_path, "wb") as file, \
-                tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
+        partial_path = Path(model_path) / (model_filename + ".part")
+
+        with open(partial_path, "w+b") as partf:
             try:
-                while True:
-                    last_progress = progress_bar.n
-                    try:
-                        for data in response.iter_content(block_size):
-                            file.write(data)
-                            progress_bar.update(len(data))
-                    except ChunkedEncodingError as cee:
-                        if cee.args and isinstance(pe := cee.args[0], ProtocolError):
-                            if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
-                                assert progress_bar.n <= ir.partial  # urllib3 may be ahead of us but never behind
-                                # the socket was closed during a read - retry
-                                response = make_request(progress_bar.n)
-                                continue
-                        raise
-                    if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
-                        if progress_bar.n == last_progress:
-                            raise RuntimeError('Download not making progress, aborting.')
-                        # server closed connection prematurely - retry
-                        response = make_request(progress_bar.n)
-                        continue
-                    break
+                with tqdm(desc="Downloading", total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
+                    while True:
+                        last_progress = progress_bar.n
+                        try:
+                            for data in response.iter_content(block_size):
+                                partf.write(data)
+                                progress_bar.update(len(data))
+                        except ChunkedEncodingError as cee:
+                            if cee.args and isinstance(pe := cee.args[0], ProtocolError):
+                                if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
+                                    assert progress_bar.n <= ir.partial  # urllib3 may be ahead of us but never behind
+                                    # the socket was closed during a read - retry
+                                    response = make_request(progress_bar.n)
+                                    continue
+                            raise
+                        if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
+                            if progress_bar.n == last_progress:
+                                raise RuntimeError("Download not making progress, aborting.")
+                            # server closed connection prematurely - retry
+                            response = make_request(progress_bar.n)
+                            continue
+                        break
+
+                # verify file integrity
+                file_size = partf.tell()
+                if expected_size is not None and file_size != expected_size:
+                    raise ValueError(f"Expected file size of {expected_size} bytes, got {file_size}")
+                if expected_md5 is not None:
+                    partf.seek(0)
+                    hsh = hashlib.md5()
+                    with tqdm(desc="Verifying", total=file_size, unit="iB", unit_scale=True) as bar:
+                        while chunk := partf.read(block_size):
+                            hsh.update(chunk)
+                            bar.update(len(chunk))
+                    if hsh.hexdigest() != expected_md5.lower():
+                        raise ValueError(f"Expected MD5 hash of {expected_md5!r}, got {hsh.hexdigest()!r}")
             except Exception:
                 if verbose:
                     print("Cleaning up the interrupted download...", file=sys.stderr)
                 try:
-                    os.remove(download_path)
+                    os.remove(partial_path)
                 except OSError:
                     pass
                 raise
 
-        if os.name == 'nt':
-            time.sleep(2)  # Sleep for a little bit so Windows can remove file lock
+            # flush buffers and sync the inode
+            partf.flush()
+            _fsync(partf)
+
+        # move to final destination
+        download_path = Path(model_path) / model_filename
+        try:
+            os.rename(partial_path, download_path)
+        except FileExistsError:
+            try:
+                os.remove(partial_path)
+            except OSError:
+                pass
+            raise
 
         if verbose:
             print(f"Model downloaded to {str(download_path)!r}", file=sys.stderr)
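Two details worth noting above: `os.rename` replaces the destination atomically on POSIX but refuses to overwrite an existing file on Windows, which is what the `FileExistsError` branch handles (drop the `.part` file and re-raise); and verification happens before the rename, so a bad download never lands under the final filename. With the new keyword arguments wired through, a direct call might look like this (filename, size, and digest are hypothetical; in normal use they come from models3.json as shown earlier):

```python
import os
from gpt4all import GPT4All

path = GPT4All.download_model(
    "example-model.Q4_0.gguf",                         # hypothetical filename
    model_path=os.path.expanduser("~/.cache/gpt4all"),
    verbose=True,
    expected_size=4108916384,                          # registry "filesize"
    expected_md5="0123456789abcdef0123456789abcdef",   # registry "md5sum"
)
print(path)  # .../example-model.Q4_0.gguf
```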
@@ -561,3 +603,19 @@ def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
         model_name += ".gguf"
     return model_name
+
+
+class _HasFileno(Protocol):
+    def fileno(self) -> int: ...
+
+
+def _fsync(fd: int | _HasFileno) -> None:
+    if sys.platform == 'darwin':
+        # Apple's fsync does not flush the drive write cache
+        try:
+            fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
+        except OSError:
+            pass  # fall back to fsync
+        else:
+            return
+    os.fsync(fd)
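`_fsync` exists because Apple's `fsync(2)` can return while data still sits in the drive's write cache; `F_FULLFSYNC` asks the device to flush to permanent storage. Together with `partf.flush()` and the rename, this is the standard durable-write sequence; roughly, as a general sketch (not the binding code):

```python
import os
import sys

def durable_replace(tmp_path: str, final_path: str) -> None:
    # temp file -> flush -> fsync -> rename: readers never see a partial file
    with open(tmp_path, "rb+") as f:
        f.flush()  # push userspace buffers into the OS page cache
        if sys.platform == "darwin":
            import fcntl
            fcntl.fcntl(f.fileno(), fcntl.F_FULLFSYNC)  # flush the drive cache too
        else:
            os.fsync(f.fileno())  # persist file contents and metadata
    os.rename(tmp_path, final_path)  # atomic within one POSIX filesystem
```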
gpt4all-bindings/python/setup.py

@@ -102,7 +102,8 @@ setup(
             'mkdocstrings[python]',
             'mkdocs-jupyter',
             'black',
-            'isort'
+            'isort',
+            'typing-extensions>=3.10',
         ]
     },
     package_data={'llmodel': [os.path.join(DEST_CLIB_DIRECTORY, "*")]},
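The `typing-extensions>=3.10` pin pairs with the import change above: `TypeAlias` only joined the standard `typing` module in Python 3.10, so older interpreters need the backport. Since the import is guarded by `TYPE_CHECKING`, only static type checkers ever evaluate it, which is presumably why the dependency lands in this extras list rather than in the runtime requirements. The pattern in isolation:

```python
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Evaluated by static type checkers only, never at runtime.
    from typing_extensions import TypeAlias

# The string form defers evaluation, so no runtime import is needed.
ConfigType: 'TypeAlias' = 'dict[str, Any]'
```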