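"""Local completion provider built on the gpt4all Python bindings."""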
from __future__ import annotations

import os

from gpt4all import GPT4All

from .models import get_models
from ..typing import Messages
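# Registry of known models, populated lazily on first use.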
MODEL_LIST: dict[str, dict] | None = None


def find_model_dir(model_file: str) -> str:
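    """Return the directory that contains ``model_file``.

    The project-level ``models`` directory is checked first, then the
    legacy ``models`` directory next to this module, and finally the
    current working directory tree.
    """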
    local_dir = os.path.dirname(os.path.abspath(__file__))
    project_dir = os.path.dirname(os.path.dirname(local_dir))

    new_model_dir = os.path.join(project_dir, "models")
    new_model_file = os.path.join(new_model_dir, model_file)
    if os.path.isfile(new_model_file):
        return new_model_dir

    old_model_dir = os.path.join(local_dir, "models")
    old_model_file = os.path.join(old_model_dir, model_file)
    if os.path.isfile(old_model_file):
        return old_model_dir
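    # Neither known location has the file; search the working directory tree.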
    working_dir = "./"
    for root, _dirs, files in os.walk(working_dir):
        if model_file in files:
            return root
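    # Nothing found anywhere; report the preferred location and let the
    # caller decide whether to download into it.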
    return new_model_dir


class LocalProvider:
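    """Generates completions with models that run locally via gpt4all."""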
    @staticmethod
    def create_completion(model: str, messages: Messages, stream: bool = False, **kwargs):
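        """Yield a completion for ``messages`` from a local gpt4all model.

        With ``stream=True`` tokens are yielded as they are generated;
        otherwise the full response is yielded as a single string.
        """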
        global MODEL_LIST
        if MODEL_LIST is None:
            MODEL_LIST = get_models()
        if model not in MODEL_LIST:
            raise ValueError(f'Model "{model}" not found / not yet implemented')
        model_config = MODEL_LIST[model]
        model_file = model_config["path"]
        model_dir = find_model_dir(model_file)
        if not os.path.isfile(os.path.join(model_dir, model_file)):
            print(f'Model file "models/{model_file}" not found.')
            download = input(f"Do you want to download {model_file}? [y/n]: ")
            if download in ("y", "Y"):
                GPT4All.download_model(model_file, model_dir)
            else:
                raise ValueError(f'Model "{model_file}" not found.')
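        # allow_download=False: the file was resolved (or fetched) above, so
        # GPT4All must not try to download anything on its own here.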
        model = GPT4All(model_name=model_file,
                        # n_threads=8,
                        verbose=False,
                        allow_download=False,
                        model_path=model_dir)
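        # Collapse all system messages into one system prompt; fall back to a
        # generic prompt when the conversation has none.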
        system_message = "\n".join(message["content"] for message in messages if message["role"] == "system")
        if not system_message:
            system_message = "A chat between a curious user and an artificial intelligence assistant."
        prompt_template = "USER: {0}\nASSISTANT: "
        conversation = "\n".join(
            f"{message['role'].upper()}: {message['content']}"
            for message in messages
            if message["role"] != "system"
        ) + "\nASSISTANT: "
        def should_not_stop(token_id: int, token: str):
            # gpt4all keeps generating while this callback returns True; stop
            # as soon as the model starts hallucinating the next "USER" turn.
            return "USER" not in token
        with model.chat_session(system_message, prompt_template):
            if stream:
                for token in model.generate(conversation, streaming=True, callback=should_not_stop):
                    yield token
            else:
                yield model.generate(conversation, callback=should_not_stop)
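# A minimal usage sketch, assuming get_models() registers an entry such as a
# hypothetical "mistral-7b" (the exact names depend on models.py):
#
#     messages = [
#         {"role": "system", "content": "You are a helpful assistant."},
#         {"role": "user", "content": "Hello!"},
#     ]
#     for chunk in LocalProvider.create_completion("mistral-7b", messages, stream=True):
#         print(chunk, end="", flush=True)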