feat(api): support async streaming in chat completions

This commit is contained in:
kqlio67 2024-10-31 00:34:49 +02:00
parent a0087269b3
commit 0d05825a71

View File

@ -165,19 +165,6 @@ class Api:
@self.app.post("/v1/chat/completions") @self.app.post("/v1/chat/completions")
async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None): async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None):
try: try:
# Find the last delimiter with ':' or '-'
if ':' in config.model:
model_parts = config.model.rsplit(":", 1)
elif '-' in config.model:
model_parts = config.model.rsplit("-", 1)
else:
model_parts = [config.model] # There is no prefix.
base_model = model_parts[0] # We use the base model name
model_prefix = model_parts[1] if len(model_parts) > 1 else None
config.model = base_model # Update the configuration to the basic model
config.provider = provider if config.provider is None else config.provider config.provider = provider if config.provider is None else config.provider
if config.api_key is None and request is not None: if config.api_key is None and request is not None:
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
@ -206,9 +193,13 @@ class Api:
return JSONResponse(response_list[0].to_json()) return JSONResponse(response_list[0].to_json())
# Streaming response # Streaming response
async def async_generator(sync_gen):
for item in sync_gen:
yield item
async def streaming(): async def streaming():
try: try:
async for chunk in response: async for chunk in async_generator(response):
yield f"data: {json.dumps(chunk.to_json())}\n\n" yield f"data: {json.dumps(chunk.to_json())}\n\n"
except GeneratorExit: except GeneratorExit:
pass pass
@ -242,7 +233,6 @@ class Api:
async def completions(): async def completions():
return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json") return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig]) -> str: def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig]) -> str:
last_provider = g4f.get_last_provider(True) last_provider = g4f.get_last_provider(True)
return json.dumps({ return json.dumps({