# platform/pods/embeddings/server.py
# (154 lines, 4.8 KiB, Python; header captured from the repository web view)

from functools import partial
from http.server import BaseHTTPRequestHandler, HTTPServer
import json
import model as embeddings
import argparse
import traceback
def toArray(emb):
    """Convert an iterable of scalar-like values into a list of plain values.

    Each element of ``emb`` must expose an ``item()`` method; the result holds
    whatever ``item()`` returns for each element, in iteration order. The
    handlers below feed the result straight into ``json.dumps``.

    Args:
        emb: Iterable whose elements provide ``item()``.
            NOTE(review): presumably a 1-D tensor produced by the embedding
            service - confirm against ``model.EmbeddingService``.

    Returns:
        A new list with one entry per element of ``emb``.
    """
    return [v.item() for v in emb]
class EmbeddingsServer(BaseHTTPRequestHandler):
    """HTTP request handler exposing an embedding service over JSON POST routes.

    Routes (all ``POST`` with a JSON request body):

    * ``/embeddings`` - embed ``input``.
    * ``/completion`` - complete ``input`` using ``max_length`` and
      ``temperature``.
    * ``/compare`` - embed ``input`` and ``compare`` and report their
      similarity.

    Every route builds its complete response payload *before* any status line
    or header is written. A failure therefore produces exactly one well-formed
    JSON error response instead of a second response appended to one that was
    already started.
    """

    # String annotation: documents the expected type without evaluating the
    # expression when the class object is created.
    embService: "embeddings.EmbeddingService"

    def __init__(self, embService, *args, **kwargs):
        """Bind the service, then let the base class handle the request.

        Args:
            embService: Object providing ``embeddings``, ``completion`` and
                ``compare``.
            *args: Forwarded to ``BaseHTTPRequestHandler``.
            **kwargs: Forwarded to ``BaseHTTPRequestHandler``.
        """
        # The base constructor processes the whole request (and so calls
        # do_POST) before it returns, hence the service must be set first.
        self.embService = embService
        super().__init__(*args, **kwargs)

    def do_POST(self):
        """Dispatch a POST request to the route registered for ``self.path``."""
        routes = {
            '/embeddings': self.sendEmbeddings,
            '/completion': self.sendCompletion,
            '/compare': self.sendCompare,
        }
        handler = routes.get(self.path)
        if handler is None:
            # Status 200 with ``result: False`` is what clients have always
            # received for an unknown path; kept for wire compatibility.
            self._sendJson(200, {"result": False, "error": "Unknown service"})
            return
        try:
            handler()
        except Exception as e:
            # Request and service failures are answered inside the route, so
            # only a failure while writing the response gets here. Nothing
            # more can safely be sent at that point; log and give up.
            print('Failed to process', e)
            traceback.print_exc()

    def sendEmbeddings(self):
        """Answer ``/embeddings`` with the embedding of the request ``input``."""
        self._respond(self._embeddingsPayload)

    def sendCompletion(self):
        """Answer ``/completion`` with the completion of the request ``input``."""
        self._respond(self._completionPayload)

    def sendCompare(self):
        """Answer ``/compare`` with both embeddings and their similarity."""
        self._respond(self._comparePayload)

    def _embeddingsPayload(self, jsbody):
        """Build the ``/embeddings`` response object from the parsed request."""
        model = jsbody["model"]
        # The service result is indexed as [0] = vector, [1] = value reported
        # as ``prompt_tokens``.
        result = self.embService.embeddings(jsbody["input"])
        emb = toArray(result[0])
        return {
            "data": [
                {
                    "embedding": emb,
                    "size": len(emb)
                }
            ],
            "model": model,
            "usage": {
                "prompt_tokens": result[1],
                # NOTE(review): hard-coded to 1 in the original; it looks like
                # it should mirror prompt_tokens - confirm before changing.
                "total_tokens": 1
            }
        }

    def _completionPayload(self, jsbody):
        """Build the ``/completion`` response object from the parsed request."""
        completion = self.embService.completion(
            jsbody["input"],
            max_length=jsbody["max_length"],
            temperature=jsbody["temperature"],
        )
        return {
            "data": [
                {
                    "completion": completion
                }
            ],
            "model": jsbody["model"]
        }

    def _comparePayload(self, jsbody):
        """Build the ``/compare`` response object from the parsed request."""
        emb1 = self.embService.embeddings(jsbody["input"])
        emb2 = self.embService.embeddings(jsbody["compare"])
        model = jsbody["model"]
        e1 = toArray(emb1[0])
        e2 = toArray(emb2[0])
        return {
            "similarity": self.embService.compare(emb1[0], emb2[0]).item(),
            "input": e1,
            "input_len": len(e1),
            "compare": e2,
            "compare_len": len(e2),
            "model": model
        }

    def _readJson(self):
        """Read and decode the JSON request body.

        Raises:
            ValueError: If ``Content-Length`` is malformed or negative, or the
                body is not valid JSON (a missing header reads as an empty
                body, which is not valid JSON either).
        """
        length = int(self.headers.get('Content-Length') or 0)
        if length < 0:
            # read(-1) would block until the peer closes the connection.
            raise ValueError("Invalid Content-Length")
        return json.loads(self.rfile.read(length))

    def _respond(self, buildPayload):
        """Run one route: parse the request, build the payload, send one reply.

        Args:
            buildPayload: Callable taking the parsed JSON request and returning
                a JSON-serialisable response object.
        """
        try:
            # Serialising inside the ``try`` means an unserialisable payload
            # is also reported as an error rather than failing mid-response.
            body = json.dumps(buildPayload(self._readJson())).encode("utf-8")
        except Exception as e:
            print('error', e)
            traceback.print_exc()
            # 400 is the status the /embeddings route has always used for
            # failures. The text travels in the JSON body, never in the status
            # line, so a multi-line message cannot corrupt the response.
            self._sendJson(400, {"result": False, "error": str(e)})
            return
        self._sendBytes(200, body)

    def _sendJson(self, status, obj):
        """Serialise ``obj`` and send it as the complete response."""
        self._sendBytes(status, json.dumps(obj).encode("utf-8"))

    def _sendBytes(self, status, body):
        """Write the status line, headers and an already-encoded JSON body."""
        self.send_response(status)
        # "text/json" is kept as-is: it is the content type clients already see.
        self.send_header("Content-type", "text/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
def parseArgs(argv=None):
    """Parse the command-line options of the embeddings server.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with ``model``, ``host``, ``device`` and ``port``. ``port``
        is always an ``int``, whether it comes from the default or from the
        command line.
    """
    parser = argparse.ArgumentParser(
        prog='Embedding\'s service')
    # Embedding sizes of the suggested models:
    #   1024, sentence-transformers/all-roberta-large-v1
    #    384, sentence-transformers/all-MiniLM-L6-v2
    parser.add_argument('--model', default="sentence-transformers/all-MiniLM-L6-v2")
    parser.add_argument('--host', default="0.0.0.0")
    parser.add_argument('--device', default='cpu')
    # type=int: without it a port given on the command line stays a string
    # and binding the socket fails.
    parser.add_argument('--port', type=int, default=4070)
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parseArgs()
    hostName = args.host
    serverPort = args.port
    device = args.device
    modelName = args.model
    print('loading model:', modelName, ' on device:', device)
    emb = embeddings.EmbeddingService(modelName, device)
    # bind_and_activate=False so the socket options below are applied before
    # the socket is bound.
    webServer = HTTPServer((hostName, serverPort), partial(EmbeddingsServer, emb), bind_and_activate=False)
    try:
        webServer.allow_reuse_address = True
        # NOTE(review): daemon_threads is only consulted by threading servers;
        # plain HTTPServer handles requests in the serving thread. Kept as-is.
        webServer.daemon_threads = True
        webServer.server_bind()
        webServer.server_activate()
        print("Embedding started http://%s:%s" % (hostName, serverPort))
        webServer.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # Release the listening socket even when binding or serving fails.
        webServer.server_close()
    print("Server stopped.")