refactor(g4f/api/__init__.py): use asynchronous methods in Client

This commit is contained in:
kqlio67 2024-10-15 15:07:06 +03:00
parent 069b6cebdd
commit c6d11e5cef

View File

@ -17,7 +17,7 @@ from typing import Union, Optional
import g4f import g4f
import g4f.debug import g4f.debug
from g4f.client import AsyncClient from g4f.client import Client
from g4f.typing import Messages from g4f.typing import Messages
from g4f.cookies import read_cookie_files from g4f.cookies import read_cookie_files
@ -69,7 +69,7 @@ class AppConfig():
class Api: class Api:
def __init__(self, app: FastAPI) -> None: def __init__(self, app: FastAPI) -> None:
self.app = app self.app = app
self.client = AsyncClient() self.client = Client()
self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key") self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key")
def register_authorization(self): def register_authorization(self):
@ -156,7 +156,8 @@ class Api:
auth_header = auth_header.split(None, 1)[-1] auth_header = auth_header.split(None, 1)[-1]
if auth_header and auth_header != "Bearer": if auth_header and auth_header != "Bearer":
config.api_key = auth_header config.api_key = auth_header
response = self.client.chat.completions.create( # Use the asynchronous create method and await it
response = await self.client.chat.completions.async_create(
**{ **{
**AppConfig.defaults, **AppConfig.defaults,
**config.dict(exclude_none=True), **config.dict(exclude_none=True),
@ -164,7 +165,7 @@ class Api:
ignored=AppConfig.ignored_providers ignored=AppConfig.ignored_providers
) )
if not config.stream: if not config.stream:
return JSONResponse((await response).to_json()) return JSONResponse(response.to_json())
async def streaming(): async def streaming():
try: try:
@ -196,10 +197,11 @@ class Api:
auth_header = auth_header.split(None, 1)[-1] auth_header = auth_header.split(None, 1)[-1]
if auth_header and auth_header != "Bearer": if auth_header and auth_header != "Bearer":
config.api_key = auth_header config.api_key = auth_header
response = self.client.images.generate( # Use the asynchronous generate method and await it
response = await self.client.images.async_generate(
**config.dict(exclude_none=True), **config.dict(exclude_none=True),
) )
return JSONResponse((await response).to_json()) return JSONResponse(response.to_json())
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
return Response(content=format_exception(e, config), status_code=500, media_type="application/json") return Response(content=format_exception(e, config), status_code=500, media_type="application/json")
@ -232,4 +234,4 @@ def run_api(
use_colors=use_colors, use_colors=use_colors,
factory=True, factory=True,
reload=debug reload=debug
) )