update code, handle escape sequence and others

This commit is contained in:
ezerinz 2023-04-28 16:38:05 +08:00
parent 5b38bcf8e6
commit be04fcd7f3
No known key found for this signature in database
GPG Key ID: 8713BDED6DE33C4D

View File

@ -1,60 +1,62 @@
import json
import re
from fake_useragent import UserAgent
import requests
class Completion:
@staticmethod
def create(
systemprompt:str,
text:str,
assistantprompt:str
):
def create(messages):
headers = {
"authority": "openai.a2hosted.com",
"accept": "text/event-stream",
"accept-language": "en-US,en;q=0.9,id;q=0.8,ja;q=0.7",
"cache-control": "no-cache",
"sec-fetch-dest": "empty",
"sec-fetch-mode": "cors",
"sec-fetch-site": "cross-site",
"user-agent": "Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/112.0",
}
data = [
{"role": "system", "content": systemprompt},
{"role": "user", "content": "hi"},
{"role": "assistant", "content": assistantprompt},
{"role": "user", "content": text},
]
url = f'https://openai.a2hosted.com/chat?q={Completion.__get_query_param(data)}'
try:
response = requests.get(url, headers=Completion.__get_headers(), stream=True)
except:
query_param = Completion.__create_query_param(messages)
url = f"https://openai.a2hosted.com/chat?q={query_param}"
request = requests.get(url, headers=headers, stream=True)
if request.status_code != 200:
return Completion.__get_failure_response()
sentence = ""
content = request.content
response = Completion.__join_response(content)
return {"responses": response}
for message in response.iter_content(chunk_size=1024):
message = message.decode('utf-8')
msg_match, num_match = re.search(r'"msg":"([^"]+)"', message), re.search(r'\[DONE\] (\d+)', message)
if msg_match:
# Put the captured group into a sentence
sentence += msg_match.group(1)
return {
'response': sentence
}
@classmethod
def __get_headers(cls) -> dict:
return {
'authority': 'openai.a2hosted.com',
'accept': 'text/event-stream',
'accept-language': 'en-US,en;q=0.9,id;q=0.8,ja;q=0.7',
'cache-control': 'no-cache',
'sec-fetch-dest': 'empty',
'sec-fetch-mode': 'cors',
'sec-fetch-site': 'cross-site',
'user-agent': UserAgent().random
}
@classmethod
def __get_failure_response(cls) -> dict:
return dict(response='Unable to fetch the response, Please try again.', links=[], extra={})
return dict(
response="Unable to fetch the response, Please try again.",
links=[],
extra={},
)
@classmethod
def __get_query_param(cls, conversation) -> str:
def __multiple_replace(cls, string, reps) -> str:
for original, replacement in reps.items():
string = string.replace(original, replacement)
return string
@classmethod
def __create_query_param(cls, conversation) -> str:
encoded_conversation = json.dumps(conversation)
return encoded_conversation.replace(" ", "%20").replace('"', '%22').replace("'", "%27")
replacement = {" ": "%20", '"': "%22", "'": "%27"}
return Completion.__multiple_replace(encoded_conversation, replacement)
@classmethod
def __convert_escape_codes(cls, text) -> str:
replacement = {'\\\\"': '"', '\\"': '"', "\\n": "\n", "\\'": "'"}
return Completion.__multiple_replace(text, replacement)
@classmethod
def __join_response(cls, data) -> str:
data = data.decode("utf-8")
find_ans = re.findall(r'(?<={"msg":)[^}]*', str(data))
ans = [Completion.__convert_escape_codes(x[1:-1]) for x in find_ans]
response = "".join(ans)
return response