Update HuggingChat.py

This commit is contained in:
H Lohaus 2024-04-07 00:01:04 +02:00 committed by GitHub
parent 5dcbf147b6
commit e5e811fd7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,10 +5,10 @@ import requests
from aiohttp import ClientSession, BaseConnector from aiohttp import ClientSession, BaseConnector
from ..typing import AsyncResult, Messages from ..typing import AsyncResult, Messages
from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt, get_connector from .helper import format_prompt, get_connector
class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin): class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://huggingface.co/chat" url = "https://huggingface.co/chat"
working = True working = True
@ -60,11 +60,11 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
headers=headers, headers=headers,
connector=get_connector(connector, proxy) connector=get_connector(connector, proxy)
) as session: ) as session:
async with session.post(f"{cls.url}/conversation", json=options, proxy=proxy) as response: async with session.post(f"{cls.url}/conversation", json=options) as response:
response.raise_for_status() await raise_for_status(response)
conversation_id = (await response.json())["conversationId"] conversation_id = (await response.json())["conversationId"]
async with session.get(f"{cls.url}/conversation/{conversation_id}/__data.json") as response: async with session.get(f"{cls.url}/conversation/{conversation_id}/__data.json") as response:
response.raise_for_status() await raise_for_status(response)
data: list = (await response.json())["nodes"][1]["data"] data: list = (await response.json())["nodes"][1]["data"]
keys: list[int] = data[data[0]["messages"]] keys: list[int] = data[data[0]["messages"]]
message_keys: dict = data[keys[0]] message_keys: dict = data[keys[0]]
@ -79,7 +79,7 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
async with session.post(f"{cls.url}/conversation/{conversation_id}", json=options) as response: async with session.post(f"{cls.url}/conversation/{conversation_id}", json=options) as response:
first_token = True first_token = True
async for line in response.content: async for line in response.content:
response.raise_for_status() await raise_for_status(response)
line = json.loads(line) line = json.loads(line)
if "type" not in line: if "type" not in line:
raise RuntimeError(f"Response: {line}") raise RuntimeError(f"Response: {line}")
@ -91,5 +91,5 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
yield token yield token
elif line["type"] == "finalAnswer": elif line["type"] == "finalAnswer":
break break
async with session.delete(f"{cls.url}/conversation/{conversation_id}", proxy=proxy) as response: async with session.delete(f"{cls.url}/conversation/{conversation_id}") as response:
response.raise_for_status() await raise_for_status(response)