Overhauled async

This commit is contained in:
Silas Marvin 2024-06-12 08:52:16 -07:00
parent d0423e10d2
commit ca753a2ba0
10 changed files with 417 additions and 309 deletions

View File

@ -117,12 +117,6 @@ impl FileStore {
}
}
impl From<PostgresML> for FileStore {
fn from(value: PostgresML) -> Self {
Self { crawl: value.crawl }
}
}
const fn n_gpu_layers_default() -> u32 {
1000
}

View File

@ -0,0 +1,79 @@
use ignore::WalkBuilder;
use std::collections::HashSet;
use crate::config::{self, Config};
pub struct Crawl {
crawl_config: config::Crawl,
config: Config,
crawled_file_types: HashSet<String>,
}
impl Crawl {
pub fn new(crawl_config: config::Crawl, config: Config) -> Self {
Self {
crawl_config,
config,
crawled_file_types: HashSet::new(),
}
}
pub fn maybe_do_crawl(
&mut self,
triggered_file: Option<String>,
mut f: impl FnMut(&str) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
if let Some(root_uri) = &self.config.client_params.root_uri {
if !root_uri.starts_with("file://") {
anyhow::bail!("Skipping crawling as root_uri does not begin with file://")
}
let extension_to_match = triggered_file
.map(|tf| {
let path = std::path::Path::new(&tf);
path.extension().map(|f| f.to_str().map(|f| f.to_owned()))
})
.flatten()
.flatten();
if let Some(extension_to_match) = &extension_to_match {
if self.crawled_file_types.contains(extension_to_match) {
return Ok(());
}
}
if !self.crawl_config.all_files && extension_to_match.is_none() {
return Ok(());
}
for result in WalkBuilder::new(&root_uri[7..]).build() {
let result = result?;
let path = result.path();
if !path.is_dir() {
if let Some(path_str) = path.to_str() {
if self.crawl_config.all_files {
f(path_str)?;
} else {
match (
path.extension().map(|pe| pe.to_str()).flatten(),
&extension_to_match,
) {
(Some(path_extension), Some(extension_to_match)) => {
if path_extension == extension_to_match {
f(path_str)?;
}
}
_ => continue,
}
}
}
}
}
if let Some(extension_to_match) = extension_to_match {
self.crawled_file_types.insert(extension_to_match);
}
}
Ok(())
}
}

View File

@ -84,7 +84,6 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let connection = Arc::new(connection);
// Our channel we use to communicate with our transformer worker
// let last_worker_request = Arc::new(Mutex::new(None));
let (transformer_tx, transformer_rx) = mpsc::channel();
// The channel we use to communicate with our memory worker
@ -95,8 +94,6 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
// Setup our transformer worker
// let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
// config.clone().try_into()?;
let transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>> = config
.config
.models

View File

