diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index 036fdbd3..ded4ea5e 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -299,7 +299,7 @@ class LLModel:
 
     @overload
     def generate_embeddings(
-        self, text: str, prefix: str, dimensionality: int, do_mean: bool, count_tokens: bool, atlas: bool,
+        self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
     ) -> EmbedResult[list[float]]: ...
     @overload
     def generate_embeddings(
diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index ae8dbfa1..ae4feeb1 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -3,6 +3,7 @@ Python only API for running all GPT4All models.
 """
 from __future__ import annotations
 
+import hashlib
 import os
 import re
 import sys
@@ -10,7 +11,7 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Iterable, Literal, overload
+from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
 
 import requests
 from requests.exceptions import ChunkedEncodingError
@@ -21,14 +22,17 @@ from . import _pyllmodel
 from ._pyllmodel import EmbedResult as EmbedResult
 
 if TYPE_CHECKING:
-    from typing import TypeAlias
+    from typing_extensions import TypeAlias
+
+if sys.platform == 'darwin':
+    import fcntl
 
 
 # TODO: move to config
 DEFAULT_MODEL_DIRECTORY = Path.home() / ".cache" / "gpt4all"
 
 DEFAULT_PROMPT_TEMPLATE = "### Human:\n{0}\n\n### Assistant:\n"
 
-ConfigType: TypeAlias = 'dict[str, str]'
+ConfigType: TypeAlias = 'dict[str, Any]'
 MessageType: TypeAlias = 'dict[str, str]'
 
@@ -260,7 +264,11 @@ class GPT4All:
                 print(f"Found model file at {str(model_dest)!r}", file=sys.stderr)
         elif allow_download:
             # If model file does not exist, download
-            config["path"] = str(cls.download_model(model_filename, model_path, verbose=verbose, url=config.get("url")))
+            filesize = config.get("filesize")
+            config["path"] = str(cls.download_model(
+                model_filename, model_path, verbose=verbose, url=config.get("url"),
+                expected_size=None if filesize is None else int(filesize), expected_md5=config.get("md5sum"),
+            ))
         else:
             raise FileNotFoundError(f"Model file does not exist: {model_dest!r}")
 
@@ -272,6 +280,8 @@ class GPT4All:
         model_path: str | os.PathLike[str],
         verbose: bool = True,
         url: str | None = None,
+        expected_size: int | None = None,
+        expected_md5: str | None = None,
     ) -> str | os.PathLike[str]:
         """
         Download model from https://gpt4all.io.
@@ -281,13 +291,14 @@ class GPT4All:
             model_path: Path to download model to.
             verbose: If True (default), print debug messages.
             url: the models remote url (e.g. may be hosted on HF)
+            expected_size: The expected size of the download.
+            expected_md5: The expected MD5 hash of the download.
 
         Returns:
             Model file destination.
         """
 
         # Download model
-        download_path = Path(model_path) / model_filename
         if url is None:
             url = f"https://gpt4all.io/models/gguf/{model_filename}"
 
@@ -296,11 +307,14 @@ class GPT4All:
             if offset:
                 print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
                 headers['Range'] = f'bytes={offset}-'  # resume incomplete response
+            headers["Accept-Encoding"] = "identity"  # Content-Encoding changes meaning of ranges
             response = requests.get(url, stream=True, headers=headers)
             if response.status_code not in (200, 206):
                 raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
             if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')):
                 raise ValueError('Connection was interrupted and server does not support range requests')
+            if (enc := response.headers.get("Content-Encoding")) is not None:
+                raise ValueError(f"Expected identity Content-Encoding, got {enc}")
             return response
 
         response = make_request()
@@ -308,41 +322,69 @@ class GPT4All:
         total_size_in_bytes = int(response.headers.get("content-length", 0))
         block_size = 2**20  # 1 MB
 
-        with open(download_path, "wb") as file, \
-                tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
+        partial_path = Path(model_path) / (model_filename + ".part")
+
+        with open(partial_path, "w+b") as partf:
             try:
-                while True:
-                    last_progress = progress_bar.n
-                    try:
-                        for data in response.iter_content(block_size):
-                            file.write(data)
-                            progress_bar.update(len(data))
-                    except ChunkedEncodingError as cee:
-                        if cee.args and isinstance(pe := cee.args[0], ProtocolError):
-                            if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
-                                assert progress_bar.n <= ir.partial  # urllib3 may be ahead of us but never behind
-                                # the socket was closed during a read - retry
-                                response = make_request(progress_bar.n)
-                                continue
-                        raise
-                    if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
-                        if progress_bar.n == last_progress:
-                            raise RuntimeError('Download not making progress, aborting.')
-                        # server closed connection prematurely - retry
-                        response = make_request(progress_bar.n)
-                        continue
-                    break
+                with tqdm(desc="Downloading", total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
+                    while True:
+                        last_progress = progress_bar.n
+                        try:
+                            for data in response.iter_content(block_size):
+                                partf.write(data)
+                                progress_bar.update(len(data))
+                        except ChunkedEncodingError as cee:
+                            if cee.args and isinstance(pe := cee.args[0], ProtocolError):
+                                if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
+                                    assert progress_bar.n <= ir.partial  # urllib3 may be ahead of us but never behind
+                                    # the socket was closed during a read - retry
+                                    response = make_request(progress_bar.n)
+                                    continue
+                            raise
+                        if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
+                            if progress_bar.n == last_progress:
+                                raise RuntimeError("Download not making progress, aborting.")
+                            # server closed connection prematurely - retry
+                            response = make_request(progress_bar.n)
+                            continue
+                        break
+
+                # verify file integrity
+                file_size = partf.tell()
+                if expected_size is not None and file_size != expected_size:
+                    raise ValueError(f"Expected file size of {expected_size} bytes, got {file_size}")
+                if expected_md5 is not None:
+                    partf.seek(0)
+                    hsh = hashlib.md5()
+                    with tqdm(desc="Verifying", total=file_size, unit="iB", unit_scale=True) as bar:
+                        while chunk := partf.read(block_size):
+                            hsh.update(chunk)
+                            bar.update(len(chunk))
+                    if hsh.hexdigest() != expected_md5.lower():
+                        raise ValueError(f"Expected MD5 hash of {expected_md5!r}, got {hsh.hexdigest()!r}")
             except Exception:
                 if verbose:
                     print("Cleaning up the interrupted download...", file=sys.stderr)
                 try:
-                    os.remove(download_path)
+                    os.remove(partial_path)
                 except OSError:
                     pass
                 raise
 
-        if os.name == 'nt':
-            time.sleep(2)  # Sleep for a little bit so Windows can remove file lock
+            # flush buffers and sync the inode
+            partf.flush()
+            _fsync(partf)
+
+        # move to final destination
+        download_path = Path(model_path) / model_filename
+        try:
+            os.rename(partial_path, download_path)
+        except FileExistsError:
+            try:
+                os.remove(partial_path)
+            except OSError:
+                pass
+            raise
 
         if verbose:
             print(f"Model downloaded to {str(download_path)!r}", file=sys.stderr)
@@ -561,3 +603,19 @@ def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
         model_name += ".gguf"
     return model_name
+
+
+class _HasFileno(Protocol):
+    def fileno(self) -> int: ...
+
+
+def _fsync(fd: int | _HasFileno) -> None:
+    if sys.platform == 'darwin':
+        # Apple's fsync does not flush the drive write cache
+        try:
+            fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
+        except OSError:
+            pass  # fall back to fsync
+        else:
+            return
+    os.fsync(fd)
diff --git a/gpt4all-bindings/python/setup.py b/gpt4all-bindings/python/setup.py
index f9952c4b..5d138934 100644
--- a/gpt4all-bindings/python/setup.py
+++ b/gpt4all-bindings/python/setup.py
@@ -102,7 +102,8 @@ setup(
             'mkdocstrings[python]',
             'mkdocs-jupyter',
             'black',
-            'isort'
+            'isort',
+            'typing-extensions>=3.10',
         ]
     },
     package_data={'llmodel': [os.path.join(DEST_CLIB_DIRECTORY, "*")]},
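
A note on the Accept-Encoding header added in make_request(): requests asks for compressed transfer encodings by default, and a Content-Encoding-transformed body is counted in encoded bytes, so a Range offset computed from the bytes already written to disk would not line up with what the server sends. Forcing identity encoding, and rejecting any response that still carries a Content-Encoding, keeps resume offsets meaningful. A minimal standalone sketch of the resume request follows; the name resume_request and its arguments are illustrative only, not part of the patch:

    import requests

    def resume_request(url: str, offset: int) -> requests.Response:
        # Request the remainder of the file in identity (un-encoded) form so
        # that byte offsets refer to the same bytes we already have on disk.
        headers = {
            "Range": f"bytes={offset}-",
            "Accept-Encoding": "identity",
        }
        response = requests.get(url, stream=True, headers=headers)
        # 206 Partial Content confirms the server honored the Range header;
        # anything else means we cannot safely append to the partial file.
        if response.status_code != 206 or str(offset) not in response.headers.get("Content-Range", ""):
            raise ValueError("server does not support range requests")
        return response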
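Taken together, the gpt4all.py changes turn the download into a write-to-partial-file, verify, sync, atomically-rename pipeline, which also replaces the old Windows-specific time.sleep(2) workaround. Below is a condensed sketch of that pattern, assuming a POSIX-style rename; the helper names fsync_file and finish_download are illustrative and do not appear in the patch:

    from __future__ import annotations

    import hashlib
    import os
    import sys
    from pathlib import Path

    if sys.platform == "darwin":
        import fcntl

    def fsync_file(f) -> None:
        # Flush Python's buffers, then force the data to stable storage.
        # On macOS, plain fsync() does not flush the drive's write cache,
        # so F_FULLFSYNC is attempted first.
        f.flush()
        if sys.platform == "darwin":
            try:
                fcntl.fcntl(f.fileno(), fcntl.F_FULLFSYNC)
                return
            except OSError:
                pass  # fall back to regular fsync
        os.fsync(f.fileno())

    def finish_download(partial_path: Path, final_path: Path, expected_md5: str | None = None) -> Path:
        # Verify the ".part" file in place, sync it, then rename it into position.
        # os.rename() is atomic on POSIX; on Windows it raises FileExistsError if
        # the destination exists, which is why the patch catches that exception.
        with open(partial_path, "r+b") as partf:
            if expected_md5 is not None:
                hsh = hashlib.md5()
                while chunk := partf.read(2**20):
                    hsh.update(chunk)
                if hsh.hexdigest() != expected_md5.lower():
                    raise ValueError(f"Expected MD5 hash of {expected_md5!r}, got {hsh.hexdigest()!r}")
            fsync_file(partf)
        os.rename(partial_path, final_path)
        return final_path

Because the data only reaches its final name after the size/MD5 checks and the fsync succeed, a crash or aborted transfer leaves at most a .part file behind, never a truncated file at the real model filename.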