Add Ollama provider, Add vision support to Openai

This commit is contained in:
Heiner Lohaus 2024-05-05 23:38:31 +02:00
parent ff8c1fc140
commit 8fcf618bbe
8 changed files with 72 additions and 32 deletions

View File

@ -457,10 +457,13 @@ async def stream_generate(
returned_text = ''
message_id = None
while do_read:
try:
msg = await wss.receive_str()
except TypeError:
continue
objects = msg.split(Defaults.delimiter)
for obj in objects:
if obj is None or not obj:
if not obj:
continue
try:
response = json.loads(obj)

View File

@ -1,8 +1,7 @@
from __future__ import annotations
import requests
from ..typing import AsyncResult, Messages, ImageType
from ..image import to_data_uri
from ..typing import AsyncResult, Messages
from .needs_auth.Openai import Openai
class DeepInfra(Openai):
@ -33,7 +32,6 @@ class DeepInfra(Openai):
model: str,
messages: Messages,
stream: bool,
image: ImageType = None,
api_base: str = "https://api.deepinfra.com/v1/openai",
temperature: float = 0.7,
max_tokens: int = 1028,
@ -54,19 +52,6 @@ class DeepInfra(Openai):
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"macOS"',
}
if image is not None:
if not model:
model = cls.default_vision_model
messages[-1]["content"] = [
{
"type": "image_url",
"image_url": {"url": to_data_uri(image)}
},
{
"type": "text",
"text": messages[-1]["content"]
}
]
return super().create_async_generator(
model, messages,
stream=stream,

33
g4f/Provider/Ollama.py Normal file
View File

@ -0,0 +1,33 @@
from __future__ import annotations
import requests
from .needs_auth.Openai import Openai
from ..typing import AsyncResult, Messages
class Ollama(Openai):
label = "Ollama"
url = "https://ollama.com"
needs_auth = False
working = True
@classmethod
def get_models(cls):
if not cls.models:
url = 'http://127.0.0.1:11434/api/tags'
models = requests.get(url).json()["models"]
cls.models = [model['name'] for model in models]
cls.default_model = cls.models[0]
return cls.models
@classmethod
def create_async_generator(
cls,
model: str,
messages: Messages,
api_base: str = "http://localhost:11434/v1",
**kwargs
) -> AsyncResult:
return super().create_async_generator(
model, messages, api_base=api_base, **kwargs
)

View File

@ -43,6 +43,7 @@ from .Llama import Llama
from .Local import Local
from .MetaAI import MetaAI
from .MetaAIAccount import MetaAIAccount
from .Ollama import Ollama
from .PerplexityLabs import PerplexityLabs
from .Pi import Pi
from .Replicate import Replicate

View File

@ -4,9 +4,10 @@ import json
from ..helper import filter_none
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
from ...typing import Union, Optional, AsyncResult, Messages
from ...typing import Union, Optional, AsyncResult, Messages, ImageType
from ...requests import StreamSession, raise_for_status
from ...errors import MissingAuthError, ResponseError
from ...image import to_data_uri
class Openai(AsyncGeneratorProvider, ProviderModelMixin):
label = "OpenAI API"
@ -23,6 +24,7 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
proxy: str = None,
timeout: int = 120,
image: ImageType = None,
api_key: str = None,
api_base: str = "https://api.openai.com/v1",
temperature: float = None,
@ -36,6 +38,19 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
if cls.needs_auth and api_key is None:
raise MissingAuthError('Add a "api_key"')
if image is not None:
if not model and hasattr(cls, "default_vision_model"):
model = cls.default_vision_model
messages[-1]["content"] = [
{
"type": "image_url",
"image_url": {"url": to_data_uri(image)}
},
{
"type": "text",
"text": messages[-1]["content"]
}
]
async with StreamSession(
proxies={"all": proxy},
headers=cls.get_headers(stream, api_key, headers),
@ -51,7 +66,6 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
stream=stream,
**extra_data
)
async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
await raise_for_status(response)
if not stream:
@ -103,8 +117,7 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
"Content-Type": "application/json",
**(
{"Authorization": f"Bearer {api_key}"}
if cls.needs_auth and api_key is not None
else {}
if api_key is not None else {}
),
**({} if headers is None else headers)
}

View File

@ -201,7 +201,7 @@ def run_api(
if bind is not None:
host, port = bind.split(":")
uvicorn.run(
f"g4f.api:{'create_app_debug' if debug else 'create_app'}",
f"g4f.api:create_app{'_debug' if debug else ''}",
host=host, port=int(port),
workers=workers,
use_colors=use_colors,

View File

@ -11,6 +11,10 @@ def main():
api_parser = subparsers.add_parser("api")
api_parser.add_argument("--bind", default="0.0.0.0:1337", help="The bind string.")
api_parser.add_argument("--debug", action="store_true", help="Enable verbose logging.")
api_parser.add_argument("--model", default=None, help="Default model for chat completion. (incompatible with --debug and --workers)")
api_parser.add_argument("--provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working],
default=None, help="Default provider for chat completion. (incompatible with --debug and --workers)")
api_parser.add_argument("--proxy", default=None, help="Default used proxy.")
api_parser.add_argument("--workers", type=int, default=None, help="Number of workers.")
api_parser.add_argument("--disable-colors", action="store_true", help="Don't use colors.")
api_parser.add_argument("--ignore-cookie-files", action="store_true", help="Don't read .har and cookie files.")
@ -31,14 +35,15 @@ def main():
def run_api_args(args):
from g4f.api import AppConfig, run_api
AppConfig.set_ignore_cookie_files(
args.ignore_cookie_files
)
AppConfig.set_list_ignored_providers(
args.ignored_providers
)
AppConfig.set_g4f_api_key(
args.g4f_api_key
AppConfig.set_config(
ignore_cookie_files=args.ignore_cookie_files,
ignored_providers=args.ignored_providers,
g4f_api_key=args.g4f_api_key,
defaults={
"model": args.model,
"provider": args.provider,
"proxy": args.proxy
}
)
run_api(
bind=args.bind,

View File

@ -40,7 +40,7 @@ async def get_args_from_webview(url: str) -> dict:
"Referer": window.real_url
}
cookies = [list(*cookie.items()) for cookie in window.get_cookies()]
cookies = dict([(name, cookie.value) for name, cookie in cookies])
cookies = {name: cookie.value for name, cookie in cookies}
window.destroy()
return {"headers": headers, "cookies": cookies}