mirror of
https://github.com/hcengineering/platform.git
synced 2024-12-23 11:31:57 +03:00
ba8ab45dc2
Signed-off-by: Andrey Sobolev <haiodo@gmail.com>
154 lines
4.8 KiB
Python
154 lines
4.8 KiB
Python
from functools import partial
|
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
import json
|
|
|
|
import model as embeddings
|
|
import argparse
|
|
|
|
import traceback
|
|
|
|
def toArray(emb):
    """Convert an iterable of scalar-like elements into a plain list.

    Each element of ``emb`` must expose an ``item()`` method; the value it
    returns is what ends up in the resulting list, in iteration order.
    """
    values = []
    for element in emb:
        values.append(element.item())
    return values
|
|
|
|
|
|
class EmbeddingsServer(BaseHTTPRequestHandler):
    """HTTP request handler exposing an embedding service through JSON POST routes.

    Routes:
      POST /embeddings  -- embed ``input``; body needs ``input`` and ``model``.
      POST /completion  -- complete ``input``; body needs ``input``,
                           ``max_length``, ``temperature`` and ``model``.
      POST /compare     -- embed ``input`` and ``compare`` and report their
                           similarity; body needs ``input``, ``compare`` and
                           ``model``.

    A successful call answers 200 with a ``text/json`` body. Any failure while
    reading the request or computing the result answers 400, and nothing is
    written to the client before the result is fully computed, so a failure
    can never produce a half-written or doubled response. An unknown path
    answers 200 with ``{"result": false, "error": "Unknown service"}``.
    """

    # Service object used by every route. The annotation is a string so the
    # class definition itself does not need the service module to be loaded.
    embService: "embeddings.EmbeddingService"

    def __init__(self, embService, *args, **kwargs):
        """Store the service, then let the base class process the request.

        The attribute must be assigned before ``super().__init__`` because
        the base request handler runs the whole request from its constructor.
        """
        self.embService = embService
        super().__init__(*args, **kwargs)

    def do_POST(self):
        """Dispatch a POST request to the route handler matching ``self.path``."""
        routes = {
            '/embeddings': self.sendEmbeddings,
            '/completion': self.sendCompletion,
            '/compare': self.sendCompare,
        }
        handler = routes.get(self.path)
        if handler is not None:
            # Route handlers report their own failures as a 400 response.
            handler()
            return

        # Unknown route: status 200 with an error object is kept for
        # compatibility with existing clients of this endpoint.
        obj = {
            "result": False,
            "error": "Unknown service"
        }
        self._sendJson(200, bytes(json.dumps(obj), "utf-8"))

    def sendEmbeddings(self):
        """Handle POST /embeddings."""
        self._reply(self._buildEmbeddings)

    def sendCompletion(self):
        """Handle POST /completion."""
        self._reply(self._buildCompletion)

    def sendCompare(self):
        """Handle POST /compare."""
        self._reply(self._buildCompare)

    def _buildEmbeddings(self, jsbody):
        """Build the /embeddings response object from the parsed request body."""
        model = jsbody["model"]
        # Element 0 of the service result is the vector; element 1 is
        # reported as the prompt token count.
        result = self.embService.embeddings(jsbody["input"])
        emb = toArray(result[0])
        return {
            "data": [
                {
                    "embedding": emb,
                    "size": len(emb)
                }
            ],
            "model": model,
            "usage": {
                "prompt_tokens": result[1],
                # NOTE(review): hard-coded to 1 in the original; preserved as-is.
                "total_tokens": 1
            }
        }

    def _buildCompletion(self, jsbody):
        """Build the /completion response object from the parsed request body."""
        model = jsbody["model"]
        completion = self.embService.completion(
            jsbody["input"],
            max_length=jsbody["max_length"],
            temperature=jsbody["temperature"])
        return {
            "data": [
                {
                    "completion": completion
                }
            ],
            "model": model
        }

    def _buildCompare(self, jsbody):
        """Build the /compare response object from the parsed request body."""
        model = jsbody["model"]
        emb1 = self.embService.embeddings(jsbody["input"])
        emb2 = self.embService.embeddings(jsbody["compare"])
        e1 = toArray(emb1[0])
        e2 = toArray(emb2[0])
        return {
            "similarity": self.embService.compare(emb1[0], emb2[0]).item(),
            "input": e1,
            "input_len": len(e1),
            "compare": e2,
            "compare_len": len(e2),
            "model": model
        }

    def _reply(self, build):
        """Read the JSON body, run ``build`` on it and send the outcome.

        ``build`` receives the parsed body and returns a JSON-serialisable
        object. Reading, building and serialising all happen before the first
        byte is sent, so every failure becomes a single clean 400 response.
        """
        try:
            payload = bytes(json.dumps(build(self._readJsonBody())), "utf-8")
        except Exception as e:
            self._sendFailure(e)
            return
        self._sendJson(200, payload)

    def _readJsonBody(self):
        """Return the request body parsed as JSON.

        Raises:
            ValueError: if Content-Length is missing, not an integer or
                negative, or if the body is not valid JSON.
        """
        rawLength = self.headers.get('Content-Length')
        if rawLength is None:
            raise ValueError("Missing Content-Length header")
        length = int(rawLength)
        if length < 0:
            # A negative size would make read() wait for end-of-stream.
            raise ValueError("Invalid Content-Length header")
        return json.loads(self.rfile.read(length))

    def _sendJson(self, status, payload):
        """Send ``payload`` (already encoded bytes) as a ``text/json`` response."""
        self.send_response(status)
        self.send_header("Content-type", "text/json")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    def _sendFailure(self, error):
        """Log ``error`` and answer 400; must be called from an ``except`` block.

        The text of the error goes into the response body (``explain``), not
        the status line: the status line cannot carry newlines or characters
        outside latin-1. ``send_error`` finishes the headers and writes the
        body itself, so ``end_headers`` must not be called afterwards.
        """
        print('error', error)
        traceback.print_exc()
        self.send_error(400, explain=str(error))
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, load the model, then serve
    # HTTP requests until interrupted with Ctrl-C.
    parser = argparse.ArgumentParser(
        prog = 'Embedding\'s service')

    # Embedding sizes of the models this was used with:
    # 1024, sentence-transformers/all-roberta-large-v1
    # 384, sentence-transformers/all-MiniLM-L6-v2
    parser.add_argument('--model', default="sentence-transformers/all-MiniLM-L6-v2")
    parser.add_argument('--host', default="0.0.0.0")
    parser.add_argument('--device', default='cpu')
    # type=int: without it a port given on the command line arrives as a
    # string and binding the socket fails.
    parser.add_argument('--port', type=int, default=4070)

    args = parser.parse_args()

    hostName = args.host
    serverPort = args.port
    device = args.device
    model = args.model

    print('loading model:', model, ' on device:', device)

    emb = embeddings.EmbeddingService(model, device)

    # bind_and_activate=False so allow_reuse_address can be set before the
    # socket is bound; partial() hands the shared service to every handler.
    webServer = HTTPServer((hostName, serverPort), partial(EmbeddingsServer, emb), bind_and_activate=False)
    webServer.allow_reuse_address = True
    # NOTE(review): daemon_threads is only consulted by threading servers; it
    # is kept here but has no effect on a plain HTTPServer.
    webServer.daemon_threads = True

    try:
        webServer.server_bind()
        webServer.server_activate()
        print("Embedding started http://%s:%s" % (hostName, serverPort))

        webServer.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # Release the socket even when binding or serving fails.
        webServer.server_close()

    print("Server stopped.")