diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index 9c7c236a..277f44e7 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -476,7 +476,9 @@ const std::vector<LLModel::Token> &LLamaModel::endTokens() const
 bool LLamaModel::shouldAddBOS() const
 {
     int add_bos = llama_add_bos_token(d_ptr->model);
-    return add_bos != -1 ? bool(add_bos) : llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_SPM;
+    if (add_bos != -1) { return add_bos; }
+    auto vocab_type = llama_vocab_type(d_ptr->model);
+    return vocab_type == LLAMA_VOCAB_TYPE_SPM || vocab_type == LLAMA_VOCAB_TYPE_WPM;
 }
 
 int32_t LLamaModel::maxContextLength(std::string const &modelPath) const
@@ -638,6 +640,7 @@ static const EmbModelGroup EMBEDDING_MODEL_SPECS[] {
     {LLM_EMBEDDER_SPEC, {"llm-embedder"}},
     {BGE_SPEC,          {"bge-small-en", "bge-base-en", "bge-large-en",
                          "bge-small-en-v1.5", "bge-base-en-v1.5", "bge-large-en-v1.5"}},
+    // NOTE: E5 Mistral is not yet implemented in llama.cpp, so it's not in EMBEDDING_ARCHES
     {E5_SPEC,           {"e5-small", "e5-base", "e5-large",
                          "e5-small-unsupervised", "e5-base-unsupervised", "e5-large-unsupervised",
                          "e5-small-v2", "e5-base-v2", "e5-large-v2"}},
@@ -658,20 +661,20 @@ static const EmbModelSpec *getEmbedSpec(const std::string &modelName) {
 }
 
 void LLamaModel::embed(
-    const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality, bool doMean,
-    bool atlas
+    const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality, size_t *tokenCount,
+    bool doMean, bool atlas
 ) {
     const EmbModelSpec *spec;
     std::optional<std::string> prefix;
     if (d_ptr->model && (spec = getEmbedSpec(llama_model_name(d_ptr->model))))
         prefix = isRetrieval ? spec->queryPrefix : spec->docPrefix;
-    embed(texts, embeddings, prefix, dimensionality, doMean, atlas);
+    embed(texts, embeddings, prefix, dimensionality, tokenCount, doMean, atlas);
 }
 
 void LLamaModel::embed(
     const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
-    bool doMean, bool atlas
+    size_t *tokenCount, bool doMean, bool atlas
 ) {
     if (!d_ptr->model)
         throw std::logic_error("no model is loaded");
@@ -698,12 +701,9 @@ void LLamaModel::embed(
     }
 
     if (!prefix) {
-        if (spec) {
-            prefix = spec->docPrefix;
-        } else {
-            std::cerr << __func__ << ": warning: assuming no prefix\n";
-            prefix = "";
-        }
+        if (!spec)
+            throw std::invalid_argument("unknown model "s + modelName + ", specify a prefix if applicable or an empty string");
+        prefix = spec->docPrefix;
     } else if (spec && prefix != spec->docPrefix && prefix != spec->queryPrefix &&
                std::find(spec->otherPrefixes.begin(), spec->otherPrefixes.end(), *prefix) == spec->otherPrefixes.end())
     {
@@ -712,7 +712,7 @@ void LLamaModel::embed(
         throw std::invalid_argument(ss.str());
     }
 
-    embedInternal(texts, embeddings, *prefix, dimensionality, doMean, atlas, spec);
+    embedInternal(texts, embeddings, *prefix, dimensionality, tokenCount, doMean, atlas, spec);
 }
 
 // MD5 hash of "nomic empty"
@@ -730,7 +730,7 @@ double getL2NormScale(T *start, T *end) {
 
 void LLamaModel::embedInternal(
     const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
-    bool doMean, bool atlas, const EmbModelSpec *spec
+    size_t *tokenCount, bool doMean, bool atlas, const EmbModelSpec *spec
 ) {
     typedef std::vector<LLModel::Token> TokenString;
     static constexpr int32_t atlasMaxLength = 8192;
@@ -796,6 +796,7 @@ void LLamaModel::embedInternal(
     // split into max_len-sized chunks
     struct split_batch { unsigned idx; TokenString batch; };
     std::vector<split_batch> batches;
+    size_t totalTokens = 0;
     for (unsigned i = 0; i < inputs.size(); i++) {
         auto &input = inputs[i];
         for (auto it = input.begin(); it < input.end(); it += max_len) {
@@ -805,6 +806,7 @@ void LLamaModel::embedInternal(
             auto &batch = batches.back().batch;
             batch = prefixTokens;
             batch.insert(batch.end(), it, end);
+            totalTokens += end - it;
             batch.push_back(eos_token);
             if (!doMean) { break; /* limit text to one chunk */ }
         }
@@ -889,6 +891,8 @@ void LLamaModel::embedInternal(
         std::transform(embd, embd_end, embeddings, product(scale));
         embeddings += dimensionality;
     }
+
+    if (tokenCount) { *tokenCount = totalTokens; }
 }
 
 #if defined(_WIN32)
diff --git a/gpt4all-backend/llamamodel_impl.h b/gpt4all-backend/llamamodel_impl.h
index cd9dbd57..5cd6394f 100644
--- a/gpt4all-backend/llamamodel_impl.h
+++ b/gpt4all-backend/llamamodel_impl.h
@@ -39,10 +39,10 @@ public:
     size_t embeddingSize() const override;
     // user-specified prefix
     void embed(const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix,
-               int dimensionality = -1, bool doMean = true, bool atlas = false) override;
+               int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
     // automatic prefix
     void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality = -1,
-               bool doMean = true, bool atlas = false) override;
+               size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
 
 private:
     std::unique_ptr<LLamaPrivate> d_ptr;
@@ -61,7 +61,7 @@ protected:
     int32_t layerCount(std::string const &modelPath) const override;
 
     void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
-                       bool doMean, bool atlas, const EmbModelSpec *spec);
+                       size_t *tokenCount, bool doMean, bool atlas, const EmbModelSpec *spec);
 };
 
 #endif // LLAMAMODEL_H
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index ac7f2055..17abb80f 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -110,10 +110,10 @@ public:
     }
     // user-specified prefix
     virtual void embed(const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix,
-                       int dimensionality = -1, bool doMean = true, bool atlas = false);
+                       int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false);
     // automatic prefix
     virtual void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval,
-                       int dimensionality = -1, bool doMean = true, bool atlas = false);
+                       int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false);
 
     virtual void setThreadCount(int32_t n_threads) { (void)n_threads; }
     virtual int32_t threadCount() const { return 1; }
diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp
index e0809f1d..950a7320 100644
--- a/gpt4all-backend/llmodel_c.cpp
+++ b/gpt4all-backend/llmodel_c.cpp
@@ -158,7 +158,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
 
 float *llmodel_embed(
     llmodel_model model, const char **texts, size_t *embedding_size, const char *prefix, int dimensionality,
-    bool do_mean, bool atlas, const char **error
+    size_t *token_count, bool do_mean, bool atlas, const char **error
 )
 {
     auto *wrapper = static_cast<LLModelWrapper *>(model);
@@ -184,7 +184,7 @@ float *llmodel_embed(
         if (prefix) { prefixStr = prefix; }
 
         embedding = new float[embd_size];
-        wrapper->llModel->embed(textsVec, embedding, prefixStr, dimensionality, do_mean, atlas);
+        wrapper->llModel->embed(textsVec, embedding, prefixStr, dimensionality, token_count, do_mean, atlas);
     } catch (std::exception const &e) {
        llmodel_set_error(error, e.what());
        return nullptr;
diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h
index 913834ec..e26722ca 100644
--- a/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/llmodel_c.h
@@ -193,6 +193,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
  * @param prefix The model-specific prefix representing the embedding task, without the trailing colon. NULL for no
  *               prefix.
  * @param dimensionality The embedding dimension, for use with Matryoshka-capable models. Set to -1 to for full-size.
+ * @param token_count Return location for the number of prompt tokens processed, or NULL.
  * @param do_mean True to average multiple embeddings if the text is longer than the model can accept, False to
  *                truncate.
  * @param atlas Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens with
@@ -202,7 +203,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
  *              be responsible for lifetime of this memory. NULL if an error occurred.
  */
 float *llmodel_embed(llmodel_model model, const char **texts, size_t *embedding_size, const char *prefix,
-                     int dimensionality, bool do_mean, bool atlas, const char **error);
+                     int dimensionality, size_t *token_count, bool do_mean, bool atlas, const char **error);
 
 /**
  * Frees the memory allocated by the llmodel_embedding function.
diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp
index 6cc7e905..3f2b23ea 100644
--- a/gpt4all-backend/llmodel_shared.cpp
+++ b/gpt4all-backend/llmodel_shared.cpp
@@ -270,25 +270,27 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
 
 void LLModel::embed(
     const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
-    bool doMean, bool atlas
+    size_t *tokenCount, bool doMean, bool atlas
 ) {
     (void)texts;
     (void)embeddings;
     (void)prefix;
     (void)dimensionality;
+    (void)tokenCount;
     (void)doMean;
     (void)atlas;
     throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
 }
 
 void LLModel::embed(
-    const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality, bool doMean,
-    bool atlas
+    const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality, size_t *tokenCount,
+    bool doMean, bool atlas
 ) {
     (void)texts;
     (void)embeddings;
     (void)isRetrieval;
     (void)dimensionality;
+    (void)tokenCount;
     (void)doMean;
     (void)atlas;
     throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
diff --git a/gpt4all-bindings/python/docs/gpt4all_chat.md b/gpt4all-bindings/python/docs/gpt4all_chat.md
index 96da44d7..e2bc9f6a 100644
--- a/gpt4all-bindings/python/docs/gpt4all_chat.md
+++ b/gpt4all-bindings/python/docs/gpt4all_chat.md
@@ -7,7 +7,7 @@ It is optimized to run 7-13B parameter LLMs on the CPU's of any computer running
 ## Running LLMs on CPU
 The GPT4All Chat UI supports models from all newer versions of `llama.cpp` with `GGUF` models including the `Mistral`, `LLaMA2`, `LLaMA`, `OpenLLaMa`, `Falcon`, `MPT`, `Replit`, `Starcoder`, and `Bert` architectures
 
-GPT4All maintains an official list of recommended models located in [models2.json](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-chat/metadata/models2.json). You can pull request new models to it and if accepted they will show up in the official download dialog.
+GPT4All maintains an official list of recommended models located in [models3.json](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-chat/metadata/models3.json). You can pull request new models to it and if accepted they will show up in the official download dialog.
 
 #### Sideloading any GGUF model
 If a model is compatible with the gpt4all-backend, you can sideload it into GPT4All Chat by:
diff --git a/gpt4all-bindings/python/docs/gpt4all_faq.md b/gpt4all-bindings/python/docs/gpt4all_faq.md
index 981494de..74a95772 100644
--- a/gpt4all-bindings/python/docs/gpt4all_faq.md
+++ b/gpt4all-bindings/python/docs/gpt4all_faq.md
@@ -61,12 +61,12 @@ or `allowDownload=true` (default), a model is automatically downloaded into `.ca
 unless it already exists.
 
 In case of connection issues or errors during the download, you might want to manually verify the model file's MD5
-checksum by comparing it with the one listed in [models2.json].
+checksum by comparing it with the one listed in [models3.json].
 
 As an alternative to the basic downloader built into the bindings, you can choose to download from the website
 instead. Scroll down to 'Model Explorer' and pick your preferred model.
 
-[models2.json]: https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-chat/metadata/models2.json
+[models3.json]: https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-chat/metadata/models3.json
 
 #### I need the chat GUI and bindings to behave the same
@@ -93,7 +93,7 @@ The chat GUI and bindings are based on the same backend. You can make them behav
 - Next you'll have to compare the templates, adjusting them as necessary, based on how you're using the bindings.
 - Specifically, in Python:
     - With simple `generate()` calls, the input has to be surrounded with system and prompt templates.
-    - When using a chat session, it depends on whether the bindings are allowed to download [models2.json]. If yes,
+    - When using a chat session, it depends on whether the bindings are allowed to download [models3.json]. If yes,
       and in the chat GUI the default templates are used, it'll be handled automatically. If no, use `chat_session()`
       template parameters to customize them.
diff --git a/gpt4all-bindings/python/docs/index.md b/gpt4all-bindings/python/docs/index.md
index 9fabf321..b87ee14e 100644
--- a/gpt4all-bindings/python/docs/index.md
+++ b/gpt4all-bindings/python/docs/index.md
@@ -38,7 +38,7 @@ The GPT4All software ecosystem is compatible with the following Transformer arch
 - `MPT` (including `Replit`)
 - `GPT-J`
 
-You can find an exhaustive list of supported models on the [website](https://gpt4all.io) or in the [models directory](https://raw.githubusercontent.com/nomic-ai/gpt4all/main/gpt4all-chat/metadata/models2.json)
+You can find an exhaustive list of supported models on the [website](https://gpt4all.io) or in the [models directory](https://raw.githubusercontent.com/nomic-ai/gpt4all/main/gpt4all-chat/metadata/models3.json)
 
 GPT4All models are artifacts produced through a process known as neural network quantization.
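Note: taken together, the backend changes above add an optional token_count out-parameter to llmodel_embed() and turn the old "assuming no prefix" warning for unknown embedding models into a hard error unless a prefix (possibly empty) is supplied. Below is a minimal ctypes sketch of how a C-API caller is expected to thread the new argument through; the helper name embed_with_count is made up for illustration, and it assumes `llmodel` is the already-loaded shared library with argtypes/restype configured as in _pyllmodel.py and `model` is a loaded llmodel_model handle. The Python binding changes that follow do exactly this for real.

import ctypes

def embed_with_count(llmodel, model, texts, prefix=None, dimensionality=-1):
    # Illustrative only: `llmodel` is the loaded libllmodel with argtypes/restype
    # set up as in _pyllmodel.py, and `model` is an already-loaded llmodel_model.
    embedding_size = ctypes.c_size_t()
    token_count = ctypes.c_size_t()   # new out-parameter added by this change
    error = ctypes.c_char_p()
    c_prefix = ctypes.c_char_p() if prefix is None else prefix.encode()
    c_texts = (ctypes.c_char_p * (len(texts) + 1))()
    for i, t in enumerate(texts):
        c_texts[i] = t.encode()

    ptr = llmodel.llmodel_embed(
        model, c_texts, ctypes.byref(embedding_size), c_prefix, dimensionality,
        ctypes.byref(token_count), True, False, ctypes.byref(error),  # do_mean=True, atlas=False
    )
    if not ptr:
        raise RuntimeError((error.value or b"unknown error").decode())

    n_embd = embedding_size.value // len(texts)  # floats per input text
    embeddings = [ptr[i:i + n_embd] for i in range(0, embedding_size.value, n_embd)]
    llmodel.llmodel_free_embedding(ptr)
    return embeddings, token_count.value  # mirrors what llmodel_embed wrote to *token_count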
diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index e65c6fe1..036fdbd3 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -9,13 +9,15 @@ import sys
 import threading
 from enum import Enum
 from queue import Queue
-from typing import Any, Callable, Iterable, overload
+from typing import Any, Callable, Generic, Iterable, TypedDict, TypeVar, overload
 
 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
 else:
     import importlib_resources
 
+EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
+
 # TODO: provide a config file to make this more robust
 MODEL_LIB_PATH = importlib_resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build"
@@ -25,7 +27,7 @@ def load_llmodel_library():
     ext = {"Darwin": "dylib", "Linux": "so", "Windows": "dll"}[platform.system()]
 
     try:
-        # Linux, Windows, MinGW
+        # macOS, Linux, MinGW
         lib = ctypes.CDLL(str(MODEL_LIB_PATH / f"libllmodel.{ext}"))
     except FileNotFoundError:
         if ext != 'dll':
@@ -108,6 +110,7 @@ llmodel.llmodel_embed.argtypes = [
     ctypes.POINTER(ctypes.c_size_t),
     ctypes.c_char_p,
     ctypes.c_int,
+    ctypes.POINTER(ctypes.c_size_t),
     ctypes.c_bool,
     ctypes.c_bool,
     ctypes.POINTER(ctypes.c_char_p),
@@ -157,6 +160,11 @@ class Sentinel(Enum):
     TERMINATING_SYMBOL = 0
 
 
+class EmbedResult(Generic[EmbeddingsType], TypedDict):
+    embeddings: EmbeddingsType
+    n_prompt_tokens: int
+
+
 class LLModel:
     """
     Base class and universal wrapper for GPT4All language models
@@ -188,7 +196,7 @@ class LLModel:
             raise RuntimeError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
         self.model = model
 
-    def __del__(self):
+    def __del__(self, llmodel=llmodel):
        if hasattr(self, 'model'):
            llmodel.llmodel_model_destroy(self.model)
 
@@ -291,20 +299,20 @@ class LLModel:
 
     @overload
     def generate_embeddings(
         self, text: str, prefix: str, dimensionality: int, do_mean: bool, atlas: bool,
-    ) -> list[float]: ...
+    ) -> EmbedResult[list[float]]: ...
     @overload
     def generate_embeddings(
         self, text: list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
-    ) -> list[list[float]]: ...
+    ) -> EmbedResult[list[list[float]]]: ...
     @overload
     def generate_embeddings(
         self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
-    ) -> Any: ...
+    ) -> EmbedResult[list[Any]]: ...
 
     def generate_embeddings(
         self, text: str | list[str], prefix: str | None, dimensionality: int, do_mean: bool, atlas: bool,
-    ) -> Any:
+    ) -> EmbedResult[list[Any]]:
         if not text:
             raise ValueError("text must not be None or empty")
 
@@ -313,6 +321,7 @@ class LLModel:
 
         # prepare input
         embedding_size = ctypes.c_size_t()
+        token_count = ctypes.c_size_t()
         error = ctypes.c_char_p()
         c_prefix = ctypes.c_char_p() if prefix is None else prefix.encode()
         c_texts = (ctypes.c_char_p * (len(text) + 1))()
@@ -321,8 +330,8 @@ class LLModel:
 
         # generate the embeddings
         embedding_ptr = llmodel.llmodel_embed(
-            self.model, c_texts, ctypes.byref(embedding_size), c_prefix, dimensionality, do_mean, atlas,
-            ctypes.byref(error),
+            self.model, c_texts, ctypes.byref(embedding_size), c_prefix, dimensionality, ctypes.byref(token_count),
+            do_mean, atlas, ctypes.byref(error),
         )
 
         if not embedding_ptr:
@@ -337,7 +346,8 @@ class LLModel:
         ]
         llmodel.llmodel_free_embedding(embedding_ptr)
 
-        return embedding_array[0] if single_text else embedding_array
+        embeddings = embedding_array[0] if single_text else embedding_array
+        return {'embeddings': embeddings, 'n_prompt_tokens': token_count.value}
 
     def prompt_model(
         self,
diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index 38a43be4..ae8dbfa1 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -18,6 +18,7 @@ from tqdm import tqdm
 from urllib3.exceptions import IncompleteRead, ProtocolError
 
 from . import _pyllmodel
+from ._pyllmodel import EmbedResult as EmbedResult
 
 if TYPE_CHECKING:
     from typing import TypeAlias
@@ -49,35 +50,69 @@ class Embed4All:
             model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
         self.gpt4all = GPT4All(model_name, n_threads=n_threads, **kwargs)
 
+    # return_dict=False
     @overload
     def embed(
-        self, text: str, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        atlas: bool = ...,
+        self, text: str, *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
+        return_dict: Literal[False] = ..., atlas: bool = ...,
     ) -> list[float]: ...
     @overload
     def embed(
-        self, text: list[str], prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
-        atlas: bool = ...,
+        self, text: list[str], *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
+        return_dict: Literal[False] = ..., atlas: bool = ...,
     ) -> list[list[float]]: ...
+    @overload
+    def embed(
+        self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
+        long_text_mode: str = ..., return_dict: Literal[False] = ..., atlas: bool = ...,
+    ) -> list[Any]: ...
+
+    # return_dict=True
+    @overload
+    def embed(
+        self, text: str, *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
+        return_dict: Literal[True], atlas: bool = ...,
+    ) -> EmbedResult[list[float]]: ...
+    @overload
+    def embed(
+        self, text: list[str], *, prefix: str | None = ..., dimensionality: int | None = ..., long_text_mode: str = ...,
+        return_dict: Literal[True], atlas: bool = ...,
+    ) -> EmbedResult[list[list[float]]]: ...
+    @overload
+    def embed(
+        self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
+        long_text_mode: str = ..., return_dict: Literal[True], atlas: bool = ...,
+    ) -> EmbedResult[list[Any]]: ...
+
+    # return type unknown
+    @overload
+    def embed(
+        self, text: str | list[str], *, prefix: str | None = ..., dimensionality: int | None = ...,
+        long_text_mode: str = ..., return_dict: bool = ..., atlas: bool = ...,
+    ) -> Any: ...
 
     def embed(
-        self, text: str | list[str], prefix: str | None = None, dimensionality: int | None = None,
-        long_text_mode: str = "mean", atlas: bool = False,
-    ) -> list[Any]:
+        self, text: str | list[str], *, prefix: str | None = None, dimensionality: int | None = None,
+        long_text_mode: str = "mean", return_dict: bool = False, atlas: bool = False,
+    ) -> Any:
         """
         Generate one or more embeddings.
 
         Args:
             text: A text or list of texts to generate embeddings for.
             prefix: The model-specific prefix representing the embedding task, without the trailing colon. For Nomic
-                Embed this can be `search_query`, `search_document`, `classification`, or `clustering`.
+                Embed, this can be `search_query`, `search_document`, `classification`, or `clustering`. Defaults to
+                `search_document` or equivalent if known; otherwise, you must explicitly pass a prefix or an empty
+                string if none applies.
             dimensionality: The embedding dimension, for use with Matryoshka-capable models. Defaults to full-size.
             long_text_mode: How to handle texts longer than the model can accept. One of `mean` or `truncate`.
+            return_dict: Return the result as a dict that includes the number of prompt tokens processed.
             atlas: Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens
                 with long_text_mode="mean" will raise an error. Disabled by default.
 
         Returns:
-            An embedding or list of embeddings of your text(s).
+            With return_dict=False, an embedding or list of embeddings of your text(s).
+            With return_dict=True, a dict with keys 'embeddings' and 'n_prompt_tokens'.
         """
         if dimensionality is None:
             dimensionality = -1
@@ -93,7 +128,8 @@ class Embed4All:
             do_mean = {"mean": True, "truncate": False}[long_text_mode]
         except KeyError:
             raise ValueError(f"Long text mode must be one of 'mean' or 'truncate', got {long_text_mode!r}")
-        return self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas)
+        result = self.gpt4all.model.generate_embeddings(text, prefix, dimensionality, do_mean, atlas)
+        return result if return_dict else result['embeddings']
 
 
 class GPT4All:
@@ -157,12 +193,12 @@ class GPT4All:
     @staticmethod
     def list_models() -> list[ConfigType]:
         """
-        Fetch model list from https://gpt4all.io/models/models2.json.
+        Fetch model list from https://gpt4all.io/models/models3.json.
 
         Returns:
             Model list in JSON format.
         """
-        resp = requests.get("https://gpt4all.io/models/models2.json")
+        resp = requests.get("https://gpt4all.io/models/models3.json")
         if resp.status_code != 200:
             raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}')
         return resp.json()
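Note: with the binding changes above, Embed4All.embed() keeps its old return type by default and only reports the token count when asked for it via the new keyword-only return_dict flag. A short usage sketch, assuming the default all-MiniLM-L6-v2.gguf2.f16.gguf embedding model is already downloaded or can be fetched:

from gpt4all import Embed4All

embedder = Embed4All()  # defaults to all-MiniLM-L6-v2.gguf2.f16.gguf

# Unchanged default behaviour: a plain list of floats (or list of lists for multiple texts).
vec = embedder.embed("The quick brown fox jumps over the lazy dog")

# New: return_dict=True also reports how many prompt tokens were processed.
result = embedder.embed(["first text", "second text"], return_dict=True)
print(len(result['embeddings']), result['n_prompt_tokens'])

# Models without a known prefix spec now require an explicit prefix (an empty
# string if none applies) instead of silently assuming no prefix.

The 'n_prompt_tokens' value corresponds to the totalTokens counter accumulated in LLamaModel::embedInternal and surfaced through the new token_count out-parameter of llmodel_embed.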