python: implement close() and context manager interface (#2177)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2024-03-28 16:48:07 -04:00 committed by GitHub
parent dddaf49428
commit 3313c7de0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 57 additions and 4 deletions

View File

@@ -9,7 +9,7 @@ import sys
import threading
from enum import Enum
from queue import Queue
from typing import Any, Callable, Generic, Iterable, TypeVar, overload
from typing import Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload
if sys.version_info >= (3, 9):
import importlib.resources as importlib_resources
@@ -200,13 +209,22 @@ class LLModel:
if model is None:
s = err.value
raise RuntimeError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
self.model = model
self.model: ctypes.c_void_p | None = model
def __del__(self, llmodel=llmodel):
    # NOTE(review): `llmodel` is captured as a default argument — presumably so
    # the ctypes wrapper stays reachable during interpreter shutdown, when
    # module globals may already be cleared; confirm before removing.
    # hasattr guard: __init__ may have raised before self.model was assigned.
    if hasattr(self, 'model'):
        self.close()
def close(self) -> None:
    """Destroy the underlying native model handle.

    Idempotent: after the first call ``self.model`` is ``None`` and
    subsequent calls return immediately.
    """
    if self.model is None:
        return
    # Clear the attribute first so a re-entrant call sees the closed state,
    # then release the native handle.
    handle, self.model = self.model, None
    llmodel.llmodel_model_destroy(handle)
def _raise_closed(self) -> NoReturn:
raise ValueError("Attempted operation on a closed LLModel")
def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]:
assert self.model is not None
num_devices = ctypes.c_int32(0)
devices_ptr = llmodel.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices))
if not devices_ptr:
@@ -214,6 +223,9 @@ class LLModel:
return devices_ptr[:num_devices.value]
def init_gpu(self, device: str):
if self.model is None:
self._raise_closed()
mem_required = llmodel.llmodel_required_mem(self.model, self.model_path, self.n_ctx, self.ngl)
if llmodel.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode()):
@@ -246,14 +258,21 @@ class LLModel:
-------
True if model loaded successfully, False otherwise
"""
if self.model is None:
self._raise_closed()
return llmodel.llmodel_loadModel(self.model, self.model_path, self.n_ctx, self.ngl)
def set_thread_count(self, n_threads):
    """Set the number of CPU threads the backend uses for this model.

    Raises:
        ValueError: if the model has been closed.
        RuntimeError: if no model is currently loaded.
    """
    if self.model is None:
        self._raise_closed()
    if not llmodel.llmodel_isModelLoaded(self.model):
        # RuntimeError instead of bare Exception: more specific, and still a
        # subclass of Exception, so existing `except Exception` callers work.
        raise RuntimeError("Model not loaded")
    llmodel.llmodel_setThreadCount(self.model, n_threads)
def thread_count(self):
    """Return the number of CPU threads the backend is using for this model.

    Raises:
        ValueError: if the model has been closed.
        RuntimeError: if no model is currently loaded.
    """
    if self.model is None:
        self._raise_closed()
    if not llmodel.llmodel_isModelLoaded(self.model):
        # RuntimeError instead of bare Exception: more specific, and still a
        # subclass of Exception, so existing `except Exception` callers work.
        raise RuntimeError("Model not loaded")
    return llmodel.llmodel_threadCount(self.model)
@@ -322,6 +341,9 @@ class LLModel:
if not text:
raise ValueError("text must not be None or empty")
if self.model is None:
self._raise_closed()
if (single_text := isinstance(text, str)):
text = [text]
@@ -387,6 +409,9 @@
None
"""
if self.model is None:
self._raise_closed()
self.buffer.clear()
self.buff_expecting_cont_bytes = 0
@@ -419,6 +444,9 @@
def prompt_model_streaming(
self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
) -> Iterable[str]:
if self.model is None:
self._raise_closed()
output_queue: Queue[str | Sentinel] = Queue()
# Put response tokens into an output queue

View File

@@ -11,6 +11,7 @@ import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
import requests
@@ -22,7 +23,7 @@ from . import _pyllmodel
from ._pyllmodel import EmbedResult as EmbedResult
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from typing_extensions import Self, TypeAlias
if sys.platform == 'darwin':
import fcntl
@@ -54,6 +55,18 @@ class Embed4All:
model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
self.gpt4all = GPT4All(model_name, n_threads=n_threads, **kwargs)
def __enter__(self) -> Self:
    # Context-manager entry: nothing to acquire beyond what __init__ did.
    return self
def __exit__(
    self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
) -> None:
    # Free native resources on scope exit. Returning None (falsy) means any
    # in-flight exception is propagated, not suppressed.
    self.close()
def close(self) -> None:
    """Delete the model instance and free associated system resources."""
    # Delegates to the wrapped GPT4All instance, which owns the native model.
    self.gpt4all.close()
# return_dict=False
@overload
def embed(
@@ -190,6 +203,18 @@ class GPT4All:
self._history: list[MessageType] | None = None
self._current_prompt_template: str = "{0}"
def __enter__(self) -> Self:
    # Context-manager entry: nothing to acquire beyond what __init__ did.
    return self
def __exit__(
    self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
) -> None:
    # Free native resources on scope exit. Returning None (falsy) means any
    # in-flight exception is propagated, not suppressed.
    self.close()
def close(self) -> None:
    """Delete the model instance and free associated system resources."""
    # Delegates to the low-level LLModel wrapper, whose close() destroys the
    # native handle and no-ops on repeated calls.
    self.model.close()
@property
def current_chat_session(self) -> list[MessageType] | None:
return None if self._history is None else list(self._history)

View File

@@ -68,7 +68,7 @@ def get_long_description():
setup(
name=package_name,
version="2.3.2",
version="2.3.3",
description="Python bindings for GPT4All",
long_description=get_long_description(),
long_description_content_type="text/markdown",