diff --git a/distribution/lib/Standard/Visualization/0.0.0-dev/src/AI.enso b/distribution/lib/Standard/Visualization/0.0.0-dev/src/AI.enso index c7853133750..3e380f2c595 100644 --- a/distribution/lib/Standard/Visualization/0.0.0-dev/src/AI.enso +++ b/distribution/lib/Standard/Visualization/0.0.0-dev/src/AI.enso @@ -1,5 +1,6 @@ from Standard.Base import all from Standard.Table import Table +import project.Helpers ## PRIVATE goal_placeholder = "__$$GOAL$$__" @@ -34,3 +35,6 @@ Table.build_ai_prompt self = ## PRIVATE build_ai_prompt subject = subject.build_ai_prompt + +## PRIVATE +print subject = subject.to_default_visualization_data diff --git a/docs/language-server/protocol-language-server.md b/docs/language-server/protocol-language-server.md index 842fcc15bd3..ae0199f8174 100644 --- a/docs/language-server/protocol-language-server.md +++ b/docs/language-server/protocol-language-server.md @@ -187,6 +187,9 @@ transport formats, please look [here](./protocol-architecture). - [Profiling Operations](#profiling-operations) - [`profiling/start`](#profilingstart) - [`profiling/stop`](#profilingstop) +- [AI Operations](#ai-operations) + - [`ai/completion_v2`](#aicompletionv2) + - [`ai/completionProgress`](#aicompletionprogres) - [Errors](#errors-75) - [`Error`](#error) - [`AccessDeniedError`](#accessdeniederror) @@ -4296,7 +4299,7 @@ interface SearchGetSuggestionsDatabaseVersionResult { ### `search/suggestionsDatabaseUpdate` -Sent from server to the client to inform abouth the change in the suggestions +Sent from server to the client to inform about the change in the suggestions database. - **Type:** Notification @@ -4315,7 +4318,7 @@ interface SearchSuggestionsDatabaseUpdateNotification { ### `search/suggestionsOrderDatabaseUpdate` -Sent from server to the client to inform abouth the change in the suggestions +Sent from server to the client to inform about the change in the suggestions order database. - **Type:** Notification @@ -5178,6 +5181,92 @@ interface ProfilingStopResult {} None +## AI Operations + +### `ai/completion_v2` + +Sent from the client to the server to ask the AI model the code suggestion. + +- **Type:** Request +- **Direction:** Client -> Server +- **Connection:** Protocol +- **Visibility:** Public + +#### Parameters + +```typescript +interface AiCompletionParameters { + /** The execution context id to use for executing expressions. */ + contextId: UUID; + /** + * The expression providing the execution scope. The same as `expressionId` + * parameter of `executionContext/executeExpression` method. + */ + expressionId: UUID; + /** The user prompt. */ + prompt: string; + /** The system prompt describing the AI role. */ + systemPrompt?: string; + /** The AI model to use. */ + model?: string; +} +``` + +#### Result + +```typescript +type AiCompletionResult = AiCompletionResultSuccess | AiCompletionResultFailure; + +interface AiCompletionResultSuccess { + /** The code of the function producing the desired result. */ + fn: string; + + /** The code of how to call the suggested function. */ + fnCall: string; +} + +interface AiCompletionResultFailure { + /** + * The explanation given by the AI model for why it was unable to provide the + * answer. + */ + reason: string; +} +``` + +#### Errors + +- [`AiHttpError`](#aihttperror) Signals about an error during the processing of + AI http respnse. +- [`AiEvaluationError`](#aievaluationerror) Signals about an error during the + evaluation of expression requested by AI. + +### `ai/completionProgress` + +Sent from server to the client to inform about the progress of the +[`ai/completion`](#aicompletion) request. + +- **Type:** Notification +- **Direction:** Server -> Client +- **Connection:** Protocol +- **Visibility:** Public + +#### Notification + +```typescript +interface AiCompletionProgressNotification { + /** Code snippte that AI model requested to evaluate. */ + code: string; + /** Explanation given by the AI model why it needs an extra information. */ + reason: string; + /** + * The id of the visualization being executed. When evaluated, the + * visualization update will contain the result of the executed expression. + */ + visualizationId: UUID; +} +``` + ## Errors The language server component also has its own set of errors. This section is @@ -5763,3 +5852,34 @@ Signals that the refactoring of the given expression is not supported. "message" : "Refactoring not supported for expression []" } ``` + +### `AiHttpError` + +Signals about an error during the processing of AI http respnse. + +```typescript +"error" : { + "code" : 10001, + "message" : "Failed to process HTTP response", + "payload" : { + "reason" : "", + "request" : "", + "response" : "" + } +} +``` + +### `AiEvaluationError` + +Signals about an error during the evaluation of expression requested by AI. + +```typescript +"error" : { + "code" : 10002, + "message" : "Failed to execute expression", + "payload" : { + "expression" : "", + "error" : "" + } +} +``` diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/ai/AICompletion.scala b/engine/language-server/src/main/scala/org/enso/languageserver/ai/AICompletion.scala deleted file mode 100644 index 953f3d058f0..00000000000 --- a/engine/language-server/src/main/scala/org/enso/languageserver/ai/AICompletion.scala +++ /dev/null @@ -1,18 +0,0 @@ -package org.enso.languageserver.ai - -import org.enso.jsonrpc.{HasParams, HasResult, Method} - -case object AICompletion extends Method("ai/completion") { - case class Params(prompt: String, stopSequence: String) - case class Result(code: String) - - implicit val hasParams: HasParams.Aux[this.type, AICompletion.Params] = - new HasParams[this.type] { - type Params = AICompletion.Params - } - - implicit val hasResult: HasResult.Aux[this.type, AICompletion.Result] = - new HasResult[this.type] { - type Result = AICompletion.Result - } -} diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/ai/AiApi.scala b/engine/language-server/src/main/scala/org/enso/languageserver/ai/AiApi.scala new file mode 100644 index 00000000000..88b8ba78bfd --- /dev/null +++ b/engine/language-server/src/main/scala/org/enso/languageserver/ai/AiApi.scala @@ -0,0 +1,82 @@ +package org.enso.languageserver.ai + +import io.circe.Json +import io.circe.syntax._ +import org.enso.jsonrpc.{Error, HasParams, HasResult, Method} + +import java.util.UUID + +case object AiApi { + + case object AiCompletion extends Method("ai/completion") { + + case class Params(prompt: String, stopSequence: String) + case class Result(code: String) + + implicit val hasParams: HasParams.Aux[this.type, AiCompletion.Params] = + new HasParams[this.type] { + type Params = AiCompletion.Params + } + + implicit val hasResult: HasResult.Aux[this.type, AiCompletion.Result] = + new HasResult[this.type] { + type Result = AiCompletion.Result + } + } + + case object AiCompletion2 extends Method("ai/completion_v2") { + + case class Params( + contextId: UUID, + expressionId: UUID, + prompt: String, + systemPrompt: Option[String], + model: Option[String] + ) + type Result = AiProtocol.AiCompletionResult + + implicit val hasParams: HasParams.Aux[this.type, AiCompletion2.Params] = + new HasParams[this.type] { + type Params = AiCompletion2.Params + } + + implicit val hasResult: HasResult.Aux[this.type, AiCompletion2.Result] = + new HasResult[this.type] { + type Result = AiCompletion2.Result + } + } + + case object AiCompletionProgress extends Method("ai/completionProgress") { + + case class Params(code: String, reason: String, visualizationId: UUID) + + implicit + val hasParams: HasParams.Aux[this.type, AiCompletionProgress.Params] = + new HasParams[this.type] { + type Params = AiCompletionProgress.Params + } + } + + case class AiHttpError(reason: String, request: Json, response: String) + extends Error(10001, "Failed to process HTTP response") { + + override val payload: Option[Json] = Some( + Json.obj( + ("reason", reason.asJson), + ("request", request), + ("response", response.asJson) + ) + ) + } + + case class AiEvaluationError(expression: String, error: String) + extends Error(10002, "Failed to execute expression") { + + override val payload: Option[Json] = Some( + Json.obj( + ("expression", expression.asJson), + ("error", error.asJson) + ) + ) + } +} diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/ai/AiProtocol.scala b/engine/language-server/src/main/scala/org/enso/languageserver/ai/AiProtocol.scala new file mode 100644 index 00000000000..a148729351d --- /dev/null +++ b/engine/language-server/src/main/scala/org/enso/languageserver/ai/AiProtocol.scala @@ -0,0 +1,53 @@ +package org.enso.languageserver.ai + +import java.util.UUID + +object AiProtocol { + + /** Base trait for the AI completion results. */ + sealed trait AiCompletionResult + + case object AiCompletionResult { + + /** Successful completion result. + * + * @param fn the code for the function returning the answer + * @param fnCall the code how to call the function + */ + sealed case class Success(fn: String, fnCall: String) + extends AiCompletionResult + + /** Failed completion result + * + * @param reason the explanation why the AI was unable to provide the result. + */ + sealed case class Failure(reason: String) extends AiCompletionResult + } + + /** The request from AI to evaluate an expression. + * + * @param reason the explanation why the AI requires this information + * @param code the expression code + */ + case class AiEvalRequest(reason: String, code: String) + + /** The message sent to AI. + * + * @param role the AI role + * @param content the message content + */ + case class CompletionsMessage(role: String, content: String) + + /** The progress notification sent when AI requests to evaluate an expression. + * + * @param code the code that AI requested to evaluate + * @param reason the explanation why AI requires this information + * @param visualizationId the id of the visualization being executed. When evaluated, + * the visualization update will contain the result of the executed expression. + */ + case class AiCompletionProgressNotification( + code: String, + reason: String, + visualizationId: UUID + ) +} diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala index 1dfa9113994..65a45311b58 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala @@ -7,7 +7,12 @@ import com.typesafe.scalalogging.LazyLogging import org.enso.cli.task.ProgressUnit import org.enso.cli.task.notifications.TaskNotificationApi import org.enso.jsonrpc._ -import org.enso.languageserver.ai.AICompletion +import org.enso.languageserver.ai.AiApi.{ + AiCompletion, + AiCompletion2, + AiCompletionProgress +} +import org.enso.languageserver.ai.AiProtocol import org.enso.languageserver.boot.resource.{ InitializationComponent, InitializationComponentInitialized @@ -472,6 +477,16 @@ class JsonConnectionController( translateProgressNotification(payload) webActor ! translated + case AiProtocol.AiCompletionProgressNotification( + code, + reason, + visualizationId + ) => + webActor ! Notification( + AiCompletionProgress, + AiCompletionProgress.Params(code, reason, visualizationId) + ) + case req @ Request(method, _, _) if requestHandlers.contains(method) => refreshIdleTime(method) val handler = context.actorOf( @@ -566,9 +581,14 @@ class JsonConnectionController( .props(requestTimeout, suggestionsHandler), InvalidateSuggestionsDatabase -> search.InvalidateSuggestionsDatabaseHandler .props(requestTimeout, suggestionsHandler), - AICompletion -> ai.AICompletionHandler.props( + AiCompletion -> ai.AICompletionHandler.props( languageServerConfig.aiCompletionConfig ), + AiCompletion2 -> ai.AICompletion2Handler.props( + languageServerConfig.aiCompletionConfig, + rpcSession, + runtimeConnector + ), ExecuteExpression -> ExecuteExpressionHandler .props(rpcSession.clientId, requestTimeout, contextRegistry), AttachVisualization -> AttachVisualizationHandler diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonRpc.scala b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonRpc.scala index 3040aadef43..13ea37c3fba 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonRpc.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonRpc.scala @@ -7,7 +7,11 @@ import org.enso.cli.task.notifications.TaskNotificationApi.{ TaskStarted } import org.enso.jsonrpc.Protocol -import org.enso.languageserver.ai.AICompletion +import org.enso.languageserver.ai.AiApi.{ + AiCompletion, + AiCompletion2, + AiCompletionProgress +} import org.enso.languageserver.capability.CapabilityApi.{ AcquireCapability, ForceReleaseCapability, @@ -88,7 +92,8 @@ object JsonRpc { .registerRequest(GetSuggestionsDatabaseVersion) .registerRequest(InvalidateSuggestionsDatabase) .registerRequest(Completion) - .registerRequest(AICompletion) + .registerRequest(AiCompletion) + .registerRequest(AiCompletion2) .registerRequest(RenameProject) .registerRequest(RenameSymbol) .registerRequest(ProjectInfo) @@ -131,5 +136,6 @@ object JsonRpc { .registerNotification(SuggestionsDatabaseUpdates) .registerNotification(VisualizationEvaluationFailed) .registerNotification(ProjectRenamed) + .registerNotification(AiCompletionProgress) .finalized() } diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/UnsupportedHandler.scala b/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/UnsupportedHandler.scala new file mode 100644 index 00000000000..3714e710942 --- /dev/null +++ b/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/UnsupportedHandler.scala @@ -0,0 +1,19 @@ +package org.enso.languageserver.requesthandler + +import akka.actor.Actor +import com.typesafe.scalalogging.LazyLogging +import org.enso.jsonrpc.{Errors, Method, Request, ResponseError} +import org.enso.languageserver.util.UnhandledLogging + +final class UnsupportedHandler(method: Method) + extends Actor + with LazyLogging + with UnhandledLogging { + + override def receive: Receive = { case Request(`method`, id, _) => + sender() ! ResponseError( + Some(id), + Errors.MethodNotFound + ) + } +} diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/ai/AICompletion2Handler.scala b/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/ai/AICompletion2Handler.scala new file mode 100644 index 00000000000..db9e13e7c89 --- /dev/null +++ b/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/ai/AICompletion2Handler.scala @@ -0,0 +1,409 @@ +package org.enso.languageserver.requesthandler.ai + +import akka.actor.{Actor, ActorRef, PoisonPill, Props} +import akka.http.scaladsl.Http +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.OAuth2BearerToken +import akka.pattern.PipeToSupport +import akka.stream.Materializer +import akka.util.ByteString +import com.typesafe.scalalogging.LazyLogging +import io.circe.Json +import io.circe.syntax._ +import io.circe.generic.auto._ +import org.enso.jsonrpc._ +import org.enso.languageserver.ai.AiApi.{ + AiCompletion2, + AiEvaluationError, + AiHttpError +} +import org.enso.languageserver.ai.AiProtocol +import org.enso.languageserver.data.AICompletionConfig +import org.enso.languageserver.requesthandler.UnsupportedHandler +import org.enso.languageserver.runtime.{ + ContextRegistryProtocol, + RuntimeFailureMapper +} +import org.enso.languageserver.session.JsonSession +import org.enso.languageserver.util.UnhandledLogging +import org.enso.logger.akka.ActorMessageLogging +import org.enso.polyglot.runtime.Runtime.Api + +import java.nio.charset.StandardCharsets +import java.util.UUID + +import scala.concurrent.ExecutionContext +import scala.concurrent.duration.FiniteDuration + +class AICompletion2Handler( + cfg: AICompletionConfig, + session: JsonSession, + runtime: ActorRef +) extends Actor + with LazyLogging + with ActorMessageLogging + with UnhandledLogging + with PipeToSupport { + + import AICompletion2Handler._ + + override def preStart(): Unit = { + super.preStart() + + context.system.eventStream.subscribe(self, classOf[Api.VisualizationUpdate]) + context.system.eventStream + .subscribe(self, classOf[Api.VisualizationEvaluationFailed]) + } + + override def receive: Receive = requestStage + + private val http = Http(context.system) + implicit val ec: ExecutionContext = context.dispatcher + implicit val materializer: Materializer = Materializer(context) + + private def requestStage: Receive = LoggingReceive.withLabel("requestStage") { + case Request( + AiCompletion2, + id, + AiCompletion2.Params( + contextId, + expressionId, + prompt, + systemPrompt, + model + ) + ) => + val messages = Vector( + AiProtocol.CompletionsMessage( + "system", + systemPrompt.getOrElse(SYSTEM_PROMPT) + ), + AiProtocol.CompletionsMessage("user", prompt) + ) + val httpReq = sendHttpRequest(messages, model) + val debugInfo = DebugInfo(httpReq) + + context.become( + awaitingCompletionResponse( + id, + sender(), + contextId, + expressionId, + messages, + model, + debugInfo + ) + ) + } + + private def evalRequestStage( + id: Id, + replyTo: ActorRef, + contextId: Api.ContextId, + expressionId: Api.ExpressionId, + messages: Vector[AiProtocol.CompletionsMessage], + model: Option[String] + ): Receive = LoggingReceive.withLabel("evalRequestStage") { + case req @ AiProtocol.AiEvalRequest(reason, code) => + val requestId = UUID.randomUUID() + val visualizationId = UUID.randomUUID() + + val executeExpression = Api.ExecuteExpression( + contextId, + visualizationId, + expressionId, + code + ) + runtime ! Api.Request(requestId, executeExpression) + + session.rpcController ! AiProtocol.AiCompletionProgressNotification( + code, + reason, + visualizationId + ) + + context.become( + evalResponseStage( + id, + replyTo, + contextId, + expressionId, + visualizationId, + req, + messages, + model + ) + ) + } + + private def evalResponseStage( + id: Id, + replyTo: ActorRef, + contextId: Api.ContextId, + expressionId: Api.ExpressionId, + visualizationId: Api.VisualizationId, + request: AiProtocol.AiEvalRequest, + messages: Vector[AiProtocol.CompletionsMessage], + model: Option[String] + ): Receive = LoggingReceive.withLabel("evalResponseStage") { + case Api.VisualizationUpdate(ctx, data) + if ctx.visualizationId == visualizationId => + val visualizationResult = new String(data, StandardCharsets.UTF_8) + val message = AiProtocol.CompletionsMessage( + "user", + s"EVALUATED:\n${request.code}\n\nOUTPUT:\n$visualizationResult" + ) + val newMessages = messages :+ message + + val httpReq = sendHttpRequest(newMessages, model) + val debugInfo = DebugInfo(httpReq) + context.become( + awaitingCompletionResponse( + id, + replyTo, + contextId, + expressionId, + newMessages, + model, + debugInfo + ) + ) + + case Api.VisualizationEvaluationFailed(ctx, message, _) + if ctx.visualizationId == visualizationId => + val aiError = AiEvaluationError(request.code, message) + replyTo ! ResponseError(Some(id), aiError) + stop() + + case error: ContextRegistryProtocol.Failure => + replyTo ! ResponseError(Some(id), RuntimeFailureMapper.mapFailure(error)) + } + + private def awaitingCompletionResponse( + id: Id, + replyTo: ActorRef, + contextId: Api.ContextId, + expressionId: Api.ExpressionId, + messages: Vector[AiProtocol.CompletionsMessage], + model: Option[String], + debugInfo: DebugInfo + ): Receive = LoggingReceive.withLabel("awaitingCompletionStage") { + case HttpResponse(StatusCodes.OK, data) => + val responseUtf8String = data.utf8String + logger.trace("AI response:\n{}", responseUtf8String) + + parse(responseUtf8String) match { + case Some(response) => + getResponseKind(response) match { + case Some("final") => + getFinalResult(response).fold { + val aiError = AiHttpError( + "Failed to parse final kind of AI response", + debugInfo.httpReq, + responseUtf8String + ) + replyTo ! ResponseError(Some(id), aiError) + }(success => replyTo ! ResponseResult(AiCompletion2, id, success)) + stop() + + case Some("eval") => + getEvalResult(response).fold { + val aiError = AiHttpError( + "Failed to parse eval kind of AI response", + debugInfo.httpReq, + responseUtf8String + ) + replyTo ! ResponseError(Some(id), aiError) + stop() + } { evalRequest => + self ! evalRequest + context.become( + evalRequestStage( + id, + replyTo, + contextId, + expressionId, + messages, + model + ) + ) + } + + case Some("fail") => + getFailResult(response).fold { + val aiError = AiHttpError( + "Failed to parse fail kind of AI response", + debugInfo.httpReq, + responseUtf8String + ) + replyTo ! ResponseError(Some(id), aiError) + }(fail => replyTo ! ResponseResult(AiCompletion2, id, fail)) + stop() + + case _ => + val aiError = AiHttpError( + "Unknown kind of AI response", + debugInfo.httpReq, + responseUtf8String + ) + replyTo ! ResponseError(Some(id), aiError) + stop() + } + + case None => + val aiError = AiHttpError( + "Failed to parse AI response as JSON", + debugInfo.httpReq, + data.utf8String + ) + replyTo ! ResponseError(Some(id), aiError) + stop() + } + + case HttpResponse(status, data) => + val aiError = + AiHttpError( + s"Unknown AI response [${status.value}]", + debugInfo.httpReq, + data.utf8String + ) + replyTo ! ResponseError(Some(id), aiError) + stop() + } + + private def sendHttpRequest( + messages: Vector[AiProtocol.CompletionsMessage], + modelOption: Option[String] + ): Json = { + val body = Json.obj( + ("model", modelOption.getOrElse(MODEL).asJson), + ("response_format", Json.obj(("type", "json_object".asJson))), + ("messages", Json.arr(messages.map(_.asJson): _*)) + ) + + logger.trace("AI request:\n{}", body) + + val req = + HttpRequest( + uri = API_OPENAI_URI, + method = HttpMethods.POST, + headers = Seq(headers.Authorization(OAuth2BearerToken(cfg.apiKey))), + entity = HttpEntity(ContentTypes.`application/json`, body.noSpaces) + ) + + http + .singleRequest(req) + .flatMap(response => { + response.entity + .toStrict(FiniteDuration(10, "s")) + .map(e => { + HttpResponse(response.status, e.data) + }) + }) + .pipeTo(self) + + body + } + + private def stop(): Unit = { + self ! PoisonPill + } +} + +object AICompletion2Handler { + + private val MODEL = "gpt-4-turbo-preview" + private val API_OPENAI_URI = "https://api.openai.com/v1/chat/completions" + private val SYSTEM_PROMPT = + """You are a data analyst. You use Python3. Installed libraries: ['pandas']. + |Your task is to output JSON object with fields: + |- 'kind': 'final' + |- 'fn': String, Python function returning what user wants. Always write as generic code as possible that will work even if the input data (e.g. file content) changes. Do not assume any input data exists if not provided with it explicitly. Use your knowledge about the world if the provided data is missing. + |- 'fnCall': String, Python code that calls the generated function. + |- 'resultPreview': Code in Python. When evaluated, prints to stdout a preview of the result, e.g. 'Visualization.AI.print("Number 5")'. Make it as generic as possible. It should work even if the input data changes. The string written to stdout should be one-line, as short as possible, and as informative as possible, e.g. 'Table with 50 rows and columns "c1", "c2", and "c3"'. It can assume that the 'fnCall' result is in scope. + |- 'queryParts': Array of user query divided into either non-editable text, or widgets. The idea is that users can click widgets to change them to other values. Every part should be one of the JSON objects: + | * Fields: + | a. 'kind': 'text' + | b. 'text': Part of the user query that should not be widget. In particular, numbers shoul not be widgets. + | * Fields: + | a. 'kind': 'dropdown' + | b. 'values': list of possible values. For example, if the query contains name of a column in a data set, provide all other column names, like ['columnName1', 'columnName2']. If the query contains a common comparator like 'less than', provide other comparators like ['greater than', 'equal to']. The same applies to other comparators like 'most popular'. + | + |If in order to provide the answer you need to investigate what is inside the data, you can run code and be asked again the same question with provided stdout by outputing JSON object with fields: + |- 'kind': 'eval' + |- 'code': Python code required to investigate data. The code should write to stdout as little as possible. Use 'Visualization.AI.print' function for printing to stdout. You can only use data you are already provided with, no more data can be provided and you can't ask for more data. + |- 'reason': Reason why you were not able to provide final code. + |Always prefer outputting the final code. Use kind "eval" only if you can't output object with kind "final". + | + |If you can't provide the answer, because the current data and your knowledge about the world is not enough, output JSON object with fields: + |- 'kind': 'fail' + |- 'reason': Reason why you were not able to provide the answer. As short as possible. Do not mention Python nor code, this is information for non-tech users. + |""".stripMargin + + private case class HttpResponse(status: StatusCode, data: ByteString) + + private case class DebugInfo( + httpReq: Json + ) + + def props( + cfg: Option[AICompletionConfig], + session: JsonSession, + runtime: ActorRef + ): Props = + cfg + .map(conf => Props(new AICompletion2Handler(conf, session, runtime))) + .getOrElse(Props(new UnsupportedHandler(AiCompletion2))) + + private def parse(str: String): Option[Json] = + for { + response <- io.circe.parser.parse(str).toOption + responseObj <- response.asObject + choices <- responseObj("choices") + choicesArr <- choices.asArray + firstChoice <- choicesArr.headOption + firstChoiceObj <- firstChoice.asObject + message <- firstChoiceObj("message") + messageObj <- message.asObject + content <- messageObj("content") + contentString <- content.asString + contentJson <- io.circe.parser.parse(contentString).toOption + } yield contentJson + + private def getFinalResult( + response: Json + ): Option[AiProtocol.AiCompletionResult] = + for { + obj <- response.asObject + fn <- obj("fn") + fnString <- fn.asString + fnCall <- obj("fnCall") + fnCallString <- fnCall.asString + } yield AiProtocol.AiCompletionResult.Success(fnString, fnCallString) + + private def getEvalResult(response: Json): Option[AiProtocol.AiEvalRequest] = + for { + obj <- response.asObject + reason <- obj("reason") + reasonString <- reason.asString + code <- obj("code") + codeString <- code.asString + } yield AiProtocol.AiEvalRequest(reasonString, codeString) + + private def getFailResult( + response: Json + ): Option[AiProtocol.AiCompletionResult] = + for { + obj <- response.asObject + reason <- obj("reason") + reasonString <- reason.asString + } yield AiProtocol.AiCompletionResult.Failure(reasonString) + + private def getResponseKind(response: Json): Option[String] = + for { + obj <- response.asObject + key <- obj("kind") + keyString <- key.asString + } yield keyString + +} diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/ai/AICompletionHandler.scala b/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/ai/AICompletionHandler.scala index 2a0f9080e35..643fdeef056 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/ai/AICompletionHandler.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/ai/AICompletionHandler.scala @@ -3,7 +3,7 @@ package org.enso.languageserver.requesthandler.ai import akka.actor.{Actor, ActorRef, Props} import com.typesafe.scalalogging.LazyLogging import org.enso.jsonrpc.{Errors, Id, Request, ResponseError, ResponseResult} -import org.enso.languageserver.ai.AICompletion +import org.enso.languageserver.ai.AiApi.AiCompletion import org.enso.languageserver.util.UnhandledLogging import akka.http.scaladsl.model._ import akka.http.scaladsl.Http @@ -13,6 +13,7 @@ import akka.stream.Materializer import akka.util.ByteString import io.circe.Json import org.enso.languageserver.data.AICompletionConfig +import org.enso.languageserver.requesthandler.UnsupportedHandler import scala.concurrent.ExecutionContext import scala.concurrent.duration.FiniteDuration @@ -31,7 +32,7 @@ class AICompletionHandler(cfg: AICompletionConfig) implicit val materializer: Materializer = Materializer(context) private def requestStage: Receive = { - case Request(AICompletion, id, AICompletion.Params(prompt, stop)) => + case Request(AiCompletion, id, AiCompletion.Params(prompt, stop)) => val body = Json.fromFields( Seq( ("model", Json.fromString("gpt-3.5-turbo-instruct")), @@ -77,9 +78,9 @@ class AICompletionHandler(cfg: AICompletionConfig) firstChoiceText <- firstChoiceObj("text") firstChoiceTextStr <- firstChoiceText.asString } yield ResponseResult( - AICompletion, + AiCompletion, id, - AICompletion.Result(firstChoiceTextStr) + AiCompletion.Result(firstChoiceTextStr) ) val handledErrors = response.getOrElse(ResponseError(Some(id), Errors.ServiceError)) @@ -92,16 +93,6 @@ class AICompletionHandler(cfg: AICompletionConfig) } } -class UnsupportedHandler extends Actor with LazyLogging with UnhandledLogging { - override def receive: Receive = { case Request(AICompletion, id, _) => - sender() ! ResponseError( - Some(id), - Errors.MethodNotFound - ) - - } -} - object AICompletionHandler { def props(cfg: Option[AICompletionConfig]): Props = cfg .map(conf => @@ -109,5 +100,5 @@ object AICompletionHandler { new AICompletionHandler(conf) ) ) - .getOrElse(Props(new UnsupportedHandler())) + .getOrElse(Props(new UnsupportedHandler(AiCompletion))) }