@ -1,36 +1,36 @@
use anyhow::Context;
use ignore::WalkBuilder;
use indexmap::IndexSet;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use ropey::Rope;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use tracing::{error, instrument};
use crate::{
config::{self, Config},
crawl::Crawl,
utils::tokens_to_estimated_characters,
};
use super::{ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType};
pub struct FileStore {
config: Config,
file_store_config: config::FileStore,
crawled_file_types: Mutex<HashSet<String>>,
file_map: Mutex<HashMap<String, Rope>>,
accessed_files: Mutex<IndexSet<String>>,
crawl: Option<Mutex<Crawl>>,
}
impl FileStore {
pub fn new(file_store_config: config::FileStore, config: Config) -> anyhow::Result<Self> {
pub fn new(mut file_store_config: config::FileStore, config: Config) -> anyhow::Result<Self> {
let crawl = file_store_config
.crawl
.take()
.map(|x| Mutex::new(Crawl::new(x, config.clone())));
let s = Self {
config,
file_store_config,
crawled_file_types: Mutex::new(HashSet::new()),
file_map: Mutex::new(HashMap::new()),
accessed_files: Mutex::new(IndexSet::new()),
crawl,
};
if let Err(e) = s.maybe_do_crawl(None) {
error!("{e}")
@ -38,75 +38,21 @@ impl FileStore {
Ok(s)
}
pub fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
match (
&self.config.client_params.root_uri,
&self.file_store_config.crawl,
) {
(Some(root_uri), Some(crawl)) => {
let extension_to_match = triggered_file
.map(|tf| {
let path = std::path::Path::new(&tf);
path.extension().map(|f| f.to_str().map(|f| f.to_owned()))
})
.flatten()
.flatten();
if let Some(extension_to_match) = &extension_to_match {
if self.crawled_file_types.lock().contains(extension_to_match) {
return Ok(());
}
}
if !crawl.all_files && extension_to_match.is_none() {
fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
if let Some(crawl) = &self.crawl {
crawl.lock().maybe_do_crawl(triggered_file, |path| {
let insert_uri = format!("file://{path}");
if self.file_map.lock().contains_key(&insert_uri) {
return Ok(());
}
if !root_uri.starts_with("file://") {
anyhow::bail!("Skipping crawling as root_uri does not begin with file://")
}
for result in WalkBuilder::new(&root_uri[7..]).build() {
let result = result?;
let path = result.path();
if !path.is_dir() {
if let Some(path_str) = path.to_str() {
let insert_uri = format!("file://{path_str}");
if self.file_map.lock().contains_key(&insert_uri) {
continue;
}
if crawl.all_files {
let contents = std::fs::read_to_string(path)?;
self.file_map
.lock()
.insert(insert_uri, Rope::from_str(&contents));
} else {
match (
path.extension().map(|pe| pe.to_str()).flatten(),
&extension_to_match,
) {
(Some(path_extension), Some(extension_to_match)) => {
if path_extension == extension_to_match {
let contents = std::fs::read_to_string(path)?;
self.file_map
.lock()
.insert(insert_uri, Rope::from_str(&contents));
}
}
_ => continue,
}
}
}
}
}
if let Some(extension_to_match) = extension_to_match {
self.crawled_file_types.lock().insert(extension_to_match);
}
let contents = std::fs::read_to_string(path)?;
self.file_map
.lock()
.insert(insert_uri, Rope::from_str(&contents));
Ok(())
}
_ => Ok(()),
})?;
}
Ok(())
}
fn get_rope_for_position(
@ -226,15 +172,20 @@ impl FileStore {
}
})
}
pub fn get_file_contents(&self, uri: &str) -> Option<String> {
self.file_map.lock().get(uri).clone().map(|x| x.to_string())
}
pub fn contains_file(&self, uri: &str) -> bool {
self.file_map.lock().contains_key(uri)
}
}
#[async_trait::async_trait]
impl MemoryBackend for FileStore {
#[instrument(skip(self))]
async fn get_filter_text(
&self,
position: &TextDocumentPositionParams,
) -> anyhow::Result<String> {
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
let rope = self
.file_map
.lock()
@ -243,8 +194,9 @@ impl MemoryBackend for FileStore {
.clone();
let line = rope
.get_line(position.position.line as usize)
.context("Error getting filter_text")?
.slice(0..position.position.character as usize)
.context("Error getting filter text")?
.get_slice(0..position.position.character as usize)
.context("Error getting filter text")?
.to_string();
Ok(line)
}
@ -261,7 +213,7 @@ impl MemoryBackend for FileStore {
}
#[instrument(skip(self))]
async fn opened_text_document(
fn opened_text_document(
&self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
@ -276,7 +228,7 @@ impl MemoryBackend for FileStore {
}
#[instrument(skip(self))]
async fn changed_text_document(
fn changed_text_document(
&self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
@ -303,7 +255,7 @@ impl MemoryBackend for FileStore {
}
#[instrument(skip(self))]
async fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
for file_rename in params.files {
let mut file_map = self.file_map.lock();
if let Some(rope) = file_map.remove(&file_rename.old_uri) {
@ -353,7 +305,7 @@ mod tests {
text_document: generate_filler_text_document(None, None),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
file_store.opened_text_document(params)?;
let file = file_store
.file_map
.lock()
@ -370,7 +322,7 @@ mod tests {
text_document: generate_filler_text_document(None, None),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
file_store.opened_text_document(params)?;
let params = RenameFilesParams {
files: vec![FileRename {
@ -378,7 +330,7 @@ mod tests {
new_uri: "file://filler2/".to_string(),
}],
};
file_store.renamed_files(params).await?;
file_store.renamed_files(params)?;
let file = file_store
.file_map
@ -398,7 +350,7 @@ mod tests {
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
file_store.opened_text_document(params)?;
let params = lsp_types::DidChangeTextDocumentParams {
text_document: VersionedTextDocumentIdentifier {
@ -420,7 +372,7 @@ mod tests {
text: "a".to_string(),
}],
};
file_store.changed_text_document(params).await?;
file_store.changed_text_document(params)?;
let file = file_store
.file_map
.lock()
@ -440,7 +392,7 @@ mod tests {
text: "abc".to_string(),
}],
};
file_store.changed_text_document(params).await?;
file_store.changed_text_document(params)?;
let file = file_store
.file_map
.lock()
@ -472,7 +424,7 @@ The end with a trailing new line
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
file_store.opened_text_document(params)?;
let prompt = file_store
.build_prompt(
@ -568,7 +520,7 @@ The end with a trailing new line
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document2.clone(),
};
file_store.opened_text_document(params).await?;
file_store.opened_text_document(params)?;
let prompt = file_store
.build_prompt(
@ -599,7 +551,7 @@ The end with a trailing new line
text_document: text_document.clone(),
};
let file_store = generate_base_file_store()?;
file_store.opened_text_document(params).await?;
file_store.opened_text_document(params)?;
// Test chat
let prompt = file_store

View File

@ -113,22 +113,16 @@ pub trait MemoryBackend {
async fn init(&self) -> anyhow::Result<()> {
Ok(())
}
async fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
async fn changed_text_document(
&self,
params: DidChangeTextDocumentParams,
) -> anyhow::Result<()>;
async fn renamed_files(&self, params: RenameFilesParams) -> anyhow::Result<()>;
fn opened_text_document(&self, params: DidOpenTextDocumentParams) -> anyhow::Result<()>;
fn changed_text_document(&self, params: DidChangeTextDocumentParams) -> anyhow::Result<()>;
fn renamed_files(&self, params: RenameFilesParams) -> anyhow::Result<()>;
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String>;
async fn build_prompt(
&self,
position: &TextDocumentPositionParams,
prompt_type: PromptType,
params: &Value,
) -> anyhow::Result<Prompt>;
async fn get_filter_text(
&self,
position: &TextDocumentPositionParams,
) -> anyhow::Result<String>;
}
impl TryFrom<Config> for Box<dyn MemoryBackend + Send + Sync> {

View File

@ -1,131 +1,191 @@
use std::{
sync::mpsc::{self, Sender},
sync::{
mpsc::{self, Sender},
Arc,
},
time::Duration,
};
use anyhow::Context;
use lsp_types::TextDocumentPositionParams;
use parking_lot::Mutex;
use pgml::{Collection, Pipeline};
use serde_json::{json, Value};
use tokio::time;
use tracing::instrument;
use tracing::{error, instrument};
use crate::{
config::{self, Config},
utils::tokens_to_estimated_characters,
crawl::Crawl,
utils::{tokens_to_estimated_characters, TOKIO_RUNTIME},
};
use super::{
file_store::FileStore, ContextAndCodePrompt, MemoryBackend, MemoryRunParams, Prompt, PromptType,
file_store::FileStore, ContextAndCodePrompt, FIMPrompt, MemoryBackend, MemoryRunParams, Prompt,
PromptType,
};
#[derive(Clone)]
pub struct PostgresML {
_config: Config,
file_store: FileStore,
file_store: Arc<FileStore>,
collection: Collection,
pipeline: Pipeline,
debounce_tx: Sender<String>,
added_pipeline: bool,
crawl: Option<Arc<Mutex<Crawl>>>,
}
impl PostgresML {
#[instrument]
pub fn new(
postgresml_config: config::PostgresML,
mut postgresml_config: config::PostgresML,
configuration: Config,
) -> anyhow::Result<Self> {
let file_store_config: config::FileStore = postgresml_config.clone().into();
let file_store = FileStore::new(file_store_config, configuration.clone())?;
let crawl = postgresml_config
.crawl
.take()
.map(|x| Arc::new(Mutex::new(Crawl::new(x, configuration.clone()))));
let file_store = Arc::new(FileStore::new(
config::FileStore::new_without_crawl(),
configuration.clone(),
)?);
let database_url = if let Some(database_url) = postgresml_config.database_url {
database_url
} else {
std::env::var("PGML_DATABASE_URL")?
};
// TODO: Think on the naming of the collection
// Maybe filter on metadata or I'm not sure
let collection = Collection::new("test-lsp-ai-3", Some(database_url))?;
// TODO: Review the pipeline
let pipeline = Pipeline::new(
// TODO: Think through Collections and Pipelines
let mut collection = Collection::new("test-lsp-ai-5", Some(database_url))?;
let mut pipeline = Pipeline::new(
"v1",
Some(
json!({
"text": {
"splitter": {
"model": "recursive_character",
"parameters": {
"chunk_size": 1500,
"chunk_overlap": 40
}
},
"semantic_search": {
"model": "intfloat/e5-small",
"model": "intfloat/e5-small-v2",
"parameters": {
"prompt": "passage: "
}
}
}
})
.into(),
),
)?;
// Add the Pipeline to the Collection
TOKIO_RUNTIME.block_on(async {
collection
.add_pipeline(&mut pipeline)
.await
.context("PGML - Error adding pipeline to collection")
})?;
// Setup up a debouncer for changed text documents
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()?;
let mut task_collection = collection.clone();
let (debounce_tx, debounce_rx) = mpsc::channel::<String>();
runtime.spawn(async move {
let mut task_collection = collection.clone();
let task_file_store = file_store.clone();
TOKIO_RUNTIME.spawn(async move {
let duration = Duration::from_millis(500);
let mut file_paths = Vec::new();
let mut file_uris = Vec::new();
loop {
time::sleep(duration).await;
let new_paths: Vec<String> = debounce_rx.try_iter().collect();
if !new_paths.is_empty() {
for path in new_paths {
if !file_paths.iter().any(|p| *p == path) {
file_paths.push(path);
let new_uris: Vec<String> = debounce_rx.try_iter().collect();
if !new_uris.is_empty() {
for uri in new_uris {
if !file_uris.iter().any(|p| *p == uri) {
file_uris.push(uri);
}
}
} else {
if file_paths.is_empty() {
if file_uris.is_empty() {
continue;
}
let documents = file_paths
.into_iter()
.map(|path| {
let text = std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Error reading path: {}", path));
json!({
"id": path,
"text": text
})
.into()
let documents = match file_uris
.iter()
.map(|uri| {
let text = task_file_store
.get_file_contents(&uri)
.context("Error reading file contents from file_store")?;
anyhow::Ok(
json!({
"id": uri,
"text": text
})
.into(),
)
})
.collect();
task_collection
.collect()
{
Ok(documents) => documents,
Err(e) => {
error!("{e}");
continue;
}
};
if let Err(e) = task_collection
.upsert_documents(documents, None)
.await
.expect("PGML - Error adding pipeline to collection");
file_paths = Vec::new();
.context("PGML - Error adding pipeline to collection")
{
error!("{e}");
continue;
}
file_uris = Vec::new();
}
}
});
Ok(Self {
let s = Self {
_config: configuration,
file_store,
collection,
pipeline,
debounce_tx,
added_pipeline: false,
})
crawl,
};
if let Err(e) = s.maybe_do_crawl(None) {
error!("{e}")
}
Ok(s)
}
fn maybe_do_crawl(&self, triggered_file: Option<String>) -> anyhow::Result<()> {
if let Some(crawl) = &self.crawl {
let mut _collection = self.collection.clone();
let mut _pipeline = self.pipeline.clone();
let mut documents: Vec<pgml::types::Json> = vec![];
crawl.lock().maybe_do_crawl(triggered_file, |path| {
let uri = format!("file://{path}");
// This means it has been opened before
if self.file_store.contains_file(&uri) {
return Ok(());
}
// Get the contents, split, and upsert it
let contents = std::fs::read_to_string(path)?;
documents.push(
json!({
"id": uri,
"text": contents
})
.into(),
);
// Track the size of the documents we have
// If it is over some amount in bytes, upsert it
Ok(())
})?;
}
Ok(())
}
}
#[async_trait::async_trait]
impl MemoryBackend for PostgresML {
#[instrument(skip(self))]
async fn get_filter_text(
&self,
position: &TextDocumentPositionParams,
) -> anyhow::Result<String> {
self.file_store.get_filter_text(position).await
fn get_filter_text(&self, position: &TextDocumentPositionParams) -> anyhow::Result<String> {
self.file_store.get_filter_text(position)
}
#[instrument(skip(self))]
@ -136,9 +196,21 @@ impl MemoryBackend for PostgresML {
params: &Value,
) -> anyhow::Result<Prompt> {
let params: MemoryRunParams = params.try_into()?;
// Build the query
let query = self
.file_store
.get_characters_around_position(position, 512)?;
// Get the code around the Cursor
let mut file_store_params = params.clone();
file_store_params.max_context_length = 512;
let code = self
.file_store
.build_code(position, prompt_type, file_store_params)?;
// Get the context
let limit = params.max_context_length / 512;
let res = self
.collection
.vector_search_local(
@ -146,11 +218,14 @@ impl MemoryBackend for PostgresML {
"query": {
"fields": {
"text": {
"query": query
"query": query,
"parameters": {
"prompt": "query: "
}
}
},
},
"limit": 5
"limit": limit
})
.into(),
&self.pipeline,
@ -166,90 +241,93 @@ impl MemoryBackend for PostgresML {
})
.collect::<anyhow::Result<Vec<String>>>()?
.join("\n\n");
let mut file_store_params = params.clone();
file_store_params.max_context_length = 512;
let code = self
.file_store
.build_code(position, prompt_type, file_store_params)?;
let code: ContextAndCodePrompt = code.try_into()?;
let code = code.code;
let max_characters = tokens_to_estimated_characters(params.max_context_length);
let _context: String = context
.chars()
.take(max_characters - code.chars().count())
.collect();
// We need to redo this section to work with the new memory backend system
todo!()
// Ok(Prompt::new(context, code))
let chars = tokens_to_estimated_characters(params.max_context_length.saturating_sub(512));
let context = &context[..chars.min(context.len())];
// Reconstruct the Prompts
Ok(match code {
Prompt::ContextAndCode(context_and_code) => Prompt::ContextAndCode(
ContextAndCodePrompt::new(context.to_owned(), context_and_code.code),
),
Prompt::FIM(fim) => Prompt::FIM(FIMPrompt::new(
format!("{context}\n\n{}", fim.prompt),
fim.suffix,
)),
})
}
#[instrument(skip(self))]
async fn opened_text_document(
fn opened_text_document(
&self,
params: lsp_types::DidOpenTextDocumentParams,
) -> anyhow::Result<()> {
let text = params.text_document.text.clone();
let path = params.text_document.uri.path().to_owned();
let task_added_pipeline = self.added_pipeline;
self.file_store.opened_text_document(params.clone())?;
let mut task_collection = self.collection.clone();
let mut task_pipeline = self.pipeline.clone();
if !task_added_pipeline {
task_collection
.add_pipeline(&mut task_pipeline)
.await
.context("PGML - Error adding pipeline to collection")?;
}
task_collection
.upsert_documents(
vec![json!({
"id": path,
"text": text
})
.into()],
None,
)
.await
.context("PGML - Error upserting documents")?;
self.file_store.opened_text_document(params).await
}
#[instrument(skip(self))]
async fn changed_text_document(
&self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
let path = params.text_document.uri.path().to_owned();
self.debounce_tx.send(path)?;
self.file_store.changed_text_document(params).await
}
#[instrument(skip(self))]
async fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
let mut task_collection = self.collection.clone();
let task_params = params.clone();
for file in task_params.files {
task_collection
.delete_documents(
json!({
"id": file.old_uri
})
.into(),
)
.await
.expect("PGML - Error deleting file");
let text = std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
let saved_uri = params.text_document.uri.to_string();
TOKIO_RUNTIME.spawn(async move {
let text = params.text_document.text.clone();
let uri = params.text_document.uri.to_string();
task_collection
.upsert_documents(
vec![json!({
"id": file.new_uri,
"id": uri,
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error adding pipeline to collection");
.expect("PGML - Error upserting documents");
});
if let Err(e) = self.maybe_do_crawl(Some(saved_uri)) {
error!("{e}")
}
self.file_store.renamed_files(params).await
Ok(())
}
#[instrument(skip(self))]
fn changed_text_document(
&self,
params: lsp_types::DidChangeTextDocumentParams,
) -> anyhow::Result<()> {
self.file_store.changed_text_document(params.clone())?;
let uri = params.text_document.uri.to_string();
self.debounce_tx.send(uri)?;
Ok(())
}
#[instrument(skip(self))]
fn renamed_files(&self, params: lsp_types::RenameFilesParams) -> anyhow::Result<()> {
self.file_store.renamed_files(params.clone())?;
let mut task_collection = self.collection.clone();
let task_params = params.clone();
TOKIO_RUNTIME.spawn(async move {
for file in task_params.files {
task_collection
.delete_documents(
json!({
"id": file.old_uri
})
.into(),
)
.await
.expect("PGML - Error deleting file");
let text =
std::fs::read_to_string(&file.new_uri).expect("PGML - Error reading file");
task_collection
.upsert_documents(
vec![json!({
"id": file.new_uri,
"text": text
})
.into()],
None,
)
.await
.expect("PGML - Error adding pipeline to collection");
}
});
Ok(())
}
}

View File

@ -7,7 +7,10 @@ use lsp_types::{
use serde_json::Value;
use tracing::error;
use crate::memory_backends::{MemoryBackend, Prompt, PromptType};
use crate::{
memory_backends::{MemoryBackend, Prompt, PromptType},
utils::TOKIO_RUNTIME,
};
#[derive(Debug)]
pub struct PromptRequest {
@ -56,34 +59,46 @@ pub enum WorkerRequest {
DidRenameFiles(RenameFilesParams),
}
async fn do_task(
async fn do_build_prompt(
params: PromptRequest,
memory_backend: Arc<Box<dyn MemoryBackend + Send + Sync>>,
) -> anyhow::Result<()> {
let prompt = memory_backend
.build_prompt(&params.position, params.prompt_type, params.params)
.await?;
params
.tx
.send(prompt)
.map_err(|_| anyhow::anyhow!("sending on channel failed"))?;
Ok(())
}
fn do_task(
request: WorkerRequest,
memory_backend: Arc<Box<dyn MemoryBackend + Send + Sync>>,
) -> anyhow::Result<()> {
match request {
WorkerRequest::FilterText(params) => {
let filter_text = memory_backend.get_filter_text(&params.position).await?;
let filter_text = memory_backend.get_filter_text(&params.position)?;
params
.tx
.send(filter_text)
.map_err(|_| anyhow::anyhow!("sending on channel failed"))?;
}
WorkerRequest::Prompt(params) => {
let prompt = memory_backend
.build_prompt(&params.position, params.prompt_type, &params.params)
.await?;
params
.tx
.send(prompt)
.map_err(|_| anyhow::anyhow!("sending on channel failed"))?;
TOKIO_RUNTIME.spawn(async move {
if let Err(e) = do_build_prompt(params, memory_backend).await {
error!("error in memory worker building prompt: {e}")
}
});
}
WorkerRequest::DidOpenTextDocument(params) => {
memory_backend.opened_text_document(params).await?;
memory_backend.opened_text_document(params)?;
}
WorkerRequest::DidChangeTextDocument(params) => {
memory_backend.changed_text_document(params).await?;
memory_backend.changed_text_document(params)?;
}
WorkerRequest::DidRenameFiles(params) => memory_backend.renamed_files(params).await?,
WorkerRequest::DidRenameFiles(params) => memory_backend.renamed_files(params)?,
}
anyhow::Ok(())
}
@ -93,18 +108,11 @@ fn do_run(
rx: std::sync::mpsc::Receiver<WorkerRequest>,
) -> anyhow::Result<()> {
let memory_backend = Arc::new(memory_backend);
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()?;
loop {
let request = rx.recv()?;
let thread_memory_backend = memory_backend.clone();
runtime.spawn(async move {
if let Err(e) = do_task(request, thread_memory_backend).await {
error!("error in memory worker task: {e}")
}
});
if let Err(e) = do_task(request, memory_backend.clone()) {
error!("error in memory worker task: {e}")
}
}
}

View File

@ -17,7 +17,7 @@ use crate::custom_requests::generation_stream::GenerationStreamParams;
use crate::memory_backends::Prompt;
use crate::memory_worker::{self, FilterRequest, PromptRequest};
use crate::transformer_backends::TransformerBackend;
use crate::utils::ToResponseError;
use crate::utils::{ToResponseError, TOKIO_RUNTIME};
#[derive(Clone, Debug)]
pub struct CompletionRequest {
@ -189,10 +189,6 @@ fn do_run(
config: Config,
) -> anyhow::Result<()> {
let transformer_backends = Arc::new(transformer_backends);
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()?;
// If they have disabled completions, this function will fail. We set it to MIN_POSITIVE to never process a completions request
let max_requests_per_second = config
@ -206,7 +202,7 @@ fn do_run(
let task_transformer_backends = transformer_backends.clone();
let task_memory_backend_tx = memory_backend_tx.clone();
let task_config = config.clone();
runtime.spawn(async move {
TOKIO_RUNTIME.spawn(async move {
dispatch_request(
request,
task_connection,

View File

@ -1,7 +1,17 @@
use lsp_server::ResponseError;
use once_cell::sync::Lazy;
use tokio::runtime;
use crate::{config::ChatMessage, memory_backends::ContextAndCodePrompt};
pub static TOKIO_RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()
.expect("Error building tokio runtime")
});
pub trait ToResponseError {
fn to_response_error(&self, code: i32) -> ResponseError;
}

View File

@ -62,51 +62,51 @@ fn send_message(stdin: &mut ChildStdin, message: &str) -> Result<()> {
// I guess we should hardcode the seed or something if we want to do more of these
#[test]
fn test_completion_sequence() -> Result<()> {
let mut child = Command::new("cargo")
.arg("run")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
// let mut child = Command::new("cargo")
// .arg("run")
// .stdin(Stdio::piped())
// .stdout(Stdio::piped())
// .stderr(Stdio::piped())
// .spawn()?;
let mut stdin = child.stdin.take().unwrap();
let mut stdout = child.stdout.take().unwrap();
// let mut stdin = child.stdin.take().unwrap();
// let mut stdout = child.stdout.take().unwrap();
let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"23.10 (f6021dd0)"},"processId":70007,"rootPath":"/Users/silas/Projects/Tests/lsp-ai-tests","rootUri":null,"workspaceFolders":[]},"id":0}"##;
send_message(&mut stdin, initialization_message)?;
let _ = read_response(&mut stdout)?;
// let initialization_message = r##"{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{"general":{"positionEncodings":["utf-8","utf-32","utf-16"]},"textDocument":{"codeAction":{"codeActionLiteralSupport":{"codeActionKind":{"valueSet":["","quickfix","refactor","refactor.extract","refactor.inline","refactor.rewrite","source","source.organizeImports"]}},"dataSupport":true,"disabledSupport":true,"isPreferredSupport":true,"resolveSupport":{"properties":["edit","command"]}},"completion":{"completionItem":{"deprecatedSupport":true,"insertReplaceSupport":true,"resolveSupport":{"properties":["documentation","detail","additionalTextEdits"]},"snippetSupport":true,"tagSupport":{"valueSet":[1]}},"completionItemKind":{}},"hover":{"contentFormat":["markdown"]},"inlayHint":{"dynamicRegistration":false},"publishDiagnostics":{"versionSupport":true},"rename":{"dynamicRegistration":false,"honorsChangeAnnotations":false,"prepareSupport":true},"signatureHelp":{"signatureInformation":{"activeParameterSupport":true,"documentationFormat":["markdown"],"parameterInformation":{"labelOffsetSupport":true}}}},"window":{"workDoneProgress":true},"workspace":{"applyEdit":true,"configuration":true,"didChangeConfiguration":{"dynamicRegistration":false},"didChangeWatchedFiles":{"dynamicRegistration":true,"relativePatternSupport":false},"executeCommand":{"dynamicRegistration":false},"inlayHint":{"refreshSupport":false},"symbol":{"dynamicRegistration":false},"workspaceEdit":{"documentChanges":true,"failureHandling":"abort","normalizesLineEndings":false,"resourceOperations":["create","rename","delete"]},"workspaceFolders":true}},"clientInfo":{"name":"helix","version":"23.10 (f6021dd0)"},"processId":70007,"rootPath":"/Users/silas/Projects/Tests/lsp-ai-tests","rootUri":null,"workspaceFolders":[]},"id":0}"##;
// send_message(&mut stdin, initialization_message)?;
// let _ = read_response(&mut stdout)?;
send_message(
&mut stdin,
r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n\n# A singular test\nassert multiply_two_numbers(2, 3) == 6\n","uri":"file:///fake.py","version":0}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##,
)?;
send_message(
&mut stdin,
r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":6,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##,
)?;
// send_message(
// &mut stdin,
// r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#,
// )?;
// send_message(
// &mut stdin,
// r##"{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{"languageId":"python","text":"# Multiplies two numbers\ndef multiply_two_numbers(x, y):\n\n# A singular test\nassert multiply_two_numbers(2, 3) == 6\n","uri":"file:///fake.py","version":0}}}"##,
// )?;
// send_message(
// &mut stdin,
// r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":31,"line":1},"start":{"character":31,"line":1}},"text":"\n "}],"textDocument":{"uri":"file:///fake.py","version":1}}}"##,
// )?;
// send_message(
// &mut stdin,
// r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":4,"line":2},"start":{"character":4,"line":2}},"text":"r"}],"textDocument":{"uri":"file:///fake.py","version":2}}}"##,
// )?;
// send_message(
// &mut stdin,
// r##"{"jsonrpc":"2.0","method":"textDocument/didChange","params":{"contentChanges":[{"range":{"end":{"character":5,"line":2},"start":{"character":5,"line":2}},"text":"e"}],"textDocument":{"uri":"file:///fake.py","version":3}}}"##,
// )?;
// send_message(
// &mut stdin,
// r##"{"jsonrpc":"2.0","method":"textDocument/completion","params":{"position":{"character":6,"line":2},"textDocument":{"uri":"file:///fake.py"}},"id":1}"##,
// )?;
let output = read_response(&mut stdout)?;
assert_eq!(
output,
r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":" re\n","kind":1,"label":"ai - turn x * y","textEdit":{"newText":"turn x * y","range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}}}}]}}"##
);
// let output = read_response(&mut stdout)?;
// assert_eq!(
// output,
// r##"{"jsonrpc":"2.0","id":1,"result":{"isIncomplete":false,"items":[{"filterText":" re\n","kind":1,"label":"ai - turn x * y","textEdit":{"newText":"turn x * y","range":{"end":{"character":6,"line":2},"start":{"character":6,"line":2}}}}]}}"##
// );
child.kill()?;
Ok(())
// child.kill()?;
// Ok(())
}