mirror of
https://github.com/zed-industries/zed.git
synced 2024-09-18 18:08:07 +03:00
Merge branch 'zed2' into zed2-workspace
This commit is contained in:
commit
3dadfb8ba8
42
Cargo.lock
generated
42
Cargo.lock
generated
@ -108,6 +108,33 @@ dependencies = [
|
|||||||
"util",
|
"util",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ai2"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"async-trait",
|
||||||
|
"bincode",
|
||||||
|
"futures 0.3.28",
|
||||||
|
"gpui2",
|
||||||
|
"isahc",
|
||||||
|
"language2",
|
||||||
|
"lazy_static",
|
||||||
|
"log",
|
||||||
|
"matrixmultiply",
|
||||||
|
"ordered-float 2.10.0",
|
||||||
|
"parking_lot 0.11.2",
|
||||||
|
"parse_duration",
|
||||||
|
"postage",
|
||||||
|
"rand 0.8.5",
|
||||||
|
"regex",
|
||||||
|
"rusqlite",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"tiktoken-rs",
|
||||||
|
"util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "alacritty_config"
|
name = "alacritty_config"
|
||||||
version = "0.1.2-dev"
|
version = "0.1.2-dev"
|
||||||
@ -1138,7 +1165,7 @@ dependencies = [
|
|||||||
"audio2",
|
"audio2",
|
||||||
"client2",
|
"client2",
|
||||||
"collections",
|
"collections",
|
||||||
"fs",
|
"fs2",
|
||||||
"futures 0.3.28",
|
"futures 0.3.28",
|
||||||
"gpui2",
|
"gpui2",
|
||||||
"language2",
|
"language2",
|
||||||
@ -4795,6 +4822,13 @@ dependencies = [
|
|||||||
"gpui",
|
"gpui",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "menu2"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"gpui2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "metal"
|
name = "metal"
|
||||||
version = "0.21.0"
|
version = "0.21.0"
|
||||||
@ -6000,7 +6034,7 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"client2",
|
"client2",
|
||||||
"collections",
|
"collections",
|
||||||
"fs",
|
"fs2",
|
||||||
"futures 0.3.28",
|
"futures 0.3.28",
|
||||||
"gpui2",
|
"gpui2",
|
||||||
"language2",
|
"language2",
|
||||||
@ -6167,7 +6201,7 @@ dependencies = [
|
|||||||
"ctor",
|
"ctor",
|
||||||
"db2",
|
"db2",
|
||||||
"env_logger 0.9.3",
|
"env_logger 0.9.3",
|
||||||
"fs",
|
"fs2",
|
||||||
"fsevent",
|
"fsevent",
|
||||||
"futures 0.3.28",
|
"futures 0.3.28",
|
||||||
"fuzzy2",
|
"fuzzy2",
|
||||||
@ -8740,6 +8774,7 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"clap 4.4.4",
|
"clap 4.4.4",
|
||||||
|
"convert_case 0.6.0",
|
||||||
"gpui2",
|
"gpui2",
|
||||||
"log",
|
"log",
|
||||||
"rust-embed",
|
"rust-embed",
|
||||||
@ -10932,6 +10967,7 @@ dependencies = [
|
|||||||
name = "zed2"
|
name = "zed2"
|
||||||
version = "0.109.0"
|
version = "0.109.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"ai2",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-compression",
|
"async-compression",
|
||||||
"async-recursion 0.3.2",
|
"async-recursion 0.3.2",
|
||||||
|
@ -59,6 +59,7 @@ members = [
|
|||||||
"crates/lsp2",
|
"crates/lsp2",
|
||||||
"crates/media",
|
"crates/media",
|
||||||
"crates/menu",
|
"crates/menu",
|
||||||
|
"crates/menu2",
|
||||||
"crates/multi_buffer",
|
"crates/multi_buffer",
|
||||||
"crates/node_runtime",
|
"crates/node_runtime",
|
||||||
"crates/notifications",
|
"crates/notifications",
|
||||||
|
38
crates/Cargo.toml
Normal file
38
crates/Cargo.toml
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
[package]
|
||||||
|
name = "ai"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
publish = false
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/ai.rs"
|
||||||
|
doctest = false
|
||||||
|
|
||||||
|
[features]
|
||||||
|
test-support = []
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
gpui = { path = "../gpui" }
|
||||||
|
util = { path = "../util" }
|
||||||
|
language = { path = "../language" }
|
||||||
|
async-trait.workspace = true
|
||||||
|
anyhow.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
lazy_static.workspace = true
|
||||||
|
ordered-float.workspace = true
|
||||||
|
parking_lot.workspace = true
|
||||||
|
isahc.workspace = true
|
||||||
|
regex.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
postage.workspace = true
|
||||||
|
rand.workspace = true
|
||||||
|
log.workspace = true
|
||||||
|
parse_duration = "2.1.1"
|
||||||
|
tiktoken-rs = "0.5.0"
|
||||||
|
matrixmultiply = "0.3.7"
|
||||||
|
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
|
||||||
|
bincode = "1.3.3"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
gpui = { path = "../gpui", features = ["test-support"] }
|
@ -8,6 +8,9 @@ publish = false
|
|||||||
path = "src/ai.rs"
|
path = "src/ai.rs"
|
||||||
doctest = false
|
doctest = false
|
||||||
|
|
||||||
|
[features]
|
||||||
|
test-support = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
gpui = { path = "../gpui" }
|
gpui = { path = "../gpui" }
|
||||||
util = { path = "../util" }
|
util = { path = "../util" }
|
||||||
|
@ -1,4 +1,8 @@
|
|||||||
|
pub mod auth;
|
||||||
pub mod completion;
|
pub mod completion;
|
||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
pub mod models;
|
pub mod models;
|
||||||
pub mod templates;
|
pub mod prompts;
|
||||||
|
pub mod providers;
|
||||||
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
pub mod test;
|
||||||
|
15
crates/ai/src/auth.rs
Normal file
15
crates/ai/src/auth.rs
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
use gpui::AppContext;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum ProviderCredential {
|
||||||
|
Credentials { api_key: String },
|
||||||
|
NoCredentials,
|
||||||
|
NotNeeded,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CredentialProvider: Send + Sync {
|
||||||
|
fn has_credentials(&self) -> bool;
|
||||||
|
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
|
||||||
|
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
|
||||||
|
fn delete_credentials(&self, cx: &AppContext);
|
||||||
|
}
|
@ -1,214 +1,23 @@
|
|||||||
use anyhow::{anyhow, Result};
|
use anyhow::Result;
|
||||||
use futures::{
|
use futures::{future::BoxFuture, stream::BoxStream};
|
||||||
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
|
||||||
Stream, StreamExt,
|
|
||||||
};
|
|
||||||
use gpui::executor::Background;
|
|
||||||
use isahc::{http::StatusCode, Request, RequestExt};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::{
|
|
||||||
fmt::{self, Display},
|
|
||||||
io,
|
|
||||||
sync::Arc,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
use crate::{auth::CredentialProvider, models::LanguageModel};
|
||||||
|
|
||||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
pub trait CompletionRequest: Send + Sync {
|
||||||
#[serde(rename_all = "lowercase")]
|
fn data(&self) -> serde_json::Result<String>;
|
||||||
pub enum Role {
|
|
||||||
User,
|
|
||||||
Assistant,
|
|
||||||
System,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Role {
|
pub trait CompletionProvider: CredentialProvider {
|
||||||
pub fn cycle(&mut self) {
|
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||||
*self = match self {
|
|
||||||
Role::User => Role::Assistant,
|
|
||||||
Role::Assistant => Role::System,
|
|
||||||
Role::System => Role::User,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for Role {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Role::User => write!(f, "User"),
|
|
||||||
Role::Assistant => write!(f, "Assistant"),
|
|
||||||
Role::System => write!(f, "System"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
pub struct RequestMessage {
|
|
||||||
pub role: Role,
|
|
||||||
pub content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Default, Serialize)]
|
|
||||||
pub struct OpenAIRequest {
|
|
||||||
pub model: String,
|
|
||||||
pub messages: Vec<RequestMessage>,
|
|
||||||
pub stream: bool,
|
|
||||||
pub stop: Vec<String>,
|
|
||||||
pub temperature: f32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
pub struct ResponseMessage {
|
|
||||||
pub role: Option<Role>,
|
|
||||||
pub content: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
pub struct OpenAIUsage {
|
|
||||||
pub prompt_tokens: u32,
|
|
||||||
pub completion_tokens: u32,
|
|
||||||
pub total_tokens: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
pub struct ChatChoiceDelta {
|
|
||||||
pub index: u32,
|
|
||||||
pub delta: ResponseMessage,
|
|
||||||
pub finish_reason: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
pub struct OpenAIResponseStreamEvent {
|
|
||||||
pub id: Option<String>,
|
|
||||||
pub object: String,
|
|
||||||
pub created: u32,
|
|
||||||
pub model: String,
|
|
||||||
pub choices: Vec<ChatChoiceDelta>,
|
|
||||||
pub usage: Option<OpenAIUsage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn stream_completion(
|
|
||||||
api_key: String,
|
|
||||||
executor: Arc<Background>,
|
|
||||||
mut request: OpenAIRequest,
|
|
||||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
|
||||||
request.stream = true;
|
|
||||||
|
|
||||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
|
||||||
|
|
||||||
let json_data = serde_json::to_string(&request)?;
|
|
||||||
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
|
||||||
.body(json_data)?
|
|
||||||
.send_async()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
if status == StatusCode::OK {
|
|
||||||
executor
|
|
||||||
.spawn(async move {
|
|
||||||
let mut lines = BufReader::new(response.body_mut()).lines();
|
|
||||||
|
|
||||||
fn parse_line(
|
|
||||||
line: Result<String, io::Error>,
|
|
||||||
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
|
||||||
if let Some(data) = line?.strip_prefix("data: ") {
|
|
||||||
let event = serde_json::from_str(&data)?;
|
|
||||||
Ok(Some(event))
|
|
||||||
} else {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
while let Some(line) = lines.next().await {
|
|
||||||
if let Some(event) = parse_line(line).transpose() {
|
|
||||||
let done = event.as_ref().map_or(false, |event| {
|
|
||||||
event
|
|
||||||
.choices
|
|
||||||
.last()
|
|
||||||
.map_or(false, |choice| choice.finish_reason.is_some())
|
|
||||||
});
|
|
||||||
if tx.unbounded_send(event).is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if done {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
anyhow::Ok(())
|
|
||||||
})
|
|
||||||
.detach();
|
|
||||||
|
|
||||||
Ok(rx)
|
|
||||||
} else {
|
|
||||||
let mut body = String::new();
|
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenAIResponse {
|
|
||||||
error: OpenAIError,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenAIError {
|
|
||||||
message: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
match serde_json::from_str::<OpenAIResponse>(&body) {
|
|
||||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
|
||||||
"Failed to connect to OpenAI API: {}",
|
|
||||||
response.error.message,
|
|
||||||
)),
|
|
||||||
|
|
||||||
_ => Err(anyhow!(
|
|
||||||
"Failed to connect to OpenAI API: {} {}",
|
|
||||||
response.status(),
|
|
||||||
body,
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait CompletionProvider {
|
|
||||||
fn complete(
|
fn complete(
|
||||||
&self,
|
&self,
|
||||||
prompt: OpenAIRequest,
|
prompt: Box<dyn CompletionRequest>,
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||||
|
fn box_clone(&self) -> Box<dyn CompletionProvider>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct OpenAICompletionProvider {
|
impl Clone for Box<dyn CompletionProvider> {
|
||||||
api_key: String,
|
fn clone(&self) -> Box<dyn CompletionProvider> {
|
||||||
executor: Arc<Background>,
|
self.box_clone()
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAICompletionProvider {
|
|
||||||
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
|
|
||||||
Self { api_key, executor }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CompletionProvider for OpenAICompletionProvider {
|
|
||||||
fn complete(
|
|
||||||
&self,
|
|
||||||
prompt: OpenAIRequest,
|
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
|
||||||
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
|
|
||||||
async move {
|
|
||||||
let response = request.await?;
|
|
||||||
let stream = response
|
|
||||||
.filter_map(|response| async move {
|
|
||||||
match response {
|
|
||||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
|
||||||
Err(error) => Some(Err(error)),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.boxed();
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
||||||
.boxed()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,32 +1,13 @@
|
|||||||
use anyhow::{anyhow, Result};
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::AsyncReadExt;
|
|
||||||
use gpui::executor::Background;
|
|
||||||
use gpui::{serde_json, AppContext};
|
|
||||||
use isahc::http::StatusCode;
|
|
||||||
use isahc::prelude::Configurable;
|
|
||||||
use isahc::{AsyncBody, Response};
|
|
||||||
use lazy_static::lazy_static;
|
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
use parking_lot::Mutex;
|
|
||||||
use parse_duration::parse;
|
|
||||||
use postage::watch;
|
|
||||||
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
||||||
use rusqlite::ToSql;
|
use rusqlite::ToSql;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::env;
|
|
||||||
use std::ops::Add;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::{Duration, Instant};
|
|
||||||
use tiktoken_rs::{cl100k_base, CoreBPE};
|
|
||||||
use util::http::{HttpClient, Request};
|
|
||||||
use util::ResultExt;
|
|
||||||
|
|
||||||
use crate::completion::OPENAI_API_URL;
|
use crate::auth::CredentialProvider;
|
||||||
|
use crate::models::LanguageModel;
|
||||||
lazy_static! {
|
|
||||||
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone)]
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
pub struct Embedding(pub Vec<f32>);
|
pub struct Embedding(pub Vec<f32>);
|
||||||
@ -87,301 +68,14 @@ impl Embedding {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct OpenAIEmbeddings {
|
|
||||||
pub client: Arc<dyn HttpClient>,
|
|
||||||
pub executor: Arc<Background>,
|
|
||||||
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
|
||||||
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct OpenAIEmbeddingRequest<'a> {
|
|
||||||
model: &'static str,
|
|
||||||
input: Vec<&'a str>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenAIEmbeddingResponse {
|
|
||||||
data: Vec<OpenAIEmbedding>,
|
|
||||||
usage: OpenAIEmbeddingUsage,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct OpenAIEmbedding {
|
|
||||||
embedding: Vec<f32>,
|
|
||||||
index: usize,
|
|
||||||
object: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenAIEmbeddingUsage {
|
|
||||||
prompt_tokens: usize,
|
|
||||||
total_tokens: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait EmbeddingProvider: Sync + Send {
|
pub trait EmbeddingProvider: CredentialProvider {
|
||||||
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
|
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||||
async fn embed_batch(
|
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
|
||||||
&self,
|
|
||||||
spans: Vec<String>,
|
|
||||||
api_key: Option<String>,
|
|
||||||
) -> Result<Vec<Embedding>>;
|
|
||||||
fn max_tokens_per_batch(&self) -> usize;
|
fn max_tokens_per_batch(&self) -> usize;
|
||||||
fn truncate(&self, span: &str) -> (String, usize);
|
|
||||||
fn rate_limit_expiration(&self) -> Option<Instant>;
|
fn rate_limit_expiration(&self) -> Option<Instant>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct DummyEmbeddings {}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl EmbeddingProvider for DummyEmbeddings {
|
|
||||||
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
|
|
||||||
Some("Dummy API KEY".to_string())
|
|
||||||
}
|
|
||||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
async fn embed_batch(
|
|
||||||
&self,
|
|
||||||
spans: Vec<String>,
|
|
||||||
_api_key: Option<String>,
|
|
||||||
) -> Result<Vec<Embedding>> {
|
|
||||||
// 1024 is the OpenAI Embeddings size for ada models.
|
|
||||||
// the model we will likely be starting with.
|
|
||||||
let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
|
|
||||||
return Ok(vec![dummy_vec; spans.len()]);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn max_tokens_per_batch(&self) -> usize {
|
|
||||||
OPENAI_INPUT_LIMIT
|
|
||||||
}
|
|
||||||
|
|
||||||
fn truncate(&self, span: &str) -> (String, usize) {
|
|
||||||
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
|
||||||
let token_count = tokens.len();
|
|
||||||
let output = if token_count > OPENAI_INPUT_LIMIT {
|
|
||||||
tokens.truncate(OPENAI_INPUT_LIMIT);
|
|
||||||
let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
|
|
||||||
new_input.ok().unwrap_or_else(|| span.to_string())
|
|
||||||
} else {
|
|
||||||
span.to_string()
|
|
||||||
};
|
|
||||||
|
|
||||||
(output, tokens.len())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const OPENAI_INPUT_LIMIT: usize = 8190;
|
|
||||||
|
|
||||||
impl OpenAIEmbeddings {
|
|
||||||
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
|
|
||||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
|
||||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
|
||||||
|
|
||||||
OpenAIEmbeddings {
|
|
||||||
client,
|
|
||||||
executor,
|
|
||||||
rate_limit_count_rx,
|
|
||||||
rate_limit_count_tx,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn resolve_rate_limit(&self) {
|
|
||||||
let reset_time = *self.rate_limit_count_tx.lock().borrow();
|
|
||||||
|
|
||||||
if let Some(reset_time) = reset_time {
|
|
||||||
if Instant::now() >= reset_time {
|
|
||||||
*self.rate_limit_count_tx.lock().borrow_mut() = None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log::trace!(
|
|
||||||
"resolving reset time: {:?}",
|
|
||||||
*self.rate_limit_count_tx.lock().borrow()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn update_reset_time(&self, reset_time: Instant) {
|
|
||||||
let original_time = *self.rate_limit_count_tx.lock().borrow();
|
|
||||||
|
|
||||||
let updated_time = if let Some(original_time) = original_time {
|
|
||||||
if reset_time < original_time {
|
|
||||||
Some(reset_time)
|
|
||||||
} else {
|
|
||||||
Some(original_time)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Some(reset_time)
|
|
||||||
};
|
|
||||||
|
|
||||||
log::trace!("updating rate limit time: {:?}", updated_time);
|
|
||||||
|
|
||||||
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
|
|
||||||
}
|
|
||||||
async fn send_request(
|
|
||||||
&self,
|
|
||||||
api_key: &str,
|
|
||||||
spans: Vec<&str>,
|
|
||||||
request_timeout: u64,
|
|
||||||
) -> Result<Response<AsyncBody>> {
|
|
||||||
let request = Request::post("https://api.openai.com/v1/embeddings")
|
|
||||||
.redirect_policy(isahc::config::RedirectPolicy::Follow)
|
|
||||||
.timeout(Duration::from_secs(request_timeout))
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
|
||||||
.body(
|
|
||||||
serde_json::to_string(&OpenAIEmbeddingRequest {
|
|
||||||
input: spans.clone(),
|
|
||||||
model: "text-embedding-ada-002",
|
|
||||||
})
|
|
||||||
.unwrap()
|
|
||||||
.into(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(self.client.send(request).await?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl EmbeddingProvider for OpenAIEmbeddings {
|
|
||||||
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
|
|
||||||
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
|
||||||
Some(api_key)
|
|
||||||
} else if let Some((_, api_key)) = cx
|
|
||||||
.platform()
|
|
||||||
.read_credentials(OPENAI_API_URL)
|
|
||||||
.log_err()
|
|
||||||
.flatten()
|
|
||||||
{
|
|
||||||
String::from_utf8(api_key).log_err()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn max_tokens_per_batch(&self) -> usize {
|
|
||||||
50000
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
|
||||||
*self.rate_limit_count_rx.borrow()
|
|
||||||
}
|
|
||||||
fn truncate(&self, span: &str) -> (String, usize) {
|
|
||||||
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
|
||||||
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
|
|
||||||
tokens.truncate(OPENAI_INPUT_LIMIT);
|
|
||||||
OPENAI_BPE_TOKENIZER
|
|
||||||
.decode(tokens.clone())
|
|
||||||
.ok()
|
|
||||||
.unwrap_or_else(|| span.to_string())
|
|
||||||
} else {
|
|
||||||
span.to_string()
|
|
||||||
};
|
|
||||||
|
|
||||||
(output, tokens.len())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn embed_batch(
|
|
||||||
&self,
|
|
||||||
spans: Vec<String>,
|
|
||||||
api_key: Option<String>,
|
|
||||||
) -> Result<Vec<Embedding>> {
|
|
||||||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
|
||||||
const MAX_RETRIES: usize = 4;
|
|
||||||
|
|
||||||
let Some(api_key) = api_key else {
|
|
||||||
return Err(anyhow!("no open ai key provided"));
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut request_number = 0;
|
|
||||||
let mut rate_limiting = false;
|
|
||||||
let mut request_timeout: u64 = 15;
|
|
||||||
let mut response: Response<AsyncBody>;
|
|
||||||
while request_number < MAX_RETRIES {
|
|
||||||
response = self
|
|
||||||
.send_request(
|
|
||||||
&api_key,
|
|
||||||
spans.iter().map(|x| &**x).collect(),
|
|
||||||
request_timeout,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
request_number += 1;
|
|
||||||
|
|
||||||
match response.status() {
|
|
||||||
StatusCode::REQUEST_TIMEOUT => {
|
|
||||||
request_timeout += 5;
|
|
||||||
}
|
|
||||||
StatusCode::OK => {
|
|
||||||
let mut body = String::new();
|
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
|
||||||
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
|
|
||||||
|
|
||||||
log::trace!(
|
|
||||||
"openai embedding completed. tokens: {:?}",
|
|
||||||
response.usage.total_tokens
|
|
||||||
);
|
|
||||||
|
|
||||||
// If we complete a request successfully that was previously rate_limited
|
|
||||||
// resolve the rate limit
|
|
||||||
if rate_limiting {
|
|
||||||
self.resolve_rate_limit()
|
|
||||||
}
|
|
||||||
|
|
||||||
return Ok(response
|
|
||||||
.data
|
|
||||||
.into_iter()
|
|
||||||
.map(|embedding| Embedding::from(embedding.embedding))
|
|
||||||
.collect());
|
|
||||||
}
|
|
||||||
StatusCode::TOO_MANY_REQUESTS => {
|
|
||||||
rate_limiting = true;
|
|
||||||
let mut body = String::new();
|
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
|
||||||
|
|
||||||
let delay_duration = {
|
|
||||||
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
|
|
||||||
if let Some(time_to_reset) =
|
|
||||||
response.headers().get("x-ratelimit-reset-tokens")
|
|
||||||
{
|
|
||||||
if let Ok(time_str) = time_to_reset.to_str() {
|
|
||||||
parse(time_str).unwrap_or(delay)
|
|
||||||
} else {
|
|
||||||
delay
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
delay
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// If we've previously rate limited, increment the duration but not the count
|
|
||||||
let reset_time = Instant::now().add(delay_duration);
|
|
||||||
self.update_reset_time(reset_time);
|
|
||||||
|
|
||||||
log::trace!(
|
|
||||||
"openai rate limiting: waiting {:?} until lifted",
|
|
||||||
&delay_duration
|
|
||||||
);
|
|
||||||
|
|
||||||
self.executor.timer(delay_duration).await;
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
let mut body = String::new();
|
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
|
||||||
return Err(anyhow!(
|
|
||||||
"open ai bad request: {:?} {:?}",
|
|
||||||
&response.status(),
|
|
||||||
body
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(anyhow!("openai max retries"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -1,66 +1,16 @@
|
|||||||
use anyhow::anyhow;
|
pub enum TruncationDirection {
|
||||||
use tiktoken_rs::CoreBPE;
|
Start,
|
||||||
use util::ResultExt;
|
End,
|
||||||
|
}
|
||||||
|
|
||||||
pub trait LanguageModel {
|
pub trait LanguageModel {
|
||||||
fn name(&self) -> String;
|
fn name(&self) -> String;
|
||||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
||||||
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
|
fn truncate(
|
||||||
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
|
&self,
|
||||||
|
content: &str,
|
||||||
|
length: usize,
|
||||||
|
direction: TruncationDirection,
|
||||||
|
) -> anyhow::Result<String>;
|
||||||
fn capacity(&self) -> anyhow::Result<usize>;
|
fn capacity(&self) -> anyhow::Result<usize>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct OpenAILanguageModel {
|
|
||||||
name: String,
|
|
||||||
bpe: Option<CoreBPE>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAILanguageModel {
|
|
||||||
pub fn load(model_name: &str) -> Self {
|
|
||||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
|
|
||||||
OpenAILanguageModel {
|
|
||||||
name: model_name.to_string(),
|
|
||||||
bpe,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LanguageModel for OpenAILanguageModel {
|
|
||||||
fn name(&self) -> String {
|
|
||||||
self.name.clone()
|
|
||||||
}
|
|
||||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
|
||||||
if let Some(bpe) = &self.bpe {
|
|
||||||
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
|
|
||||||
} else {
|
|
||||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
|
|
||||||
if let Some(bpe) = &self.bpe {
|
|
||||||
let tokens = bpe.encode_with_special_tokens(content);
|
|
||||||
if tokens.len() > length {
|
|
||||||
bpe.decode(tokens[..length].to_vec())
|
|
||||||
} else {
|
|
||||||
bpe.decode(tokens)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
|
|
||||||
if let Some(bpe) = &self.bpe {
|
|
||||||
let tokens = bpe.encode_with_special_tokens(content);
|
|
||||||
if tokens.len() > length {
|
|
||||||
bpe.decode(tokens[length..].to_vec())
|
|
||||||
} else {
|
|
||||||
bpe.decode(tokens)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fn capacity(&self) -> anyhow::Result<usize> {
|
|
||||||
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -6,7 +6,7 @@ use language::BufferSnapshot;
|
|||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
use crate::models::LanguageModel;
|
use crate::models::LanguageModel;
|
||||||
use crate::templates::repository_context::PromptCodeSnippet;
|
use crate::prompts::repository_context::PromptCodeSnippet;
|
||||||
|
|
||||||
pub(crate) enum PromptFileType {
|
pub(crate) enum PromptFileType {
|
||||||
Text,
|
Text,
|
||||||
@ -125,6 +125,9 @@ impl PromptChain {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub(crate) mod tests {
|
pub(crate) mod tests {
|
||||||
|
use crate::models::TruncationDirection;
|
||||||
|
use crate::test::FakeLanguageModel;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -141,7 +144,11 @@ pub(crate) mod tests {
|
|||||||
let mut token_count = args.model.count_tokens(&content)?;
|
let mut token_count = args.model.count_tokens(&content)?;
|
||||||
if let Some(max_token_length) = max_token_length {
|
if let Some(max_token_length) = max_token_length {
|
||||||
if token_count > max_token_length {
|
if token_count > max_token_length {
|
||||||
content = args.model.truncate(&content, max_token_length)?;
|
content = args.model.truncate(
|
||||||
|
&content,
|
||||||
|
max_token_length,
|
||||||
|
TruncationDirection::End,
|
||||||
|
)?;
|
||||||
token_count = max_token_length;
|
token_count = max_token_length;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -162,7 +169,11 @@ pub(crate) mod tests {
|
|||||||
let mut token_count = args.model.count_tokens(&content)?;
|
let mut token_count = args.model.count_tokens(&content)?;
|
||||||
if let Some(max_token_length) = max_token_length {
|
if let Some(max_token_length) = max_token_length {
|
||||||
if token_count > max_token_length {
|
if token_count > max_token_length {
|
||||||
content = args.model.truncate(&content, max_token_length)?;
|
content = args.model.truncate(
|
||||||
|
&content,
|
||||||
|
max_token_length,
|
||||||
|
TruncationDirection::End,
|
||||||
|
)?;
|
||||||
token_count = max_token_length;
|
token_count = max_token_length;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -171,38 +182,7 @@ pub(crate) mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
|
||||||
struct DummyLanguageModel {
|
|
||||||
capacity: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LanguageModel for DummyLanguageModel {
|
|
||||||
fn name(&self) -> String {
|
|
||||||
"dummy".to_string()
|
|
||||||
}
|
|
||||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
|
||||||
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
|
||||||
}
|
|
||||||
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
|
|
||||||
anyhow::Ok(
|
|
||||||
content.chars().collect::<Vec<char>>()[..length]
|
|
||||||
.into_iter()
|
|
||||||
.collect::<String>(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
|
|
||||||
anyhow::Ok(
|
|
||||||
content.chars().collect::<Vec<char>>()[length..]
|
|
||||||
.into_iter()
|
|
||||||
.collect::<String>(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
fn capacity(&self) -> anyhow::Result<usize> {
|
|
||||||
anyhow::Ok(self.capacity)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
|
|
||||||
let args = PromptArguments {
|
let args = PromptArguments {
|
||||||
model: model.clone(),
|
model: model.clone(),
|
||||||
language_name: None,
|
language_name: None,
|
||||||
@ -238,7 +218,7 @@ pub(crate) mod tests {
|
|||||||
|
|
||||||
// Testing with Truncation Off
|
// Testing with Truncation Off
|
||||||
// Should ignore capacity and return all prompts
|
// Should ignore capacity and return all prompts
|
||||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
|
||||||
let args = PromptArguments {
|
let args = PromptArguments {
|
||||||
model: model.clone(),
|
model: model.clone(),
|
||||||
language_name: None,
|
language_name: None,
|
||||||
@ -275,7 +255,7 @@ pub(crate) mod tests {
|
|||||||
// Testing with Truncation Off
|
// Testing with Truncation Off
|
||||||
// Should ignore capacity and return all prompts
|
// Should ignore capacity and return all prompts
|
||||||
let capacity = 20;
|
let capacity = 20;
|
||||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||||
let args = PromptArguments {
|
let args = PromptArguments {
|
||||||
model: model.clone(),
|
model: model.clone(),
|
||||||
language_name: None,
|
language_name: None,
|
||||||
@ -311,7 +291,7 @@ pub(crate) mod tests {
|
|||||||
// Change Ordering of Prompts Based on Priority
|
// Change Ordering of Prompts Based on Priority
|
||||||
let capacity = 120;
|
let capacity = 120;
|
||||||
let reserved_tokens = 10;
|
let reserved_tokens = 10;
|
||||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||||
let args = PromptArguments {
|
let args = PromptArguments {
|
||||||
model: model.clone(),
|
model: model.clone(),
|
||||||
language_name: None,
|
language_name: None,
|
@ -3,8 +3,9 @@ use language::BufferSnapshot;
|
|||||||
use language::ToOffset;
|
use language::ToOffset;
|
||||||
|
|
||||||
use crate::models::LanguageModel;
|
use crate::models::LanguageModel;
|
||||||
use crate::templates::base::PromptArguments;
|
use crate::models::TruncationDirection;
|
||||||
use crate::templates::base::PromptTemplate;
|
use crate::prompts::base::PromptArguments;
|
||||||
|
use crate::prompts::base::PromptTemplate;
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@ -70,8 +71,9 @@ fn retrieve_context(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let truncated_start_window =
|
let truncated_start_window =
|
||||||
model.truncate_start(&start_window, start_goal_tokens)?;
|
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
|
||||||
let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
|
let truncated_end_window =
|
||||||
|
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
|
||||||
writeln!(
|
writeln!(
|
||||||
prompt,
|
prompt,
|
||||||
"{truncated_start_window}{selected_window}{truncated_end_window}"
|
"{truncated_start_window}{selected_window}{truncated_end_window}"
|
||||||
@ -89,7 +91,7 @@ fn retrieve_context(
|
|||||||
if let Some(max_token_count) = max_token_count {
|
if let Some(max_token_count) = max_token_count {
|
||||||
if model.count_tokens(&prompt)? > max_token_count {
|
if model.count_tokens(&prompt)? > max_token_count {
|
||||||
truncated = true;
|
truncated = true;
|
||||||
prompt = model.truncate(&prompt, max_token_count)?;
|
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
|
|||||||
|
|
||||||
// Really dumb truncation strategy
|
// Really dumb truncation strategy
|
||||||
if let Some(max_tokens) = max_token_length {
|
if let Some(max_tokens) = max_token_length {
|
||||||
prompt = args.model.truncate(&prompt, max_tokens)?;
|
prompt = args
|
||||||
|
.model
|
||||||
|
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let token_count = args.model.count_tokens(&prompt)?;
|
let token_count = args.model.count_tokens(&prompt)?;
|
@ -1,4 +1,4 @@
|
|||||||
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
|
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
|
|
||||||
@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
|
|||||||
|
|
||||||
// Really dumb truncation strategy
|
// Really dumb truncation strategy
|
||||||
if let Some(max_tokens) = max_token_length {
|
if let Some(max_tokens) = max_token_length {
|
||||||
prompt = args.model.truncate(&prompt, max_tokens)?;
|
prompt = args.model.truncate(
|
||||||
|
&prompt,
|
||||||
|
max_tokens,
|
||||||
|
crate::models::TruncationDirection::End,
|
||||||
|
)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let token_count = args.model.count_tokens(&prompt)?;
|
let token_count = args.model.count_tokens(&prompt)?;
|
@ -1,4 +1,4 @@
|
|||||||
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
|
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
|
|
||||||
pub struct EngineerPreamble {}
|
pub struct EngineerPreamble {}
|
@ -1,4 +1,4 @@
|
|||||||
use crate::templates::base::{PromptArguments, PromptTemplate};
|
use crate::prompts::base::{PromptArguments, PromptTemplate};
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
use std::{ops::Range, path::PathBuf};
|
use std::{ops::Range, path::PathBuf};
|
||||||
|
|
1
crates/ai/src/providers/mod.rs
Normal file
1
crates/ai/src/providers/mod.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
pub mod open_ai;
|
298
crates/ai/src/providers/open_ai/completion.rs
Normal file
298
crates/ai/src/providers/open_ai/completion.rs
Normal file
@ -0,0 +1,298 @@
|
|||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use futures::{
|
||||||
|
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||||
|
Stream, StreamExt,
|
||||||
|
};
|
||||||
|
use gpui::{executor::Background, AppContext};
|
||||||
|
use isahc::{http::StatusCode, Request, RequestExt};
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{
|
||||||
|
env,
|
||||||
|
fmt::{self, Display},
|
||||||
|
io,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
auth::{CredentialProvider, ProviderCredential},
|
||||||
|
completion::{CompletionProvider, CompletionRequest},
|
||||||
|
models::LanguageModel,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum Role {
|
||||||
|
User,
|
||||||
|
Assistant,
|
||||||
|
System,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Role {
|
||||||
|
pub fn cycle(&mut self) {
|
||||||
|
*self = match self {
|
||||||
|
Role::User => Role::Assistant,
|
||||||
|
Role::Assistant => Role::System,
|
||||||
|
Role::System => Role::User,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for Role {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Role::User => write!(f, "User"),
|
||||||
|
Role::Assistant => write!(f, "Assistant"),
|
||||||
|
Role::System => write!(f, "System"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct RequestMessage {
|
||||||
|
pub role: Role,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Serialize)]
|
||||||
|
pub struct OpenAIRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<RequestMessage>,
|
||||||
|
pub stream: bool,
|
||||||
|
pub stop: Vec<String>,
|
||||||
|
pub temperature: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionRequest for OpenAIRequest {
|
||||||
|
fn data(&self) -> serde_json::Result<String> {
|
||||||
|
serde_json::to_string(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct ResponseMessage {
|
||||||
|
pub role: Option<Role>,
|
||||||
|
pub content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub struct OpenAIUsage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub struct ChatChoiceDelta {
|
||||||
|
pub index: u32,
|
||||||
|
pub delta: ResponseMessage,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub struct OpenAIResponseStreamEvent {
|
||||||
|
pub id: Option<String>,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u32,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChatChoiceDelta>,
|
||||||
|
pub usage: Option<OpenAIUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stream_completion(
|
||||||
|
credential: ProviderCredential,
|
||||||
|
executor: Arc<Background>,
|
||||||
|
request: Box<dyn CompletionRequest>,
|
||||||
|
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||||
|
let api_key = match credential {
|
||||||
|
ProviderCredential::Credentials { api_key } => api_key,
|
||||||
|
_ => {
|
||||||
|
return Err(anyhow!("no credentials provider for completion"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||||
|
|
||||||
|
let json_data = request.data()?;
|
||||||
|
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {}", api_key))
|
||||||
|
.body(json_data)?
|
||||||
|
.send_async()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if status == StatusCode::OK {
|
||||||
|
executor
|
||||||
|
.spawn(async move {
|
||||||
|
let mut lines = BufReader::new(response.body_mut()).lines();
|
||||||
|
|
||||||
|
fn parse_line(
|
||||||
|
line: Result<String, io::Error>,
|
||||||
|
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
||||||
|
if let Some(data) = line?.strip_prefix("data: ") {
|
||||||
|
let event = serde_json::from_str(&data)?;
|
||||||
|
Ok(Some(event))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while let Some(line) = lines.next().await {
|
||||||
|
if let Some(event) = parse_line(line).transpose() {
|
||||||
|
let done = event.as_ref().map_or(false, |event| {
|
||||||
|
event
|
||||||
|
.choices
|
||||||
|
.last()
|
||||||
|
.map_or(false, |choice| choice.finish_reason.is_some())
|
||||||
|
});
|
||||||
|
if tx.unbounded_send(event).is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if done {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
|
||||||
|
Ok(rx)
|
||||||
|
} else {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIResponse {
|
||||||
|
error: OpenAIError,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIError {
|
||||||
|
message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
match serde_json::from_str::<OpenAIResponse>(&body) {
|
||||||
|
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||||
|
"Failed to connect to OpenAI API: {}",
|
||||||
|
response.error.message,
|
||||||
|
)),
|
||||||
|
|
||||||
|
_ => Err(anyhow!(
|
||||||
|
"Failed to connect to OpenAI API: {} {}",
|
||||||
|
response.status(),
|
||||||
|
body,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OpenAICompletionProvider {
|
||||||
|
model: OpenAILanguageModel,
|
||||||
|
credential: Arc<RwLock<ProviderCredential>>,
|
||||||
|
executor: Arc<Background>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAICompletionProvider {
|
||||||
|
pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
|
||||||
|
let model = OpenAILanguageModel::load(model_name);
|
||||||
|
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
credential,
|
||||||
|
executor,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CredentialProvider for OpenAICompletionProvider {
|
||||||
|
fn has_credentials(&self) -> bool {
|
||||||
|
match *self.credential.read() {
|
||||||
|
ProviderCredential::Credentials { .. } => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
|
||||||
|
let mut credential = self.credential.write();
|
||||||
|
match *credential {
|
||||||
|
ProviderCredential::Credentials { .. } => {
|
||||||
|
return credential.clone();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||||
|
*credential = ProviderCredential::Credentials { api_key };
|
||||||
|
} else if let Some((_, api_key)) = cx
|
||||||
|
.platform()
|
||||||
|
.read_credentials(OPENAI_API_URL)
|
||||||
|
.log_err()
|
||||||
|
.flatten()
|
||||||
|
{
|
||||||
|
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||||
|
*credential = ProviderCredential::Credentials { api_key };
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
credential.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
|
||||||
|
match credential.clone() {
|
||||||
|
ProviderCredential::Credentials { api_key } => {
|
||||||
|
cx.platform()
|
||||||
|
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
*self.credential.write() = credential;
|
||||||
|
}
|
||||||
|
fn delete_credentials(&self, cx: &AppContext) {
|
||||||
|
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
|
||||||
|
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionProvider for OpenAICompletionProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||||
|
model
|
||||||
|
}
|
||||||
|
fn complete(
|
||||||
|
&self,
|
||||||
|
prompt: Box<dyn CompletionRequest>,
|
||||||
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
|
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
|
||||||
|
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
|
||||||
|
// which is currently model based, due to the langauge model.
|
||||||
|
// At some point in the future we should rectify this.
|
||||||
|
let credential = self.credential.read().clone();
|
||||||
|
let request = stream_completion(credential, self.executor.clone(), prompt);
|
||||||
|
async move {
|
||||||
|
let response = request.await?;
|
||||||
|
let stream = response
|
||||||
|
.filter_map(|response| async move {
|
||||||
|
match response {
|
||||||
|
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||||
|
Err(error) => Some(Err(error)),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed();
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||||
|
Box::new((*self).clone())
|
||||||
|
}
|
||||||
|
}
|
306
crates/ai/src/providers/open_ai/embedding.rs
Normal file
306
crates/ai/src/providers/open_ai/embedding.rs
Normal file
@ -0,0 +1,306 @@
|
|||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::AsyncReadExt;
|
||||||
|
use gpui::executor::Background;
|
||||||
|
use gpui::{serde_json, AppContext};
|
||||||
|
use isahc::http::StatusCode;
|
||||||
|
use isahc::prelude::Configurable;
|
||||||
|
use isahc::{AsyncBody, Response};
|
||||||
|
use lazy_static::lazy_static;
|
||||||
|
use parking_lot::{Mutex, RwLock};
|
||||||
|
use parse_duration::parse;
|
||||||
|
use postage::watch;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::env;
|
||||||
|
use std::ops::Add;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||||
|
use util::http::{HttpClient, Request};
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
use crate::auth::{CredentialProvider, ProviderCredential};
|
||||||
|
use crate::embedding::{Embedding, EmbeddingProvider};
|
||||||
|
use crate::models::LanguageModel;
|
||||||
|
use crate::providers::open_ai::OpenAILanguageModel;
|
||||||
|
|
||||||
|
use crate::providers::open_ai::OPENAI_API_URL;
|
||||||
|
|
||||||
|
lazy_static! {
|
||||||
|
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OpenAIEmbeddingProvider {
|
||||||
|
model: OpenAILanguageModel,
|
||||||
|
credential: Arc<RwLock<ProviderCredential>>,
|
||||||
|
pub client: Arc<dyn HttpClient>,
|
||||||
|
pub executor: Arc<Background>,
|
||||||
|
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
||||||
|
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct OpenAIEmbeddingRequest<'a> {
|
||||||
|
model: &'static str,
|
||||||
|
input: Vec<&'a str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIEmbeddingResponse {
|
||||||
|
data: Vec<OpenAIEmbedding>,
|
||||||
|
usage: OpenAIEmbeddingUsage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIEmbedding {
|
||||||
|
embedding: Vec<f32>,
|
||||||
|
index: usize,
|
||||||
|
object: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIEmbeddingUsage {
|
||||||
|
prompt_tokens: usize,
|
||||||
|
total_tokens: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAIEmbeddingProvider {
|
||||||
|
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
|
||||||
|
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||||
|
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||||
|
|
||||||
|
let model = OpenAILanguageModel::load("text-embedding-ada-002");
|
||||||
|
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||||
|
|
||||||
|
OpenAIEmbeddingProvider {
|
||||||
|
model,
|
||||||
|
credential,
|
||||||
|
client,
|
||||||
|
executor,
|
||||||
|
rate_limit_count_rx,
|
||||||
|
rate_limit_count_tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_api_key(&self) -> Result<String> {
|
||||||
|
match self.credential.read().clone() {
|
||||||
|
ProviderCredential::Credentials { api_key } => Ok(api_key),
|
||||||
|
_ => Err(anyhow!("api credentials not provided")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_rate_limit(&self) {
|
||||||
|
let reset_time = *self.rate_limit_count_tx.lock().borrow();
|
||||||
|
|
||||||
|
if let Some(reset_time) = reset_time {
|
||||||
|
if Instant::now() >= reset_time {
|
||||||
|
*self.rate_limit_count_tx.lock().borrow_mut() = None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log::trace!(
|
||||||
|
"resolving reset time: {:?}",
|
||||||
|
*self.rate_limit_count_tx.lock().borrow()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_reset_time(&self, reset_time: Instant) {
|
||||||
|
let original_time = *self.rate_limit_count_tx.lock().borrow();
|
||||||
|
|
||||||
|
let updated_time = if let Some(original_time) = original_time {
|
||||||
|
if reset_time < original_time {
|
||||||
|
Some(reset_time)
|
||||||
|
} else {
|
||||||
|
Some(original_time)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Some(reset_time)
|
||||||
|
};
|
||||||
|
|
||||||
|
log::trace!("updating rate limit time: {:?}", updated_time);
|
||||||
|
|
||||||
|
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
|
||||||
|
}
|
||||||
|
async fn send_request(
|
||||||
|
&self,
|
||||||
|
api_key: &str,
|
||||||
|
spans: Vec<&str>,
|
||||||
|
request_timeout: u64,
|
||||||
|
) -> Result<Response<AsyncBody>> {
|
||||||
|
let request = Request::post("https://api.openai.com/v1/embeddings")
|
||||||
|
.redirect_policy(isahc::config::RedirectPolicy::Follow)
|
||||||
|
.timeout(Duration::from_secs(request_timeout))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {}", api_key))
|
||||||
|
.body(
|
||||||
|
serde_json::to_string(&OpenAIEmbeddingRequest {
|
||||||
|
input: spans.clone(),
|
||||||
|
model: "text-embedding-ada-002",
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.into(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(self.client.send(request).await?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CredentialProvider for OpenAIEmbeddingProvider {
|
||||||
|
fn has_credentials(&self) -> bool {
|
||||||
|
match *self.credential.read() {
|
||||||
|
ProviderCredential::Credentials { .. } => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
|
||||||
|
let mut credential = self.credential.write();
|
||||||
|
match *credential {
|
||||||
|
ProviderCredential::Credentials { .. } => {
|
||||||
|
return credential.clone();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||||
|
*credential = ProviderCredential::Credentials { api_key };
|
||||||
|
} else if let Some((_, api_key)) = cx
|
||||||
|
.platform()
|
||||||
|
.read_credentials(OPENAI_API_URL)
|
||||||
|
.log_err()
|
||||||
|
.flatten()
|
||||||
|
{
|
||||||
|
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||||
|
*credential = ProviderCredential::Credentials { api_key };
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
credential.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
|
||||||
|
match credential.clone() {
|
||||||
|
ProviderCredential::Credentials { api_key } => {
|
||||||
|
cx.platform()
|
||||||
|
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
*self.credential.write() = credential;
|
||||||
|
}
|
||||||
|
fn delete_credentials(&self, cx: &AppContext) {
|
||||||
|
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
|
||||||
|
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||||
|
model
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_tokens_per_batch(&self) -> usize {
|
||||||
|
50000
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||||
|
*self.rate_limit_count_rx.borrow()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||||
|
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||||
|
const MAX_RETRIES: usize = 4;
|
||||||
|
|
||||||
|
let api_key = self.get_api_key()?;
|
||||||
|
|
||||||
|
let mut request_number = 0;
|
||||||
|
let mut rate_limiting = false;
|
||||||
|
let mut request_timeout: u64 = 15;
|
||||||
|
let mut response: Response<AsyncBody>;
|
||||||
|
while request_number < MAX_RETRIES {
|
||||||
|
response = self
|
||||||
|
.send_request(
|
||||||
|
&api_key,
|
||||||
|
spans.iter().map(|x| &**x).collect(),
|
||||||
|
request_timeout,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
request_number += 1;
|
||||||
|
|
||||||
|
match response.status() {
|
||||||
|
StatusCode::REQUEST_TIMEOUT => {
|
||||||
|
request_timeout += 5;
|
||||||
|
}
|
||||||
|
StatusCode::OK => {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
|
||||||
|
|
||||||
|
log::trace!(
|
||||||
|
"openai embedding completed. tokens: {:?}",
|
||||||
|
response.usage.total_tokens
|
||||||
|
);
|
||||||
|
|
||||||
|
// If we complete a request successfully that was previously rate_limited
|
||||||
|
// resolve the rate limit
|
||||||
|
if rate_limiting {
|
||||||
|
self.resolve_rate_limit()
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(response
|
||||||
|
.data
|
||||||
|
.into_iter()
|
||||||
|
.map(|embedding| Embedding::from(embedding.embedding))
|
||||||
|
.collect());
|
||||||
|
}
|
||||||
|
StatusCode::TOO_MANY_REQUESTS => {
|
||||||
|
rate_limiting = true;
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
|
||||||
|
let delay_duration = {
|
||||||
|
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
|
||||||
|
if let Some(time_to_reset) =
|
||||||
|
response.headers().get("x-ratelimit-reset-tokens")
|
||||||
|
{
|
||||||
|
if let Ok(time_str) = time_to_reset.to_str() {
|
||||||
|
parse(time_str).unwrap_or(delay)
|
||||||
|
} else {
|
||||||
|
delay
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
delay
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// If we've previously rate limited, increment the duration but not the count
|
||||||
|
let reset_time = Instant::now().add(delay_duration);
|
||||||
|
self.update_reset_time(reset_time);
|
||||||
|
|
||||||
|
log::trace!(
|
||||||
|
"openai rate limiting: waiting {:?} until lifted",
|
||||||
|
&delay_duration
|
||||||
|
);
|
||||||
|
|
||||||
|
self.executor.timer(delay_duration).await;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
return Err(anyhow!(
|
||||||
|
"open ai bad request: {:?} {:?}",
|
||||||
|
&response.status(),
|
||||||
|
body
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(anyhow!("openai max retries"))
|
||||||
|
}
|
||||||
|
}
|
9
crates/ai/src/providers/open_ai/mod.rs
Normal file
9
crates/ai/src/providers/open_ai/mod.rs
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
pub mod completion;
|
||||||
|
pub mod embedding;
|
||||||
|
pub mod model;
|
||||||
|
|
||||||
|
pub use completion::*;
|
||||||
|
pub use embedding::*;
|
||||||
|
pub use model::OpenAILanguageModel;
|
||||||
|
|
||||||
|
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
57
crates/ai/src/providers/open_ai/model.rs
Normal file
57
crates/ai/src/providers/open_ai/model.rs
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
use anyhow::anyhow;
|
||||||
|
use tiktoken_rs::CoreBPE;
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
use crate::models::{LanguageModel, TruncationDirection};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OpenAILanguageModel {
|
||||||
|
name: String,
|
||||||
|
bpe: Option<CoreBPE>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAILanguageModel {
|
||||||
|
pub fn load(model_name: &str) -> Self {
|
||||||
|
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
|
||||||
|
OpenAILanguageModel {
|
||||||
|
name: model_name.to_string(),
|
||||||
|
bpe,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModel for OpenAILanguageModel {
|
||||||
|
fn name(&self) -> String {
|
||||||
|
self.name.clone()
|
||||||
|
}
|
||||||
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||||
|
if let Some(bpe) = &self.bpe {
|
||||||
|
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn truncate(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
length: usize,
|
||||||
|
direction: TruncationDirection,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
if let Some(bpe) = &self.bpe {
|
||||||
|
let tokens = bpe.encode_with_special_tokens(content);
|
||||||
|
if tokens.len() > length {
|
||||||
|
match direction {
|
||||||
|
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
|
||||||
|
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bpe.decode(tokens)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn capacity(&self) -> anyhow::Result<usize> {
|
||||||
|
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
|
||||||
|
}
|
||||||
|
}
|
11
crates/ai/src/providers/open_ai/new.rs
Normal file
11
crates/ai/src/providers/open_ai/new.rs
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
pub trait LanguageModel {
|
||||||
|
fn name(&self) -> String;
|
||||||
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
||||||
|
fn truncate(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
length: usize,
|
||||||
|
direction: TruncationDirection,
|
||||||
|
) -> anyhow::Result<String>;
|
||||||
|
fn capacity(&self) -> anyhow::Result<usize>;
|
||||||
|
}
|
191
crates/ai/src/test.rs
Normal file
191
crates/ai/src/test.rs
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
use std::{
|
||||||
|
sync::atomic::{self, AtomicUsize, Ordering},
|
||||||
|
time::Instant,
|
||||||
|
};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
|
use gpui::AppContext;
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
auth::{CredentialProvider, ProviderCredential},
|
||||||
|
completion::{CompletionProvider, CompletionRequest},
|
||||||
|
embedding::{Embedding, EmbeddingProvider},
|
||||||
|
models::{LanguageModel, TruncationDirection},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct FakeLanguageModel {
|
||||||
|
pub capacity: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModel for FakeLanguageModel {
|
||||||
|
fn name(&self) -> String {
|
||||||
|
"dummy".to_string()
|
||||||
|
}
|
||||||
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||||
|
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
||||||
|
}
|
||||||
|
fn truncate(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
length: usize,
|
||||||
|
direction: TruncationDirection,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
println!("TRYING TO TRUNCATE: {:?}", length.clone());
|
||||||
|
|
||||||
|
if length > self.count_tokens(content)? {
|
||||||
|
println!("NOT TRUNCATING");
|
||||||
|
return anyhow::Ok(content.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::Ok(match direction {
|
||||||
|
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
|
||||||
|
.into_iter()
|
||||||
|
.collect::<String>(),
|
||||||
|
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
|
||||||
|
.into_iter()
|
||||||
|
.collect::<String>(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
fn capacity(&self) -> anyhow::Result<usize> {
|
||||||
|
anyhow::Ok(self.capacity)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FakeEmbeddingProvider {
|
||||||
|
pub embedding_count: AtomicUsize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for FakeEmbeddingProvider {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
FakeEmbeddingProvider {
|
||||||
|
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for FakeEmbeddingProvider {
|
||||||
|
fn default() -> Self {
|
||||||
|
FakeEmbeddingProvider {
|
||||||
|
embedding_count: AtomicUsize::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeEmbeddingProvider {
|
||||||
|
pub fn embedding_count(&self) -> usize {
|
||||||
|
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_sync(&self, span: &str) -> Embedding {
|
||||||
|
let mut result = vec![1.0; 26];
|
||||||
|
for letter in span.chars() {
|
||||||
|
let letter = letter.to_ascii_lowercase();
|
||||||
|
if letter as u32 >= 'a' as u32 {
|
||||||
|
let ix = (letter as u32) - ('a' as u32);
|
||||||
|
if ix < 26 {
|
||||||
|
result[ix as usize] += 1.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||||
|
for x in &mut result {
|
||||||
|
*x /= norm;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CredentialProvider for FakeEmbeddingProvider {
|
||||||
|
fn has_credentials(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
|
||||||
|
ProviderCredential::NotNeeded
|
||||||
|
}
|
||||||
|
fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
|
||||||
|
fn delete_credentials(&self, _cx: &AppContext) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
Box::new(FakeLanguageModel { capacity: 1000 })
|
||||||
|
}
|
||||||
|
fn max_tokens_per_batch(&self) -> usize {
|
||||||
|
1000
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
|
||||||
|
self.embedding_count
|
||||||
|
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||||
|
|
||||||
|
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FakeCompletionProvider {
|
||||||
|
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for FakeCompletionProvider {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
last_completion_tx: Mutex::new(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeCompletionProvider {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
last_completion_tx: Mutex::new(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_completion(&self, completion: impl Into<String>) {
|
||||||
|
let mut tx = self.last_completion_tx.lock();
|
||||||
|
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn finish_completion(&self) {
|
||||||
|
self.last_completion_tx.lock().take().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CredentialProvider for FakeCompletionProvider {
|
||||||
|
fn has_credentials(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
|
||||||
|
ProviderCredential::NotNeeded
|
||||||
|
}
|
||||||
|
fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
|
||||||
|
fn delete_credentials(&self, _cx: &AppContext) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionProvider for FakeCompletionProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
|
||||||
|
model
|
||||||
|
}
|
||||||
|
fn complete(
|
||||||
|
&self,
|
||||||
|
_prompt: Box<dyn CompletionRequest>,
|
||||||
|
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
|
||||||
|
let (tx, rx) = mpsc::channel(1);
|
||||||
|
*self.last_completion_tx.lock() = Some(tx);
|
||||||
|
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||||
|
}
|
||||||
|
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||||
|
Box::new((*self).clone())
|
||||||
|
}
|
||||||
|
}
|
38
crates/ai2/Cargo.toml
Normal file
38
crates/ai2/Cargo.toml
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
[package]
|
||||||
|
name = "ai2"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
publish = false
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/ai2.rs"
|
||||||
|
doctest = false
|
||||||
|
|
||||||
|
[features]
|
||||||
|
test-support = []
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
gpui2 = { path = "../gpui2" }
|
||||||
|
util = { path = "../util" }
|
||||||
|
language2 = { path = "../language2" }
|
||||||
|
async-trait.workspace = true
|
||||||
|
anyhow.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
|
lazy_static.workspace = true
|
||||||
|
ordered-float.workspace = true
|
||||||
|
parking_lot.workspace = true
|
||||||
|
isahc.workspace = true
|
||||||
|
regex.workspace = true
|
||||||
|
serde.workspace = true
|
||||||
|
serde_json.workspace = true
|
||||||
|
postage.workspace = true
|
||||||
|
rand.workspace = true
|
||||||
|
log.workspace = true
|
||||||
|
parse_duration = "2.1.1"
|
||||||
|
tiktoken-rs = "0.5.0"
|
||||||
|
matrixmultiply = "0.3.7"
|
||||||
|
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
|
||||||
|
bincode = "1.3.3"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
gpui2 = { path = "../gpui2", features = ["test-support"] }
|
8
crates/ai2/src/ai2.rs
Normal file
8
crates/ai2/src/ai2.rs
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
pub mod auth;
|
||||||
|
pub mod completion;
|
||||||
|
pub mod embedding;
|
||||||
|
pub mod models;
|
||||||
|
pub mod prompts;
|
||||||
|
pub mod providers;
|
||||||
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
pub mod test;
|
17
crates/ai2/src/auth.rs
Normal file
17
crates/ai2/src/auth.rs
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use gpui2::AppContext;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum ProviderCredential {
|
||||||
|
Credentials { api_key: String },
|
||||||
|
NoCredentials,
|
||||||
|
NotNeeded,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait CredentialProvider: Send + Sync {
|
||||||
|
fn has_credentials(&self) -> bool;
|
||||||
|
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
|
||||||
|
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
|
||||||
|
async fn delete_credentials(&self, cx: &mut AppContext);
|
||||||
|
}
|
23
crates/ai2/src/completion.rs
Normal file
23
crates/ai2/src/completion.rs
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use futures::{future::BoxFuture, stream::BoxStream};
|
||||||
|
|
||||||
|
use crate::{auth::CredentialProvider, models::LanguageModel};
|
||||||
|
|
||||||
|
pub trait CompletionRequest: Send + Sync {
|
||||||
|
fn data(&self) -> serde_json::Result<String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait CompletionProvider: CredentialProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||||
|
fn complete(
|
||||||
|
&self,
|
||||||
|
prompt: Box<dyn CompletionRequest>,
|
||||||
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||||
|
fn box_clone(&self) -> Box<dyn CompletionProvider>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for Box<dyn CompletionProvider> {
|
||||||
|
fn clone(&self) -> Box<dyn CompletionProvider> {
|
||||||
|
self.box_clone()
|
||||||
|
}
|
||||||
|
}
|
123
crates/ai2/src/embedding.rs
Normal file
123
crates/ai2/src/embedding.rs
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use ordered_float::OrderedFloat;
|
||||||
|
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
||||||
|
use rusqlite::ToSql;
|
||||||
|
|
||||||
|
use crate::auth::CredentialProvider;
|
||||||
|
use crate::models::LanguageModel;
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct Embedding(pub Vec<f32>);
|
||||||
|
|
||||||
|
// This is needed for semantic index functionality
|
||||||
|
// Unfortunately it has to live wherever the "Embedding" struct is created.
|
||||||
|
// Keeping this in here though, introduces a 'rusqlite' dependency into AI
|
||||||
|
// which is less than ideal
|
||||||
|
impl FromSql for Embedding {
|
||||||
|
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||||
|
let bytes = value.as_blob()?;
|
||||||
|
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||||
|
if embedding.is_err() {
|
||||||
|
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
|
||||||
|
}
|
||||||
|
Ok(Embedding(embedding.unwrap()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToSql for Embedding {
|
||||||
|
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||||
|
let bytes = bincode::serialize(&self.0)
|
||||||
|
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
|
||||||
|
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl From<Vec<f32>> for Embedding {
|
||||||
|
fn from(value: Vec<f32>) -> Self {
|
||||||
|
Embedding(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Embedding {
|
||||||
|
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
|
||||||
|
let len = self.0.len();
|
||||||
|
assert_eq!(len, other.0.len());
|
||||||
|
|
||||||
|
let mut result = 0.0;
|
||||||
|
unsafe {
|
||||||
|
matrixmultiply::sgemm(
|
||||||
|
1,
|
||||||
|
len,
|
||||||
|
1,
|
||||||
|
1.0,
|
||||||
|
self.0.as_ptr(),
|
||||||
|
len as isize,
|
||||||
|
1,
|
||||||
|
other.0.as_ptr(),
|
||||||
|
1,
|
||||||
|
len as isize,
|
||||||
|
0.0,
|
||||||
|
&mut result as *mut f32,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
OrderedFloat(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait EmbeddingProvider: CredentialProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||||
|
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
|
||||||
|
fn max_tokens_per_batch(&self) -> usize;
|
||||||
|
fn rate_limit_expiration(&self) -> Option<Instant>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use rand::prelude::*;
|
||||||
|
|
||||||
|
#[gpui2::test]
|
||||||
|
fn test_similarity(mut rng: StdRng) {
|
||||||
|
assert_eq!(
|
||||||
|
Embedding::from(vec![1., 0., 0., 0., 0.])
|
||||||
|
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
|
||||||
|
0.
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
Embedding::from(vec![2., 0., 0., 0., 0.])
|
||||||
|
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
|
||||||
|
6.
|
||||||
|
);
|
||||||
|
|
||||||
|
for _ in 0..100 {
|
||||||
|
let size = 1536;
|
||||||
|
let mut a = vec![0.; size];
|
||||||
|
let mut b = vec![0.; size];
|
||||||
|
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
|
||||||
|
*a = rng.gen();
|
||||||
|
*b = rng.gen();
|
||||||
|
}
|
||||||
|
let a = Embedding::from(a);
|
||||||
|
let b = Embedding::from(b);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
round_to_decimals(a.similarity(&b), 1),
|
||||||
|
round_to_decimals(reference_dot(&a.0, &b.0), 1)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
|
||||||
|
let factor = (10.0 as f32).powi(decimal_places);
|
||||||
|
(n * factor).round() / factor
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
|
||||||
|
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
16
crates/ai2/src/models.rs
Normal file
16
crates/ai2/src/models.rs
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
pub enum TruncationDirection {
|
||||||
|
Start,
|
||||||
|
End,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait LanguageModel {
|
||||||
|
fn name(&self) -> String;
|
||||||
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
||||||
|
fn truncate(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
length: usize,
|
||||||
|
direction: TruncationDirection,
|
||||||
|
) -> anyhow::Result<String>;
|
||||||
|
fn capacity(&self) -> anyhow::Result<usize>;
|
||||||
|
}
|
330
crates/ai2/src/prompts/base.rs
Normal file
330
crates/ai2/src/prompts/base.rs
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
use std::cmp::Reverse;
|
||||||
|
use std::ops::Range;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use language2::BufferSnapshot;
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
use crate::models::LanguageModel;
|
||||||
|
use crate::prompts::repository_context::PromptCodeSnippet;
|
||||||
|
|
||||||
|
pub(crate) enum PromptFileType {
|
||||||
|
Text,
|
||||||
|
Code,
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Set this up to manage for defaults well
|
||||||
|
pub struct PromptArguments {
|
||||||
|
pub model: Arc<dyn LanguageModel>,
|
||||||
|
pub user_prompt: Option<String>,
|
||||||
|
pub language_name: Option<String>,
|
||||||
|
pub project_name: Option<String>,
|
||||||
|
pub snippets: Vec<PromptCodeSnippet>,
|
||||||
|
pub reserved_tokens: usize,
|
||||||
|
pub buffer: Option<BufferSnapshot>,
|
||||||
|
pub selected_range: Option<Range<usize>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptArguments {
|
||||||
|
pub(crate) fn get_file_type(&self) -> PromptFileType {
|
||||||
|
if self
|
||||||
|
.language_name
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
|
||||||
|
.unwrap_or(true)
|
||||||
|
{
|
||||||
|
PromptFileType::Code
|
||||||
|
} else {
|
||||||
|
PromptFileType::Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait PromptTemplate {
|
||||||
|
fn generate(
|
||||||
|
&self,
|
||||||
|
args: &PromptArguments,
|
||||||
|
max_token_length: Option<usize>,
|
||||||
|
) -> anyhow::Result<(String, usize)>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(i8)]
|
||||||
|
#[derive(PartialEq, Eq, Ord)]
|
||||||
|
pub enum PromptPriority {
|
||||||
|
Mandatory, // Ignores truncation
|
||||||
|
Ordered { order: usize }, // Truncates based on priority
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialOrd for PromptPriority {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||||
|
match (self, other) {
|
||||||
|
(Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
|
||||||
|
(Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
|
||||||
|
(Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
|
||||||
|
(Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PromptChain {
|
||||||
|
args: PromptArguments,
|
||||||
|
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptChain {
|
||||||
|
pub fn new(
|
||||||
|
args: PromptArguments,
|
||||||
|
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
|
||||||
|
) -> Self {
|
||||||
|
PromptChain { args, templates }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
|
||||||
|
// Argsort based on Prompt Priority
|
||||||
|
let seperator = "\n";
|
||||||
|
let seperator_tokens = self.args.model.count_tokens(seperator)?;
|
||||||
|
let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
|
||||||
|
sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
|
||||||
|
|
||||||
|
// If Truncate
|
||||||
|
let mut tokens_outstanding = if truncate {
|
||||||
|
Some(self.args.model.capacity()? - self.args.reserved_tokens)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut prompts = vec!["".to_string(); sorted_indices.len()];
|
||||||
|
for idx in sorted_indices {
|
||||||
|
let (_, template) = &self.templates[idx];
|
||||||
|
|
||||||
|
if let Some((template_prompt, prompt_token_count)) =
|
||||||
|
template.generate(&self.args, tokens_outstanding).log_err()
|
||||||
|
{
|
||||||
|
if template_prompt != "" {
|
||||||
|
prompts[idx] = template_prompt;
|
||||||
|
|
||||||
|
if let Some(remaining_tokens) = tokens_outstanding {
|
||||||
|
let new_tokens = prompt_token_count + seperator_tokens;
|
||||||
|
tokens_outstanding = if remaining_tokens > new_tokens {
|
||||||
|
Some(remaining_tokens - new_tokens)
|
||||||
|
} else {
|
||||||
|
Some(0)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prompts.retain(|x| x != "");
|
||||||
|
|
||||||
|
let full_prompt = prompts.join(seperator);
|
||||||
|
let total_token_count = self.args.model.count_tokens(&full_prompt)?;
|
||||||
|
anyhow::Ok((prompts.join(seperator), total_token_count))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) mod tests {
|
||||||
|
use crate::models::TruncationDirection;
|
||||||
|
use crate::test::FakeLanguageModel;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
pub fn test_prompt_chain() {
|
||||||
|
struct TestPromptTemplate {}
|
||||||
|
impl PromptTemplate for TestPromptTemplate {
|
||||||
|
fn generate(
|
||||||
|
&self,
|
||||||
|
args: &PromptArguments,
|
||||||
|
max_token_length: Option<usize>,
|
||||||
|
) -> anyhow::Result<(String, usize)> {
|
||||||
|
let mut content = "This is a test prompt template".to_string();
|
||||||
|
|
||||||
|
let mut token_count = args.model.count_tokens(&content)?;
|
||||||
|
if let Some(max_token_length) = max_token_length {
|
||||||
|
if token_count > max_token_length {
|
||||||
|
content = args.model.truncate(
|
||||||
|
&content,
|
||||||
|
max_token_length,
|
||||||
|
TruncationDirection::End,
|
||||||
|
)?;
|
||||||
|
token_count = max_token_length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::Ok((content, token_count))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TestLowPriorityTemplate {}
|
||||||
|
impl PromptTemplate for TestLowPriorityTemplate {
|
||||||
|
fn generate(
|
||||||
|
&self,
|
||||||
|
args: &PromptArguments,
|
||||||
|
max_token_length: Option<usize>,
|
||||||
|
) -> anyhow::Result<(String, usize)> {
|
||||||
|
let mut content = "This is a low priority test prompt template".to_string();
|
||||||
|
|
||||||
|
let mut token_count = args.model.count_tokens(&content)?;
|
||||||
|
if let Some(max_token_length) = max_token_length {
|
||||||
|
if token_count > max_token_length {
|
||||||
|
content = args.model.truncate(
|
||||||
|
&content,
|
||||||
|
max_token_length,
|
||||||
|
TruncationDirection::End,
|
||||||
|
)?;
|
||||||
|
token_count = max_token_length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::Ok((content, token_count))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
|
||||||
|
let args = PromptArguments {
|
||||||
|
model: model.clone(),
|
||||||
|
language_name: None,
|
||||||
|
project_name: None,
|
||||||
|
snippets: Vec::new(),
|
||||||
|
reserved_tokens: 0,
|
||||||
|
buffer: None,
|
||||||
|
selected_range: None,
|
||||||
|
user_prompt: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||||
|
(
|
||||||
|
PromptPriority::Ordered { order: 0 },
|
||||||
|
Box::new(TestPromptTemplate {}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
PromptPriority::Ordered { order: 1 },
|
||||||
|
Box::new(TestLowPriorityTemplate {}),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
let chain = PromptChain::new(args, templates);
|
||||||
|
|
||||||
|
let (prompt, token_count) = chain.generate(false).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
prompt,
|
||||||
|
"This is a test prompt template\nThis is a low priority test prompt template"
|
||||||
|
.to_string()
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
|
||||||
|
|
||||||
|
// Testing with Truncation Off
|
||||||
|
// Should ignore capacity and return all prompts
|
||||||
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
|
||||||
|
let args = PromptArguments {
|
||||||
|
model: model.clone(),
|
||||||
|
language_name: None,
|
||||||
|
project_name: None,
|
||||||
|
snippets: Vec::new(),
|
||||||
|
reserved_tokens: 0,
|
||||||
|
buffer: None,
|
||||||
|
selected_range: None,
|
||||||
|
user_prompt: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||||
|
(
|
||||||
|
PromptPriority::Ordered { order: 0 },
|
||||||
|
Box::new(TestPromptTemplate {}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
PromptPriority::Ordered { order: 1 },
|
||||||
|
Box::new(TestLowPriorityTemplate {}),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
let chain = PromptChain::new(args, templates);
|
||||||
|
|
||||||
|
let (prompt, token_count) = chain.generate(false).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
prompt,
|
||||||
|
"This is a test prompt template\nThis is a low priority test prompt template"
|
||||||
|
.to_string()
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
|
||||||
|
|
||||||
|
// Testing with Truncation Off
|
||||||
|
// Should ignore capacity and return all prompts
|
||||||
|
let capacity = 20;
|
||||||
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||||
|
let args = PromptArguments {
|
||||||
|
model: model.clone(),
|
||||||
|
language_name: None,
|
||||||
|
project_name: None,
|
||||||
|
snippets: Vec::new(),
|
||||||
|
reserved_tokens: 0,
|
||||||
|
buffer: None,
|
||||||
|
selected_range: None,
|
||||||
|
user_prompt: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||||
|
(
|
||||||
|
PromptPriority::Ordered { order: 0 },
|
||||||
|
Box::new(TestPromptTemplate {}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
PromptPriority::Ordered { order: 1 },
|
||||||
|
Box::new(TestLowPriorityTemplate {}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
PromptPriority::Ordered { order: 2 },
|
||||||
|
Box::new(TestLowPriorityTemplate {}),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
let chain = PromptChain::new(args, templates);
|
||||||
|
|
||||||
|
let (prompt, token_count) = chain.generate(true).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(prompt, "This is a test promp".to_string());
|
||||||
|
assert_eq!(token_count, capacity);
|
||||||
|
|
||||||
|
// Change Ordering of Prompts Based on Priority
|
||||||
|
let capacity = 120;
|
||||||
|
let reserved_tokens = 10;
|
||||||
|
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||||
|
let args = PromptArguments {
|
||||||
|
model: model.clone(),
|
||||||
|
language_name: None,
|
||||||
|
project_name: None,
|
||||||
|
snippets: Vec::new(),
|
||||||
|
reserved_tokens,
|
||||||
|
buffer: None,
|
||||||
|
selected_range: None,
|
||||||
|
user_prompt: None,
|
||||||
|
};
|
||||||
|
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||||
|
(
|
||||||
|
PromptPriority::Mandatory,
|
||||||
|
Box::new(TestLowPriorityTemplate {}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
PromptPriority::Ordered { order: 0 },
|
||||||
|
Box::new(TestPromptTemplate {}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
PromptPriority::Ordered { order: 1 },
|
||||||
|
Box::new(TestLowPriorityTemplate {}),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
let chain = PromptChain::new(args, templates);
|
||||||
|
|
||||||
|
let (prompt, token_count) = chain.generate(true).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
prompt,
|
||||||
|
"This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
|
||||||
|
.to_string()
|
||||||
|
);
|
||||||
|
assert_eq!(token_count, capacity - reserved_tokens);
|
||||||
|
}
|
||||||
|
}
|
164
crates/ai2/src/prompts/file_context.rs
Normal file
164
crates/ai2/src/prompts/file_context.rs
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
use anyhow::anyhow;
|
||||||
|
use language2::BufferSnapshot;
|
||||||
|
use language2::ToOffset;
|
||||||
|
|
||||||
|
use crate::models::LanguageModel;
|
||||||
|
use crate::models::TruncationDirection;
|
||||||
|
use crate::prompts::base::PromptArguments;
|
||||||
|
use crate::prompts::base::PromptTemplate;
|
||||||
|
use std::fmt::Write;
|
||||||
|
use std::ops::Range;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
fn retrieve_context(
|
||||||
|
buffer: &BufferSnapshot,
|
||||||
|
selected_range: &Option<Range<usize>>,
|
||||||
|
model: Arc<dyn LanguageModel>,
|
||||||
|
max_token_count: Option<usize>,
|
||||||
|
) -> anyhow::Result<(String, usize, bool)> {
|
||||||
|
let mut prompt = String::new();
|
||||||
|
let mut truncated = false;
|
||||||
|
if let Some(selected_range) = selected_range {
|
||||||
|
let start = selected_range.start.to_offset(buffer);
|
||||||
|
let end = selected_range.end.to_offset(buffer);
|
||||||
|
|
||||||
|
let start_window = buffer.text_for_range(0..start).collect::<String>();
|
||||||
|
|
||||||
|
let mut selected_window = String::new();
|
||||||
|
if start == end {
|
||||||
|
write!(selected_window, "<|START|>").unwrap();
|
||||||
|
} else {
|
||||||
|
write!(selected_window, "<|START|").unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
write!(
|
||||||
|
selected_window,
|
||||||
|
"{}",
|
||||||
|
buffer.text_for_range(start..end).collect::<String>()
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
if start != end {
|
||||||
|
write!(selected_window, "|END|>").unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
|
||||||
|
|
||||||
|
if let Some(max_token_count) = max_token_count {
|
||||||
|
let selected_tokens = model.count_tokens(&selected_window)?;
|
||||||
|
if selected_tokens > max_token_count {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"selected range is greater than model context window, truncation not possible"
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut remaining_tokens = max_token_count - selected_tokens;
|
||||||
|
let start_window_tokens = model.count_tokens(&start_window)?;
|
||||||
|
let end_window_tokens = model.count_tokens(&end_window)?;
|
||||||
|
let outside_tokens = start_window_tokens + end_window_tokens;
|
||||||
|
if outside_tokens > remaining_tokens {
|
||||||
|
let (start_goal_tokens, end_goal_tokens) =
|
||||||
|
if start_window_tokens < end_window_tokens {
|
||||||
|
let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
|
||||||
|
remaining_tokens -= start_goal_tokens;
|
||||||
|
let end_goal_tokens = remaining_tokens.min(end_window_tokens);
|
||||||
|
(start_goal_tokens, end_goal_tokens)
|
||||||
|
} else {
|
||||||
|
let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
|
||||||
|
remaining_tokens -= end_goal_tokens;
|
||||||
|
let start_goal_tokens = remaining_tokens.min(start_window_tokens);
|
||||||
|
(start_goal_tokens, end_goal_tokens)
|
||||||
|
};
|
||||||
|
|
||||||
|
let truncated_start_window =
|
||||||
|
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
|
||||||
|
let truncated_end_window =
|
||||||
|
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
|
||||||
|
writeln!(
|
||||||
|
prompt,
|
||||||
|
"{truncated_start_window}{selected_window}{truncated_end_window}"
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
truncated = true;
|
||||||
|
} else {
|
||||||
|
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If we dont have a selected range, include entire file.
|
||||||
|
writeln!(prompt, "{}", &buffer.text()).unwrap();
|
||||||
|
|
||||||
|
// Dumb truncation strategy
|
||||||
|
if let Some(max_token_count) = max_token_count {
|
||||||
|
if model.count_tokens(&prompt)? > max_token_count {
|
||||||
|
truncated = true;
|
||||||
|
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let token_count = model.count_tokens(&prompt)?;
|
||||||
|
anyhow::Ok((prompt, token_count, truncated))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FileContext {}
|
||||||
|
|
||||||
|
impl PromptTemplate for FileContext {
|
||||||
|
fn generate(
|
||||||
|
&self,
|
||||||
|
args: &PromptArguments,
|
||||||
|
max_token_length: Option<usize>,
|
||||||
|
) -> anyhow::Result<(String, usize)> {
|
||||||
|
if let Some(buffer) = &args.buffer {
|
||||||
|
let mut prompt = String::new();
|
||||||
|
// Add Initial Preamble
|
||||||
|
// TODO: Do we want to add the path in here?
|
||||||
|
writeln!(
|
||||||
|
prompt,
|
||||||
|
"The file you are currently working on has the following content:"
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let language_name = args
|
||||||
|
.language_name
|
||||||
|
.clone()
|
||||||
|
.unwrap_or("".to_string())
|
||||||
|
.to_lowercase();
|
||||||
|
|
||||||
|
let (context, _, truncated) = retrieve_context(
|
||||||
|
buffer,
|
||||||
|
&args.selected_range,
|
||||||
|
args.model.clone(),
|
||||||
|
max_token_length,
|
||||||
|
)?;
|
||||||
|
writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
|
||||||
|
|
||||||
|
if truncated {
|
||||||
|
writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(selected_range) = &args.selected_range {
|
||||||
|
let start = selected_range.start.to_offset(buffer);
|
||||||
|
let end = selected_range.end.to_offset(buffer);
|
||||||
|
|
||||||
|
if start == end {
|
||||||
|
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
|
||||||
|
} else {
|
||||||
|
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Really dumb truncation strategy
|
||||||
|
if let Some(max_tokens) = max_token_length {
|
||||||
|
prompt = args
|
||||||
|
.model
|
||||||
|
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let token_count = args.model.count_tokens(&prompt)?;
|
||||||
|
anyhow::Ok((prompt, token_count))
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("no buffer provided to retrieve file context from"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
99
crates/ai2/src/prompts/generate.rs
Normal file
99
crates/ai2/src/prompts/generate.rs
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||||
|
use anyhow::anyhow;
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
|
pub fn capitalize(s: &str) -> String {
|
||||||
|
let mut c = s.chars();
|
||||||
|
match c.next() {
|
||||||
|
None => String::new(),
|
||||||
|
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct GenerateInlineContent {}
|
||||||
|
|
||||||
|
impl PromptTemplate for GenerateInlineContent {
|
||||||
|
fn generate(
|
||||||
|
&self,
|
||||||
|
args: &PromptArguments,
|
||||||
|
max_token_length: Option<usize>,
|
||||||
|
) -> anyhow::Result<(String, usize)> {
|
||||||
|
let Some(user_prompt) = &args.user_prompt else {
|
||||||
|
return Err(anyhow!("user prompt not provided"));
|
||||||
|
};
|
||||||
|
|
||||||
|
let file_type = args.get_file_type();
|
||||||
|
let content_type = match &file_type {
|
||||||
|
PromptFileType::Code => "code",
|
||||||
|
PromptFileType::Text => "text",
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut prompt = String::new();
|
||||||
|
|
||||||
|
if let Some(selected_range) = &args.selected_range {
|
||||||
|
if selected_range.start == selected_range.end {
|
||||||
|
writeln!(
|
||||||
|
prompt,
|
||||||
|
"Assume the cursor is located where the `<|START|>` span is."
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
writeln!(
|
||||||
|
prompt,
|
||||||
|
"{} can't be replaced, so assume your answer will be inserted at the cursor.",
|
||||||
|
capitalize(content_type)
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
writeln!(
|
||||||
|
prompt,
|
||||||
|
"Generate {content_type} based on the users prompt: {user_prompt}",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
} else {
|
||||||
|
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
|
||||||
|
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
|
||||||
|
writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
writeln!(
|
||||||
|
prompt,
|
||||||
|
"Generate {content_type} based on the users prompt: {user_prompt}"
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(language_name) = &args.language_name {
|
||||||
|
writeln!(
|
||||||
|
prompt,
|
||||||
|
"Your answer MUST always and only be valid {}.",
|
||||||
|
language_name
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
writeln!(prompt, "Never make remarks about the output.").unwrap();
|
||||||
|
writeln!(
|
||||||
|
prompt,
|
||||||
|
"Do not return anything else, except the generated {content_type}."
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
match file_type {
|
||||||
|
PromptFileType::Code => {
|
||||||
|
// writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Really dumb truncation strategy
|
||||||
|
if let Some(max_tokens) = max_token_length {
|
||||||
|
prompt = args.model.truncate(
|
||||||
|
&prompt,
|
||||||
|
max_tokens,
|
||||||
|
crate::models::TruncationDirection::End,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let token_count = args.model.count_tokens(&prompt)?;
|
||||||
|
|
||||||
|
anyhow::Ok((prompt, token_count))
|
||||||
|
}
|
||||||
|
}
|
5
crates/ai2/src/prompts/mod.rs
Normal file
5
crates/ai2/src/prompts/mod.rs
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
pub mod base;
|
||||||
|
pub mod file_context;
|
||||||
|
pub mod generate;
|
||||||
|
pub mod preamble;
|
||||||
|
pub mod repository_context;
|
52
crates/ai2/src/prompts/preamble.rs
Normal file
52
crates/ai2/src/prompts/preamble.rs
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
|
pub struct EngineerPreamble {}
|
||||||
|
|
||||||
|
impl PromptTemplate for EngineerPreamble {
|
||||||
|
fn generate(
|
||||||
|
&self,
|
||||||
|
args: &PromptArguments,
|
||||||
|
max_token_length: Option<usize>,
|
||||||
|
) -> anyhow::Result<(String, usize)> {
|
||||||
|
let mut prompts = Vec::new();
|
||||||
|
|
||||||
|
match args.get_file_type() {
|
||||||
|
PromptFileType::Code => {
|
||||||
|
prompts.push(format!(
|
||||||
|
"You are an expert {}engineer.",
|
||||||
|
args.language_name.clone().unwrap_or("".to_string()) + " "
|
||||||
|
));
|
||||||
|
}
|
||||||
|
PromptFileType::Text => {
|
||||||
|
prompts.push("You are an expert engineer.".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(project_name) = args.project_name.clone() {
|
||||||
|
prompts.push(format!(
|
||||||
|
"You are currently working inside the '{project_name}' project in code editor Zed."
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(mut remaining_tokens) = max_token_length {
|
||||||
|
let mut prompt = String::new();
|
||||||
|
let mut total_count = 0;
|
||||||
|
for prompt_piece in prompts {
|
||||||
|
let prompt_token_count =
|
||||||
|
args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
|
||||||
|
if remaining_tokens > prompt_token_count {
|
||||||
|
writeln!(prompt, "{prompt_piece}").unwrap();
|
||||||
|
remaining_tokens -= prompt_token_count;
|
||||||
|
total_count += prompt_token_count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::Ok((prompt, total_count))
|
||||||
|
} else {
|
||||||
|
let prompt = prompts.join("\n");
|
||||||
|
let token_count = args.model.count_tokens(&prompt)?;
|
||||||
|
anyhow::Ok((prompt, token_count))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
98
crates/ai2/src/prompts/repository_context.rs
Normal file
98
crates/ai2/src/prompts/repository_context.rs
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
use crate::prompts::base::{PromptArguments, PromptTemplate};
|
||||||
|
use std::fmt::Write;
|
||||||
|
use std::{ops::Range, path::PathBuf};
|
||||||
|
|
||||||
|
use gpui2::{AsyncAppContext, Model};
|
||||||
|
use language2::{Anchor, Buffer};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct PromptCodeSnippet {
|
||||||
|
path: Option<PathBuf>,
|
||||||
|
language_name: Option<String>,
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptCodeSnippet {
|
||||||
|
pub fn new(
|
||||||
|
buffer: Model<Buffer>,
|
||||||
|
range: Range<Anchor>,
|
||||||
|
cx: &mut AsyncAppContext,
|
||||||
|
) -> anyhow::Result<Self> {
|
||||||
|
let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
|
||||||
|
let snapshot = buffer.snapshot();
|
||||||
|
let content = snapshot.text_for_range(range.clone()).collect::<String>();
|
||||||
|
|
||||||
|
let language_name = buffer
|
||||||
|
.language()
|
||||||
|
.and_then(|language| Some(language.name().to_string().to_lowercase()));
|
||||||
|
|
||||||
|
let file_path = buffer
|
||||||
|
.file()
|
||||||
|
.and_then(|file| Some(file.path().to_path_buf()));
|
||||||
|
|
||||||
|
(content, language_name, file_path)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
anyhow::Ok(PromptCodeSnippet {
|
||||||
|
path: file_path,
|
||||||
|
language_name,
|
||||||
|
content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToString for PromptCodeSnippet {
|
||||||
|
fn to_string(&self) -> String {
|
||||||
|
let path = self
|
||||||
|
.path
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|path| Some(path.to_string_lossy().to_string()))
|
||||||
|
.unwrap_or("".to_string());
|
||||||
|
let language_name = self.language_name.clone().unwrap_or("".to_string());
|
||||||
|
let content = self.content.clone();
|
||||||
|
|
||||||
|
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RepositoryContext {}
|
||||||
|
|
||||||
|
impl PromptTemplate for RepositoryContext {
|
||||||
|
fn generate(
|
||||||
|
&self,
|
||||||
|
args: &PromptArguments,
|
||||||
|
max_token_length: Option<usize>,
|
||||||
|
) -> anyhow::Result<(String, usize)> {
|
||||||
|
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
|
||||||
|
let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
|
||||||
|
let mut prompt = String::new();
|
||||||
|
|
||||||
|
let mut remaining_tokens = max_token_length.clone();
|
||||||
|
let seperator_token_length = args.model.count_tokens("\n")?;
|
||||||
|
for snippet in &args.snippets {
|
||||||
|
let mut snippet_prompt = template.to_string();
|
||||||
|
let content = snippet.to_string();
|
||||||
|
writeln!(snippet_prompt, "{content}").unwrap();
|
||||||
|
|
||||||
|
let token_count = args.model.count_tokens(&snippet_prompt)?;
|
||||||
|
if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
|
||||||
|
if let Some(tokens_left) = remaining_tokens {
|
||||||
|
if tokens_left >= token_count {
|
||||||
|
writeln!(prompt, "{snippet_prompt}").unwrap();
|
||||||
|
remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
|
||||||
|
{
|
||||||
|
Some(tokens_left - token_count - seperator_token_length)
|
||||||
|
} else {
|
||||||
|
Some(0)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
writeln!(prompt, "{snippet_prompt}").unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let total_token_count = args.model.count_tokens(&prompt)?;
|
||||||
|
anyhow::Ok((prompt, total_token_count))
|
||||||
|
}
|
||||||
|
}
|
1
crates/ai2/src/providers/mod.rs
Normal file
1
crates/ai2/src/providers/mod.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
pub mod open_ai;
|
306
crates/ai2/src/providers/open_ai/completion.rs
Normal file
306
crates/ai2/src/providers/open_ai/completion.rs
Normal file
@ -0,0 +1,306 @@
|
|||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::{
|
||||||
|
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||||
|
Stream, StreamExt,
|
||||||
|
};
|
||||||
|
use gpui2::{AppContext, Executor};
|
||||||
|
use isahc::{http::StatusCode, Request, RequestExt};
|
||||||
|
use parking_lot::RwLock;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{
|
||||||
|
env,
|
||||||
|
fmt::{self, Display},
|
||||||
|
io,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
auth::{CredentialProvider, ProviderCredential},
|
||||||
|
completion::{CompletionProvider, CompletionRequest},
|
||||||
|
models::LanguageModel,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum Role {
|
||||||
|
User,
|
||||||
|
Assistant,
|
||||||
|
System,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Role {
|
||||||
|
pub fn cycle(&mut self) {
|
||||||
|
*self = match self {
|
||||||
|
Role::User => Role::Assistant,
|
||||||
|
Role::Assistant => Role::System,
|
||||||
|
Role::System => Role::User,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for Role {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Role::User => write!(f, "User"),
|
||||||
|
Role::Assistant => write!(f, "Assistant"),
|
||||||
|
Role::System => write!(f, "System"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct RequestMessage {
|
||||||
|
pub role: Role,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Serialize)]
|
||||||
|
pub struct OpenAIRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<RequestMessage>,
|
||||||
|
pub stream: bool,
|
||||||
|
pub stop: Vec<String>,
|
||||||
|
pub temperature: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionRequest for OpenAIRequest {
|
||||||
|
fn data(&self) -> serde_json::Result<String> {
|
||||||
|
serde_json::to_string(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
pub struct ResponseMessage {
|
||||||
|
pub role: Option<Role>,
|
||||||
|
pub content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub struct OpenAIUsage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub struct ChatChoiceDelta {
|
||||||
|
pub index: u32,
|
||||||
|
pub delta: ResponseMessage,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
pub struct OpenAIResponseStreamEvent {
|
||||||
|
pub id: Option<String>,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u32,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChatChoiceDelta>,
|
||||||
|
pub usage: Option<OpenAIUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stream_completion(
|
||||||
|
credential: ProviderCredential,
|
||||||
|
executor: Arc<Executor>,
|
||||||
|
request: Box<dyn CompletionRequest>,
|
||||||
|
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||||
|
let api_key = match credential {
|
||||||
|
ProviderCredential::Credentials { api_key } => api_key,
|
||||||
|
_ => {
|
||||||
|
return Err(anyhow!("no credentials provider for completion"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||||
|
|
||||||
|
let json_data = request.data()?;
|
||||||
|
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {}", api_key))
|
||||||
|
.body(json_data)?
|
||||||
|
.send_async()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if status == StatusCode::OK {
|
||||||
|
executor
|
||||||
|
.spawn(async move {
|
||||||
|
let mut lines = BufReader::new(response.body_mut()).lines();
|
||||||
|
|
||||||
|
fn parse_line(
|
||||||
|
line: Result<String, io::Error>,
|
||||||
|
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
||||||
|
if let Some(data) = line?.strip_prefix("data: ") {
|
||||||
|
let event = serde_json::from_str(&data)?;
|
||||||
|
Ok(Some(event))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while let Some(line) = lines.next().await {
|
||||||
|
if let Some(event) = parse_line(line).transpose() {
|
||||||
|
let done = event.as_ref().map_or(false, |event| {
|
||||||
|
event
|
||||||
|
.choices
|
||||||
|
.last()
|
||||||
|
.map_or(false, |choice| choice.finish_reason.is_some())
|
||||||
|
});
|
||||||
|
if tx.unbounded_send(event).is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if done {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
})
|
||||||
|
.detach();
|
||||||
|
|
||||||
|
Ok(rx)
|
||||||
|
} else {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIResponse {
|
||||||
|
error: OpenAIError,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIError {
|
||||||
|
message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
match serde_json::from_str::<OpenAIResponse>(&body) {
|
||||||
|
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||||
|
"Failed to connect to OpenAI API: {}",
|
||||||
|
response.error.message,
|
||||||
|
)),
|
||||||
|
|
||||||
|
_ => Err(anyhow!(
|
||||||
|
"Failed to connect to OpenAI API: {} {}",
|
||||||
|
response.status(),
|
||||||
|
body,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OpenAICompletionProvider {
|
||||||
|
model: OpenAILanguageModel,
|
||||||
|
credential: Arc<RwLock<ProviderCredential>>,
|
||||||
|
executor: Arc<Executor>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAICompletionProvider {
|
||||||
|
pub fn new(model_name: &str, executor: Arc<Executor>) -> Self {
|
||||||
|
let model = OpenAILanguageModel::load(model_name);
|
||||||
|
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
credential,
|
||||||
|
executor,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl CredentialProvider for OpenAICompletionProvider {
|
||||||
|
fn has_credentials(&self) -> bool {
|
||||||
|
match *self.credential.read() {
|
||||||
|
ProviderCredential::Credentials { .. } => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
|
||||||
|
let existing_credential = self.credential.read().clone();
|
||||||
|
|
||||||
|
let retrieved_credential = cx
|
||||||
|
.run_on_main(move |cx| match existing_credential {
|
||||||
|
ProviderCredential::Credentials { .. } => {
|
||||||
|
return existing_credential.clone();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
||||||
|
return ProviderCredential::Credentials { api_key };
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
|
||||||
|
{
|
||||||
|
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||||
|
return ProviderCredential::Credentials { api_key };
|
||||||
|
} else {
|
||||||
|
return ProviderCredential::NoCredentials;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return ProviderCredential::NoCredentials;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
*self.credential.write() = retrieved_credential.clone();
|
||||||
|
retrieved_credential
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
|
||||||
|
*self.credential.write() = credential.clone();
|
||||||
|
let credential = credential.clone();
|
||||||
|
cx.run_on_main(move |cx| match credential {
|
||||||
|
ProviderCredential::Credentials { api_key } => {
|
||||||
|
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
async fn delete_credentials(&self, cx: &mut AppContext) {
|
||||||
|
cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
|
||||||
|
.await;
|
||||||
|
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionProvider for OpenAICompletionProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||||
|
model
|
||||||
|
}
|
||||||
|
fn complete(
|
||||||
|
&self,
|
||||||
|
prompt: Box<dyn CompletionRequest>,
|
||||||
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
|
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
|
||||||
|
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
|
||||||
|
// which is currently model based, due to the langauge model.
|
||||||
|
// At some point in the future we should rectify this.
|
||||||
|
let credential = self.credential.read().clone();
|
||||||
|
let request = stream_completion(credential, self.executor.clone(), prompt);
|
||||||
|
async move {
|
||||||
|
let response = request.await?;
|
||||||
|
let stream = response
|
||||||
|
.filter_map(|response| async move {
|
||||||
|
match response {
|
||||||
|
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||||
|
Err(error) => Some(Err(error)),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed();
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||||
|
Box::new((*self).clone())
|
||||||
|
}
|
||||||
|
}
|
313
crates/ai2/src/providers/open_ai/embedding.rs
Normal file
313
crates/ai2/src/providers/open_ai/embedding.rs
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::AsyncReadExt;
|
||||||
|
use gpui2::Executor;
|
||||||
|
use gpui2::{serde_json, AppContext};
|
||||||
|
use isahc::http::StatusCode;
|
||||||
|
use isahc::prelude::Configurable;
|
||||||
|
use isahc::{AsyncBody, Response};
|
||||||
|
use lazy_static::lazy_static;
|
||||||
|
use parking_lot::{Mutex, RwLock};
|
||||||
|
use parse_duration::parse;
|
||||||
|
use postage::watch;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::env;
|
||||||
|
use std::ops::Add;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||||
|
use util::http::{HttpClient, Request};
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
use crate::auth::{CredentialProvider, ProviderCredential};
|
||||||
|
use crate::embedding::{Embedding, EmbeddingProvider};
|
||||||
|
use crate::models::LanguageModel;
|
||||||
|
use crate::providers::open_ai::OpenAILanguageModel;
|
||||||
|
|
||||||
|
use crate::providers::open_ai::OPENAI_API_URL;
|
||||||
|
|
||||||
|
lazy_static! {
|
||||||
|
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OpenAIEmbeddingProvider {
|
||||||
|
model: OpenAILanguageModel,
|
||||||
|
credential: Arc<RwLock<ProviderCredential>>,
|
||||||
|
pub client: Arc<dyn HttpClient>,
|
||||||
|
pub executor: Arc<Executor>,
|
||||||
|
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
||||||
|
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct OpenAIEmbeddingRequest<'a> {
|
||||||
|
model: &'static str,
|
||||||
|
input: Vec<&'a str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIEmbeddingResponse {
|
||||||
|
data: Vec<OpenAIEmbedding>,
|
||||||
|
usage: OpenAIEmbeddingUsage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OpenAIEmbedding {
|
||||||
|
embedding: Vec<f32>,
|
||||||
|
index: usize,
|
||||||
|
object: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OpenAIEmbeddingUsage {
|
||||||
|
prompt_tokens: usize,
|
||||||
|
total_tokens: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAIEmbeddingProvider {
|
||||||
|
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Executor>) -> Self {
|
||||||
|
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||||
|
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||||
|
|
||||||
|
let model = OpenAILanguageModel::load("text-embedding-ada-002");
|
||||||
|
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||||
|
|
||||||
|
OpenAIEmbeddingProvider {
|
||||||
|
model,
|
||||||
|
credential,
|
||||||
|
client,
|
||||||
|
executor,
|
||||||
|
rate_limit_count_rx,
|
||||||
|
rate_limit_count_tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_api_key(&self) -> Result<String> {
|
||||||
|
match self.credential.read().clone() {
|
||||||
|
ProviderCredential::Credentials { api_key } => Ok(api_key),
|
||||||
|
_ => Err(anyhow!("api credentials not provided")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_rate_limit(&self) {
|
||||||
|
let reset_time = *self.rate_limit_count_tx.lock().borrow();
|
||||||
|
|
||||||
|
if let Some(reset_time) = reset_time {
|
||||||
|
if Instant::now() >= reset_time {
|
||||||
|
*self.rate_limit_count_tx.lock().borrow_mut() = None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log::trace!(
|
||||||
|
"resolving reset time: {:?}",
|
||||||
|
*self.rate_limit_count_tx.lock().borrow()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_reset_time(&self, reset_time: Instant) {
|
||||||
|
let original_time = *self.rate_limit_count_tx.lock().borrow();
|
||||||
|
|
||||||
|
let updated_time = if let Some(original_time) = original_time {
|
||||||
|
if reset_time < original_time {
|
||||||
|
Some(reset_time)
|
||||||
|
} else {
|
||||||
|
Some(original_time)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Some(reset_time)
|
||||||
|
};
|
||||||
|
|
||||||
|
log::trace!("updating rate limit time: {:?}", updated_time);
|
||||||
|
|
||||||
|
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
|
||||||
|
}
|
||||||
|
async fn send_request(
|
||||||
|
&self,
|
||||||
|
api_key: &str,
|
||||||
|
spans: Vec<&str>,
|
||||||
|
request_timeout: u64,
|
||||||
|
) -> Result<Response<AsyncBody>> {
|
||||||
|
let request = Request::post("https://api.openai.com/v1/embeddings")
|
||||||
|
.redirect_policy(isahc::config::RedirectPolicy::Follow)
|
||||||
|
.timeout(Duration::from_secs(request_timeout))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {}", api_key))
|
||||||
|
.body(
|
||||||
|
serde_json::to_string(&OpenAIEmbeddingRequest {
|
||||||
|
input: spans.clone(),
|
||||||
|
model: "text-embedding-ada-002",
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
.into(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(self.client.send(request).await?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl CredentialProvider for OpenAIEmbeddingProvider {
|
||||||
|
fn has_credentials(&self) -> bool {
|
||||||
|
match *self.credential.read() {
|
||||||
|
ProviderCredential::Credentials { .. } => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
|
||||||
|
let existing_credential = self.credential.read().clone();
|
||||||
|
|
||||||
|
let retrieved_credential = cx
|
||||||
|
.run_on_main(move |cx| match existing_credential {
|
||||||
|
ProviderCredential::Credentials { .. } => {
|
||||||
|
return existing_credential.clone();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
||||||
|
return ProviderCredential::Credentials { api_key };
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
|
||||||
|
{
|
||||||
|
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||||
|
return ProviderCredential::Credentials { api_key };
|
||||||
|
} else {
|
||||||
|
return ProviderCredential::NoCredentials;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return ProviderCredential::NoCredentials;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
*self.credential.write() = retrieved_credential.clone();
|
||||||
|
retrieved_credential
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
|
||||||
|
*self.credential.write() = credential.clone();
|
||||||
|
let credential = credential.clone();
|
||||||
|
cx.run_on_main(move |cx| match credential {
|
||||||
|
ProviderCredential::Credentials { api_key } => {
|
||||||
|
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||||
|
.log_err();
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
async fn delete_credentials(&self, cx: &mut AppContext) {
|
||||||
|
cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
|
||||||
|
.await;
|
||||||
|
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||||
|
model
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_tokens_per_batch(&self) -> usize {
|
||||||
|
50000
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||||
|
*self.rate_limit_count_rx.borrow()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||||
|
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||||
|
const MAX_RETRIES: usize = 4;
|
||||||
|
|
||||||
|
let api_key = self.get_api_key()?;
|
||||||
|
|
||||||
|
let mut request_number = 0;
|
||||||
|
let mut rate_limiting = false;
|
||||||
|
let mut request_timeout: u64 = 15;
|
||||||
|
let mut response: Response<AsyncBody>;
|
||||||
|
while request_number < MAX_RETRIES {
|
||||||
|
response = self
|
||||||
|
.send_request(
|
||||||
|
&api_key,
|
||||||
|
spans.iter().map(|x| &**x).collect(),
|
||||||
|
request_timeout,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
request_number += 1;
|
||||||
|
|
||||||
|
match response.status() {
|
||||||
|
StatusCode::REQUEST_TIMEOUT => {
|
||||||
|
request_timeout += 5;
|
||||||
|
}
|
||||||
|
StatusCode::OK => {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
|
||||||
|
|
||||||
|
log::trace!(
|
||||||
|
"openai embedding completed. tokens: {:?}",
|
||||||
|
response.usage.total_tokens
|
||||||
|
);
|
||||||
|
|
||||||
|
// If we complete a request successfully that was previously rate_limited
|
||||||
|
// resolve the rate limit
|
||||||
|
if rate_limiting {
|
||||||
|
self.resolve_rate_limit()
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(response
|
||||||
|
.data
|
||||||
|
.into_iter()
|
||||||
|
.map(|embedding| Embedding::from(embedding.embedding))
|
||||||
|
.collect());
|
||||||
|
}
|
||||||
|
StatusCode::TOO_MANY_REQUESTS => {
|
||||||
|
rate_limiting = true;
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
|
||||||
|
let delay_duration = {
|
||||||
|
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
|
||||||
|
if let Some(time_to_reset) =
|
||||||
|
response.headers().get("x-ratelimit-reset-tokens")
|
||||||
|
{
|
||||||
|
if let Ok(time_str) = time_to_reset.to_str() {
|
||||||
|
parse(time_str).unwrap_or(delay)
|
||||||
|
} else {
|
||||||
|
delay
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
delay
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// If we've previously rate limited, increment the duration but not the count
|
||||||
|
let reset_time = Instant::now().add(delay_duration);
|
||||||
|
self.update_reset_time(reset_time);
|
||||||
|
|
||||||
|
log::trace!(
|
||||||
|
"openai rate limiting: waiting {:?} until lifted",
|
||||||
|
&delay_duration
|
||||||
|
);
|
||||||
|
|
||||||
|
self.executor.timer(delay_duration).await;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
return Err(anyhow!(
|
||||||
|
"open ai bad request: {:?} {:?}",
|
||||||
|
&response.status(),
|
||||||
|
body
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(anyhow!("openai max retries"))
|
||||||
|
}
|
||||||
|
}
|
9
crates/ai2/src/providers/open_ai/mod.rs
Normal file
9
crates/ai2/src/providers/open_ai/mod.rs
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
pub mod completion;
|
||||||
|
pub mod embedding;
|
||||||
|
pub mod model;
|
||||||
|
|
||||||
|
pub use completion::*;
|
||||||
|
pub use embedding::*;
|
||||||
|
pub use model::OpenAILanguageModel;
|
||||||
|
|
||||||
|
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
57
crates/ai2/src/providers/open_ai/model.rs
Normal file
57
crates/ai2/src/providers/open_ai/model.rs
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
use anyhow::anyhow;
|
||||||
|
use tiktoken_rs::CoreBPE;
|
||||||
|
use util::ResultExt;
|
||||||
|
|
||||||
|
use crate::models::{LanguageModel, TruncationDirection};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OpenAILanguageModel {
|
||||||
|
name: String,
|
||||||
|
bpe: Option<CoreBPE>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAILanguageModel {
|
||||||
|
pub fn load(model_name: &str) -> Self {
|
||||||
|
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
|
||||||
|
OpenAILanguageModel {
|
||||||
|
name: model_name.to_string(),
|
||||||
|
bpe,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModel for OpenAILanguageModel {
|
||||||
|
fn name(&self) -> String {
|
||||||
|
self.name.clone()
|
||||||
|
}
|
||||||
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||||
|
if let Some(bpe) = &self.bpe {
|
||||||
|
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn truncate(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
length: usize,
|
||||||
|
direction: TruncationDirection,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
if let Some(bpe) = &self.bpe {
|
||||||
|
let tokens = bpe.encode_with_special_tokens(content);
|
||||||
|
if tokens.len() > length {
|
||||||
|
match direction {
|
||||||
|
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
|
||||||
|
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bpe.decode(tokens)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn capacity(&self) -> anyhow::Result<usize> {
|
||||||
|
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
|
||||||
|
}
|
||||||
|
}
|
11
crates/ai2/src/providers/open_ai/new.rs
Normal file
11
crates/ai2/src/providers/open_ai/new.rs
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
pub trait LanguageModel {
|
||||||
|
fn name(&self) -> String;
|
||||||
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
||||||
|
fn truncate(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
length: usize,
|
||||||
|
direction: TruncationDirection,
|
||||||
|
) -> anyhow::Result<String>;
|
||||||
|
fn capacity(&self) -> anyhow::Result<usize>;
|
||||||
|
}
|
193
crates/ai2/src/test.rs
Normal file
193
crates/ai2/src/test.rs
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
use std::{
|
||||||
|
sync::atomic::{self, AtomicUsize, Ordering},
|
||||||
|
time::Instant,
|
||||||
|
};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
|
use gpui2::AppContext;
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
auth::{CredentialProvider, ProviderCredential},
|
||||||
|
completion::{CompletionProvider, CompletionRequest},
|
||||||
|
embedding::{Embedding, EmbeddingProvider},
|
||||||
|
models::{LanguageModel, TruncationDirection},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct FakeLanguageModel {
|
||||||
|
pub capacity: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModel for FakeLanguageModel {
|
||||||
|
fn name(&self) -> String {
|
||||||
|
"dummy".to_string()
|
||||||
|
}
|
||||||
|
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||||
|
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
||||||
|
}
|
||||||
|
fn truncate(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
length: usize,
|
||||||
|
direction: TruncationDirection,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
println!("TRYING TO TRUNCATE: {:?}", length.clone());
|
||||||
|
|
||||||
|
if length > self.count_tokens(content)? {
|
||||||
|
println!("NOT TRUNCATING");
|
||||||
|
return anyhow::Ok(content.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::Ok(match direction {
|
||||||
|
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
|
||||||
|
.into_iter()
|
||||||
|
.collect::<String>(),
|
||||||
|
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
|
||||||
|
.into_iter()
|
||||||
|
.collect::<String>(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
fn capacity(&self) -> anyhow::Result<usize> {
|
||||||
|
anyhow::Ok(self.capacity)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FakeEmbeddingProvider {
|
||||||
|
pub embedding_count: AtomicUsize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for FakeEmbeddingProvider {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
FakeEmbeddingProvider {
|
||||||
|
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for FakeEmbeddingProvider {
|
||||||
|
fn default() -> Self {
|
||||||
|
FakeEmbeddingProvider {
|
||||||
|
embedding_count: AtomicUsize::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeEmbeddingProvider {
|
||||||
|
pub fn embedding_count(&self) -> usize {
|
||||||
|
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_sync(&self, span: &str) -> Embedding {
|
||||||
|
let mut result = vec![1.0; 26];
|
||||||
|
for letter in span.chars() {
|
||||||
|
let letter = letter.to_ascii_lowercase();
|
||||||
|
if letter as u32 >= 'a' as u32 {
|
||||||
|
let ix = (letter as u32) - ('a' as u32);
|
||||||
|
if ix < 26 {
|
||||||
|
result[ix as usize] += 1.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||||
|
for x in &mut result {
|
||||||
|
*x /= norm;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl CredentialProvider for FakeEmbeddingProvider {
|
||||||
|
fn has_credentials(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
|
||||||
|
ProviderCredential::NotNeeded
|
||||||
|
}
|
||||||
|
async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
|
||||||
|
async fn delete_credentials(&self, _cx: &mut AppContext) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
Box::new(FakeLanguageModel { capacity: 1000 })
|
||||||
|
}
|
||||||
|
fn max_tokens_per_batch(&self) -> usize {
|
||||||
|
1000
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
|
||||||
|
self.embedding_count
|
||||||
|
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||||
|
|
||||||
|
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FakeCompletionProvider {
|
||||||
|
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for FakeCompletionProvider {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
last_completion_tx: Mutex::new(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeCompletionProvider {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
last_completion_tx: Mutex::new(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_completion(&self, completion: impl Into<String>) {
|
||||||
|
let mut tx = self.last_completion_tx.lock();
|
||||||
|
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn finish_completion(&self) {
|
||||||
|
self.last_completion_tx.lock().take().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl CredentialProvider for FakeCompletionProvider {
|
||||||
|
fn has_credentials(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
|
||||||
|
ProviderCredential::NotNeeded
|
||||||
|
}
|
||||||
|
async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
|
||||||
|
async fn delete_credentials(&self, _cx: &mut AppContext) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionProvider for FakeCompletionProvider {
|
||||||
|
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||||
|
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
|
||||||
|
model
|
||||||
|
}
|
||||||
|
fn complete(
|
||||||
|
&self,
|
||||||
|
_prompt: Box<dyn CompletionRequest>,
|
||||||
|
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
|
||||||
|
let (tx, rx) = mpsc::channel(1);
|
||||||
|
*self.last_completion_tx.lock() = Some(tx);
|
||||||
|
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||||
|
}
|
||||||
|
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||||
|
Box::new((*self).clone())
|
||||||
|
}
|
||||||
|
}
|
@ -45,6 +45,7 @@ tiktoken-rs = "0.5"
|
|||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
editor = { path = "../editor", features = ["test-support"] }
|
editor = { path = "../editor", features = ["test-support"] }
|
||||||
project = { path = "../project", features = ["test-support"] }
|
project = { path = "../project", features = ["test-support"] }
|
||||||
|
ai = { path = "../ai", features = ["test-support"]}
|
||||||
|
|
||||||
ctor.workspace = true
|
ctor.workspace = true
|
||||||
env_logger.workspace = true
|
env_logger.workspace = true
|
||||||
|
@ -4,7 +4,7 @@ mod codegen;
|
|||||||
mod prompts;
|
mod prompts;
|
||||||
mod streaming_diff;
|
mod streaming_diff;
|
||||||
|
|
||||||
use ai::completion::Role;
|
use ai::providers::open_ai::Role;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
pub use assistant_panel::AssistantPanel;
|
pub use assistant_panel::AssistantPanel;
|
||||||
use assistant_settings::OpenAIModel;
|
use assistant_settings::OpenAIModel;
|
||||||
|
@ -5,12 +5,14 @@ use crate::{
|
|||||||
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
|
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
|
||||||
SavedMessage,
|
SavedMessage,
|
||||||
};
|
};
|
||||||
|
|
||||||
use ai::{
|
use ai::{
|
||||||
completion::{
|
auth::ProviderCredential,
|
||||||
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
|
completion::{CompletionProvider, CompletionRequest},
|
||||||
},
|
providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
|
||||||
templates::repository_context::PromptCodeSnippet,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use ai::prompts::repository_context::PromptCodeSnippet;
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use chrono::{DateTime, Local};
|
use chrono::{DateTime, Local};
|
||||||
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
|
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
|
||||||
@ -43,8 +45,8 @@ use search::BufferSearchBar;
|
|||||||
use semantic_index::{SemanticIndex, SemanticIndexStatus};
|
use semantic_index::{SemanticIndex, SemanticIndexStatus};
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use std::{
|
use std::{
|
||||||
cell::{Cell, RefCell},
|
cell::Cell,
|
||||||
cmp, env,
|
cmp,
|
||||||
fmt::Write,
|
fmt::Write,
|
||||||
iter,
|
iter,
|
||||||
ops::Range,
|
ops::Range,
|
||||||
@ -97,8 +99,8 @@ pub fn init(cx: &mut AppContext) {
|
|||||||
cx.capture_action(ConversationEditor::copy);
|
cx.capture_action(ConversationEditor::copy);
|
||||||
cx.add_action(ConversationEditor::split);
|
cx.add_action(ConversationEditor::split);
|
||||||
cx.capture_action(ConversationEditor::cycle_message_role);
|
cx.capture_action(ConversationEditor::cycle_message_role);
|
||||||
cx.add_action(AssistantPanel::save_api_key);
|
cx.add_action(AssistantPanel::save_credentials);
|
||||||
cx.add_action(AssistantPanel::reset_api_key);
|
cx.add_action(AssistantPanel::reset_credentials);
|
||||||
cx.add_action(AssistantPanel::toggle_zoom);
|
cx.add_action(AssistantPanel::toggle_zoom);
|
||||||
cx.add_action(AssistantPanel::deploy);
|
cx.add_action(AssistantPanel::deploy);
|
||||||
cx.add_action(AssistantPanel::select_next_match);
|
cx.add_action(AssistantPanel::select_next_match);
|
||||||
@ -140,9 +142,8 @@ pub struct AssistantPanel {
|
|||||||
zoomed: bool,
|
zoomed: bool,
|
||||||
has_focus: bool,
|
has_focus: bool,
|
||||||
toolbar: ViewHandle<Toolbar>,
|
toolbar: ViewHandle<Toolbar>,
|
||||||
api_key: Rc<RefCell<Option<String>>>,
|
completion_provider: Box<dyn CompletionProvider>,
|
||||||
api_key_editor: Option<ViewHandle<Editor>>,
|
api_key_editor: Option<ViewHandle<Editor>>,
|
||||||
has_read_credentials: bool,
|
|
||||||
languages: Arc<LanguageRegistry>,
|
languages: Arc<LanguageRegistry>,
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
subscriptions: Vec<Subscription>,
|
subscriptions: Vec<Subscription>,
|
||||||
@ -202,6 +203,11 @@ impl AssistantPanel {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let semantic_index = SemanticIndex::global(cx);
|
let semantic_index = SemanticIndex::global(cx);
|
||||||
|
// Defaulting currently to GPT4, allow for this to be set via config.
|
||||||
|
let completion_provider = Box::new(OpenAICompletionProvider::new(
|
||||||
|
"gpt-4",
|
||||||
|
cx.background().clone(),
|
||||||
|
));
|
||||||
|
|
||||||
let mut this = Self {
|
let mut this = Self {
|
||||||
workspace: workspace_handle,
|
workspace: workspace_handle,
|
||||||
@ -213,9 +219,8 @@ impl AssistantPanel {
|
|||||||
zoomed: false,
|
zoomed: false,
|
||||||
has_focus: false,
|
has_focus: false,
|
||||||
toolbar,
|
toolbar,
|
||||||
api_key: Rc::new(RefCell::new(None)),
|
completion_provider,
|
||||||
api_key_editor: None,
|
api_key_editor: None,
|
||||||
has_read_credentials: false,
|
|
||||||
languages: workspace.app_state().languages.clone(),
|
languages: workspace.app_state().languages.clone(),
|
||||||
fs: workspace.app_state().fs.clone(),
|
fs: workspace.app_state().fs.clone(),
|
||||||
width: None,
|
width: None,
|
||||||
@ -254,10 +259,7 @@ impl AssistantPanel {
|
|||||||
cx: &mut ViewContext<Workspace>,
|
cx: &mut ViewContext<Workspace>,
|
||||||
) {
|
) {
|
||||||
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
|
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
|
||||||
if this
|
if this.update(cx, |assistant, _| assistant.has_credentials()) {
|
||||||
.update(cx, |assistant, cx| assistant.load_api_key(cx))
|
|
||||||
.is_some()
|
|
||||||
{
|
|
||||||
this
|
this
|
||||||
} else {
|
} else {
|
||||||
workspace.focus_panel::<AssistantPanel>(cx);
|
workspace.focus_panel::<AssistantPanel>(cx);
|
||||||
@ -289,12 +291,6 @@ impl AssistantPanel {
|
|||||||
cx: &mut ViewContext<Self>,
|
cx: &mut ViewContext<Self>,
|
||||||
project: &ModelHandle<Project>,
|
project: &ModelHandle<Project>,
|
||||||
) {
|
) {
|
||||||
let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
|
|
||||||
api_key
|
|
||||||
} else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
let selection = editor.read(cx).selections.newest_anchor().clone();
|
let selection = editor.read(cx).selections.newest_anchor().clone();
|
||||||
if selection.start.excerpt_id != selection.end.excerpt_id {
|
if selection.start.excerpt_id != selection.end.excerpt_id {
|
||||||
return;
|
return;
|
||||||
@ -325,10 +321,13 @@ impl AssistantPanel {
|
|||||||
|
|
||||||
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
|
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
|
||||||
let provider = Arc::new(OpenAICompletionProvider::new(
|
let provider = Arc::new(OpenAICompletionProvider::new(
|
||||||
api_key,
|
"gpt-4",
|
||||||
cx.background().clone(),
|
cx.background().clone(),
|
||||||
));
|
));
|
||||||
|
|
||||||
|
// Retrieve Credentials Authenticates the Provider
|
||||||
|
// provider.retrieve_credentials(cx);
|
||||||
|
|
||||||
let codegen = cx.add_model(|cx| {
|
let codegen = cx.add_model(|cx| {
|
||||||
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
|
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
|
||||||
});
|
});
|
||||||
@ -745,13 +744,14 @@ impl AssistantPanel {
|
|||||||
content: prompt,
|
content: prompt,
|
||||||
});
|
});
|
||||||
|
|
||||||
let request = OpenAIRequest {
|
let request = Box::new(OpenAIRequest {
|
||||||
model: model.full_name().into(),
|
model: model.full_name().into(),
|
||||||
messages,
|
messages,
|
||||||
stream: true,
|
stream: true,
|
||||||
stop: vec!["|END|>".to_string()],
|
stop: vec!["|END|>".to_string()],
|
||||||
temperature,
|
temperature,
|
||||||
};
|
});
|
||||||
|
|
||||||
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
|
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
|
||||||
anyhow::Ok(())
|
anyhow::Ok(())
|
||||||
})
|
})
|
||||||
@ -811,7 +811,7 @@ impl AssistantPanel {
|
|||||||
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
|
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
|
||||||
let editor = cx.add_view(|cx| {
|
let editor = cx.add_view(|cx| {
|
||||||
ConversationEditor::new(
|
ConversationEditor::new(
|
||||||
self.api_key.clone(),
|
self.completion_provider.clone(),
|
||||||
self.languages.clone(),
|
self.languages.clone(),
|
||||||
self.fs.clone(),
|
self.fs.clone(),
|
||||||
self.workspace.clone(),
|
self.workspace.clone(),
|
||||||
@ -870,17 +870,19 @@ impl AssistantPanel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||||
if let Some(api_key) = self
|
if let Some(api_key) = self
|
||||||
.api_key_editor
|
.api_key_editor
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|editor| editor.read(cx).text(cx))
|
.map(|editor| editor.read(cx).text(cx))
|
||||||
{
|
{
|
||||||
if !api_key.is_empty() {
|
if !api_key.is_empty() {
|
||||||
cx.platform()
|
let credential = ProviderCredential::Credentials {
|
||||||
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
api_key: api_key.clone(),
|
||||||
.log_err();
|
};
|
||||||
*self.api_key.borrow_mut() = Some(api_key);
|
|
||||||
|
self.completion_provider.save_credentials(cx, credential);
|
||||||
|
|
||||||
self.api_key_editor.take();
|
self.api_key_editor.take();
|
||||||
cx.focus_self();
|
cx.focus_self();
|
||||||
cx.notify();
|
cx.notify();
|
||||||
@ -890,9 +892,8 @@ impl AssistantPanel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
||||||
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
|
self.completion_provider.delete_credentials(cx);
|
||||||
self.api_key.take();
|
|
||||||
self.api_key_editor = Some(build_api_key_editor(cx));
|
self.api_key_editor = Some(build_api_key_editor(cx));
|
||||||
cx.focus_self();
|
cx.focus_self();
|
||||||
cx.notify();
|
cx.notify();
|
||||||
@ -1151,13 +1152,12 @@ impl AssistantPanel {
|
|||||||
|
|
||||||
let fs = self.fs.clone();
|
let fs = self.fs.clone();
|
||||||
let workspace = self.workspace.clone();
|
let workspace = self.workspace.clone();
|
||||||
let api_key = self.api_key.clone();
|
|
||||||
let languages = self.languages.clone();
|
let languages = self.languages.clone();
|
||||||
cx.spawn(|this, mut cx| async move {
|
cx.spawn(|this, mut cx| async move {
|
||||||
let saved_conversation = fs.load(&path).await?;
|
let saved_conversation = fs.load(&path).await?;
|
||||||
let saved_conversation = serde_json::from_str(&saved_conversation)?;
|
let saved_conversation = serde_json::from_str(&saved_conversation)?;
|
||||||
let conversation = cx.add_model(|cx| {
|
let conversation = cx.add_model(|cx| {
|
||||||
Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
|
Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
|
||||||
});
|
});
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
// If, by the time we've loaded the conversation, the user has already opened
|
// If, by the time we've loaded the conversation, the user has already opened
|
||||||
@ -1181,30 +1181,12 @@ impl AssistantPanel {
|
|||||||
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
|
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
|
fn has_credentials(&mut self) -> bool {
|
||||||
if self.api_key.borrow().is_none() && !self.has_read_credentials {
|
self.completion_provider.has_credentials()
|
||||||
self.has_read_credentials = true;
|
}
|
||||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
|
||||||
Some(api_key)
|
|
||||||
} else if let Some((_, api_key)) = cx
|
|
||||||
.platform()
|
|
||||||
.read_credentials(OPENAI_API_URL)
|
|
||||||
.log_err()
|
|
||||||
.flatten()
|
|
||||||
{
|
|
||||||
String::from_utf8(api_key).log_err()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
if let Some(api_key) = api_key {
|
|
||||||
*self.api_key.borrow_mut() = Some(api_key);
|
|
||||||
} else if self.api_key_editor.is_none() {
|
|
||||||
self.api_key_editor = Some(build_api_key_editor(cx));
|
|
||||||
cx.notify();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.api_key.borrow().clone()
|
fn load_credentials(&mut self, cx: &mut ViewContext<Self>) {
|
||||||
|
self.completion_provider.retrieve_credentials(cx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1389,7 +1371,7 @@ impl Panel for AssistantPanel {
|
|||||||
|
|
||||||
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
|
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
|
||||||
if active {
|
if active {
|
||||||
self.load_api_key(cx);
|
self.load_credentials(cx);
|
||||||
|
|
||||||
if self.editors.is_empty() {
|
if self.editors.is_empty() {
|
||||||
self.new_conversation(cx);
|
self.new_conversation(cx);
|
||||||
@ -1454,10 +1436,10 @@ struct Conversation {
|
|||||||
token_count: Option<usize>,
|
token_count: Option<usize>,
|
||||||
max_token_count: usize,
|
max_token_count: usize,
|
||||||
pending_token_count: Task<Option<()>>,
|
pending_token_count: Task<Option<()>>,
|
||||||
api_key: Rc<RefCell<Option<String>>>,
|
|
||||||
pending_save: Task<Result<()>>,
|
pending_save: Task<Result<()>>,
|
||||||
path: Option<PathBuf>,
|
path: Option<PathBuf>,
|
||||||
_subscriptions: Vec<Subscription>,
|
_subscriptions: Vec<Subscription>,
|
||||||
|
completion_provider: Box<dyn CompletionProvider>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Entity for Conversation {
|
impl Entity for Conversation {
|
||||||
@ -1466,9 +1448,9 @@ impl Entity for Conversation {
|
|||||||
|
|
||||||
impl Conversation {
|
impl Conversation {
|
||||||
fn new(
|
fn new(
|
||||||
api_key: Rc<RefCell<Option<String>>>,
|
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
|
completion_provider: Box<dyn CompletionProvider>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let markdown = language_registry.language_for_name("Markdown");
|
let markdown = language_registry.language_for_name("Markdown");
|
||||||
let buffer = cx.add_model(|cx| {
|
let buffer = cx.add_model(|cx| {
|
||||||
@ -1507,8 +1489,8 @@ impl Conversation {
|
|||||||
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
||||||
pending_save: Task::ready(Ok(())),
|
pending_save: Task::ready(Ok(())),
|
||||||
path: None,
|
path: None,
|
||||||
api_key,
|
|
||||||
buffer,
|
buffer,
|
||||||
|
completion_provider,
|
||||||
};
|
};
|
||||||
let message = MessageAnchor {
|
let message = MessageAnchor {
|
||||||
id: MessageId(post_inc(&mut this.next_message_id.0)),
|
id: MessageId(post_inc(&mut this.next_message_id.0)),
|
||||||
@ -1554,7 +1536,6 @@ impl Conversation {
|
|||||||
fn deserialize(
|
fn deserialize(
|
||||||
saved_conversation: SavedConversation,
|
saved_conversation: SavedConversation,
|
||||||
path: PathBuf,
|
path: PathBuf,
|
||||||
api_key: Rc<RefCell<Option<String>>>,
|
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
@ -1563,6 +1544,10 @@ impl Conversation {
|
|||||||
None => Some(Uuid::new_v4().to_string()),
|
None => Some(Uuid::new_v4().to_string()),
|
||||||
};
|
};
|
||||||
let model = saved_conversation.model;
|
let model = saved_conversation.model;
|
||||||
|
let completion_provider: Box<dyn CompletionProvider> = Box::new(
|
||||||
|
OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
|
||||||
|
);
|
||||||
|
completion_provider.retrieve_credentials(cx);
|
||||||
let markdown = language_registry.language_for_name("Markdown");
|
let markdown = language_registry.language_for_name("Markdown");
|
||||||
let mut message_anchors = Vec::new();
|
let mut message_anchors = Vec::new();
|
||||||
let mut next_message_id = MessageId(0);
|
let mut next_message_id = MessageId(0);
|
||||||
@ -1609,8 +1594,8 @@ impl Conversation {
|
|||||||
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
||||||
pending_save: Task::ready(Ok(())),
|
pending_save: Task::ready(Ok(())),
|
||||||
path: Some(path),
|
path: Some(path),
|
||||||
api_key,
|
|
||||||
buffer,
|
buffer,
|
||||||
|
completion_provider,
|
||||||
};
|
};
|
||||||
this.count_remaining_tokens(cx);
|
this.count_remaining_tokens(cx);
|
||||||
this
|
this
|
||||||
@ -1731,11 +1716,11 @@ impl Conversation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if should_assist {
|
if should_assist {
|
||||||
let Some(api_key) = self.api_key.borrow().clone() else {
|
if !self.completion_provider.has_credentials() {
|
||||||
return Default::default();
|
return Default::default();
|
||||||
};
|
}
|
||||||
|
|
||||||
let request = OpenAIRequest {
|
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
||||||
model: self.model.full_name().to_string(),
|
model: self.model.full_name().to_string(),
|
||||||
messages: self
|
messages: self
|
||||||
.messages(cx)
|
.messages(cx)
|
||||||
@ -1745,9 +1730,9 @@ impl Conversation {
|
|||||||
stream: true,
|
stream: true,
|
||||||
stop: vec![],
|
stop: vec![],
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
};
|
});
|
||||||
|
|
||||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
let stream = self.completion_provider.complete(request);
|
||||||
let assistant_message = self
|
let assistant_message = self
|
||||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -1765,33 +1750,28 @@ impl Conversation {
|
|||||||
let mut messages = stream.await?;
|
let mut messages = stream.await?;
|
||||||
|
|
||||||
while let Some(message) = messages.next().await {
|
while let Some(message) = messages.next().await {
|
||||||
let mut message = message?;
|
let text = message?;
|
||||||
if let Some(choice) = message.choices.pop() {
|
|
||||||
this.upgrade(&cx)
|
|
||||||
.ok_or_else(|| anyhow!("conversation was dropped"))?
|
|
||||||
.update(&mut cx, |this, cx| {
|
|
||||||
let text: Arc<str> = choice.delta.content?.into();
|
|
||||||
let message_ix =
|
|
||||||
this.message_anchors.iter().position(|message| {
|
|
||||||
message.id == assistant_message_id
|
|
||||||
})?;
|
|
||||||
this.buffer.update(cx, |buffer, cx| {
|
|
||||||
let offset = this.message_anchors[message_ix + 1..]
|
|
||||||
.iter()
|
|
||||||
.find(|message| message.start.is_valid(buffer))
|
|
||||||
.map_or(buffer.len(), |message| {
|
|
||||||
message
|
|
||||||
.start
|
|
||||||
.to_offset(buffer)
|
|
||||||
.saturating_sub(1)
|
|
||||||
});
|
|
||||||
buffer.edit([(offset..offset, text)], None, cx);
|
|
||||||
});
|
|
||||||
cx.emit(ConversationEvent::StreamedCompletion);
|
|
||||||
|
|
||||||
Some(())
|
this.upgrade(&cx)
|
||||||
|
.ok_or_else(|| anyhow!("conversation was dropped"))?
|
||||||
|
.update(&mut cx, |this, cx| {
|
||||||
|
let message_ix = this
|
||||||
|
.message_anchors
|
||||||
|
.iter()
|
||||||
|
.position(|message| message.id == assistant_message_id)?;
|
||||||
|
this.buffer.update(cx, |buffer, cx| {
|
||||||
|
let offset = this.message_anchors[message_ix + 1..]
|
||||||
|
.iter()
|
||||||
|
.find(|message| message.start.is_valid(buffer))
|
||||||
|
.map_or(buffer.len(), |message| {
|
||||||
|
message.start.to_offset(buffer).saturating_sub(1)
|
||||||
|
});
|
||||||
|
buffer.edit([(offset..offset, text)], None, cx);
|
||||||
});
|
});
|
||||||
}
|
cx.emit(ConversationEvent::StreamedCompletion);
|
||||||
|
|
||||||
|
Some(())
|
||||||
|
});
|
||||||
smol::future::yield_now().await;
|
smol::future::yield_now().await;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2013,57 +1993,54 @@ impl Conversation {
|
|||||||
|
|
||||||
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
|
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
|
||||||
if self.message_anchors.len() >= 2 && self.summary.is_none() {
|
if self.message_anchors.len() >= 2 && self.summary.is_none() {
|
||||||
let api_key = self.api_key.borrow().clone();
|
if !self.completion_provider.has_credentials() {
|
||||||
if let Some(api_key) = api_key {
|
return;
|
||||||
let messages = self
|
|
||||||
.messages(cx)
|
|
||||||
.take(2)
|
|
||||||
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
|
|
||||||
.chain(Some(RequestMessage {
|
|
||||||
role: Role::User,
|
|
||||||
content:
|
|
||||||
"Summarize the conversation into a short title without punctuation"
|
|
||||||
.into(),
|
|
||||||
}));
|
|
||||||
let request = OpenAIRequest {
|
|
||||||
model: self.model.full_name().to_string(),
|
|
||||||
messages: messages.collect(),
|
|
||||||
stream: true,
|
|
||||||
stop: vec![],
|
|
||||||
temperature: 1.0,
|
|
||||||
};
|
|
||||||
|
|
||||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
|
||||||
self.pending_summary = cx.spawn(|this, mut cx| {
|
|
||||||
async move {
|
|
||||||
let mut messages = stream.await?;
|
|
||||||
|
|
||||||
while let Some(message) = messages.next().await {
|
|
||||||
let mut message = message?;
|
|
||||||
if let Some(choice) = message.choices.pop() {
|
|
||||||
let text = choice.delta.content.unwrap_or_default();
|
|
||||||
this.update(&mut cx, |this, cx| {
|
|
||||||
this.summary
|
|
||||||
.get_or_insert(Default::default())
|
|
||||||
.text
|
|
||||||
.push_str(&text);
|
|
||||||
cx.emit(ConversationEvent::SummaryChanged);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
this.update(&mut cx, |this, cx| {
|
|
||||||
if let Some(summary) = this.summary.as_mut() {
|
|
||||||
summary.done = true;
|
|
||||||
cx.emit(ConversationEvent::SummaryChanged);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
anyhow::Ok(())
|
|
||||||
}
|
|
||||||
.log_err()
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let messages = self
|
||||||
|
.messages(cx)
|
||||||
|
.take(2)
|
||||||
|
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
|
||||||
|
.chain(Some(RequestMessage {
|
||||||
|
role: Role::User,
|
||||||
|
content: "Summarize the conversation into a short title without punctuation"
|
||||||
|
.into(),
|
||||||
|
}));
|
||||||
|
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
||||||
|
model: self.model.full_name().to_string(),
|
||||||
|
messages: messages.collect(),
|
||||||
|
stream: true,
|
||||||
|
stop: vec![],
|
||||||
|
temperature: 1.0,
|
||||||
|
});
|
||||||
|
|
||||||
|
let stream = self.completion_provider.complete(request);
|
||||||
|
self.pending_summary = cx.spawn(|this, mut cx| {
|
||||||
|
async move {
|
||||||
|
let mut messages = stream.await?;
|
||||||
|
|
||||||
|
while let Some(message) = messages.next().await {
|
||||||
|
let text = message?;
|
||||||
|
this.update(&mut cx, |this, cx| {
|
||||||
|
this.summary
|
||||||
|
.get_or_insert(Default::default())
|
||||||
|
.text
|
||||||
|
.push_str(&text);
|
||||||
|
cx.emit(ConversationEvent::SummaryChanged);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
this.update(&mut cx, |this, cx| {
|
||||||
|
if let Some(summary) = this.summary.as_mut() {
|
||||||
|
summary.done = true;
|
||||||
|
cx.emit(ConversationEvent::SummaryChanged);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
}
|
||||||
|
.log_err()
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2224,13 +2201,14 @@ struct ConversationEditor {
|
|||||||
|
|
||||||
impl ConversationEditor {
|
impl ConversationEditor {
|
||||||
fn new(
|
fn new(
|
||||||
api_key: Rc<RefCell<Option<String>>>,
|
completion_provider: Box<dyn CompletionProvider>,
|
||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
workspace: WeakViewHandle<Workspace>,
|
workspace: WeakViewHandle<Workspace>,
|
||||||
cx: &mut ViewContext<Self>,
|
cx: &mut ViewContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
|
let conversation =
|
||||||
|
cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider));
|
||||||
Self::for_conversation(conversation, fs, workspace, cx)
|
Self::for_conversation(conversation, fs, workspace, cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3419,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::MessageId;
|
use crate::MessageId;
|
||||||
|
use ai::test::FakeCompletionProvider;
|
||||||
use gpui::AppContext;
|
use gpui::AppContext;
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
@ -3426,7 +3405,9 @@ mod tests {
|
|||||||
cx.set_global(SettingsStore::test(cx));
|
cx.set_global(SettingsStore::test(cx));
|
||||||
init(cx);
|
init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test());
|
let registry = Arc::new(LanguageRegistry::test());
|
||||||
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
|
|
||||||
|
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||||
|
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read(cx).buffer.clone();
|
||||||
|
|
||||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||||
@ -3554,7 +3535,9 @@ mod tests {
|
|||||||
cx.set_global(SettingsStore::test(cx));
|
cx.set_global(SettingsStore::test(cx));
|
||||||
init(cx);
|
init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test());
|
let registry = Arc::new(LanguageRegistry::test());
|
||||||
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
|
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||||
|
|
||||||
|
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read(cx).buffer.clone();
|
||||||
|
|
||||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||||
@ -3650,7 +3633,8 @@ mod tests {
|
|||||||
cx.set_global(SettingsStore::test(cx));
|
cx.set_global(SettingsStore::test(cx));
|
||||||
init(cx);
|
init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test());
|
let registry = Arc::new(LanguageRegistry::test());
|
||||||
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
|
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||||
|
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read(cx).buffer.clone();
|
||||||
|
|
||||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||||
@ -3732,8 +3716,9 @@ mod tests {
|
|||||||
cx.set_global(SettingsStore::test(cx));
|
cx.set_global(SettingsStore::test(cx));
|
||||||
init(cx);
|
init(cx);
|
||||||
let registry = Arc::new(LanguageRegistry::test());
|
let registry = Arc::new(LanguageRegistry::test());
|
||||||
|
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||||
let conversation =
|
let conversation =
|
||||||
cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
|
cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
|
||||||
let buffer = conversation.read(cx).buffer.clone();
|
let buffer = conversation.read(cx).buffer.clone();
|
||||||
let message_0 = conversation.read(cx).message_anchors[0].id;
|
let message_0 = conversation.read(cx).message_anchors[0].id;
|
||||||
let message_1 = conversation.update(cx, |conversation, cx| {
|
let message_1 = conversation.update(cx, |conversation, cx| {
|
||||||
@ -3770,7 +3755,6 @@ mod tests {
|
|||||||
Conversation::deserialize(
|
Conversation::deserialize(
|
||||||
conversation.read(cx).serialize(cx),
|
conversation.read(cx).serialize(cx),
|
||||||
Default::default(),
|
Default::default(),
|
||||||
Default::default(),
|
|
||||||
registry.clone(),
|
registry.clone(),
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use crate::streaming_diff::{Hunk, StreamingDiff};
|
use crate::streaming_diff::{Hunk, StreamingDiff};
|
||||||
use ai::completion::{CompletionProvider, OpenAIRequest};
|
use ai::completion::{CompletionProvider, CompletionRequest};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
|
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
|
||||||
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
||||||
@ -96,7 +96,7 @@ impl Codegen {
|
|||||||
self.error.as_ref()
|
self.error.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
|
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
|
||||||
let range = self.range();
|
let range = self.range();
|
||||||
let snapshot = self.snapshot.clone();
|
let snapshot = self.snapshot.clone();
|
||||||
let selected_text = snapshot
|
let selected_text = snapshot
|
||||||
@ -336,17 +336,25 @@ fn strip_markdown_codeblock(
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use futures::{
|
use ai::test::FakeCompletionProvider;
|
||||||
future::BoxFuture,
|
use futures::stream::{self};
|
||||||
stream::{self, BoxStream},
|
|
||||||
};
|
|
||||||
use gpui::{executor::Deterministic, TestAppContext};
|
use gpui::{executor::Deterministic, TestAppContext};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
||||||
use parking_lot::Mutex;
|
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
|
use serde::Serialize;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use smol::future::FutureExt;
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct DummyCompletionRequest {
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionRequest for DummyCompletionRequest {
|
||||||
|
fn data(&self) -> serde_json::Result<String> {
|
||||||
|
serde_json::to_string(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test(iterations = 10)]
|
#[gpui::test(iterations = 10)]
|
||||||
async fn test_transform_autoindent(
|
async fn test_transform_autoindent(
|
||||||
@ -372,7 +380,7 @@ mod tests {
|
|||||||
let snapshot = buffer.snapshot(cx);
|
let snapshot = buffer.snapshot(cx);
|
||||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||||
});
|
});
|
||||||
let provider = Arc::new(TestCompletionProvider::new());
|
let provider = Arc::new(FakeCompletionProvider::new());
|
||||||
let codegen = cx.add_model(|cx| {
|
let codegen = cx.add_model(|cx| {
|
||||||
Codegen::new(
|
Codegen::new(
|
||||||
buffer.clone(),
|
buffer.clone(),
|
||||||
@ -381,7 +389,11 @@ mod tests {
|
|||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
|
||||||
|
let request = Box::new(DummyCompletionRequest {
|
||||||
|
name: "test".to_string(),
|
||||||
|
});
|
||||||
|
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||||
|
|
||||||
let mut new_text = concat!(
|
let mut new_text = concat!(
|
||||||
" let mut x = 0;\n",
|
" let mut x = 0;\n",
|
||||||
@ -434,7 +446,7 @@ mod tests {
|
|||||||
let snapshot = buffer.snapshot(cx);
|
let snapshot = buffer.snapshot(cx);
|
||||||
snapshot.anchor_before(Point::new(1, 6))
|
snapshot.anchor_before(Point::new(1, 6))
|
||||||
});
|
});
|
||||||
let provider = Arc::new(TestCompletionProvider::new());
|
let provider = Arc::new(FakeCompletionProvider::new());
|
||||||
let codegen = cx.add_model(|cx| {
|
let codegen = cx.add_model(|cx| {
|
||||||
Codegen::new(
|
Codegen::new(
|
||||||
buffer.clone(),
|
buffer.clone(),
|
||||||
@ -443,7 +455,11 @@ mod tests {
|
|||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
|
||||||
|
let request = Box::new(DummyCompletionRequest {
|
||||||
|
name: "test".to_string(),
|
||||||
|
});
|
||||||
|
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||||
|
|
||||||
let mut new_text = concat!(
|
let mut new_text = concat!(
|
||||||
"t mut x = 0;\n",
|
"t mut x = 0;\n",
|
||||||
@ -496,7 +512,7 @@ mod tests {
|
|||||||
let snapshot = buffer.snapshot(cx);
|
let snapshot = buffer.snapshot(cx);
|
||||||
snapshot.anchor_before(Point::new(1, 2))
|
snapshot.anchor_before(Point::new(1, 2))
|
||||||
});
|
});
|
||||||
let provider = Arc::new(TestCompletionProvider::new());
|
let provider = Arc::new(FakeCompletionProvider::new());
|
||||||
let codegen = cx.add_model(|cx| {
|
let codegen = cx.add_model(|cx| {
|
||||||
Codegen::new(
|
Codegen::new(
|
||||||
buffer.clone(),
|
buffer.clone(),
|
||||||
@ -505,7 +521,11 @@ mod tests {
|
|||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
|
||||||
|
let request = Box::new(DummyCompletionRequest {
|
||||||
|
name: "test".to_string(),
|
||||||
|
});
|
||||||
|
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||||
|
|
||||||
let mut new_text = concat!(
|
let mut new_text = concat!(
|
||||||
"let mut x = 0;\n",
|
"let mut x = 0;\n",
|
||||||
@ -593,38 +613,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TestCompletionProvider {
|
|
||||||
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TestCompletionProvider {
|
|
||||||
fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
last_completion_tx: Mutex::new(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn send_completion(&self, completion: impl Into<String>) {
|
|
||||||
let mut tx = self.last_completion_tx.lock();
|
|
||||||
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn finish_completion(&self) {
|
|
||||||
self.last_completion_tx.lock().take().unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CompletionProvider for TestCompletionProvider {
|
|
||||||
fn complete(
|
|
||||||
&self,
|
|
||||||
_prompt: OpenAIRequest,
|
|
||||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
|
||||||
let (tx, rx) = mpsc::channel(1);
|
|
||||||
*self.last_completion_tx.lock() = Some(tx);
|
|
||||||
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rust_lang() -> Language {
|
fn rust_lang() -> Language {
|
||||||
Language::new(
|
Language::new(
|
||||||
LanguageConfig {
|
LanguageConfig {
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
use ai::models::{LanguageModel, OpenAILanguageModel};
|
use ai::models::LanguageModel;
|
||||||
use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
|
use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
|
||||||
use ai::templates::file_context::FileContext;
|
use ai::prompts::file_context::FileContext;
|
||||||
use ai::templates::generate::GenerateInlineContent;
|
use ai::prompts::generate::GenerateInlineContent;
|
||||||
use ai::templates::preamble::EngineerPreamble;
|
use ai::prompts::preamble::EngineerPreamble;
|
||||||
use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
|
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
|
||||||
|
use ai::providers::open_ai::OpenAILanguageModel;
|
||||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||||
use std::cmp::{self, Reverse};
|
use std::cmp::{self, Reverse};
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
|
@ -25,7 +25,7 @@ collections = { path = "../collections" }
|
|||||||
gpui2 = { path = "../gpui2" }
|
gpui2 = { path = "../gpui2" }
|
||||||
log.workspace = true
|
log.workspace = true
|
||||||
live_kit_client = { path = "../live_kit_client" }
|
live_kit_client = { path = "../live_kit_client" }
|
||||||
fs = { path = "../fs" }
|
fs2 = { path = "../fs2" }
|
||||||
language2 = { path = "../language2" }
|
language2 = { path = "../language2" }
|
||||||
media = { path = "../media" }
|
media = { path = "../media" }
|
||||||
project2 = { path = "../project2" }
|
project2 = { path = "../project2" }
|
||||||
@ -43,7 +43,7 @@ serde_derive.workspace = true
|
|||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
client2 = { path = "../client2", features = ["test-support"] }
|
client2 = { path = "../client2", features = ["test-support"] }
|
||||||
fs = { path = "../fs", features = ["test-support"] }
|
fs2 = { path = "../fs2", features = ["test-support"] }
|
||||||
language2 = { path = "../language2", features = ["test-support"] }
|
language2 = { path = "../language2", features = ["test-support"] }
|
||||||
collections = { path = "../collections", features = ["test-support"] }
|
collections = { path = "../collections", features = ["test-support"] }
|
||||||
gpui2 = { path = "../gpui2", features = ["test-support"] }
|
gpui2 = { path = "../gpui2", features = ["test-support"] }
|
||||||
|
@ -12,8 +12,8 @@ use client2::{
|
|||||||
use collections::HashSet;
|
use collections::HashSet;
|
||||||
use futures::{future::Shared, FutureExt};
|
use futures::{future::Shared, FutureExt};
|
||||||
use gpui2::{
|
use gpui2::{
|
||||||
AppContext, AsyncAppContext, Context, EventEmitter, Handle, ModelContext, Subscription, Task,
|
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task,
|
||||||
WeakHandle,
|
WeakModel,
|
||||||
};
|
};
|
||||||
use postage::watch;
|
use postage::watch;
|
||||||
use project2::Project;
|
use project2::Project;
|
||||||
@ -23,10 +23,10 @@ use std::sync::Arc;
|
|||||||
pub use participant::ParticipantLocation;
|
pub use participant::ParticipantLocation;
|
||||||
pub use room::Room;
|
pub use room::Room;
|
||||||
|
|
||||||
pub fn init(client: Arc<Client>, user_store: Handle<UserStore>, cx: &mut AppContext) {
|
pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
|
||||||
CallSettings::register(cx);
|
CallSettings::register(cx);
|
||||||
|
|
||||||
let active_call = cx.entity(|cx| ActiveCall::new(client, user_store, cx));
|
let active_call = cx.build_model(|cx| ActiveCall::new(client, user_store, cx));
|
||||||
cx.set_global(active_call);
|
cx.set_global(active_call);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,16 +40,16 @@ pub struct IncomingCall {
|
|||||||
|
|
||||||
/// Singleton global maintaining the user's participation in a room across workspaces.
|
/// Singleton global maintaining the user's participation in a room across workspaces.
|
||||||
pub struct ActiveCall {
|
pub struct ActiveCall {
|
||||||
room: Option<(Handle<Room>, Vec<Subscription>)>,
|
room: Option<(Model<Room>, Vec<Subscription>)>,
|
||||||
pending_room_creation: Option<Shared<Task<Result<Handle<Room>, Arc<anyhow::Error>>>>>,
|
pending_room_creation: Option<Shared<Task<Result<Model<Room>, Arc<anyhow::Error>>>>>,
|
||||||
location: Option<WeakHandle<Project>>,
|
location: Option<WeakModel<Project>>,
|
||||||
pending_invites: HashSet<u64>,
|
pending_invites: HashSet<u64>,
|
||||||
incoming_call: (
|
incoming_call: (
|
||||||
watch::Sender<Option<IncomingCall>>,
|
watch::Sender<Option<IncomingCall>>,
|
||||||
watch::Receiver<Option<IncomingCall>>,
|
watch::Receiver<Option<IncomingCall>>,
|
||||||
),
|
),
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
user_store: Handle<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
_subscriptions: Vec<client2::Subscription>,
|
_subscriptions: Vec<client2::Subscription>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,11 +58,7 @@ impl EventEmitter for ActiveCall {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ActiveCall {
|
impl ActiveCall {
|
||||||
fn new(
|
fn new(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut ModelContext<Self>) -> Self {
|
||||||
client: Arc<Client>,
|
|
||||||
user_store: Handle<UserStore>,
|
|
||||||
cx: &mut ModelContext<Self>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
room: None,
|
room: None,
|
||||||
pending_room_creation: None,
|
pending_room_creation: None,
|
||||||
@ -84,7 +80,7 @@ impl ActiveCall {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_incoming_call(
|
async fn handle_incoming_call(
|
||||||
this: Handle<Self>,
|
this: Model<Self>,
|
||||||
envelope: TypedEnvelope<proto::IncomingCall>,
|
envelope: TypedEnvelope<proto::IncomingCall>,
|
||||||
_: Arc<Client>,
|
_: Arc<Client>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
@ -112,7 +108,7 @@ impl ActiveCall {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_call_canceled(
|
async fn handle_call_canceled(
|
||||||
this: Handle<Self>,
|
this: Model<Self>,
|
||||||
envelope: TypedEnvelope<proto::CallCanceled>,
|
envelope: TypedEnvelope<proto::CallCanceled>,
|
||||||
_: Arc<Client>,
|
_: Arc<Client>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
@ -129,14 +125,14 @@ impl ActiveCall {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn global(cx: &AppContext) -> Handle<Self> {
|
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||||
cx.global::<Handle<Self>>().clone()
|
cx.global::<Model<Self>>().clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn invite(
|
pub fn invite(
|
||||||
&mut self,
|
&mut self,
|
||||||
called_user_id: u64,
|
called_user_id: u64,
|
||||||
initial_project: Option<Handle<Project>>,
|
initial_project: Option<Model<Project>>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Result<()>> {
|
||||||
if !self.pending_invites.insert(called_user_id) {
|
if !self.pending_invites.insert(called_user_id) {
|
||||||
@ -291,7 +287,7 @@ impl ActiveCall {
|
|||||||
&mut self,
|
&mut self,
|
||||||
channel_id: u64,
|
channel_id: u64,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<Handle<Room>>> {
|
) -> Task<Result<Model<Room>>> {
|
||||||
if let Some(room) = self.room().cloned() {
|
if let Some(room) = self.room().cloned() {
|
||||||
if room.read(cx).channel_id() == Some(channel_id) {
|
if room.read(cx).channel_id() == Some(channel_id) {
|
||||||
return Task::ready(Ok(room));
|
return Task::ready(Ok(room));
|
||||||
@ -327,7 +323,7 @@ impl ActiveCall {
|
|||||||
|
|
||||||
pub fn share_project(
|
pub fn share_project(
|
||||||
&mut self,
|
&mut self,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<u64>> {
|
) -> Task<Result<u64>> {
|
||||||
if let Some((room, _)) = self.room.as_ref() {
|
if let Some((room, _)) = self.room.as_ref() {
|
||||||
@ -340,7 +336,7 @@ impl ActiveCall {
|
|||||||
|
|
||||||
pub fn unshare_project(
|
pub fn unshare_project(
|
||||||
&mut self,
|
&mut self,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if let Some((room, _)) = self.room.as_ref() {
|
if let Some((room, _)) = self.room.as_ref() {
|
||||||
@ -351,13 +347,13 @@ impl ActiveCall {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn location(&self) -> Option<&WeakHandle<Project>> {
|
pub fn location(&self) -> Option<&WeakModel<Project>> {
|
||||||
self.location.as_ref()
|
self.location.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_location(
|
pub fn set_location(
|
||||||
&mut self,
|
&mut self,
|
||||||
project: Option<&Handle<Project>>,
|
project: Option<&Model<Project>>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Result<()>> {
|
||||||
if project.is_some() || !*ZED_ALWAYS_ACTIVE {
|
if project.is_some() || !*ZED_ALWAYS_ACTIVE {
|
||||||
@ -371,7 +367,7 @@ impl ActiveCall {
|
|||||||
|
|
||||||
fn set_room(
|
fn set_room(
|
||||||
&mut self,
|
&mut self,
|
||||||
room: Option<Handle<Room>>,
|
room: Option<Model<Room>>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Result<()>> {
|
||||||
if room.as_ref() != self.room.as_ref().map(|room| &room.0) {
|
if room.as_ref() != self.room.as_ref().map(|room| &room.0) {
|
||||||
@ -407,7 +403,7 @@ impl ActiveCall {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn room(&self) -> Option<&Handle<Room>> {
|
pub fn room(&self) -> Option<&Model<Room>> {
|
||||||
self.room.as_ref().map(|(room, _)| room)
|
self.room.as_ref().map(|(room, _)| room)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use client2::ParticipantIndex;
|
use client2::ParticipantIndex;
|
||||||
use client2::{proto, User};
|
use client2::{proto, User};
|
||||||
use gpui2::WeakHandle;
|
use gpui2::WeakModel;
|
||||||
pub use live_kit_client::Frame;
|
pub use live_kit_client::Frame;
|
||||||
use project2::Project;
|
use project2::Project;
|
||||||
use std::{fmt, sync::Arc};
|
use std::{fmt, sync::Arc};
|
||||||
@ -33,7 +33,7 @@ impl ParticipantLocation {
|
|||||||
#[derive(Clone, Default)]
|
#[derive(Clone, Default)]
|
||||||
pub struct LocalParticipant {
|
pub struct LocalParticipant {
|
||||||
pub projects: Vec<proto::ParticipantProject>,
|
pub projects: Vec<proto::ParticipantProject>,
|
||||||
pub active_project: Option<WeakHandle<Project>>,
|
pub active_project: Option<WeakModel<Project>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
@ -13,10 +13,10 @@ use client2::{
|
|||||||
Client, ParticipantIndex, TypedEnvelope, User, UserStore,
|
Client, ParticipantIndex, TypedEnvelope, User, UserStore,
|
||||||
};
|
};
|
||||||
use collections::{BTreeMap, HashMap, HashSet};
|
use collections::{BTreeMap, HashMap, HashSet};
|
||||||
use fs::Fs;
|
use fs2::Fs;
|
||||||
use futures::{FutureExt, StreamExt};
|
use futures::{FutureExt, StreamExt};
|
||||||
use gpui2::{
|
use gpui2::{
|
||||||
AppContext, AsyncAppContext, Context, EventEmitter, Handle, ModelContext, Task, WeakHandle,
|
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel,
|
||||||
};
|
};
|
||||||
use language2::LanguageRegistry;
|
use language2::LanguageRegistry;
|
||||||
use live_kit_client::{LocalTrackPublication, RemoteAudioTrackUpdate, RemoteVideoTrackUpdate};
|
use live_kit_client::{LocalTrackPublication, RemoteAudioTrackUpdate, RemoteVideoTrackUpdate};
|
||||||
@ -61,8 +61,8 @@ pub struct Room {
|
|||||||
channel_id: Option<u64>,
|
channel_id: Option<u64>,
|
||||||
// live_kit: Option<LiveKitRoom>,
|
// live_kit: Option<LiveKitRoom>,
|
||||||
status: RoomStatus,
|
status: RoomStatus,
|
||||||
shared_projects: HashSet<WeakHandle<Project>>,
|
shared_projects: HashSet<WeakModel<Project>>,
|
||||||
joined_projects: HashSet<WeakHandle<Project>>,
|
joined_projects: HashSet<WeakModel<Project>>,
|
||||||
local_participant: LocalParticipant,
|
local_participant: LocalParticipant,
|
||||||
remote_participants: BTreeMap<u64, RemoteParticipant>,
|
remote_participants: BTreeMap<u64, RemoteParticipant>,
|
||||||
pending_participants: Vec<Arc<User>>,
|
pending_participants: Vec<Arc<User>>,
|
||||||
@ -70,7 +70,7 @@ pub struct Room {
|
|||||||
pending_call_count: usize,
|
pending_call_count: usize,
|
||||||
leave_when_empty: bool,
|
leave_when_empty: bool,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
user_store: Handle<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
follows_by_leader_id_project_id: HashMap<(PeerId, u64), Vec<PeerId>>,
|
follows_by_leader_id_project_id: HashMap<(PeerId, u64), Vec<PeerId>>,
|
||||||
client_subscriptions: Vec<client2::Subscription>,
|
client_subscriptions: Vec<client2::Subscription>,
|
||||||
_subscriptions: Vec<gpui2::Subscription>,
|
_subscriptions: Vec<gpui2::Subscription>,
|
||||||
@ -111,7 +111,7 @@ impl Room {
|
|||||||
channel_id: Option<u64>,
|
channel_id: Option<u64>,
|
||||||
live_kit_connection_info: Option<proto::LiveKitConnectionInfo>,
|
live_kit_connection_info: Option<proto::LiveKitConnectionInfo>,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
user_store: Handle<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
todo!()
|
todo!()
|
||||||
@ -237,15 +237,15 @@ impl Room {
|
|||||||
|
|
||||||
pub(crate) fn create(
|
pub(crate) fn create(
|
||||||
called_user_id: u64,
|
called_user_id: u64,
|
||||||
initial_project: Option<Handle<Project>>,
|
initial_project: Option<Model<Project>>,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
user_store: Handle<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
cx: &mut AppContext,
|
cx: &mut AppContext,
|
||||||
) -> Task<Result<Handle<Self>>> {
|
) -> Task<Result<Model<Self>>> {
|
||||||
cx.spawn(move |mut cx| async move {
|
cx.spawn(move |mut cx| async move {
|
||||||
let response = client.request(proto::CreateRoom {}).await?;
|
let response = client.request(proto::CreateRoom {}).await?;
|
||||||
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
|
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
|
||||||
let room = cx.entity(|cx| {
|
let room = cx.build_model(|cx| {
|
||||||
Self::new(
|
Self::new(
|
||||||
room_proto.id,
|
room_proto.id,
|
||||||
None,
|
None,
|
||||||
@ -283,9 +283,9 @@ impl Room {
|
|||||||
pub(crate) fn join_channel(
|
pub(crate) fn join_channel(
|
||||||
channel_id: u64,
|
channel_id: u64,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
user_store: Handle<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
cx: &mut AppContext,
|
cx: &mut AppContext,
|
||||||
) -> Task<Result<Handle<Self>>> {
|
) -> Task<Result<Model<Self>>> {
|
||||||
cx.spawn(move |cx| async move {
|
cx.spawn(move |cx| async move {
|
||||||
Self::from_join_response(
|
Self::from_join_response(
|
||||||
client.request(proto::JoinChannel { channel_id }).await?,
|
client.request(proto::JoinChannel { channel_id }).await?,
|
||||||
@ -299,9 +299,9 @@ impl Room {
|
|||||||
pub(crate) fn join(
|
pub(crate) fn join(
|
||||||
call: &IncomingCall,
|
call: &IncomingCall,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
user_store: Handle<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
cx: &mut AppContext,
|
cx: &mut AppContext,
|
||||||
) -> Task<Result<Handle<Self>>> {
|
) -> Task<Result<Model<Self>>> {
|
||||||
let id = call.room_id;
|
let id = call.room_id;
|
||||||
cx.spawn(move |cx| async move {
|
cx.spawn(move |cx| async move {
|
||||||
Self::from_join_response(
|
Self::from_join_response(
|
||||||
@ -343,11 +343,11 @@ impl Room {
|
|||||||
fn from_join_response(
|
fn from_join_response(
|
||||||
response: proto::JoinRoomResponse,
|
response: proto::JoinRoomResponse,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
user_store: Handle<UserStore>,
|
user_store: Model<UserStore>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Handle<Self>> {
|
) -> Result<Model<Self>> {
|
||||||
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
|
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
|
||||||
let room = cx.entity(|cx| {
|
let room = cx.build_model(|cx| {
|
||||||
Self::new(
|
Self::new(
|
||||||
room_proto.id,
|
room_proto.id,
|
||||||
response.channel_id,
|
response.channel_id,
|
||||||
@ -424,7 +424,7 @@ impl Room {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn maintain_connection(
|
async fn maintain_connection(
|
||||||
this: WeakHandle<Self>,
|
this: WeakModel<Self>,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
@ -661,7 +661,7 @@ impl Room {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_room_updated(
|
async fn handle_room_updated(
|
||||||
this: Handle<Self>,
|
this: Model<Self>,
|
||||||
envelope: TypedEnvelope<proto::RoomUpdated>,
|
envelope: TypedEnvelope<proto::RoomUpdated>,
|
||||||
_: Arc<Client>,
|
_: Arc<Client>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
@ -1101,7 +1101,7 @@ impl Room {
|
|||||||
language_registry: Arc<LanguageRegistry>,
|
language_registry: Arc<LanguageRegistry>,
|
||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<Handle<Project>>> {
|
) -> Task<Result<Model<Project>>> {
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let user_store = self.user_store.clone();
|
let user_store = self.user_store.clone();
|
||||||
cx.emit(Event::RemoteProjectJoined { project_id: id });
|
cx.emit(Event::RemoteProjectJoined { project_id: id });
|
||||||
@ -1125,7 +1125,7 @@ impl Room {
|
|||||||
|
|
||||||
pub(crate) fn share_project(
|
pub(crate) fn share_project(
|
||||||
&mut self,
|
&mut self,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<u64>> {
|
) -> Task<Result<u64>> {
|
||||||
if let Some(project_id) = project.read(cx).remote_id() {
|
if let Some(project_id) = project.read(cx).remote_id() {
|
||||||
@ -1161,7 +1161,7 @@ impl Room {
|
|||||||
|
|
||||||
pub(crate) fn unshare_project(
|
pub(crate) fn unshare_project(
|
||||||
&mut self,
|
&mut self,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let project_id = match project.read(cx).remote_id() {
|
let project_id = match project.read(cx).remote_id() {
|
||||||
@ -1175,7 +1175,7 @@ impl Room {
|
|||||||
|
|
||||||
pub(crate) fn set_location(
|
pub(crate) fn set_location(
|
||||||
&mut self,
|
&mut self,
|
||||||
project: Option<&Handle<Project>>,
|
project: Option<&Model<Project>>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Result<()>> {
|
||||||
if self.status.is_offline() {
|
if self.status.is_offline() {
|
||||||
|
@ -14,8 +14,8 @@ use futures::{
|
|||||||
future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _, TryStreamExt,
|
future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _, TryStreamExt,
|
||||||
};
|
};
|
||||||
use gpui2::{
|
use gpui2::{
|
||||||
serde_json, AnyHandle, AnyWeakHandle, AppContext, AsyncAppContext, Handle, SemanticVersion,
|
serde_json, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Model, SemanticVersion, Task,
|
||||||
Task, WeakHandle,
|
WeakModel,
|
||||||
};
|
};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
@ -227,7 +227,7 @@ struct ClientState {
|
|||||||
_reconnect_task: Option<Task<()>>,
|
_reconnect_task: Option<Task<()>>,
|
||||||
reconnect_interval: Duration,
|
reconnect_interval: Duration,
|
||||||
entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
|
entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
|
||||||
models_by_message_type: HashMap<TypeId, AnyWeakHandle>,
|
models_by_message_type: HashMap<TypeId, AnyWeakModel>,
|
||||||
entity_types_by_message_type: HashMap<TypeId, TypeId>,
|
entity_types_by_message_type: HashMap<TypeId, TypeId>,
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
message_handlers: HashMap<
|
message_handlers: HashMap<
|
||||||
@ -236,7 +236,7 @@ struct ClientState {
|
|||||||
dyn Send
|
dyn Send
|
||||||
+ Sync
|
+ Sync
|
||||||
+ Fn(
|
+ Fn(
|
||||||
AnyHandle,
|
AnyModel,
|
||||||
Box<dyn AnyTypedEnvelope>,
|
Box<dyn AnyTypedEnvelope>,
|
||||||
&Arc<Client>,
|
&Arc<Client>,
|
||||||
AsyncAppContext,
|
AsyncAppContext,
|
||||||
@ -246,7 +246,7 @@ struct ClientState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
enum WeakSubscriber {
|
enum WeakSubscriber {
|
||||||
Entity { handle: AnyWeakHandle },
|
Entity { handle: AnyWeakModel },
|
||||||
Pending(Vec<Box<dyn AnyTypedEnvelope>>),
|
Pending(Vec<Box<dyn AnyTypedEnvelope>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -314,7 +314,7 @@ impl<T> PendingEntitySubscription<T>
|
|||||||
where
|
where
|
||||||
T: 'static + Send,
|
T: 'static + Send,
|
||||||
{
|
{
|
||||||
pub fn set_model(mut self, model: &Handle<T>, cx: &mut AsyncAppContext) -> Subscription {
|
pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
|
||||||
self.consumed = true;
|
self.consumed = true;
|
||||||
let mut state = self.client.state.write();
|
let mut state = self.client.state.write();
|
||||||
let id = (TypeId::of::<T>(), self.remote_id);
|
let id = (TypeId::of::<T>(), self.remote_id);
|
||||||
@ -552,13 +552,13 @@ impl Client {
|
|||||||
#[track_caller]
|
#[track_caller]
|
||||||
pub fn add_message_handler<M, E, H, F>(
|
pub fn add_message_handler<M, E, H, F>(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
entity: WeakHandle<E>,
|
entity: WeakModel<E>,
|
||||||
handler: H,
|
handler: H,
|
||||||
) -> Subscription
|
) -> Subscription
|
||||||
where
|
where
|
||||||
M: EnvelopedMessage,
|
M: EnvelopedMessage,
|
||||||
E: 'static + Send,
|
E: 'static + Send,
|
||||||
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||||
F: 'static + Future<Output = Result<()>> + Send,
|
F: 'static + Future<Output = Result<()>> + Send,
|
||||||
{
|
{
|
||||||
let message_type_id = TypeId::of::<M>();
|
let message_type_id = TypeId::of::<M>();
|
||||||
@ -594,13 +594,13 @@ impl Client {
|
|||||||
|
|
||||||
pub fn add_request_handler<M, E, H, F>(
|
pub fn add_request_handler<M, E, H, F>(
|
||||||
self: &Arc<Self>,
|
self: &Arc<Self>,
|
||||||
model: WeakHandle<E>,
|
model: WeakModel<E>,
|
||||||
handler: H,
|
handler: H,
|
||||||
) -> Subscription
|
) -> Subscription
|
||||||
where
|
where
|
||||||
M: RequestMessage,
|
M: RequestMessage,
|
||||||
E: 'static + Send,
|
E: 'static + Send,
|
||||||
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||||
F: 'static + Future<Output = Result<M::Response>> + Send,
|
F: 'static + Future<Output = Result<M::Response>> + Send,
|
||||||
{
|
{
|
||||||
self.add_message_handler(model, move |handle, envelope, this, cx| {
|
self.add_message_handler(model, move |handle, envelope, this, cx| {
|
||||||
@ -616,7 +616,7 @@ impl Client {
|
|||||||
where
|
where
|
||||||
M: EntityMessage,
|
M: EntityMessage,
|
||||||
E: 'static + Send,
|
E: 'static + Send,
|
||||||
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||||
F: 'static + Future<Output = Result<()>> + Send,
|
F: 'static + Future<Output = Result<()>> + Send,
|
||||||
{
|
{
|
||||||
self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
|
self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
|
||||||
@ -628,7 +628,7 @@ impl Client {
|
|||||||
where
|
where
|
||||||
M: EntityMessage,
|
M: EntityMessage,
|
||||||
E: 'static + Send,
|
E: 'static + Send,
|
||||||
H: 'static + Send + Sync + Fn(AnyHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
H: 'static + Send + Sync + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||||
F: 'static + Future<Output = Result<()>> + Send,
|
F: 'static + Future<Output = Result<()>> + Send,
|
||||||
{
|
{
|
||||||
let model_type_id = TypeId::of::<E>();
|
let model_type_id = TypeId::of::<E>();
|
||||||
@ -667,7 +667,7 @@ impl Client {
|
|||||||
where
|
where
|
||||||
M: EntityMessage + RequestMessage,
|
M: EntityMessage + RequestMessage,
|
||||||
E: 'static + Send,
|
E: 'static + Send,
|
||||||
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||||
F: 'static + Future<Output = Result<M::Response>> + Send,
|
F: 'static + Future<Output = Result<M::Response>> + Send,
|
||||||
{
|
{
|
||||||
self.add_model_message_handler(move |entity, envelope, client, cx| {
|
self.add_model_message_handler(move |entity, envelope, client, cx| {
|
||||||
@ -1546,7 +1546,7 @@ mod tests {
|
|||||||
let (done_tx1, mut done_rx1) = smol::channel::unbounded();
|
let (done_tx1, mut done_rx1) = smol::channel::unbounded();
|
||||||
let (done_tx2, mut done_rx2) = smol::channel::unbounded();
|
let (done_tx2, mut done_rx2) = smol::channel::unbounded();
|
||||||
client.add_model_message_handler(
|
client.add_model_message_handler(
|
||||||
move |model: Handle<Model>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
|
move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
|
||||||
match model.update(&mut cx, |model, _| model.id).unwrap() {
|
match model.update(&mut cx, |model, _| model.id).unwrap() {
|
||||||
1 => done_tx1.try_send(()).unwrap(),
|
1 => done_tx1.try_send(()).unwrap(),
|
||||||
2 => done_tx2.try_send(()).unwrap(),
|
2 => done_tx2.try_send(()).unwrap(),
|
||||||
@ -1555,15 +1555,15 @@ mod tests {
|
|||||||
async { Ok(()) }
|
async { Ok(()) }
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
let model1 = cx.entity(|_| Model {
|
let model1 = cx.build_model(|_| TestModel {
|
||||||
id: 1,
|
id: 1,
|
||||||
subscription: None,
|
subscription: None,
|
||||||
});
|
});
|
||||||
let model2 = cx.entity(|_| Model {
|
let model2 = cx.build_model(|_| TestModel {
|
||||||
id: 2,
|
id: 2,
|
||||||
subscription: None,
|
subscription: None,
|
||||||
});
|
});
|
||||||
let model3 = cx.entity(|_| Model {
|
let model3 = cx.build_model(|_| TestModel {
|
||||||
id: 3,
|
id: 3,
|
||||||
subscription: None,
|
subscription: None,
|
||||||
});
|
});
|
||||||
@ -1596,7 +1596,7 @@ mod tests {
|
|||||||
let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
|
let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
|
||||||
let server = FakeServer::for_client(user_id, &client, cx).await;
|
let server = FakeServer::for_client(user_id, &client, cx).await;
|
||||||
|
|
||||||
let model = cx.entity(|_| Model::default());
|
let model = cx.build_model(|_| TestModel::default());
|
||||||
let (done_tx1, _done_rx1) = smol::channel::unbounded();
|
let (done_tx1, _done_rx1) = smol::channel::unbounded();
|
||||||
let (done_tx2, mut done_rx2) = smol::channel::unbounded();
|
let (done_tx2, mut done_rx2) = smol::channel::unbounded();
|
||||||
let subscription1 = client.add_message_handler(
|
let subscription1 = client.add_message_handler(
|
||||||
@ -1624,11 +1624,11 @@ mod tests {
|
|||||||
let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
|
let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
|
||||||
let server = FakeServer::for_client(user_id, &client, cx).await;
|
let server = FakeServer::for_client(user_id, &client, cx).await;
|
||||||
|
|
||||||
let model = cx.entity(|_| Model::default());
|
let model = cx.build_model(|_| TestModel::default());
|
||||||
let (done_tx, mut done_rx) = smol::channel::unbounded();
|
let (done_tx, mut done_rx) = smol::channel::unbounded();
|
||||||
let subscription = client.add_message_handler(
|
let subscription = client.add_message_handler(
|
||||||
model.clone().downgrade(),
|
model.clone().downgrade(),
|
||||||
move |model: Handle<Model>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
|
move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
|
||||||
model
|
model
|
||||||
.update(&mut cx, |model, _| model.subscription.take())
|
.update(&mut cx, |model, _| model.subscription.take())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -1644,7 +1644,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
struct Model {
|
struct TestModel {
|
||||||
id: usize,
|
id: usize,
|
||||||
subscription: Option<Subscription>,
|
subscription: Option<Subscription>,
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,9 @@ use parking_lot::Mutex;
|
|||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use settings2::Settings;
|
use settings2::Settings;
|
||||||
use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration};
|
use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration};
|
||||||
use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt};
|
use sysinfo::{
|
||||||
|
CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt,
|
||||||
|
};
|
||||||
use tempfile::NamedTempFile;
|
use tempfile::NamedTempFile;
|
||||||
use util::http::HttpClient;
|
use util::http::HttpClient;
|
||||||
use util::{channel::ReleaseChannel, TryFutureExt};
|
use util::{channel::ReleaseChannel, TryFutureExt};
|
||||||
@ -161,8 +163,16 @@ impl Telemetry {
|
|||||||
|
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
cx.spawn(|cx| async move {
|
cx.spawn(|cx| async move {
|
||||||
let mut system = System::new_all();
|
// Avoiding calling `System::new_all()`, as there have been crashes related to it
|
||||||
system.refresh_all();
|
let refresh_kind = RefreshKind::new()
|
||||||
|
.with_memory() // For memory usage
|
||||||
|
.with_processes(ProcessRefreshKind::everything()) // For process usage
|
||||||
|
.with_cpu(CpuRefreshKind::everything()); // For core count
|
||||||
|
|
||||||
|
let mut system = System::new_with_specifics(refresh_kind);
|
||||||
|
|
||||||
|
// Avoiding calling `refresh_all()`, just update what we need
|
||||||
|
system.refresh_specifics(refresh_kind);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// Waiting some amount of time before the first query is important to get a reasonable value
|
// Waiting some amount of time before the first query is important to get a reasonable value
|
||||||
@ -170,8 +180,7 @@ impl Telemetry {
|
|||||||
const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60);
|
const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60);
|
||||||
smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await;
|
smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await;
|
||||||
|
|
||||||
system.refresh_memory();
|
system.refresh_specifics(refresh_kind);
|
||||||
system.refresh_processes();
|
|
||||||
|
|
||||||
let current_process = Pid::from_u32(std::process::id());
|
let current_process = Pid::from_u32(std::process::id());
|
||||||
let Some(process) = system.processes().get(¤t_process) else {
|
let Some(process) = system.processes().get(¤t_process) else {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
|
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use futures::{stream::BoxStream, StreamExt};
|
use futures::{stream::BoxStream, StreamExt};
|
||||||
use gpui2::{Context, Executor, Handle, TestAppContext};
|
use gpui2::{Context, Executor, Model, TestAppContext};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use rpc2::{
|
use rpc2::{
|
||||||
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
|
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
|
||||||
@ -194,9 +194,9 @@ impl FakeServer {
|
|||||||
&self,
|
&self,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
cx: &mut TestAppContext,
|
cx: &mut TestAppContext,
|
||||||
) -> Handle<UserStore> {
|
) -> Model<UserStore> {
|
||||||
let http_client = FakeHttpClient::with_404_response();
|
let http_client = FakeHttpClient::with_404_response();
|
||||||
let user_store = cx.entity(|cx| UserStore::new(client, http_client, cx));
|
let user_store = cx.build_model(|cx| UserStore::new(client, http_client, cx));
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
self.receive::<proto::GetUsers>()
|
self.receive::<proto::GetUsers>()
|
||||||
.await
|
.await
|
||||||
|
@ -3,7 +3,7 @@ use anyhow::{anyhow, Context, Result};
|
|||||||
use collections::{hash_map::Entry, HashMap, HashSet};
|
use collections::{hash_map::Entry, HashMap, HashSet};
|
||||||
use feature_flags2::FeatureFlagAppExt;
|
use feature_flags2::FeatureFlagAppExt;
|
||||||
use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt};
|
use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt};
|
||||||
use gpui2::{AsyncAppContext, EventEmitter, Handle, ImageData, ModelContext, Task};
|
use gpui2::{AsyncAppContext, EventEmitter, ImageData, Model, ModelContext, Task};
|
||||||
use postage::{sink::Sink, watch};
|
use postage::{sink::Sink, watch};
|
||||||
use rpc2::proto::{RequestMessage, UsersResponse};
|
use rpc2::proto::{RequestMessage, UsersResponse};
|
||||||
use std::sync::{Arc, Weak};
|
use std::sync::{Arc, Weak};
|
||||||
@ -213,7 +213,7 @@ impl UserStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_update_invite_info(
|
async fn handle_update_invite_info(
|
||||||
this: Handle<Self>,
|
this: Model<Self>,
|
||||||
message: TypedEnvelope<proto::UpdateInviteInfo>,
|
message: TypedEnvelope<proto::UpdateInviteInfo>,
|
||||||
_: Arc<Client>,
|
_: Arc<Client>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
@ -229,7 +229,7 @@ impl UserStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_show_contacts(
|
async fn handle_show_contacts(
|
||||||
this: Handle<Self>,
|
this: Model<Self>,
|
||||||
_: TypedEnvelope<proto::ShowContacts>,
|
_: TypedEnvelope<proto::ShowContacts>,
|
||||||
_: Arc<Client>,
|
_: Arc<Client>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
@ -243,7 +243,7 @@ impl UserStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_update_contacts(
|
async fn handle_update_contacts(
|
||||||
this: Handle<Self>,
|
this: Model<Self>,
|
||||||
message: TypedEnvelope<proto::UpdateContacts>,
|
message: TypedEnvelope<proto::UpdateContacts>,
|
||||||
_: Arc<Client>,
|
_: Arc<Client>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
@ -690,7 +690,7 @@ impl User {
|
|||||||
impl Contact {
|
impl Contact {
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
contact: proto::Contact,
|
contact: proto::Contact,
|
||||||
user_store: &Handle<UserStore>,
|
user_store: &Model<UserStore>,
|
||||||
cx: &mut AsyncAppContext,
|
cx: &mut AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let user = user_store
|
let user = user_store
|
||||||
|
@ -7,8 +7,8 @@ use async_tar::Archive;
|
|||||||
use collections::{HashMap, HashSet};
|
use collections::{HashMap, HashSet};
|
||||||
use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
|
use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
|
||||||
use gpui2::{
|
use gpui2::{
|
||||||
AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Handle, ModelContext, Task,
|
AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Model, ModelContext, Task,
|
||||||
WeakHandle,
|
WeakModel,
|
||||||
};
|
};
|
||||||
use language2::{
|
use language2::{
|
||||||
language_settings::{all_language_settings, language_settings},
|
language_settings::{all_language_settings, language_settings},
|
||||||
@ -49,7 +49,7 @@ pub fn init(
|
|||||||
node_runtime: Arc<dyn NodeRuntime>,
|
node_runtime: Arc<dyn NodeRuntime>,
|
||||||
cx: &mut AppContext,
|
cx: &mut AppContext,
|
||||||
) {
|
) {
|
||||||
let copilot = cx.entity({
|
let copilot = cx.build_model({
|
||||||
let node_runtime = node_runtime.clone();
|
let node_runtime = node_runtime.clone();
|
||||||
move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
|
move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
|
||||||
});
|
});
|
||||||
@ -183,7 +183,7 @@ struct RegisteredBuffer {
|
|||||||
impl RegisteredBuffer {
|
impl RegisteredBuffer {
|
||||||
fn report_changes(
|
fn report_changes(
|
||||||
&mut self,
|
&mut self,
|
||||||
buffer: &Handle<Buffer>,
|
buffer: &Model<Buffer>,
|
||||||
cx: &mut ModelContext<Copilot>,
|
cx: &mut ModelContext<Copilot>,
|
||||||
) -> oneshot::Receiver<(i32, BufferSnapshot)> {
|
) -> oneshot::Receiver<(i32, BufferSnapshot)> {
|
||||||
let (done_tx, done_rx) = oneshot::channel();
|
let (done_tx, done_rx) = oneshot::channel();
|
||||||
@ -278,7 +278,7 @@ pub struct Copilot {
|
|||||||
http: Arc<dyn HttpClient>,
|
http: Arc<dyn HttpClient>,
|
||||||
node_runtime: Arc<dyn NodeRuntime>,
|
node_runtime: Arc<dyn NodeRuntime>,
|
||||||
server: CopilotServer,
|
server: CopilotServer,
|
||||||
buffers: HashSet<WeakHandle<Buffer>>,
|
buffers: HashSet<WeakModel<Buffer>>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
_subscription: gpui2::Subscription,
|
_subscription: gpui2::Subscription,
|
||||||
}
|
}
|
||||||
@ -292,9 +292,9 @@ impl EventEmitter for Copilot {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Copilot {
|
impl Copilot {
|
||||||
pub fn global(cx: &AppContext) -> Option<Handle<Self>> {
|
pub fn global(cx: &AppContext) -> Option<Model<Self>> {
|
||||||
if cx.has_global::<Handle<Self>>() {
|
if cx.has_global::<Model<Self>>() {
|
||||||
Some(cx.global::<Handle<Self>>().clone())
|
Some(cx.global::<Model<Self>>().clone())
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
@ -383,7 +383,7 @@ impl Copilot {
|
|||||||
new_server_id: LanguageServerId,
|
new_server_id: LanguageServerId,
|
||||||
http: Arc<dyn HttpClient>,
|
http: Arc<dyn HttpClient>,
|
||||||
node_runtime: Arc<dyn NodeRuntime>,
|
node_runtime: Arc<dyn NodeRuntime>,
|
||||||
this: WeakHandle<Self>,
|
this: WeakModel<Self>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> impl Future<Output = ()> {
|
) -> impl Future<Output = ()> {
|
||||||
async move {
|
async move {
|
||||||
@ -590,7 +590,7 @@ impl Copilot {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register_buffer(&mut self, buffer: &Handle<Buffer>, cx: &mut ModelContext<Self>) {
|
pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
|
||||||
let weak_buffer = buffer.downgrade();
|
let weak_buffer = buffer.downgrade();
|
||||||
self.buffers.insert(weak_buffer.clone());
|
self.buffers.insert(weak_buffer.clone());
|
||||||
|
|
||||||
@ -646,7 +646,7 @@ impl Copilot {
|
|||||||
|
|
||||||
fn handle_buffer_event(
|
fn handle_buffer_event(
|
||||||
&mut self,
|
&mut self,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
event: &language2::Event,
|
event: &language2::Event,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
@ -706,7 +706,7 @@ impl Copilot {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unregister_buffer(&mut self, buffer: &WeakHandle<Buffer>) {
|
fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
|
||||||
if let Ok(server) = self.server.as_running() {
|
if let Ok(server) = self.server.as_running() {
|
||||||
if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
|
if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
|
||||||
server
|
server
|
||||||
@ -723,7 +723,7 @@ impl Copilot {
|
|||||||
|
|
||||||
pub fn completions<T>(
|
pub fn completions<T>(
|
||||||
&mut self,
|
&mut self,
|
||||||
buffer: &Handle<Buffer>,
|
buffer: &Model<Buffer>,
|
||||||
position: T,
|
position: T,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<Vec<Completion>>>
|
) -> Task<Result<Vec<Completion>>>
|
||||||
@ -735,7 +735,7 @@ impl Copilot {
|
|||||||
|
|
||||||
pub fn completions_cycling<T>(
|
pub fn completions_cycling<T>(
|
||||||
&mut self,
|
&mut self,
|
||||||
buffer: &Handle<Buffer>,
|
buffer: &Model<Buffer>,
|
||||||
position: T,
|
position: T,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<Vec<Completion>>>
|
) -> Task<Result<Vec<Completion>>>
|
||||||
@ -792,7 +792,7 @@ impl Copilot {
|
|||||||
|
|
||||||
fn request_completions<R, T>(
|
fn request_completions<R, T>(
|
||||||
&mut self,
|
&mut self,
|
||||||
buffer: &Handle<Buffer>,
|
buffer: &Model<Buffer>,
|
||||||
position: T,
|
position: T,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<Vec<Completion>>>
|
) -> Task<Result<Vec<Completion>>>
|
||||||
@ -926,7 +926,7 @@ fn id_for_language(language: Option<&Arc<Language>>) -> String {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn uri_for_buffer(buffer: &Handle<Buffer>, cx: &AppContext) -> lsp2::Url {
|
fn uri_for_buffer(buffer: &Model<Buffer>, cx: &AppContext) -> lsp2::Url {
|
||||||
if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
|
if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
|
||||||
lsp2::Url::from_file_path(file.abs_path(cx)).unwrap()
|
lsp2::Url::from_file_path(file.abs_path(cx)).unwrap()
|
||||||
} else {
|
} else {
|
||||||
|
@ -967,7 +967,6 @@ impl CompletionsMenu {
|
|||||||
self.selected_item -= 1;
|
self.selected_item -= 1;
|
||||||
} else {
|
} else {
|
||||||
self.selected_item = self.matches.len() - 1;
|
self.selected_item = self.matches.len() - 1;
|
||||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
|
||||||
}
|
}
|
||||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
||||||
self.attempt_resolve_selected_completion_documentation(project, cx);
|
self.attempt_resolve_selected_completion_documentation(project, cx);
|
||||||
@ -1538,7 +1537,6 @@ impl CodeActionsMenu {
|
|||||||
self.selected_item -= 1;
|
self.selected_item -= 1;
|
||||||
} else {
|
} else {
|
||||||
self.selected_item = self.actions.len() - 1;
|
self.selected_item = self.actions.len() - 1;
|
||||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
|
||||||
}
|
}
|
||||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
||||||
cx.notify();
|
cx.notify();
|
||||||
@ -1547,11 +1545,10 @@ impl CodeActionsMenu {
|
|||||||
fn select_next(&mut self, cx: &mut ViewContext<Editor>) {
|
fn select_next(&mut self, cx: &mut ViewContext<Editor>) {
|
||||||
if self.selected_item + 1 < self.actions.len() {
|
if self.selected_item + 1 < self.actions.len() {
|
||||||
self.selected_item += 1;
|
self.selected_item += 1;
|
||||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
|
||||||
} else {
|
} else {
|
||||||
self.selected_item = 0;
|
self.selected_item = 0;
|
||||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
|
||||||
}
|
}
|
||||||
|
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
||||||
cx.notify();
|
cx.notify();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ use crate::{
|
|||||||
current_platform, image_cache::ImageCache, Action, AnyBox, AnyView, AnyWindowHandle,
|
current_platform, image_cache::ImageCache, Action, AnyBox, AnyView, AnyWindowHandle,
|
||||||
AppMetadata, AssetSource, ClipboardItem, Context, DispatchPhase, DisplayId, Executor,
|
AppMetadata, AssetSource, ClipboardItem, Context, DispatchPhase, DisplayId, Executor,
|
||||||
FocusEvent, FocusHandle, FocusId, KeyBinding, Keymap, LayoutId, MainThread, MainThreadOnly,
|
FocusEvent, FocusHandle, FocusId, KeyBinding, Keymap, LayoutId, MainThread, MainThreadOnly,
|
||||||
Pixels, Platform, Point, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
|
Pixels, Platform, Point, Render, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
|
||||||
TextStyle, TextStyleRefinement, TextSystem, View, ViewContext, Window, WindowContext,
|
TextStyle, TextStyleRefinement, TextSystem, View, ViewContext, Window, WindowContext,
|
||||||
WindowHandle, WindowId,
|
WindowHandle, WindowId,
|
||||||
};
|
};
|
||||||
@ -309,10 +309,17 @@ impl AppContext {
|
|||||||
update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R,
|
update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R,
|
||||||
) -> Result<R>
|
) -> Result<R>
|
||||||
where
|
where
|
||||||
V: 'static,
|
V: 'static + Send,
|
||||||
{
|
{
|
||||||
self.update_window(handle.any_handle, |cx| {
|
self.update_window(handle.any_handle, |cx| {
|
||||||
let root_view = cx.window.root_view.as_ref().unwrap().downcast().unwrap();
|
let root_view = cx
|
||||||
|
.window
|
||||||
|
.root_view
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.clone()
|
||||||
|
.downcast()
|
||||||
|
.unwrap();
|
||||||
root_view.update(cx, update)
|
root_view.update(cx, update)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -685,7 +692,7 @@ impl AppContext {
|
|||||||
|
|
||||||
pub fn observe_release<E: 'static>(
|
pub fn observe_release<E: 'static>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<E>,
|
handle: &Model<E>,
|
||||||
mut on_release: impl FnMut(&mut E, &mut AppContext) + Send + 'static,
|
mut on_release: impl FnMut(&mut E, &mut AppContext) + Send + 'static,
|
||||||
) -> Subscription {
|
) -> Subscription {
|
||||||
self.release_listeners.insert(
|
self.release_listeners.insert(
|
||||||
@ -750,35 +757,35 @@ impl AppContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Context for AppContext {
|
impl Context for AppContext {
|
||||||
type EntityContext<'a, T> = ModelContext<'a, T>;
|
type ModelContext<'a, T> = ModelContext<'a, T>;
|
||||||
type Result<T> = T;
|
type Result<T> = T;
|
||||||
|
|
||||||
/// Build an entity that is owned by the application. The given function will be invoked with
|
/// Build an entity that is owned by the application. The given function will be invoked with
|
||||||
/// a `ModelContext` and must return an object representing the entity. A `Handle` will be returned
|
/// a `ModelContext` and must return an object representing the entity. A `Model` will be returned
|
||||||
/// which can be used to access the entity in a context.
|
/// which can be used to access the entity in a context.
|
||||||
fn entity<T: 'static + Send>(
|
fn build_model<T: 'static + Send>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||||
) -> Handle<T> {
|
) -> Model<T> {
|
||||||
self.update(|cx| {
|
self.update(|cx| {
|
||||||
let slot = cx.entities.reserve();
|
let slot = cx.entities.reserve();
|
||||||
let entity = build_entity(&mut ModelContext::mutable(cx, slot.downgrade()));
|
let entity = build_model(&mut ModelContext::mutable(cx, slot.downgrade()));
|
||||||
cx.entities.insert(slot, entity)
|
cx.entities.insert(slot, entity)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Update the entity referenced by the given handle. The function is passed a mutable reference to the
|
/// Update the entity referenced by the given model. The function is passed a mutable reference to the
|
||||||
/// entity along with a `ModelContext` for the entity.
|
/// entity along with a `ModelContext` for the entity.
|
||||||
fn update_entity<T: 'static, R>(
|
fn update_entity<T: 'static, R>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<T>,
|
model: &Model<T>,
|
||||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||||
) -> R {
|
) -> R {
|
||||||
self.update(|cx| {
|
self.update(|cx| {
|
||||||
let mut entity = cx.entities.lease(handle);
|
let mut entity = cx.entities.lease(model);
|
||||||
let result = update(
|
let result = update(
|
||||||
&mut entity,
|
&mut entity,
|
||||||
&mut ModelContext::mutable(cx, handle.downgrade()),
|
&mut ModelContext::mutable(cx, model.downgrade()),
|
||||||
);
|
);
|
||||||
cx.entities.end_lease(entity);
|
cx.entities.end_lease(entity);
|
||||||
result
|
result
|
||||||
@ -861,10 +868,17 @@ impl MainThread<AppContext> {
|
|||||||
update: impl FnOnce(&mut V, &mut MainThread<ViewContext<'_, '_, V>>) -> R,
|
update: impl FnOnce(&mut V, &mut MainThread<ViewContext<'_, '_, V>>) -> R,
|
||||||
) -> Result<R>
|
) -> Result<R>
|
||||||
where
|
where
|
||||||
V: 'static,
|
V: 'static + Send,
|
||||||
{
|
{
|
||||||
self.update_window(handle.any_handle, |cx| {
|
self.update_window(handle.any_handle, |cx| {
|
||||||
let root_view = cx.window.root_view.as_ref().unwrap().downcast().unwrap();
|
let root_view = cx
|
||||||
|
.window
|
||||||
|
.root_view
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.clone()
|
||||||
|
.downcast()
|
||||||
|
.unwrap();
|
||||||
root_view.update(cx, update)
|
root_view.update(cx, update)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -872,7 +886,7 @@ impl MainThread<AppContext> {
|
|||||||
/// Opens a new window with the given option and the root view returned by the given function.
|
/// Opens a new window with the given option and the root view returned by the given function.
|
||||||
/// The function is invoked with a `WindowContext`, which can be used to interact with window-specific
|
/// The function is invoked with a `WindowContext`, which can be used to interact with window-specific
|
||||||
/// functionality.
|
/// functionality.
|
||||||
pub fn open_window<V: 'static>(
|
pub fn open_window<V: Render>(
|
||||||
&mut self,
|
&mut self,
|
||||||
options: crate::WindowOptions,
|
options: crate::WindowOptions,
|
||||||
build_root_view: impl FnOnce(&mut WindowContext) -> View<V> + Send + 'static,
|
build_root_view: impl FnOnce(&mut WindowContext) -> View<V> + Send + 'static,
|
||||||
@ -955,10 +969,8 @@ impl<G: 'static> DerefMut for GlobalLease<G> {
|
|||||||
/// Contains state associated with an active drag operation, started by dragging an element
|
/// Contains state associated with an active drag operation, started by dragging an element
|
||||||
/// within the window or by dragging into the app from the underlying platform.
|
/// within the window or by dragging into the app from the underlying platform.
|
||||||
pub(crate) struct AnyDrag {
|
pub(crate) struct AnyDrag {
|
||||||
pub drag_handle_view: Option<AnyView>,
|
pub view: AnyView,
|
||||||
pub cursor_offset: Point<Pixels>,
|
pub cursor_offset: Point<Pixels>,
|
||||||
pub state: AnyBox,
|
|
||||||
pub state_type: TypeId,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
AnyWindowHandle, AppContext, Component, Context, Executor, Handle, MainThread, ModelContext,
|
AnyWindowHandle, AppContext, Context, Executor, MainThread, Model, ModelContext, Result, Task,
|
||||||
Result, Task, View, ViewContext, VisualContext, WindowContext, WindowHandle,
|
View, ViewContext, VisualContext, WindowContext, WindowHandle,
|
||||||
};
|
};
|
||||||
use anyhow::Context as _;
|
use anyhow::Context as _;
|
||||||
use derive_more::{Deref, DerefMut};
|
use derive_more::{Deref, DerefMut};
|
||||||
@ -14,25 +14,25 @@ pub struct AsyncAppContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Context for AsyncAppContext {
|
impl Context for AsyncAppContext {
|
||||||
type EntityContext<'a, T> = ModelContext<'a, T>;
|
type ModelContext<'a, T> = ModelContext<'a, T>;
|
||||||
type Result<T> = Result<T>;
|
type Result<T> = Result<T>;
|
||||||
|
|
||||||
fn entity<T: 'static>(
|
fn build_model<T: 'static>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||||
) -> Self::Result<Handle<T>>
|
) -> Self::Result<Model<T>>
|
||||||
where
|
where
|
||||||
T: 'static + Send,
|
T: 'static + Send,
|
||||||
{
|
{
|
||||||
let app = self.app.upgrade().context("app was released")?;
|
let app = self.app.upgrade().context("app was released")?;
|
||||||
let mut lock = app.lock(); // Need this to compile
|
let mut lock = app.lock(); // Need this to compile
|
||||||
Ok(lock.entity(build_entity))
|
Ok(lock.build_model(build_model))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_entity<T: 'static, R>(
|
fn update_entity<T: 'static, R>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<T>,
|
handle: &Model<T>,
|
||||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||||
) -> Self::Result<R> {
|
) -> Self::Result<R> {
|
||||||
let app = self.app.upgrade().context("app was released")?;
|
let app = self.app.upgrade().context("app was released")?;
|
||||||
let mut lock = app.lock(); // Need this to compile
|
let mut lock = app.lock(); // Need this to compile
|
||||||
@ -84,7 +84,7 @@ impl AsyncAppContext {
|
|||||||
update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R,
|
update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R,
|
||||||
) -> Result<R>
|
) -> Result<R>
|
||||||
where
|
where
|
||||||
V: 'static,
|
V: 'static + Send,
|
||||||
{
|
{
|
||||||
let app = self.app.upgrade().context("app was released")?;
|
let app = self.app.upgrade().context("app was released")?;
|
||||||
let mut app_context = app.lock();
|
let mut app_context = app.lock();
|
||||||
@ -234,24 +234,24 @@ impl AsyncWindowContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Context for AsyncWindowContext {
|
impl Context for AsyncWindowContext {
|
||||||
type EntityContext<'a, T> = ModelContext<'a, T>;
|
type ModelContext<'a, T> = ModelContext<'a, T>;
|
||||||
type Result<T> = Result<T>;
|
type Result<T> = Result<T>;
|
||||||
|
|
||||||
fn entity<T>(
|
fn build_model<T>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||||
) -> Result<Handle<T>>
|
) -> Result<Model<T>>
|
||||||
where
|
where
|
||||||
T: 'static + Send,
|
T: 'static + Send,
|
||||||
{
|
{
|
||||||
self.app
|
self.app
|
||||||
.update_window(self.window, |cx| cx.entity(build_entity))
|
.update_window(self.window, |cx| cx.build_model(build_model))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_entity<T: 'static, R>(
|
fn update_entity<T: 'static, R>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<T>,
|
handle: &Model<T>,
|
||||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||||
) -> Result<R> {
|
) -> Result<R> {
|
||||||
self.app
|
self.app
|
||||||
.update_window(self.window, |cx| cx.update_entity(handle, update))
|
.update_window(self.window, |cx| cx.update_entity(handle, update))
|
||||||
@ -261,17 +261,15 @@ impl Context for AsyncWindowContext {
|
|||||||
impl VisualContext for AsyncWindowContext {
|
impl VisualContext for AsyncWindowContext {
|
||||||
type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>;
|
type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>;
|
||||||
|
|
||||||
fn build_view<E, V>(
|
fn build_view<V>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
||||||
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
|
|
||||||
) -> Self::Result<View<V>>
|
) -> Self::Result<View<V>>
|
||||||
where
|
where
|
||||||
E: Component<V>,
|
|
||||||
V: 'static + Send,
|
V: 'static + Send,
|
||||||
{
|
{
|
||||||
self.app
|
self.app
|
||||||
.update_window(self.window, |cx| cx.build_view(build_entity, render))
|
.update_window(self.window, |cx| cx.build_view(build_view_state))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_view<V: 'static, R>(
|
fn update_view<V: 'static, R>(
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::{AnyBox, AppContext, Context, EntityHandle};
|
use crate::{AnyBox, AppContext, Context};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use derive_more::{Deref, DerefMut};
|
use derive_more::{Deref, DerefMut};
|
||||||
use parking_lot::{RwLock, RwLockUpgradableReadGuard};
|
use parking_lot::{RwLock, RwLockUpgradableReadGuard};
|
||||||
@ -53,29 +53,29 @@ impl EntityMap {
|
|||||||
/// Reserve a slot for an entity, which you can subsequently use with `insert`.
|
/// Reserve a slot for an entity, which you can subsequently use with `insert`.
|
||||||
pub fn reserve<T: 'static>(&self) -> Slot<T> {
|
pub fn reserve<T: 'static>(&self) -> Slot<T> {
|
||||||
let id = self.ref_counts.write().counts.insert(1.into());
|
let id = self.ref_counts.write().counts.insert(1.into());
|
||||||
Slot(Handle::new(id, Arc::downgrade(&self.ref_counts)))
|
Slot(Model::new(id, Arc::downgrade(&self.ref_counts)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Insert an entity into a slot obtained by calling `reserve`.
|
/// Insert an entity into a slot obtained by calling `reserve`.
|
||||||
pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Handle<T>
|
pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Model<T>
|
||||||
where
|
where
|
||||||
T: 'static + Send,
|
T: 'static + Send,
|
||||||
{
|
{
|
||||||
let handle = slot.0;
|
let model = slot.0;
|
||||||
self.entities.insert(handle.entity_id, Box::new(entity));
|
self.entities.insert(model.entity_id, Box::new(entity));
|
||||||
handle
|
model
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Move an entity to the stack.
|
/// Move an entity to the stack.
|
||||||
pub fn lease<'a, T>(&mut self, handle: &'a Handle<T>) -> Lease<'a, T> {
|
pub fn lease<'a, T>(&mut self, model: &'a Model<T>) -> Lease<'a, T> {
|
||||||
self.assert_valid_context(handle);
|
self.assert_valid_context(model);
|
||||||
let entity = Some(
|
let entity = Some(
|
||||||
self.entities
|
self.entities
|
||||||
.remove(handle.entity_id)
|
.remove(model.entity_id)
|
||||||
.expect("Circular entity lease. Is the entity already being updated?"),
|
.expect("Circular entity lease. Is the entity already being updated?"),
|
||||||
);
|
);
|
||||||
Lease {
|
Lease {
|
||||||
handle,
|
model,
|
||||||
entity,
|
entity,
|
||||||
entity_type: PhantomData,
|
entity_type: PhantomData,
|
||||||
}
|
}
|
||||||
@ -84,18 +84,18 @@ impl EntityMap {
|
|||||||
/// Return an entity after moving it to the stack.
|
/// Return an entity after moving it to the stack.
|
||||||
pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
|
pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
|
||||||
self.entities
|
self.entities
|
||||||
.insert(lease.handle.entity_id, lease.entity.take().unwrap());
|
.insert(lease.model.entity_id, lease.entity.take().unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read<T: 'static>(&self, handle: &Handle<T>) -> &T {
|
pub fn read<T: 'static>(&self, model: &Model<T>) -> &T {
|
||||||
self.assert_valid_context(handle);
|
self.assert_valid_context(model);
|
||||||
self.entities[handle.entity_id].downcast_ref().unwrap()
|
self.entities[model.entity_id].downcast_ref().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn assert_valid_context(&self, handle: &AnyHandle) {
|
fn assert_valid_context(&self, model: &AnyModel) {
|
||||||
debug_assert!(
|
debug_assert!(
|
||||||
Weak::ptr_eq(&handle.entity_map, &Arc::downgrade(&self.ref_counts)),
|
Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)),
|
||||||
"used a handle with the wrong context"
|
"used a model with the wrong context"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,7 +115,7 @@ impl EntityMap {
|
|||||||
|
|
||||||
pub struct Lease<'a, T> {
|
pub struct Lease<'a, T> {
|
||||||
entity: Option<AnyBox>,
|
entity: Option<AnyBox>,
|
||||||
pub handle: &'a Handle<T>,
|
pub model: &'a Model<T>,
|
||||||
entity_type: PhantomData<T>,
|
entity_type: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,15 +143,15 @@ impl<'a, T> Drop for Lease<'a, T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deref, DerefMut)]
|
#[derive(Deref, DerefMut)]
|
||||||
pub struct Slot<T>(Handle<T>);
|
pub struct Slot<T>(Model<T>);
|
||||||
|
|
||||||
pub struct AnyHandle {
|
pub struct AnyModel {
|
||||||
pub(crate) entity_id: EntityId,
|
pub(crate) entity_id: EntityId,
|
||||||
entity_type: TypeId,
|
pub(crate) entity_type: TypeId,
|
||||||
entity_map: Weak<RwLock<EntityRefCounts>>,
|
entity_map: Weak<RwLock<EntityRefCounts>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnyHandle {
|
impl AnyModel {
|
||||||
fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
|
fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
entity_id: id,
|
entity_id: id,
|
||||||
@ -164,18 +164,18 @@ impl AnyHandle {
|
|||||||
self.entity_id
|
self.entity_id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn downgrade(&self) -> AnyWeakHandle {
|
pub fn downgrade(&self) -> AnyWeakModel {
|
||||||
AnyWeakHandle {
|
AnyWeakModel {
|
||||||
entity_id: self.entity_id,
|
entity_id: self.entity_id,
|
||||||
entity_type: self.entity_type,
|
entity_type: self.entity_type,
|
||||||
entity_ref_counts: self.entity_map.clone(),
|
entity_ref_counts: self.entity_map.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn downcast<T: 'static>(&self) -> Option<Handle<T>> {
|
pub fn downcast<T: 'static>(&self) -> Option<Model<T>> {
|
||||||
if TypeId::of::<T>() == self.entity_type {
|
if TypeId::of::<T>() == self.entity_type {
|
||||||
Some(Handle {
|
Some(Model {
|
||||||
any_handle: self.clone(),
|
any_model: self.clone(),
|
||||||
entity_type: PhantomData,
|
entity_type: PhantomData,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
@ -184,16 +184,16 @@ impl AnyHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Clone for AnyHandle {
|
impl Clone for AnyModel {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
if let Some(entity_map) = self.entity_map.upgrade() {
|
if let Some(entity_map) = self.entity_map.upgrade() {
|
||||||
let entity_map = entity_map.read();
|
let entity_map = entity_map.read();
|
||||||
let count = entity_map
|
let count = entity_map
|
||||||
.counts
|
.counts
|
||||||
.get(self.entity_id)
|
.get(self.entity_id)
|
||||||
.expect("detected over-release of a handle");
|
.expect("detected over-release of a model");
|
||||||
let prev_count = count.fetch_add(1, SeqCst);
|
let prev_count = count.fetch_add(1, SeqCst);
|
||||||
assert_ne!(prev_count, 0, "Detected over-release of a handle.");
|
assert_ne!(prev_count, 0, "Detected over-release of a model.");
|
||||||
}
|
}
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
@ -204,16 +204,16 @@ impl Clone for AnyHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for AnyHandle {
|
impl Drop for AnyModel {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
if let Some(entity_map) = self.entity_map.upgrade() {
|
if let Some(entity_map) = self.entity_map.upgrade() {
|
||||||
let entity_map = entity_map.upgradable_read();
|
let entity_map = entity_map.upgradable_read();
|
||||||
let count = entity_map
|
let count = entity_map
|
||||||
.counts
|
.counts
|
||||||
.get(self.entity_id)
|
.get(self.entity_id)
|
||||||
.expect("Detected over-release of a handle.");
|
.expect("Detected over-release of a model.");
|
||||||
let prev_count = count.fetch_sub(1, SeqCst);
|
let prev_count = count.fetch_sub(1, SeqCst);
|
||||||
assert_ne!(prev_count, 0, "Detected over-release of a handle.");
|
assert_ne!(prev_count, 0, "Detected over-release of a model.");
|
||||||
if prev_count == 1 {
|
if prev_count == 1 {
|
||||||
// We were the last reference to this entity, so we can remove it.
|
// We were the last reference to this entity, so we can remove it.
|
||||||
let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
|
let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
|
||||||
@ -223,60 +223,65 @@ impl Drop for AnyHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> From<Handle<T>> for AnyHandle {
|
impl<T> From<Model<T>> for AnyModel {
|
||||||
fn from(handle: Handle<T>) -> Self {
|
fn from(model: Model<T>) -> Self {
|
||||||
handle.any_handle
|
model.any_model
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Hash for AnyHandle {
|
impl Hash for AnyModel {
|
||||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||||
self.entity_id.hash(state);
|
self.entity_id.hash(state);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PartialEq for AnyHandle {
|
impl PartialEq for AnyModel {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.entity_id == other.entity_id
|
self.entity_id == other.entity_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Eq for AnyHandle {}
|
impl Eq for AnyModel {}
|
||||||
|
|
||||||
#[derive(Deref, DerefMut)]
|
#[derive(Deref, DerefMut)]
|
||||||
pub struct Handle<T> {
|
pub struct Model<T> {
|
||||||
#[deref]
|
#[deref]
|
||||||
#[deref_mut]
|
#[deref_mut]
|
||||||
any_handle: AnyHandle,
|
pub(crate) any_model: AnyModel,
|
||||||
entity_type: PhantomData<T>,
|
pub(crate) entity_type: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<T> Send for Handle<T> {}
|
unsafe impl<T> Send for Model<T> {}
|
||||||
unsafe impl<T> Sync for Handle<T> {}
|
unsafe impl<T> Sync for Model<T> {}
|
||||||
|
|
||||||
impl<T: 'static> Handle<T> {
|
impl<T: 'static> Model<T> {
|
||||||
fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
|
fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
|
||||||
where
|
where
|
||||||
T: 'static,
|
T: 'static,
|
||||||
{
|
{
|
||||||
Self {
|
Self {
|
||||||
any_handle: AnyHandle::new(id, TypeId::of::<T>(), entity_map),
|
any_model: AnyModel::new(id, TypeId::of::<T>(), entity_map),
|
||||||
entity_type: PhantomData,
|
entity_type: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn downgrade(&self) -> WeakHandle<T> {
|
pub fn downgrade(&self) -> WeakModel<T> {
|
||||||
WeakHandle {
|
WeakModel {
|
||||||
any_handle: self.any_handle.downgrade(),
|
any_model: self.any_model.downgrade(),
|
||||||
entity_type: self.entity_type,
|
entity_type: self.entity_type,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert this into a dynamically typed model.
|
||||||
|
pub fn into_any(self) -> AnyModel {
|
||||||
|
self.any_model
|
||||||
|
}
|
||||||
|
|
||||||
pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
|
pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
|
||||||
cx.entities.read(self)
|
cx.entities.read(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Update the entity referenced by this handle with the given function.
|
/// Update the entity referenced by this model with the given function.
|
||||||
///
|
///
|
||||||
/// The update function receives a context appropriate for its environment.
|
/// The update function receives a context appropriate for its environment.
|
||||||
/// When updating in an `AppContext`, it receives a `ModelContext`.
|
/// When updating in an `AppContext`, it receives a `ModelContext`.
|
||||||
@ -284,7 +289,7 @@ impl<T: 'static> Handle<T> {
|
|||||||
pub fn update<C, R>(
|
pub fn update<C, R>(
|
||||||
&self,
|
&self,
|
||||||
cx: &mut C,
|
cx: &mut C,
|
||||||
update: impl FnOnce(&mut T, &mut C::EntityContext<'_, T>) -> R,
|
update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R,
|
||||||
) -> C::Result<R>
|
) -> C::Result<R>
|
||||||
where
|
where
|
||||||
C: Context,
|
C: Context,
|
||||||
@ -293,73 +298,54 @@ impl<T: 'static> Handle<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Clone for Handle<T> {
|
impl<T> Clone for Model<T> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
any_handle: self.any_handle.clone(),
|
any_model: self.any_model.clone(),
|
||||||
entity_type: self.entity_type,
|
entity_type: self.entity_type,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> std::fmt::Debug for Handle<T> {
|
impl<T> std::fmt::Debug for Model<T> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"Handle {{ entity_id: {:?}, entity_type: {:?} }}",
|
"Model {{ entity_id: {:?}, entity_type: {:?} }}",
|
||||||
self.any_handle.entity_id,
|
self.any_model.entity_id,
|
||||||
type_name::<T>()
|
type_name::<T>()
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Hash for Handle<T> {
|
impl<T> Hash for Model<T> {
|
||||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||||
self.any_handle.hash(state);
|
self.any_model.hash(state);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> PartialEq for Handle<T> {
|
impl<T> PartialEq for Model<T> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.any_handle == other.any_handle
|
self.any_model == other.any_model
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Eq for Handle<T> {}
|
impl<T> Eq for Model<T> {}
|
||||||
|
|
||||||
impl<T> PartialEq<WeakHandle<T>> for Handle<T> {
|
impl<T> PartialEq<WeakModel<T>> for Model<T> {
|
||||||
fn eq(&self, other: &WeakHandle<T>) -> bool {
|
fn eq(&self, other: &WeakModel<T>) -> bool {
|
||||||
self.entity_id == other.entity_id
|
self.entity_id() == other.entity_id()
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: 'static> EntityHandle<T> for Handle<T> {
|
|
||||||
type Weak = WeakHandle<T>;
|
|
||||||
|
|
||||||
fn entity_id(&self) -> EntityId {
|
|
||||||
self.entity_id
|
|
||||||
}
|
|
||||||
|
|
||||||
fn downgrade(&self) -> Self::Weak {
|
|
||||||
self.downgrade()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn upgrade_from(weak: &Self::Weak) -> Option<Self>
|
|
||||||
where
|
|
||||||
Self: Sized,
|
|
||||||
{
|
|
||||||
weak.upgrade()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AnyWeakHandle {
|
pub struct AnyWeakModel {
|
||||||
pub(crate) entity_id: EntityId,
|
pub(crate) entity_id: EntityId,
|
||||||
entity_type: TypeId,
|
entity_type: TypeId,
|
||||||
entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
|
entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnyWeakHandle {
|
impl AnyWeakModel {
|
||||||
pub fn entity_id(&self) -> EntityId {
|
pub fn entity_id(&self) -> EntityId {
|
||||||
self.entity_id
|
self.entity_id
|
||||||
}
|
}
|
||||||
@ -373,14 +359,14 @@ impl AnyWeakHandle {
|
|||||||
ref_count > 0
|
ref_count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn upgrade(&self) -> Option<AnyHandle> {
|
pub fn upgrade(&self) -> Option<AnyModel> {
|
||||||
let entity_map = self.entity_ref_counts.upgrade()?;
|
let entity_map = self.entity_ref_counts.upgrade()?;
|
||||||
entity_map
|
entity_map
|
||||||
.read()
|
.read()
|
||||||
.counts
|
.counts
|
||||||
.get(self.entity_id)?
|
.get(self.entity_id)?
|
||||||
.fetch_add(1, SeqCst);
|
.fetch_add(1, SeqCst);
|
||||||
Some(AnyHandle {
|
Some(AnyModel {
|
||||||
entity_id: self.entity_id,
|
entity_id: self.entity_id,
|
||||||
entity_type: self.entity_type,
|
entity_type: self.entity_type,
|
||||||
entity_map: self.entity_ref_counts.clone(),
|
entity_map: self.entity_ref_counts.clone(),
|
||||||
@ -388,55 +374,55 @@ impl AnyWeakHandle {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> From<WeakHandle<T>> for AnyWeakHandle {
|
impl<T> From<WeakModel<T>> for AnyWeakModel {
|
||||||
fn from(handle: WeakHandle<T>) -> Self {
|
fn from(model: WeakModel<T>) -> Self {
|
||||||
handle.any_handle
|
model.any_model
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Hash for AnyWeakHandle {
|
impl Hash for AnyWeakModel {
|
||||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||||
self.entity_id.hash(state);
|
self.entity_id.hash(state);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PartialEq for AnyWeakHandle {
|
impl PartialEq for AnyWeakModel {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.entity_id == other.entity_id
|
self.entity_id == other.entity_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Eq for AnyWeakHandle {}
|
impl Eq for AnyWeakModel {}
|
||||||
|
|
||||||
#[derive(Deref, DerefMut)]
|
#[derive(Deref, DerefMut)]
|
||||||
pub struct WeakHandle<T> {
|
pub struct WeakModel<T> {
|
||||||
#[deref]
|
#[deref]
|
||||||
#[deref_mut]
|
#[deref_mut]
|
||||||
any_handle: AnyWeakHandle,
|
any_model: AnyWeakModel,
|
||||||
entity_type: PhantomData<T>,
|
entity_type: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<T> Send for WeakHandle<T> {}
|
unsafe impl<T> Send for WeakModel<T> {}
|
||||||
unsafe impl<T> Sync for WeakHandle<T> {}
|
unsafe impl<T> Sync for WeakModel<T> {}
|
||||||
|
|
||||||
impl<T> Clone for WeakHandle<T> {
|
impl<T> Clone for WeakModel<T> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
any_handle: self.any_handle.clone(),
|
any_model: self.any_model.clone(),
|
||||||
entity_type: self.entity_type,
|
entity_type: self.entity_type,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: 'static> WeakHandle<T> {
|
impl<T: 'static> WeakModel<T> {
|
||||||
pub fn upgrade(&self) -> Option<Handle<T>> {
|
pub fn upgrade(&self) -> Option<Model<T>> {
|
||||||
Some(Handle {
|
Some(Model {
|
||||||
any_handle: self.any_handle.upgrade()?,
|
any_model: self.any_model.upgrade()?,
|
||||||
entity_type: self.entity_type,
|
entity_type: self.entity_type,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Update the entity referenced by this handle with the given function if
|
/// Update the entity referenced by this model with the given function if
|
||||||
/// the referenced entity still exists. Returns an error if the entity has
|
/// the referenced entity still exists. Returns an error if the entity has
|
||||||
/// been released.
|
/// been released.
|
||||||
///
|
///
|
||||||
@ -446,7 +432,7 @@ impl<T: 'static> WeakHandle<T> {
|
|||||||
pub fn update<C, R>(
|
pub fn update<C, R>(
|
||||||
&self,
|
&self,
|
||||||
cx: &mut C,
|
cx: &mut C,
|
||||||
update: impl FnOnce(&mut T, &mut C::EntityContext<'_, T>) -> R,
|
update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R,
|
||||||
) -> Result<R>
|
) -> Result<R>
|
||||||
where
|
where
|
||||||
C: Context,
|
C: Context,
|
||||||
@ -460,22 +446,22 @@ impl<T: 'static> WeakHandle<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Hash for WeakHandle<T> {
|
impl<T> Hash for WeakModel<T> {
|
||||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||||
self.any_handle.hash(state);
|
self.any_model.hash(state);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> PartialEq for WeakHandle<T> {
|
impl<T> PartialEq for WeakModel<T> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.any_handle == other.any_handle
|
self.any_model == other.any_model
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Eq for WeakHandle<T> {}
|
impl<T> Eq for WeakModel<T> {}
|
||||||
|
|
||||||
impl<T> PartialEq<Handle<T>> for WeakHandle<T> {
|
impl<T> PartialEq<Model<T>> for WeakModel<T> {
|
||||||
fn eq(&self, other: &Handle<T>) -> bool {
|
fn eq(&self, other: &Model<T>) -> bool {
|
||||||
self.entity_id == other.entity_id
|
self.entity_id() == other.entity_id()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
AppContext, AsyncAppContext, Context, Effect, EntityId, EventEmitter, Handle, MainThread,
|
AppContext, AsyncAppContext, Context, Effect, EntityId, EventEmitter, MainThread, Model,
|
||||||
Reference, Subscription, Task, WeakHandle,
|
Reference, Subscription, Task, WeakModel,
|
||||||
};
|
};
|
||||||
use derive_more::{Deref, DerefMut};
|
use derive_more::{Deref, DerefMut};
|
||||||
use futures::FutureExt;
|
use futures::FutureExt;
|
||||||
@ -15,11 +15,11 @@ pub struct ModelContext<'a, T> {
|
|||||||
#[deref]
|
#[deref]
|
||||||
#[deref_mut]
|
#[deref_mut]
|
||||||
app: Reference<'a, AppContext>,
|
app: Reference<'a, AppContext>,
|
||||||
model_state: WeakHandle<T>,
|
model_state: WeakModel<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: 'static> ModelContext<'a, T> {
|
impl<'a, T: 'static> ModelContext<'a, T> {
|
||||||
pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakHandle<T>) -> Self {
|
pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakModel<T>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
app: Reference::Mutable(app),
|
app: Reference::Mutable(app),
|
||||||
model_state,
|
model_state,
|
||||||
@ -30,20 +30,20 @@ impl<'a, T: 'static> ModelContext<'a, T> {
|
|||||||
self.model_state.entity_id
|
self.model_state.entity_id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn handle(&self) -> Handle<T> {
|
pub fn handle(&self) -> Model<T> {
|
||||||
self.weak_handle()
|
self.weak_handle()
|
||||||
.upgrade()
|
.upgrade()
|
||||||
.expect("The entity must be alive if we have a model context")
|
.expect("The entity must be alive if we have a model context")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn weak_handle(&self) -> WeakHandle<T> {
|
pub fn weak_handle(&self) -> WeakModel<T> {
|
||||||
self.model_state.clone()
|
self.model_state.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn observe<T2: 'static>(
|
pub fn observe<T2: 'static>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<T2>,
|
handle: &Model<T2>,
|
||||||
mut on_notify: impl FnMut(&mut T, Handle<T2>, &mut ModelContext<'_, T>) + Send + 'static,
|
mut on_notify: impl FnMut(&mut T, Model<T2>, &mut ModelContext<'_, T>) + Send + 'static,
|
||||||
) -> Subscription
|
) -> Subscription
|
||||||
where
|
where
|
||||||
T: 'static + Send,
|
T: 'static + Send,
|
||||||
@ -65,10 +65,8 @@ impl<'a, T: 'static> ModelContext<'a, T> {
|
|||||||
|
|
||||||
pub fn subscribe<E: 'static + EventEmitter>(
|
pub fn subscribe<E: 'static + EventEmitter>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<E>,
|
handle: &Model<E>,
|
||||||
mut on_event: impl FnMut(&mut T, Handle<E>, &E::Event, &mut ModelContext<'_, T>)
|
mut on_event: impl FnMut(&mut T, Model<E>, &E::Event, &mut ModelContext<'_, T>) + Send + 'static,
|
||||||
+ Send
|
|
||||||
+ 'static,
|
|
||||||
) -> Subscription
|
) -> Subscription
|
||||||
where
|
where
|
||||||
T: 'static + Send,
|
T: 'static + Send,
|
||||||
@ -107,7 +105,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
|
|||||||
|
|
||||||
pub fn observe_release<E: 'static>(
|
pub fn observe_release<E: 'static>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<E>,
|
handle: &Model<E>,
|
||||||
mut on_release: impl FnMut(&mut T, &mut E, &mut ModelContext<'_, T>) + Send + 'static,
|
mut on_release: impl FnMut(&mut T, &mut E, &mut ModelContext<'_, T>) + Send + 'static,
|
||||||
) -> Subscription
|
) -> Subscription
|
||||||
where
|
where
|
||||||
@ -182,7 +180,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
|
|||||||
|
|
||||||
pub fn spawn<Fut, R>(
|
pub fn spawn<Fut, R>(
|
||||||
&self,
|
&self,
|
||||||
f: impl FnOnce(WeakHandle<T>, AsyncAppContext) -> Fut + Send + 'static,
|
f: impl FnOnce(WeakModel<T>, AsyncAppContext) -> Fut + Send + 'static,
|
||||||
) -> Task<R>
|
) -> Task<R>
|
||||||
where
|
where
|
||||||
T: 'static,
|
T: 'static,
|
||||||
@ -195,7 +193,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
|
|||||||
|
|
||||||
pub fn spawn_on_main<Fut, R>(
|
pub fn spawn_on_main<Fut, R>(
|
||||||
&self,
|
&self,
|
||||||
f: impl FnOnce(WeakHandle<T>, MainThread<AsyncAppContext>) -> Fut + Send + 'static,
|
f: impl FnOnce(WeakModel<T>, MainThread<AsyncAppContext>) -> Fut + Send + 'static,
|
||||||
) -> Task<R>
|
) -> Task<R>
|
||||||
where
|
where
|
||||||
Fut: Future<Output = R> + 'static,
|
Fut: Future<Output = R> + 'static,
|
||||||
@ -220,23 +218,23 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T> Context for ModelContext<'a, T> {
|
impl<'a, T> Context for ModelContext<'a, T> {
|
||||||
type EntityContext<'b, U> = ModelContext<'b, U>;
|
type ModelContext<'b, U> = ModelContext<'b, U>;
|
||||||
type Result<U> = U;
|
type Result<U> = U;
|
||||||
|
|
||||||
fn entity<U>(
|
fn build_model<U>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, U>) -> U,
|
build_model: impl FnOnce(&mut Self::ModelContext<'_, U>) -> U,
|
||||||
) -> Handle<U>
|
) -> Model<U>
|
||||||
where
|
where
|
||||||
U: 'static + Send,
|
U: 'static + Send,
|
||||||
{
|
{
|
||||||
self.app.entity(build_entity)
|
self.app.build_model(build_model)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_entity<U: 'static, R>(
|
fn update_entity<U: 'static, R>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<U>,
|
handle: &Model<U>,
|
||||||
update: impl FnOnce(&mut U, &mut Self::EntityContext<'_, U>) -> R,
|
update: impl FnOnce(&mut U, &mut Self::ModelContext<'_, U>) -> R,
|
||||||
) -> R {
|
) -> R {
|
||||||
self.app.update_entity(handle, update)
|
self.app.update_entity(handle, update)
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
AnyWindowHandle, AppContext, AsyncAppContext, Context, Executor, Handle, MainThread,
|
AnyWindowHandle, AppContext, AsyncAppContext, Context, Executor, MainThread, Model,
|
||||||
ModelContext, Result, Task, TestDispatcher, TestPlatform, WindowContext,
|
ModelContext, Result, Task, TestDispatcher, TestPlatform, WindowContext,
|
||||||
};
|
};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
@ -12,24 +12,24 @@ pub struct TestAppContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Context for TestAppContext {
|
impl Context for TestAppContext {
|
||||||
type EntityContext<'a, T> = ModelContext<'a, T>;
|
type ModelContext<'a, T> = ModelContext<'a, T>;
|
||||||
type Result<T> = T;
|
type Result<T> = T;
|
||||||
|
|
||||||
fn entity<T: 'static>(
|
fn build_model<T: 'static>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||||
) -> Self::Result<Handle<T>>
|
) -> Self::Result<Model<T>>
|
||||||
where
|
where
|
||||||
T: 'static + Send,
|
T: 'static + Send,
|
||||||
{
|
{
|
||||||
let mut lock = self.app.lock();
|
let mut lock = self.app.lock();
|
||||||
lock.entity(build_entity)
|
lock.build_model(build_model)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_entity<T: 'static, R>(
|
fn update_entity<T: 'static, R>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<T>,
|
handle: &Model<T>,
|
||||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||||
) -> Self::Result<R> {
|
) -> Self::Result<R> {
|
||||||
let mut lock = self.app.lock();
|
let mut lock = self.app.lock();
|
||||||
lock.update_entity(handle, update)
|
lock.update_entity(handle, update)
|
||||||
|
@ -4,7 +4,7 @@ pub(crate) use smallvec::SmallVec;
|
|||||||
use std::{any::Any, mem};
|
use std::{any::Any, mem};
|
||||||
|
|
||||||
pub trait Element<V: 'static> {
|
pub trait Element<V: 'static> {
|
||||||
type ElementState: 'static;
|
type ElementState: 'static + Send;
|
||||||
|
|
||||||
fn id(&self) -> Option<ElementId>;
|
fn id(&self) -> Option<ElementId>;
|
||||||
|
|
||||||
|
@ -70,33 +70,31 @@ use taffy::TaffyLayoutEngine;
|
|||||||
type AnyBox = Box<dyn Any + Send>;
|
type AnyBox = Box<dyn Any + Send>;
|
||||||
|
|
||||||
pub trait Context {
|
pub trait Context {
|
||||||
type EntityContext<'a, T>;
|
type ModelContext<'a, T>;
|
||||||
type Result<T>;
|
type Result<T>;
|
||||||
|
|
||||||
fn entity<T>(
|
fn build_model<T>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||||
) -> Self::Result<Handle<T>>
|
) -> Self::Result<Model<T>>
|
||||||
where
|
where
|
||||||
T: 'static + Send;
|
T: 'static + Send;
|
||||||
|
|
||||||
fn update_entity<T: 'static, R>(
|
fn update_entity<T: 'static, R>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<T>,
|
handle: &Model<T>,
|
||||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||||
) -> Self::Result<R>;
|
) -> Self::Result<R>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait VisualContext: Context {
|
pub trait VisualContext: Context {
|
||||||
type ViewContext<'a, 'w, V>;
|
type ViewContext<'a, 'w, V>;
|
||||||
|
|
||||||
fn build_view<E, V>(
|
fn build_view<V>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
||||||
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
|
|
||||||
) -> Self::Result<View<V>>
|
) -> Self::Result<View<V>>
|
||||||
where
|
where
|
||||||
E: Component<V>,
|
|
||||||
V: 'static + Send;
|
V: 'static + Send;
|
||||||
|
|
||||||
fn update_view<V: 'static, R>(
|
fn update_view<V: 'static, R>(
|
||||||
@ -140,37 +138,37 @@ impl<T> DerefMut for MainThread<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<C: Context> Context for MainThread<C> {
|
impl<C: Context> Context for MainThread<C> {
|
||||||
type EntityContext<'a, T> = MainThread<C::EntityContext<'a, T>>;
|
type ModelContext<'a, T> = MainThread<C::ModelContext<'a, T>>;
|
||||||
type Result<T> = C::Result<T>;
|
type Result<T> = C::Result<T>;
|
||||||
|
|
||||||
fn entity<T>(
|
fn build_model<T>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||||
) -> Self::Result<Handle<T>>
|
) -> Self::Result<Model<T>>
|
||||||
where
|
where
|
||||||
T: 'static + Send,
|
T: 'static + Send,
|
||||||
{
|
{
|
||||||
self.0.entity(|cx| {
|
self.0.build_model(|cx| {
|
||||||
let cx = unsafe {
|
let cx = unsafe {
|
||||||
mem::transmute::<
|
mem::transmute::<
|
||||||
&mut C::EntityContext<'_, T>,
|
&mut C::ModelContext<'_, T>,
|
||||||
&mut MainThread<C::EntityContext<'_, T>>,
|
&mut MainThread<C::ModelContext<'_, T>>,
|
||||||
>(cx)
|
>(cx)
|
||||||
};
|
};
|
||||||
build_entity(cx)
|
build_model(cx)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_entity<T: 'static, R>(
|
fn update_entity<T: 'static, R>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<T>,
|
handle: &Model<T>,
|
||||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||||
) -> Self::Result<R> {
|
) -> Self::Result<R> {
|
||||||
self.0.update_entity(handle, |entity, cx| {
|
self.0.update_entity(handle, |entity, cx| {
|
||||||
let cx = unsafe {
|
let cx = unsafe {
|
||||||
mem::transmute::<
|
mem::transmute::<
|
||||||
&mut C::EntityContext<'_, T>,
|
&mut C::ModelContext<'_, T>,
|
||||||
&mut MainThread<C::EntityContext<'_, T>>,
|
&mut MainThread<C::ModelContext<'_, T>>,
|
||||||
>(cx)
|
>(cx)
|
||||||
};
|
};
|
||||||
update(entity, cx)
|
update(entity, cx)
|
||||||
@ -181,27 +179,22 @@ impl<C: Context> Context for MainThread<C> {
|
|||||||
impl<C: VisualContext> VisualContext for MainThread<C> {
|
impl<C: VisualContext> VisualContext for MainThread<C> {
|
||||||
type ViewContext<'a, 'w, V> = MainThread<C::ViewContext<'a, 'w, V>>;
|
type ViewContext<'a, 'w, V> = MainThread<C::ViewContext<'a, 'w, V>>;
|
||||||
|
|
||||||
fn build_view<E, V>(
|
fn build_view<V>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
||||||
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
|
|
||||||
) -> Self::Result<View<V>>
|
) -> Self::Result<View<V>>
|
||||||
where
|
where
|
||||||
E: Component<V>,
|
|
||||||
V: 'static + Send,
|
V: 'static + Send,
|
||||||
{
|
{
|
||||||
self.0.build_view(
|
self.0.build_view(|cx| {
|
||||||
|cx| {
|
let cx = unsafe {
|
||||||
let cx = unsafe {
|
mem::transmute::<
|
||||||
mem::transmute::<
|
&mut C::ViewContext<'_, '_, V>,
|
||||||
&mut C::ViewContext<'_, '_, V>,
|
&mut MainThread<C::ViewContext<'_, '_, V>>,
|
||||||
&mut MainThread<C::ViewContext<'_, '_, V>>,
|
>(cx)
|
||||||
>(cx)
|
};
|
||||||
};
|
build_view_state(cx)
|
||||||
build_entity(cx)
|
})
|
||||||
},
|
|
||||||
render,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_view<V: 'static, R>(
|
fn update_view<V: 'static, R>(
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
point, px, Action, AnyBox, AnyDrag, AppContext, BorrowWindow, Bounds, Component,
|
div, point, px, Action, AnyDrag, AnyView, AppContext, BorrowWindow, Bounds, Component,
|
||||||
DispatchContext, DispatchPhase, Element, ElementId, FocusHandle, KeyMatch, Keystroke,
|
DispatchContext, DispatchPhase, Div, Element, ElementId, FocusHandle, KeyMatch, Keystroke,
|
||||||
Modifiers, Overflow, Pixels, Point, SharedString, Size, Style, StyleRefinement, View,
|
Modifiers, Overflow, Pixels, Point, Render, SharedString, Size, Style, StyleRefinement, View,
|
||||||
ViewContext,
|
ViewContext,
|
||||||
};
|
};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
@ -258,17 +258,17 @@ pub trait StatelessInteractive<V: 'static>: Element<V> {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_drop<S: 'static>(
|
fn on_drop<W: 'static + Send>(
|
||||||
mut self,
|
mut self,
|
||||||
listener: impl Fn(&mut V, S, &mut ViewContext<V>) + Send + 'static,
|
listener: impl Fn(&mut V, View<W>, &mut ViewContext<V>) + Send + 'static,
|
||||||
) -> Self
|
) -> Self
|
||||||
where
|
where
|
||||||
Self: Sized,
|
Self: Sized,
|
||||||
{
|
{
|
||||||
self.stateless_interaction().drop_listeners.push((
|
self.stateless_interaction().drop_listeners.push((
|
||||||
TypeId::of::<S>(),
|
TypeId::of::<W>(),
|
||||||
Box::new(move |view, drag_state, cx| {
|
Box::new(move |view, dragged_view, cx| {
|
||||||
listener(view, *drag_state.downcast().unwrap(), cx);
|
listener(view, dragged_view.downcast().unwrap(), cx);
|
||||||
}),
|
}),
|
||||||
));
|
));
|
||||||
self
|
self
|
||||||
@ -314,36 +314,22 @@ pub trait StatefulInteractive<V: 'static>: StatelessInteractive<V> {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_drag<S, R, E>(
|
fn on_drag<W>(
|
||||||
mut self,
|
mut self,
|
||||||
listener: impl Fn(&mut V, &mut ViewContext<V>) -> Drag<S, R, V, E> + Send + 'static,
|
listener: impl Fn(&mut V, &mut ViewContext<V>) -> View<W> + Send + 'static,
|
||||||
) -> Self
|
) -> Self
|
||||||
where
|
where
|
||||||
Self: Sized,
|
Self: Sized,
|
||||||
S: Any + Send,
|
W: 'static + Send + Render,
|
||||||
R: Fn(&mut V, &mut ViewContext<V>) -> E,
|
|
||||||
R: 'static + Send,
|
|
||||||
E: Component<V>,
|
|
||||||
{
|
{
|
||||||
debug_assert!(
|
debug_assert!(
|
||||||
self.stateful_interaction().drag_listener.is_none(),
|
self.stateful_interaction().drag_listener.is_none(),
|
||||||
"calling on_drag more than once on the same element is not supported"
|
"calling on_drag more than once on the same element is not supported"
|
||||||
);
|
);
|
||||||
self.stateful_interaction().drag_listener =
|
self.stateful_interaction().drag_listener =
|
||||||
Some(Box::new(move |view_state, cursor_offset, cx| {
|
Some(Box::new(move |view_state, cursor_offset, cx| AnyDrag {
|
||||||
let drag = listener(view_state, cx);
|
view: listener(view_state, cx).into_any(),
|
||||||
let drag_handle_view = Some(
|
cursor_offset,
|
||||||
View::for_handle(cx.handle().upgrade().unwrap(), move |view_state, cx| {
|
|
||||||
(drag.render_drag_handle)(view_state, cx)
|
|
||||||
})
|
|
||||||
.into_any(),
|
|
||||||
);
|
|
||||||
AnyDrag {
|
|
||||||
drag_handle_view,
|
|
||||||
cursor_offset,
|
|
||||||
state: Box::new(drag.state),
|
|
||||||
state_type: TypeId::of::<S>(),
|
|
||||||
}
|
|
||||||
}));
|
}));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@ -412,7 +398,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
|
|||||||
if let Some(drag) = cx.active_drag.take() {
|
if let Some(drag) = cx.active_drag.take() {
|
||||||
for (state_type, group_drag_style) in &self.as_stateless().group_drag_over_styles {
|
for (state_type, group_drag_style) in &self.as_stateless().group_drag_over_styles {
|
||||||
if let Some(group_bounds) = GroupBounds::get(&group_drag_style.group, cx) {
|
if let Some(group_bounds) = GroupBounds::get(&group_drag_style.group, cx) {
|
||||||
if *state_type == drag.state_type
|
if *state_type == drag.view.entity_type()
|
||||||
&& group_bounds.contains_point(&mouse_position)
|
&& group_bounds.contains_point(&mouse_position)
|
||||||
{
|
{
|
||||||
style.refine(&group_drag_style.style);
|
style.refine(&group_drag_style.style);
|
||||||
@ -421,7 +407,8 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (state_type, drag_over_style) in &self.as_stateless().drag_over_styles {
|
for (state_type, drag_over_style) in &self.as_stateless().drag_over_styles {
|
||||||
if *state_type == drag.state_type && bounds.contains_point(&mouse_position) {
|
if *state_type == drag.view.entity_type() && bounds.contains_point(&mouse_position)
|
||||||
|
{
|
||||||
style.refine(drag_over_style);
|
style.refine(drag_over_style);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -509,7 +496,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
|
|||||||
cx.on_mouse_event(move |view, event: &MouseUpEvent, phase, cx| {
|
cx.on_mouse_event(move |view, event: &MouseUpEvent, phase, cx| {
|
||||||
if phase == DispatchPhase::Bubble && bounds.contains_point(&event.position) {
|
if phase == DispatchPhase::Bubble && bounds.contains_point(&event.position) {
|
||||||
if let Some(drag_state_type) =
|
if let Some(drag_state_type) =
|
||||||
cx.active_drag.as_ref().map(|drag| drag.state_type)
|
cx.active_drag.as_ref().map(|drag| drag.view.entity_type())
|
||||||
{
|
{
|
||||||
for (drop_state_type, listener) in &drop_listeners {
|
for (drop_state_type, listener) in &drop_listeners {
|
||||||
if *drop_state_type == drag_state_type {
|
if *drop_state_type == drag_state_type {
|
||||||
@ -517,7 +504,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
|
|||||||
.active_drag
|
.active_drag
|
||||||
.take()
|
.take()
|
||||||
.expect("checked for type drag state type above");
|
.expect("checked for type drag state type above");
|
||||||
listener(view, drag.state, cx);
|
listener(view, drag.view.clone(), cx);
|
||||||
cx.notify();
|
cx.notify();
|
||||||
cx.stop_propagation();
|
cx.stop_propagation();
|
||||||
}
|
}
|
||||||
@ -685,7 +672,7 @@ impl<V> From<ElementId> for StatefulInteraction<V> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type DropListener<V> = dyn Fn(&mut V, AnyBox, &mut ViewContext<V>) + 'static + Send;
|
type DropListener<V> = dyn Fn(&mut V, AnyView, &mut ViewContext<V>) + 'static + Send;
|
||||||
|
|
||||||
pub struct StatelessInteraction<V> {
|
pub struct StatelessInteraction<V> {
|
||||||
pub dispatch_context: DispatchContext,
|
pub dispatch_context: DispatchContext,
|
||||||
@ -866,7 +853,7 @@ pub struct Drag<S, R, V, E>
|
|||||||
where
|
where
|
||||||
R: Fn(&mut V, &mut ViewContext<V>) -> E,
|
R: Fn(&mut V, &mut ViewContext<V>) -> E,
|
||||||
V: 'static,
|
V: 'static,
|
||||||
E: Component<V>,
|
E: Component<()>,
|
||||||
{
|
{
|
||||||
pub state: S,
|
pub state: S,
|
||||||
pub render_drag_handle: R,
|
pub render_drag_handle: R,
|
||||||
@ -877,7 +864,7 @@ impl<S, R, V, E> Drag<S, R, V, E>
|
|||||||
where
|
where
|
||||||
R: Fn(&mut V, &mut ViewContext<V>) -> E,
|
R: Fn(&mut V, &mut ViewContext<V>) -> E,
|
||||||
V: 'static,
|
V: 'static,
|
||||||
E: Component<V>,
|
E: Component<()>,
|
||||||
{
|
{
|
||||||
pub fn new(state: S, render_drag_handle: R) -> Self {
|
pub fn new(state: S, render_drag_handle: R) -> Self {
|
||||||
Drag {
|
Drag {
|
||||||
@ -888,6 +875,10 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// impl<S, R, V, E> Render for Drag<S, R, V, E> {
|
||||||
|
// // fn render(&mut self, cx: ViewContext<Self>) ->
|
||||||
|
// }
|
||||||
|
|
||||||
#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug)]
|
#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug)]
|
||||||
pub enum MouseButton {
|
pub enum MouseButton {
|
||||||
Left,
|
Left,
|
||||||
@ -995,6 +986,14 @@ impl Deref for MouseExitEvent {
|
|||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct ExternalPaths(pub(crate) SmallVec<[PathBuf; 2]>);
|
pub struct ExternalPaths(pub(crate) SmallVec<[PathBuf; 2]>);
|
||||||
|
|
||||||
|
impl Render for ExternalPaths {
|
||||||
|
type Element = Div<Self>;
|
||||||
|
|
||||||
|
fn render(&mut self, _: &mut ViewContext<Self>) -> Self::Element {
|
||||||
|
div() // Intentionally left empty because the platform will render icons for the dragged files
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum FileDropEvent {
|
pub enum FileDropEvent {
|
||||||
Entered {
|
Entered {
|
||||||
|
@ -1,45 +1,35 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
AnyBox, AnyElement, AppContext, AvailableSpace, BorrowWindow, Bounds, Component, Element,
|
AnyBox, AnyElement, AnyModel, AppContext, AvailableSpace, BorrowWindow, Bounds, Component,
|
||||||
ElementId, EntityHandle, EntityId, Flatten, Handle, LayoutId, Pixels, Size, ViewContext,
|
Element, ElementId, EntityHandle, EntityId, Flatten, LayoutId, Model, Pixels, Size,
|
||||||
VisualContext, WeakHandle, WindowContext,
|
ViewContext, VisualContext, WeakModel, WindowContext,
|
||||||
};
|
};
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use parking_lot::Mutex;
|
|
||||||
use std::{
|
use std::{
|
||||||
any::Any,
|
any::{Any, TypeId},
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
sync::{Arc, Weak},
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct View<V> {
|
pub trait Render: 'static + Sized {
|
||||||
pub(crate) state: Handle<V>,
|
type Element: Element<Self> + 'static + Send;
|
||||||
render: Arc<Mutex<dyn Fn(&mut V, &mut ViewContext<V>) -> AnyElement<V> + Send + 'static>>,
|
|
||||||
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: 'static> View<V> {
|
pub struct View<V> {
|
||||||
pub fn for_handle<E>(
|
pub(crate) model: Model<V>,
|
||||||
state: Handle<V>,
|
}
|
||||||
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
|
|
||||||
) -> View<V>
|
|
||||||
where
|
|
||||||
E: Component<V>,
|
|
||||||
{
|
|
||||||
View {
|
|
||||||
state,
|
|
||||||
render: Arc::new(Mutex::new(
|
|
||||||
move |state: &mut V, cx: &mut ViewContext<'_, '_, V>| render(state, cx).render(),
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
impl<V: Render> View<V> {
|
||||||
pub fn into_any(self) -> AnyView {
|
pub fn into_any(self) -> AnyView {
|
||||||
AnyView(Arc::new(self))
|
AnyView(Arc::new(self))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: 'static> View<V> {
|
||||||
pub fn downgrade(&self) -> WeakView<V> {
|
pub fn downgrade(&self) -> WeakView<V> {
|
||||||
WeakView {
|
WeakView {
|
||||||
state: self.state.downgrade(),
|
model: self.model.downgrade(),
|
||||||
render: Arc::downgrade(&self.render),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,20 +45,19 @@ impl<V: 'static> View<V> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn read<'a>(&self, cx: &'a AppContext) -> &'a V {
|
pub fn read<'a>(&self, cx: &'a AppContext) -> &'a V {
|
||||||
cx.entities.read(&self.state)
|
self.model.read(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V> Clone for View<V> {
|
impl<V> Clone for View<V> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
state: self.state.clone(),
|
model: self.model.clone(),
|
||||||
render: self.render.clone(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: 'static, ParentViewState: 'static> Component<ParentViewState> for View<V> {
|
impl<V: Render, ParentViewState: 'static> Component<ParentViewState> for View<V> {
|
||||||
fn render(self) -> AnyElement<ParentViewState> {
|
fn render(self) -> AnyElement<ParentViewState> {
|
||||||
AnyElement::new(EraseViewState {
|
AnyElement::new(EraseViewState {
|
||||||
view: self,
|
view: self,
|
||||||
@ -77,11 +66,14 @@ impl<V: 'static, ParentViewState: 'static> Component<ParentViewState> for View<V
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: 'static> Element<()> for View<V> {
|
impl<V> Element<()> for View<V>
|
||||||
|
where
|
||||||
|
V: Render,
|
||||||
|
{
|
||||||
type ElementState = AnyElement<V>;
|
type ElementState = AnyElement<V>;
|
||||||
|
|
||||||
fn id(&self) -> Option<ElementId> {
|
fn id(&self) -> Option<crate::ElementId> {
|
||||||
Some(ElementId::View(self.state.entity_id))
|
Some(ElementId::View(self.model.entity_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn initialize(
|
fn initialize(
|
||||||
@ -91,7 +83,7 @@ impl<V: 'static> Element<()> for View<V> {
|
|||||||
cx: &mut ViewContext<()>,
|
cx: &mut ViewContext<()>,
|
||||||
) -> Self::ElementState {
|
) -> Self::ElementState {
|
||||||
self.update(cx, |state, cx| {
|
self.update(cx, |state, cx| {
|
||||||
let mut any_element = (self.render.lock())(state, cx);
|
let mut any_element = AnyElement::new(state.render(cx));
|
||||||
any_element.initialize(state, cx);
|
any_element.initialize(state, cx);
|
||||||
any_element
|
any_element
|
||||||
})
|
})
|
||||||
@ -121,7 +113,7 @@ impl<T: 'static> EntityHandle<T> for View<T> {
|
|||||||
type Weak = WeakView<T>;
|
type Weak = WeakView<T>;
|
||||||
|
|
||||||
fn entity_id(&self) -> EntityId {
|
fn entity_id(&self) -> EntityId {
|
||||||
self.state.entity_id
|
self.model.entity_id
|
||||||
}
|
}
|
||||||
|
|
||||||
fn downgrade(&self) -> Self::Weak {
|
fn downgrade(&self) -> Self::Weak {
|
||||||
@ -137,15 +129,13 @@ impl<T: 'static> EntityHandle<T> for View<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct WeakView<V> {
|
pub struct WeakView<V> {
|
||||||
pub(crate) state: WeakHandle<V>,
|
pub(crate) model: WeakModel<V>,
|
||||||
render: Weak<Mutex<dyn Fn(&mut V, &mut ViewContext<V>) -> AnyElement<V> + Send + 'static>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: 'static> WeakView<V> {
|
impl<V: 'static> WeakView<V> {
|
||||||
pub fn upgrade(&self) -> Option<View<V>> {
|
pub fn upgrade(&self) -> Option<View<V>> {
|
||||||
let state = self.state.upgrade()?;
|
let model = self.model.upgrade()?;
|
||||||
let render = self.render.upgrade()?;
|
Some(View { model })
|
||||||
Some(View { state, render })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update<C, R>(
|
pub fn update<C, R>(
|
||||||
@ -165,8 +155,7 @@ impl<V: 'static> WeakView<V> {
|
|||||||
impl<V> Clone for WeakView<V> {
|
impl<V> Clone for WeakView<V> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
state: self.state.clone(),
|
model: self.model.clone(),
|
||||||
render: self.render.clone(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -178,13 +167,13 @@ struct EraseViewState<V, ParentV> {
|
|||||||
|
|
||||||
unsafe impl<V, ParentV> Send for EraseViewState<V, ParentV> {}
|
unsafe impl<V, ParentV> Send for EraseViewState<V, ParentV> {}
|
||||||
|
|
||||||
impl<V: 'static, ParentV: 'static> Component<ParentV> for EraseViewState<V, ParentV> {
|
impl<V: Render, ParentV: 'static> Component<ParentV> for EraseViewState<V, ParentV> {
|
||||||
fn render(self) -> AnyElement<ParentV> {
|
fn render(self) -> AnyElement<ParentV> {
|
||||||
AnyElement::new(self)
|
AnyElement::new(self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: 'static, ParentV: 'static> Element<ParentV> for EraseViewState<V, ParentV> {
|
impl<V: Render, ParentV: 'static> Element<ParentV> for EraseViewState<V, ParentV> {
|
||||||
type ElementState = AnyBox;
|
type ElementState = AnyBox;
|
||||||
|
|
||||||
fn id(&self) -> Option<ElementId> {
|
fn id(&self) -> Option<ElementId> {
|
||||||
@ -221,30 +210,43 @@ impl<V: 'static, ParentV: 'static> Element<ParentV> for EraseViewState<V, Parent
|
|||||||
}
|
}
|
||||||
|
|
||||||
trait ViewObject: Send + Sync {
|
trait ViewObject: Send + Sync {
|
||||||
|
fn entity_type(&self) -> TypeId;
|
||||||
fn entity_id(&self) -> EntityId;
|
fn entity_id(&self) -> EntityId;
|
||||||
|
fn model(&self) -> AnyModel;
|
||||||
fn initialize(&self, cx: &mut WindowContext) -> AnyBox;
|
fn initialize(&self, cx: &mut WindowContext) -> AnyBox;
|
||||||
fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId;
|
fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId;
|
||||||
fn paint(&self, bounds: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext);
|
fn paint(&self, bounds: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext);
|
||||||
fn as_any(&self) -> &dyn Any;
|
fn as_any(&self) -> &dyn Any;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: 'static> ViewObject for View<V> {
|
impl<V> ViewObject for View<V>
|
||||||
|
where
|
||||||
|
V: Render,
|
||||||
|
{
|
||||||
|
fn entity_type(&self) -> TypeId {
|
||||||
|
TypeId::of::<V>()
|
||||||
|
}
|
||||||
|
|
||||||
fn entity_id(&self) -> EntityId {
|
fn entity_id(&self) -> EntityId {
|
||||||
self.state.entity_id
|
self.model.entity_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model(&self) -> AnyModel {
|
||||||
|
self.model.clone().into_any()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn initialize(&self, cx: &mut WindowContext) -> AnyBox {
|
fn initialize(&self, cx: &mut WindowContext) -> AnyBox {
|
||||||
cx.with_element_id(self.state.entity_id, |_global_id, cx| {
|
cx.with_element_id(self.model.entity_id, |_global_id, cx| {
|
||||||
self.update(cx, |state, cx| {
|
self.update(cx, |state, cx| {
|
||||||
let mut any_element = Box::new((self.render.lock())(state, cx));
|
let mut any_element = Box::new(AnyElement::new(state.render(cx)));
|
||||||
any_element.initialize(state, cx);
|
any_element.initialize(state, cx);
|
||||||
any_element as AnyBox
|
any_element
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId {
|
fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId {
|
||||||
cx.with_element_id(self.state.entity_id, |_global_id, cx| {
|
cx.with_element_id(self.model.entity_id, |_global_id, cx| {
|
||||||
self.update(cx, |state, cx| {
|
self.update(cx, |state, cx| {
|
||||||
let element = element.downcast_mut::<AnyElement<V>>().unwrap();
|
let element = element.downcast_mut::<AnyElement<V>>().unwrap();
|
||||||
element.layout(state, cx)
|
element.layout(state, cx)
|
||||||
@ -253,7 +255,7 @@ impl<V: 'static> ViewObject for View<V> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn paint(&self, _: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext) {
|
fn paint(&self, _: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext) {
|
||||||
cx.with_element_id(self.state.entity_id, |_global_id, cx| {
|
cx.with_element_id(self.model.entity_id, |_global_id, cx| {
|
||||||
self.update(cx, |state, cx| {
|
self.update(cx, |state, cx| {
|
||||||
let element = element.downcast_mut::<AnyElement<V>>().unwrap();
|
let element = element.downcast_mut::<AnyElement<V>>().unwrap();
|
||||||
element.paint(state, cx);
|
element.paint(state, cx);
|
||||||
@ -270,8 +272,12 @@ impl<V: 'static> ViewObject for View<V> {
|
|||||||
pub struct AnyView(Arc<dyn ViewObject>);
|
pub struct AnyView(Arc<dyn ViewObject>);
|
||||||
|
|
||||||
impl AnyView {
|
impl AnyView {
|
||||||
pub fn downcast<V: 'static>(&self) -> Option<View<V>> {
|
pub fn downcast<V: 'static + Send>(self) -> Option<View<V>> {
|
||||||
self.0.as_any().downcast_ref().cloned()
|
self.0.model().downcast().map(|model| View { model })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn entity_type(&self) -> TypeId {
|
||||||
|
self.0.entity_type()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn draw(&self, available_space: Size<AvailableSpace>, cx: &mut WindowContext) {
|
pub(crate) fn draw(&self, available_space: Size<AvailableSpace>, cx: &mut WindowContext) {
|
||||||
@ -343,6 +349,18 @@ impl<ParentV: 'static> Component<ParentV> for EraseAnyViewState<ParentV> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T, E> Render for T
|
||||||
|
where
|
||||||
|
T: 'static + FnMut(&mut WindowContext) -> E,
|
||||||
|
E: 'static + Send + Element<T>,
|
||||||
|
{
|
||||||
|
type Element = E;
|
||||||
|
|
||||||
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
|
||||||
|
(self)(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<ParentV: 'static> Element<ParentV> for EraseAnyViewState<ParentV> {
|
impl<ParentV: 'static> Element<ParentV> for EraseAnyViewState<ParentV> {
|
||||||
type ElementState = AnyBox;
|
type ElementState = AnyBox;
|
||||||
|
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
px, size, Action, AnyBox, AnyDrag, AnyView, AppContext, AsyncWindowContext, AvailableSpace,
|
px, size, Action, AnyBox, AnyDrag, AnyView, AppContext, AsyncWindowContext, AvailableSpace,
|
||||||
Bounds, BoxShadow, Context, Corners, DevicePixels, DispatchContext, DisplayId, Edges, Effect,
|
Bounds, BoxShadow, Context, Corners, DevicePixels, DispatchContext, DisplayId, Edges, Effect,
|
||||||
EntityHandle, EntityId, EventEmitter, ExternalPaths, FileDropEvent, FocusEvent, FontId,
|
EntityHandle, EntityId, EventEmitter, FileDropEvent, FocusEvent, FontId, GlobalElementId,
|
||||||
GlobalElementId, GlyphId, Handle, Hsla, ImageData, InputEvent, IsZero, KeyListener, KeyMatch,
|
GlyphId, Hsla, ImageData, InputEvent, IsZero, KeyListener, KeyMatch, KeyMatcher, Keystroke,
|
||||||
KeyMatcher, Keystroke, LayoutId, MainThread, MainThreadOnly, ModelContext, Modifiers,
|
LayoutId, MainThread, MainThreadOnly, Model, ModelContext, Modifiers, MonochromeSprite,
|
||||||
MonochromeSprite, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Path, Pixels,
|
MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Path, Pixels, PlatformAtlas,
|
||||||
PlatformAtlas, PlatformWindow, Point, PolychromeSprite, Quad, Reference, RenderGlyphParams,
|
PlatformWindow, Point, PolychromeSprite, Quad, Reference, RenderGlyphParams, RenderImageParams,
|
||||||
RenderImageParams, RenderSvgParams, ScaledPixels, SceneBuilder, Shadow, SharedString, Size,
|
RenderSvgParams, ScaledPixels, SceneBuilder, Shadow, SharedString, Size, Style, Subscription,
|
||||||
Style, Subscription, TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext,
|
TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext, WeakModel, WeakView,
|
||||||
WeakHandle, WeakView, WindowOptions, SUBPIXEL_VARIANTS,
|
WindowOptions, SUBPIXEL_VARIANTS,
|
||||||
};
|
};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
@ -918,15 +918,13 @@ impl<'a, 'w> WindowContext<'a, 'w> {
|
|||||||
root_view.draw(available_space, cx);
|
root_view.draw(available_space, cx);
|
||||||
});
|
});
|
||||||
|
|
||||||
if let Some(mut active_drag) = self.app.active_drag.take() {
|
if let Some(active_drag) = self.app.active_drag.take() {
|
||||||
self.stack(1, |cx| {
|
self.stack(1, |cx| {
|
||||||
let offset = cx.mouse_position() - active_drag.cursor_offset;
|
let offset = cx.mouse_position() - active_drag.cursor_offset;
|
||||||
cx.with_element_offset(Some(offset), |cx| {
|
cx.with_element_offset(Some(offset), |cx| {
|
||||||
let available_space =
|
let available_space =
|
||||||
size(AvailableSpace::MinContent, AvailableSpace::MinContent);
|
size(AvailableSpace::MinContent, AvailableSpace::MinContent);
|
||||||
if let Some(drag_handle_view) = &mut active_drag.drag_handle_view {
|
active_drag.view.draw(available_space, cx);
|
||||||
drag_handle_view.draw(available_space, cx);
|
|
||||||
}
|
|
||||||
cx.active_drag = Some(active_drag);
|
cx.active_drag = Some(active_drag);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@ -994,12 +992,12 @@ impl<'a, 'w> WindowContext<'a, 'w> {
|
|||||||
InputEvent::FileDrop(file_drop) => match file_drop {
|
InputEvent::FileDrop(file_drop) => match file_drop {
|
||||||
FileDropEvent::Entered { position, files } => {
|
FileDropEvent::Entered { position, files } => {
|
||||||
self.window.mouse_position = position;
|
self.window.mouse_position = position;
|
||||||
self.active_drag.get_or_insert_with(|| AnyDrag {
|
if self.active_drag.is_none() {
|
||||||
drag_handle_view: None,
|
self.active_drag = Some(AnyDrag {
|
||||||
cursor_offset: position,
|
view: self.build_view(|_| files).into_any(),
|
||||||
state: Box::new(files),
|
cursor_offset: position,
|
||||||
state_type: TypeId::of::<ExternalPaths>(),
|
});
|
||||||
});
|
}
|
||||||
InputEvent::MouseDown(MouseDownEvent {
|
InputEvent::MouseDown(MouseDownEvent {
|
||||||
position,
|
position,
|
||||||
button: MouseButton::Left,
|
button: MouseButton::Left,
|
||||||
@ -1267,30 +1265,30 @@ impl<'a, 'w> WindowContext<'a, 'w> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Context for WindowContext<'_, '_> {
|
impl Context for WindowContext<'_, '_> {
|
||||||
type EntityContext<'a, T> = ModelContext<'a, T>;
|
type ModelContext<'a, T> = ModelContext<'a, T>;
|
||||||
type Result<T> = T;
|
type Result<T> = T;
|
||||||
|
|
||||||
fn entity<T>(
|
fn build_model<T>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||||
) -> Handle<T>
|
) -> Model<T>
|
||||||
where
|
where
|
||||||
T: 'static + Send,
|
T: 'static + Send,
|
||||||
{
|
{
|
||||||
let slot = self.app.entities.reserve();
|
let slot = self.app.entities.reserve();
|
||||||
let entity = build_entity(&mut ModelContext::mutable(&mut *self.app, slot.downgrade()));
|
let model = build_model(&mut ModelContext::mutable(&mut *self.app, slot.downgrade()));
|
||||||
self.entities.insert(slot, entity)
|
self.entities.insert(slot, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_entity<T: 'static, R>(
|
fn update_entity<T: 'static, R>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<T>,
|
model: &Model<T>,
|
||||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||||
) -> R {
|
) -> R {
|
||||||
let mut entity = self.entities.lease(handle);
|
let mut entity = self.entities.lease(model);
|
||||||
let result = update(
|
let result = update(
|
||||||
&mut *entity,
|
&mut *entity,
|
||||||
&mut ModelContext::mutable(&mut *self.app, handle.downgrade()),
|
&mut ModelContext::mutable(&mut *self.app, model.downgrade()),
|
||||||
);
|
);
|
||||||
self.entities.end_lease(entity);
|
self.entities.end_lease(entity);
|
||||||
result
|
result
|
||||||
@ -1300,21 +1298,17 @@ impl Context for WindowContext<'_, '_> {
|
|||||||
impl VisualContext for WindowContext<'_, '_> {
|
impl VisualContext for WindowContext<'_, '_> {
|
||||||
type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>;
|
type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>;
|
||||||
|
|
||||||
/// Builds a new view in the current window. The first argument is a function that builds
|
fn build_view<V>(
|
||||||
/// an entity representing the view's state. It is invoked with a `ViewContext` that provides
|
|
||||||
/// entity-specific access to the window and application state during construction. The second
|
|
||||||
/// argument is a render function that returns a component based on the view's state.
|
|
||||||
fn build_view<E, V>(
|
|
||||||
&mut self,
|
&mut self,
|
||||||
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
||||||
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
|
|
||||||
) -> Self::Result<View<V>>
|
) -> Self::Result<View<V>>
|
||||||
where
|
where
|
||||||
E: crate::Component<V>,
|
|
||||||
V: 'static + Send,
|
V: 'static + Send,
|
||||||
{
|
{
|
||||||
let slot = self.app.entities.reserve();
|
let slot = self.app.entities.reserve();
|
||||||
let view = View::for_handle(slot.clone(), render);
|
let view = View {
|
||||||
|
model: slot.clone(),
|
||||||
|
};
|
||||||
let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade());
|
let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade());
|
||||||
let entity = build_view_state(&mut cx);
|
let entity = build_view_state(&mut cx);
|
||||||
self.entities.insert(slot, entity);
|
self.entities.insert(slot, entity);
|
||||||
@ -1327,7 +1321,7 @@ impl VisualContext for WindowContext<'_, '_> {
|
|||||||
view: &View<T>,
|
view: &View<T>,
|
||||||
update: impl FnOnce(&mut T, &mut Self::ViewContext<'_, '_, T>) -> R,
|
update: impl FnOnce(&mut T, &mut Self::ViewContext<'_, '_, T>) -> R,
|
||||||
) -> Self::Result<R> {
|
) -> Self::Result<R> {
|
||||||
let mut lease = self.app.entities.lease(&view.state);
|
let mut lease = self.app.entities.lease(&view.model);
|
||||||
let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade());
|
let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade());
|
||||||
let result = update(&mut *lease, &mut cx);
|
let result = update(&mut *lease, &mut cx);
|
||||||
cx.app.entities.end_lease(lease);
|
cx.app.entities.end_lease(lease);
|
||||||
@ -1582,8 +1576,8 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
|
|||||||
self.view.clone()
|
self.view.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn handle(&self) -> WeakHandle<V> {
|
pub fn model(&self) -> WeakModel<V> {
|
||||||
self.view.state.clone()
|
self.view.model.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stack<R>(&mut self, order: u32, f: impl FnOnce(&mut Self) -> R) -> R {
|
pub fn stack<R>(&mut self, order: u32, f: impl FnOnce(&mut Self) -> R) -> R {
|
||||||
@ -1603,8 +1597,8 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
|
|||||||
|
|
||||||
pub fn observe<E>(
|
pub fn observe<E>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<E>,
|
handle: &Model<E>,
|
||||||
mut on_notify: impl FnMut(&mut V, Handle<E>, &mut ViewContext<'_, '_, V>) + Send + 'static,
|
mut on_notify: impl FnMut(&mut V, Model<E>, &mut ViewContext<'_, '_, V>) + Send + 'static,
|
||||||
) -> Subscription
|
) -> Subscription
|
||||||
where
|
where
|
||||||
E: 'static,
|
E: 'static,
|
||||||
@ -1665,7 +1659,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
|
|||||||
) -> Subscription {
|
) -> Subscription {
|
||||||
let window_handle = self.window.handle;
|
let window_handle = self.window.handle;
|
||||||
self.app.release_listeners.insert(
|
self.app.release_listeners.insert(
|
||||||
self.view.state.entity_id,
|
self.view.model.entity_id,
|
||||||
Box::new(move |this, cx| {
|
Box::new(move |this, cx| {
|
||||||
let this = this.downcast_mut().expect("invalid entity type");
|
let this = this.downcast_mut().expect("invalid entity type");
|
||||||
// todo!("are we okay with silently swallowing the error?")
|
// todo!("are we okay with silently swallowing the error?")
|
||||||
@ -1676,7 +1670,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
|
|||||||
|
|
||||||
pub fn observe_release<T: 'static>(
|
pub fn observe_release<T: 'static>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<T>,
|
handle: &Model<T>,
|
||||||
mut on_release: impl FnMut(&mut V, &mut T, &mut ViewContext<'_, '_, V>) + Send + 'static,
|
mut on_release: impl FnMut(&mut V, &mut T, &mut ViewContext<'_, '_, V>) + Send + 'static,
|
||||||
) -> Subscription
|
) -> Subscription
|
||||||
where
|
where
|
||||||
@ -1698,7 +1692,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
|
|||||||
pub fn notify(&mut self) {
|
pub fn notify(&mut self) {
|
||||||
self.window_cx.notify();
|
self.window_cx.notify();
|
||||||
self.window_cx.app.push_effect(Effect::Notify {
|
self.window_cx.app.push_effect(Effect::Notify {
|
||||||
emitter: self.view.state.entity_id,
|
emitter: self.view.model.entity_id,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1878,7 +1872,7 @@ where
|
|||||||
V::Event: Any + Send,
|
V::Event: Any + Send,
|
||||||
{
|
{
|
||||||
pub fn emit(&mut self, event: V::Event) {
|
pub fn emit(&mut self, event: V::Event) {
|
||||||
let emitter = self.view.state.entity_id;
|
let emitter = self.view.model.entity_id;
|
||||||
self.app.push_effect(Effect::Emit {
|
self.app.push_effect(Effect::Emit {
|
||||||
emitter,
|
emitter,
|
||||||
event: Box::new(event),
|
event: Box::new(event),
|
||||||
@ -1897,41 +1891,36 @@ impl<'a, 'w, V: 'static> MainThread<ViewContext<'a, 'w, V>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'w, V> Context for ViewContext<'a, 'w, V> {
|
impl<'a, 'w, V> Context for ViewContext<'a, 'w, V> {
|
||||||
type EntityContext<'b, U> = ModelContext<'b, U>;
|
type ModelContext<'b, U> = ModelContext<'b, U>;
|
||||||
type Result<U> = U;
|
type Result<U> = U;
|
||||||
|
|
||||||
fn entity<T>(
|
fn build_model<T>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||||
) -> Handle<T>
|
) -> Model<T>
|
||||||
where
|
where
|
||||||
T: 'static + Send,
|
T: 'static + Send,
|
||||||
{
|
{
|
||||||
self.window_cx.entity(build_entity)
|
self.window_cx.build_model(build_model)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_entity<T: 'static, R>(
|
fn update_entity<T: 'static, R>(
|
||||||
&mut self,
|
&mut self,
|
||||||
handle: &Handle<T>,
|
model: &Model<T>,
|
||||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||||
) -> R {
|
) -> R {
|
||||||
self.window_cx.update_entity(handle, update)
|
self.window_cx.update_entity(model, update)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<V: 'static> VisualContext for ViewContext<'_, '_, V> {
|
impl<V: 'static> VisualContext for ViewContext<'_, '_, V> {
|
||||||
type ViewContext<'a, 'w, V2> = ViewContext<'a, 'w, V2>;
|
type ViewContext<'a, 'w, V2> = ViewContext<'a, 'w, V2>;
|
||||||
|
|
||||||
fn build_view<E, V2>(
|
fn build_view<W: 'static + Send>(
|
||||||
&mut self,
|
&mut self,
|
||||||
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V2>) -> V2,
|
build_view: impl FnOnce(&mut Self::ViewContext<'_, '_, W>) -> W,
|
||||||
render: impl Fn(&mut V2, &mut ViewContext<'_, '_, V2>) -> E + Send + 'static,
|
) -> Self::Result<View<W>> {
|
||||||
) -> Self::Result<View<V2>>
|
self.window_cx.build_view(build_view)
|
||||||
where
|
|
||||||
E: crate::Component<V2>,
|
|
||||||
V2: 'static + Send,
|
|
||||||
{
|
|
||||||
self.window_cx.build_view(build_entity, render)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_view<V2: 'static, R>(
|
fn update_view<V2: 'static, R>(
|
||||||
|
@ -5,7 +5,7 @@ use crate::language_settings::{
|
|||||||
use crate::Buffer;
|
use crate::Buffer;
|
||||||
use clock::ReplicaId;
|
use clock::ReplicaId;
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use gpui2::{AppContext, Handle};
|
use gpui2::{AppContext, Model};
|
||||||
use gpui2::{Context, TestAppContext};
|
use gpui2::{Context, TestAppContext};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use proto::deserialize_operation;
|
use proto::deserialize_operation;
|
||||||
@ -42,7 +42,7 @@ fn init_logger() {
|
|||||||
fn test_line_endings(cx: &mut gpui2::AppContext) {
|
fn test_line_endings(cx: &mut gpui2::AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "one\r\ntwo\rthree")
|
let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "one\r\ntwo\rthree")
|
||||||
.with_language(Arc::new(rust_lang()), cx);
|
.with_language(Arc::new(rust_lang()), cx);
|
||||||
assert_eq!(buffer.text(), "one\ntwo\nthree");
|
assert_eq!(buffer.text(), "one\ntwo\nthree");
|
||||||
@ -138,8 +138,8 @@ fn test_edit_events(cx: &mut gpui2::AppContext) {
|
|||||||
let buffer_1_events = Arc::new(Mutex::new(Vec::new()));
|
let buffer_1_events = Arc::new(Mutex::new(Vec::new()));
|
||||||
let buffer_2_events = Arc::new(Mutex::new(Vec::new()));
|
let buffer_2_events = Arc::new(Mutex::new(Vec::new()));
|
||||||
|
|
||||||
let buffer1 = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), "abcdef"));
|
let buffer1 = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), "abcdef"));
|
||||||
let buffer2 = cx.entity(|cx| Buffer::new(1, cx.entity_id().as_u64(), "abcdef"));
|
let buffer2 = cx.build_model(|cx| Buffer::new(1, cx.entity_id().as_u64(), "abcdef"));
|
||||||
let buffer1_ops = Arc::new(Mutex::new(Vec::new()));
|
let buffer1_ops = Arc::new(Mutex::new(Vec::new()));
|
||||||
buffer1.update(cx, {
|
buffer1.update(cx, {
|
||||||
let buffer1_ops = buffer1_ops.clone();
|
let buffer1_ops = buffer1_ops.clone();
|
||||||
@ -218,7 +218,7 @@ fn test_edit_events(cx: &mut gpui2::AppContext) {
|
|||||||
#[gpui2::test]
|
#[gpui2::test]
|
||||||
async fn test_apply_diff(cx: &mut TestAppContext) {
|
async fn test_apply_diff(cx: &mut TestAppContext) {
|
||||||
let text = "a\nbb\nccc\ndddd\neeeee\nffffff\n";
|
let text = "a\nbb\nccc\ndddd\neeeee\nffffff\n";
|
||||||
let buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
|
let buffer = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
|
||||||
let anchor = buffer.update(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)));
|
let anchor = buffer.update(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)));
|
||||||
|
|
||||||
let text = "a\nccc\ndddd\nffffff\n";
|
let text = "a\nccc\ndddd\nffffff\n";
|
||||||
@ -250,7 +250,7 @@ async fn test_normalize_whitespace(cx: &mut gpui2::TestAppContext) {
|
|||||||
]
|
]
|
||||||
.join("\n");
|
.join("\n");
|
||||||
|
|
||||||
let buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
|
let buffer = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
|
||||||
|
|
||||||
// Spawn a task to format the buffer's whitespace.
|
// Spawn a task to format the buffer's whitespace.
|
||||||
// Pause so that the foratting task starts running.
|
// Pause so that the foratting task starts running.
|
||||||
@ -314,7 +314,7 @@ async fn test_normalize_whitespace(cx: &mut gpui2::TestAppContext) {
|
|||||||
#[gpui2::test]
|
#[gpui2::test]
|
||||||
async fn test_reparse(cx: &mut gpui2::TestAppContext) {
|
async fn test_reparse(cx: &mut gpui2::TestAppContext) {
|
||||||
let text = "fn a() {}";
|
let text = "fn a() {}";
|
||||||
let buffer = cx.entity(|cx| {
|
let buffer = cx.build_model(|cx| {
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -442,7 +442,7 @@ async fn test_reparse(cx: &mut gpui2::TestAppContext) {
|
|||||||
|
|
||||||
#[gpui2::test]
|
#[gpui2::test]
|
||||||
async fn test_resetting_language(cx: &mut gpui2::TestAppContext) {
|
async fn test_resetting_language(cx: &mut gpui2::TestAppContext) {
|
||||||
let buffer = cx.entity(|cx| {
|
let buffer = cx.build_model(|cx| {
|
||||||
let mut buffer =
|
let mut buffer =
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), "{}").with_language(Arc::new(rust_lang()), cx);
|
Buffer::new(0, cx.entity_id().as_u64(), "{}").with_language(Arc::new(rust_lang()), cx);
|
||||||
buffer.set_sync_parse_timeout(Duration::ZERO);
|
buffer.set_sync_parse_timeout(Duration::ZERO);
|
||||||
@ -492,7 +492,7 @@ async fn test_outline(cx: &mut gpui2::TestAppContext) {
|
|||||||
"#
|
"#
|
||||||
.unindent();
|
.unindent();
|
||||||
|
|
||||||
let buffer = cx.entity(|cx| {
|
let buffer = cx.build_model(|cx| {
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
||||||
});
|
});
|
||||||
let outline = buffer
|
let outline = buffer
|
||||||
@ -578,7 +578,7 @@ async fn test_outline_nodes_with_newlines(cx: &mut gpui2::TestAppContext) {
|
|||||||
"#
|
"#
|
||||||
.unindent();
|
.unindent();
|
||||||
|
|
||||||
let buffer = cx.entity(|cx| {
|
let buffer = cx.build_model(|cx| {
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
||||||
});
|
});
|
||||||
let outline = buffer
|
let outline = buffer
|
||||||
@ -616,7 +616,7 @@ async fn test_outline_with_extra_context(cx: &mut gpui2::TestAppContext) {
|
|||||||
"#
|
"#
|
||||||
.unindent();
|
.unindent();
|
||||||
|
|
||||||
let buffer = cx.entity(|cx| {
|
let buffer = cx.build_model(|cx| {
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(language), cx)
|
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(language), cx)
|
||||||
});
|
});
|
||||||
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
|
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
|
||||||
@ -660,7 +660,7 @@ async fn test_symbols_containing(cx: &mut gpui2::TestAppContext) {
|
|||||||
"#
|
"#
|
||||||
.unindent();
|
.unindent();
|
||||||
|
|
||||||
let buffer = cx.entity(|cx| {
|
let buffer = cx.build_model(|cx| {
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
||||||
});
|
});
|
||||||
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
|
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
|
||||||
@ -881,7 +881,7 @@ fn test_enclosing_bracket_ranges_where_brackets_are_not_outermost_children(cx: &
|
|||||||
|
|
||||||
#[gpui2::test]
|
#[gpui2::test]
|
||||||
fn test_range_for_syntax_ancestor(cx: &mut AppContext) {
|
fn test_range_for_syntax_ancestor(cx: &mut AppContext) {
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let text = "fn a() { b(|c| {}) }";
|
let text = "fn a() { b(|c| {}) }";
|
||||||
let buffer =
|
let buffer =
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
||||||
@ -922,7 +922,7 @@ fn test_range_for_syntax_ancestor(cx: &mut AppContext) {
|
|||||||
fn test_autoindent_with_soft_tabs(cx: &mut AppContext) {
|
fn test_autoindent_with_soft_tabs(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let text = "fn a() {}";
|
let text = "fn a() {}";
|
||||||
let mut buffer =
|
let mut buffer =
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
||||||
@ -965,7 +965,7 @@ fn test_autoindent_with_hard_tabs(cx: &mut AppContext) {
|
|||||||
settings.defaults.hard_tabs = Some(true);
|
settings.defaults.hard_tabs = Some(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let text = "fn a() {}";
|
let text = "fn a() {}";
|
||||||
let mut buffer =
|
let mut buffer =
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
||||||
@ -1006,7 +1006,7 @@ fn test_autoindent_with_hard_tabs(cx: &mut AppContext) {
|
|||||||
fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppContext) {
|
fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let entity_id = cx.entity_id();
|
let entity_id = cx.entity_id();
|
||||||
let mut buffer = Buffer::new(
|
let mut buffer = Buffer::new(
|
||||||
0,
|
0,
|
||||||
@ -1080,7 +1080,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppC
|
|||||||
buffer
|
buffer
|
||||||
});
|
});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
eprintln!("second buffer: {:?}", cx.entity_id());
|
eprintln!("second buffer: {:?}", cx.entity_id());
|
||||||
|
|
||||||
let mut buffer = Buffer::new(
|
let mut buffer = Buffer::new(
|
||||||
@ -1147,7 +1147,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppC
|
|||||||
fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut AppContext) {
|
fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let mut buffer = Buffer::new(
|
let mut buffer = Buffer::new(
|
||||||
0,
|
0,
|
||||||
cx.entity_id().as_u64(),
|
cx.entity_id().as_u64(),
|
||||||
@ -1209,7 +1209,7 @@ fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut Ap
|
|||||||
fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut AppContext) {
|
fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let mut buffer = Buffer::new(
|
let mut buffer = Buffer::new(
|
||||||
0,
|
0,
|
||||||
cx.entity_id().as_u64(),
|
cx.entity_id().as_u64(),
|
||||||
@ -1266,7 +1266,7 @@ fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut AppContext) {
|
|||||||
fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut AppContext) {
|
fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let text = "a\nb";
|
let text = "a\nb";
|
||||||
let mut buffer =
|
let mut buffer =
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
||||||
@ -1284,7 +1284,7 @@ fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut AppContext) {
|
|||||||
fn test_autoindent_multi_line_insertion(cx: &mut AppContext) {
|
fn test_autoindent_multi_line_insertion(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let text = "
|
let text = "
|
||||||
const a: usize = 1;
|
const a: usize = 1;
|
||||||
fn b() {
|
fn b() {
|
||||||
@ -1326,7 +1326,7 @@ fn test_autoindent_multi_line_insertion(cx: &mut AppContext) {
|
|||||||
fn test_autoindent_block_mode(cx: &mut AppContext) {
|
fn test_autoindent_block_mode(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let text = r#"
|
let text = r#"
|
||||||
fn a() {
|
fn a() {
|
||||||
b();
|
b();
|
||||||
@ -1410,7 +1410,7 @@ fn test_autoindent_block_mode(cx: &mut AppContext) {
|
|||||||
fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut AppContext) {
|
fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let text = r#"
|
let text = r#"
|
||||||
fn a() {
|
fn a() {
|
||||||
if b() {
|
if b() {
|
||||||
@ -1490,7 +1490,7 @@ fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut AppContex
|
|||||||
fn test_autoindent_language_without_indents_query(cx: &mut AppContext) {
|
fn test_autoindent_language_without_indents_query(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let text = "
|
let text = "
|
||||||
* one
|
* one
|
||||||
- a
|
- a
|
||||||
@ -1559,7 +1559,7 @@ fn test_autoindent_with_injected_languages(cx: &mut AppContext) {
|
|||||||
language_registry.add(html_language.clone());
|
language_registry.add(html_language.clone());
|
||||||
language_registry.add(javascript_language.clone());
|
language_registry.add(javascript_language.clone());
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let (text, ranges) = marked_text_ranges(
|
let (text, ranges) = marked_text_ranges(
|
||||||
&"
|
&"
|
||||||
<div>ˇ
|
<div>ˇ
|
||||||
@ -1610,7 +1610,7 @@ fn test_autoindent_query_with_outdent_captures(cx: &mut AppContext) {
|
|||||||
settings.defaults.tab_size = Some(2.try_into().unwrap());
|
settings.defaults.tab_size = Some(2.try_into().unwrap());
|
||||||
});
|
});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let mut buffer =
|
let mut buffer =
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), "").with_language(Arc::new(ruby_lang()), cx);
|
Buffer::new(0, cx.entity_id().as_u64(), "").with_language(Arc::new(ruby_lang()), cx);
|
||||||
|
|
||||||
@ -1653,7 +1653,7 @@ fn test_autoindent_query_with_outdent_captures(cx: &mut AppContext) {
|
|||||||
fn test_language_scope_at_with_javascript(cx: &mut AppContext) {
|
fn test_language_scope_at_with_javascript(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let language = Language::new(
|
let language = Language::new(
|
||||||
LanguageConfig {
|
LanguageConfig {
|
||||||
name: "JavaScript".into(),
|
name: "JavaScript".into(),
|
||||||
@ -1742,7 +1742,7 @@ fn test_language_scope_at_with_javascript(cx: &mut AppContext) {
|
|||||||
fn test_language_scope_at_with_rust(cx: &mut AppContext) {
|
fn test_language_scope_at_with_rust(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let language = Language::new(
|
let language = Language::new(
|
||||||
LanguageConfig {
|
LanguageConfig {
|
||||||
name: "Rust".into(),
|
name: "Rust".into(),
|
||||||
@ -1810,7 +1810,7 @@ fn test_language_scope_at_with_rust(cx: &mut AppContext) {
|
|||||||
fn test_language_scope_at_with_combined_injections(cx: &mut AppContext) {
|
fn test_language_scope_at_with_combined_injections(cx: &mut AppContext) {
|
||||||
init_settings(cx, |_| {});
|
init_settings(cx, |_| {});
|
||||||
|
|
||||||
cx.entity(|cx| {
|
cx.build_model(|cx| {
|
||||||
let text = r#"
|
let text = r#"
|
||||||
<ol>
|
<ol>
|
||||||
<% people.each do |person| %>
|
<% people.each do |person| %>
|
||||||
@ -1858,7 +1858,7 @@ fn test_language_scope_at_with_combined_injections(cx: &mut AppContext) {
|
|||||||
fn test_serialization(cx: &mut gpui2::AppContext) {
|
fn test_serialization(cx: &mut gpui2::AppContext) {
|
||||||
let mut now = Instant::now();
|
let mut now = Instant::now();
|
||||||
|
|
||||||
let buffer1 = cx.entity(|cx| {
|
let buffer1 = cx.build_model(|cx| {
|
||||||
let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "abc");
|
let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "abc");
|
||||||
buffer.edit([(3..3, "D")], None, cx);
|
buffer.edit([(3..3, "D")], None, cx);
|
||||||
|
|
||||||
@ -1881,7 +1881,7 @@ fn test_serialization(cx: &mut gpui2::AppContext) {
|
|||||||
let ops = cx
|
let ops = cx
|
||||||
.executor()
|
.executor()
|
||||||
.block(buffer1.read(cx).serialize_ops(None, cx));
|
.block(buffer1.read(cx).serialize_ops(None, cx));
|
||||||
let buffer2 = cx.entity(|cx| {
|
let buffer2 = cx.build_model(|cx| {
|
||||||
let mut buffer = Buffer::from_proto(1, state, None).unwrap();
|
let mut buffer = Buffer::from_proto(1, state, None).unwrap();
|
||||||
buffer
|
buffer
|
||||||
.apply_ops(
|
.apply_ops(
|
||||||
@ -1914,10 +1914,11 @@ fn test_random_collaboration(cx: &mut AppContext, mut rng: StdRng) {
|
|||||||
let mut replica_ids = Vec::new();
|
let mut replica_ids = Vec::new();
|
||||||
let mut buffers = Vec::new();
|
let mut buffers = Vec::new();
|
||||||
let network = Arc::new(Mutex::new(Network::new(rng.clone())));
|
let network = Arc::new(Mutex::new(Network::new(rng.clone())));
|
||||||
let base_buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), base_text.as_str()));
|
let base_buffer =
|
||||||
|
cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), base_text.as_str()));
|
||||||
|
|
||||||
for i in 0..rng.gen_range(min_peers..=max_peers) {
|
for i in 0..rng.gen_range(min_peers..=max_peers) {
|
||||||
let buffer = cx.entity(|cx| {
|
let buffer = cx.build_model(|cx| {
|
||||||
let state = base_buffer.read(cx).to_proto();
|
let state = base_buffer.read(cx).to_proto();
|
||||||
let ops = cx
|
let ops = cx
|
||||||
.executor()
|
.executor()
|
||||||
@ -2034,7 +2035,7 @@ fn test_random_collaboration(cx: &mut AppContext, mut rng: StdRng) {
|
|||||||
new_replica_id,
|
new_replica_id,
|
||||||
replica_id
|
replica_id
|
||||||
);
|
);
|
||||||
new_buffer = Some(cx.entity(|cx| {
|
new_buffer = Some(cx.build_model(|cx| {
|
||||||
let mut new_buffer =
|
let mut new_buffer =
|
||||||
Buffer::from_proto(new_replica_id, old_buffer_state, None).unwrap();
|
Buffer::from_proto(new_replica_id, old_buffer_state, None).unwrap();
|
||||||
new_buffer
|
new_buffer
|
||||||
@ -2396,7 +2397,7 @@ fn javascript_lang() -> Language {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_tree_sexp(buffer: &Handle<Buffer>, cx: &mut gpui2::TestAppContext) -> String {
|
fn get_tree_sexp(buffer: &Model<Buffer>, cx: &mut gpui2::TestAppContext) -> String {
|
||||||
buffer.update(cx, |buffer, _| {
|
buffer.update(cx, |buffer, _| {
|
||||||
let snapshot = buffer.snapshot();
|
let snapshot = buffer.snapshot();
|
||||||
let layers = snapshot.syntax.layers(buffer.as_text_snapshot());
|
let layers = snapshot.syntax.layers(buffer.as_text_snapshot());
|
||||||
@ -2412,7 +2413,7 @@ fn assert_bracket_pairs(
|
|||||||
cx: &mut AppContext,
|
cx: &mut AppContext,
|
||||||
) {
|
) {
|
||||||
let (expected_text, selection_ranges) = marked_text_ranges(selection_text, false);
|
let (expected_text, selection_ranges) = marked_text_ranges(selection_text, false);
|
||||||
let buffer = cx.entity(|cx| {
|
let buffer = cx.build_model(|cx| {
|
||||||
Buffer::new(0, cx.entity_id().as_u64(), expected_text.clone())
|
Buffer::new(0, cx.entity_id().as_u64(), expected_text.clone())
|
||||||
.with_language(Arc::new(language), cx)
|
.with_language(Arc::new(language), cx)
|
||||||
});
|
});
|
||||||
|
12
crates/menu2/Cargo.toml
Normal file
12
crates/menu2/Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
[package]
|
||||||
|
name = "menu2"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
publish = false
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/menu2.rs"
|
||||||
|
doctest = false
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
gpui2 = { path = "../gpui2" }
|
25
crates/menu2/src/menu2.rs
Normal file
25
crates/menu2/src/menu2.rs
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
// todo!(use actions! macro)
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, PartialEq)]
|
||||||
|
pub struct Cancel;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, PartialEq)]
|
||||||
|
pub struct Confirm;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, PartialEq)]
|
||||||
|
pub struct SecondaryConfirm;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, PartialEq)]
|
||||||
|
pub struct SelectPrev;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, PartialEq)]
|
||||||
|
pub struct SelectNext;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, PartialEq)]
|
||||||
|
pub struct SelectFirst;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, PartialEq)]
|
||||||
|
pub struct SelectLast;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, PartialEq)]
|
||||||
|
pub struct ShowContextMenu;
|
@ -16,7 +16,7 @@ client2 = { path = "../client2" }
|
|||||||
collections = { path = "../collections"}
|
collections = { path = "../collections"}
|
||||||
language2 = { path = "../language2" }
|
language2 = { path = "../language2" }
|
||||||
gpui2 = { path = "../gpui2" }
|
gpui2 = { path = "../gpui2" }
|
||||||
fs = { path = "../fs" }
|
fs2 = { path = "../fs2" }
|
||||||
lsp2 = { path = "../lsp2" }
|
lsp2 = { path = "../lsp2" }
|
||||||
node_runtime = { path = "../node_runtime"}
|
node_runtime = { path = "../node_runtime"}
|
||||||
util = { path = "../util" }
|
util = { path = "../util" }
|
||||||
@ -32,4 +32,4 @@ parking_lot.workspace = true
|
|||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
language2 = { path = "../language2", features = ["test-support"] }
|
language2 = { path = "../language2", features = ["test-support"] }
|
||||||
gpui2 = { path = "../gpui2", features = ["test-support"] }
|
gpui2 = { path = "../gpui2", features = ["test-support"] }
|
||||||
fs = { path = "../fs", features = ["test-support"] }
|
fs2 = { path = "../fs2", features = ["test-support"] }
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use collections::{HashMap, HashSet};
|
use collections::{HashMap, HashSet};
|
||||||
use fs::Fs;
|
use fs2::Fs;
|
||||||
use gpui2::{AsyncAppContext, Handle};
|
use gpui2::{AsyncAppContext, Model};
|
||||||
use language2::{language_settings::language_settings, Buffer, BundledFormatter, Diff};
|
use language2::{language_settings::language_settings, Buffer, BundledFormatter, Diff};
|
||||||
use lsp2::{LanguageServer, LanguageServerId};
|
use lsp2::{LanguageServer, LanguageServerId};
|
||||||
use node_runtime::NodeRuntime;
|
use node_runtime::NodeRuntime;
|
||||||
@ -183,7 +183,7 @@ impl Prettier {
|
|||||||
|
|
||||||
pub async fn format(
|
pub async fn format(
|
||||||
&self,
|
&self,
|
||||||
buffer: &Handle<Buffer>,
|
buffer: &Model<Buffer>,
|
||||||
buffer_path: Option<PathBuf>,
|
buffer_path: Option<PathBuf>,
|
||||||
cx: &mut AsyncAppContext,
|
cx: &mut AsyncAppContext,
|
||||||
) -> anyhow::Result<Diff> {
|
) -> anyhow::Result<Diff> {
|
||||||
|
@ -25,7 +25,7 @@ client2 = { path = "../client2" }
|
|||||||
clock = { path = "../clock" }
|
clock = { path = "../clock" }
|
||||||
collections = { path = "../collections" }
|
collections = { path = "../collections" }
|
||||||
db2 = { path = "../db2" }
|
db2 = { path = "../db2" }
|
||||||
fs = { path = "../fs" }
|
fs2 = { path = "../fs2" }
|
||||||
fsevent = { path = "../fsevent" }
|
fsevent = { path = "../fsevent" }
|
||||||
fuzzy2 = { path = "../fuzzy2" }
|
fuzzy2 = { path = "../fuzzy2" }
|
||||||
git = { path = "../git" }
|
git = { path = "../git" }
|
||||||
@ -71,7 +71,7 @@ pretty_assertions.workspace = true
|
|||||||
client2 = { path = "../client2", features = ["test-support"] }
|
client2 = { path = "../client2", features = ["test-support"] }
|
||||||
collections = { path = "../collections", features = ["test-support"] }
|
collections = { path = "../collections", features = ["test-support"] }
|
||||||
db2 = { path = "../db2", features = ["test-support"] }
|
db2 = { path = "../db2", features = ["test-support"] }
|
||||||
fs = { path = "../fs", features = ["test-support"] }
|
fs2 = { path = "../fs2", features = ["test-support"] }
|
||||||
gpui2 = { path = "../gpui2", features = ["test-support"] }
|
gpui2 = { path = "../gpui2", features = ["test-support"] }
|
||||||
language2 = { path = "../language2", features = ["test-support"] }
|
language2 = { path = "../language2", features = ["test-support"] }
|
||||||
lsp2 = { path = "../lsp2", features = ["test-support"] }
|
lsp2 = { path = "../lsp2", features = ["test-support"] }
|
||||||
|
@ -7,7 +7,7 @@ use anyhow::{anyhow, Context, Result};
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use client2::proto::{self, PeerId};
|
use client2::proto::{self, PeerId};
|
||||||
use futures::future;
|
use futures::future;
|
||||||
use gpui2::{AppContext, AsyncAppContext, Handle};
|
use gpui2::{AppContext, AsyncAppContext, Model};
|
||||||
use language2::{
|
use language2::{
|
||||||
language_settings::{language_settings, InlayHintKind},
|
language_settings::{language_settings, InlayHintKind},
|
||||||
point_from_lsp, point_to_lsp,
|
point_from_lsp, point_to_lsp,
|
||||||
@ -53,8 +53,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
message: <Self::LspRequest as lsp2::request::Request>::Result,
|
message: <Self::LspRequest as lsp2::request::Request>::Result,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
cx: AsyncAppContext,
|
cx: AsyncAppContext,
|
||||||
) -> Result<Self::Response>;
|
) -> Result<Self::Response>;
|
||||||
@ -63,8 +63,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: Self::ProtoRequest,
|
message: Self::ProtoRequest,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
cx: AsyncAppContext,
|
cx: AsyncAppContext,
|
||||||
) -> Result<Self>;
|
) -> Result<Self>;
|
||||||
|
|
||||||
@ -79,8 +79,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: <Self::ProtoRequest as proto::RequestMessage>::Response,
|
message: <Self::ProtoRequest as proto::RequestMessage>::Response,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
cx: AsyncAppContext,
|
cx: AsyncAppContext,
|
||||||
) -> Result<Self::Response>;
|
) -> Result<Self::Response>;
|
||||||
|
|
||||||
@ -180,8 +180,8 @@ impl LspCommand for PrepareRename {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
message: Option<lsp2::PrepareRenameResponse>,
|
message: Option<lsp2::PrepareRenameResponse>,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
_: LanguageServerId,
|
_: LanguageServerId,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Option<Range<Anchor>>> {
|
) -> Result<Option<Range<Anchor>>> {
|
||||||
@ -215,8 +215,8 @@ impl LspCommand for PrepareRename {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: proto::PrepareRename,
|
message: proto::PrepareRename,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let position = message
|
let position = message
|
||||||
@ -256,8 +256,8 @@ impl LspCommand for PrepareRename {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: proto::PrepareRenameResponse,
|
message: proto::PrepareRenameResponse,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Option<Range<Anchor>>> {
|
) -> Result<Option<Range<Anchor>>> {
|
||||||
if message.can_rename {
|
if message.can_rename {
|
||||||
@ -307,8 +307,8 @@ impl LspCommand for PerformRename {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
message: Option<lsp2::WorkspaceEdit>,
|
message: Option<lsp2::WorkspaceEdit>,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<ProjectTransaction> {
|
) -> Result<ProjectTransaction> {
|
||||||
@ -343,8 +343,8 @@ impl LspCommand for PerformRename {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: proto::PerformRename,
|
message: proto::PerformRename,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let position = message
|
let position = message
|
||||||
@ -379,8 +379,8 @@ impl LspCommand for PerformRename {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: proto::PerformRenameResponse,
|
message: proto::PerformRenameResponse,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
_: Handle<Buffer>,
|
_: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<ProjectTransaction> {
|
) -> Result<ProjectTransaction> {
|
||||||
let message = message
|
let message = message
|
||||||
@ -426,8 +426,8 @@ impl LspCommand for GetDefinition {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
message: Option<lsp2::GotoDefinitionResponse>,
|
message: Option<lsp2::GotoDefinitionResponse>,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
cx: AsyncAppContext,
|
cx: AsyncAppContext,
|
||||||
) -> Result<Vec<LocationLink>> {
|
) -> Result<Vec<LocationLink>> {
|
||||||
@ -447,8 +447,8 @@ impl LspCommand for GetDefinition {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: proto::GetDefinition,
|
message: proto::GetDefinition,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let position = message
|
let position = message
|
||||||
@ -479,8 +479,8 @@ impl LspCommand for GetDefinition {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: proto::GetDefinitionResponse,
|
message: proto::GetDefinitionResponse,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
_: Handle<Buffer>,
|
_: Model<Buffer>,
|
||||||
cx: AsyncAppContext,
|
cx: AsyncAppContext,
|
||||||
) -> Result<Vec<LocationLink>> {
|
) -> Result<Vec<LocationLink>> {
|
||||||
location_links_from_proto(message.links, project, cx).await
|
location_links_from_proto(message.links, project, cx).await
|
||||||
@ -527,8 +527,8 @@ impl LspCommand for GetTypeDefinition {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
message: Option<lsp2::GotoTypeDefinitionResponse>,
|
message: Option<lsp2::GotoTypeDefinitionResponse>,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
cx: AsyncAppContext,
|
cx: AsyncAppContext,
|
||||||
) -> Result<Vec<LocationLink>> {
|
) -> Result<Vec<LocationLink>> {
|
||||||
@ -548,8 +548,8 @@ impl LspCommand for GetTypeDefinition {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: proto::GetTypeDefinition,
|
message: proto::GetTypeDefinition,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let position = message
|
let position = message
|
||||||
@ -580,8 +580,8 @@ impl LspCommand for GetTypeDefinition {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: proto::GetTypeDefinitionResponse,
|
message: proto::GetTypeDefinitionResponse,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
_: Handle<Buffer>,
|
_: Model<Buffer>,
|
||||||
cx: AsyncAppContext,
|
cx: AsyncAppContext,
|
||||||
) -> Result<Vec<LocationLink>> {
|
) -> Result<Vec<LocationLink>> {
|
||||||
location_links_from_proto(message.links, project, cx).await
|
location_links_from_proto(message.links, project, cx).await
|
||||||
@ -593,8 +593,8 @@ impl LspCommand for GetTypeDefinition {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn language_server_for_buffer(
|
fn language_server_for_buffer(
|
||||||
project: &Handle<Project>,
|
project: &Model<Project>,
|
||||||
buffer: &Handle<Buffer>,
|
buffer: &Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
cx: &mut AsyncAppContext,
|
cx: &mut AsyncAppContext,
|
||||||
) -> Result<(Arc<CachedLspAdapter>, Arc<LanguageServer>)> {
|
) -> Result<(Arc<CachedLspAdapter>, Arc<LanguageServer>)> {
|
||||||
@ -609,7 +609,7 @@ fn language_server_for_buffer(
|
|||||||
|
|
||||||
async fn location_links_from_proto(
|
async fn location_links_from_proto(
|
||||||
proto_links: Vec<proto::LocationLink>,
|
proto_links: Vec<proto::LocationLink>,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Vec<LocationLink>> {
|
) -> Result<Vec<LocationLink>> {
|
||||||
let mut links = Vec::new();
|
let mut links = Vec::new();
|
||||||
@ -671,8 +671,8 @@ async fn location_links_from_proto(
|
|||||||
|
|
||||||
async fn location_links_from_lsp(
|
async fn location_links_from_lsp(
|
||||||
message: Option<lsp2::GotoDefinitionResponse>,
|
message: Option<lsp2::GotoDefinitionResponse>,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Vec<LocationLink>> {
|
) -> Result<Vec<LocationLink>> {
|
||||||
@ -814,8 +814,8 @@ impl LspCommand for GetReferences {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
locations: Option<Vec<lsp2::Location>>,
|
locations: Option<Vec<lsp2::Location>>,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Vec<Location>> {
|
) -> Result<Vec<Location>> {
|
||||||
@ -868,8 +868,8 @@ impl LspCommand for GetReferences {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: proto::GetReferences,
|
message: proto::GetReferences,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let position = message
|
let position = message
|
||||||
@ -910,8 +910,8 @@ impl LspCommand for GetReferences {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: proto::GetReferencesResponse,
|
message: proto::GetReferencesResponse,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
_: Handle<Buffer>,
|
_: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Vec<Location>> {
|
) -> Result<Vec<Location>> {
|
||||||
let mut locations = Vec::new();
|
let mut locations = Vec::new();
|
||||||
@ -977,8 +977,8 @@ impl LspCommand for GetDocumentHighlights {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
lsp_highlights: Option<Vec<lsp2::DocumentHighlight>>,
|
lsp_highlights: Option<Vec<lsp2::DocumentHighlight>>,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
_: LanguageServerId,
|
_: LanguageServerId,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Vec<DocumentHighlight>> {
|
) -> Result<Vec<DocumentHighlight>> {
|
||||||
@ -1016,8 +1016,8 @@ impl LspCommand for GetDocumentHighlights {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: proto::GetDocumentHighlights,
|
message: proto::GetDocumentHighlights,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let position = message
|
let position = message
|
||||||
@ -1060,8 +1060,8 @@ impl LspCommand for GetDocumentHighlights {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: proto::GetDocumentHighlightsResponse,
|
message: proto::GetDocumentHighlightsResponse,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Vec<DocumentHighlight>> {
|
) -> Result<Vec<DocumentHighlight>> {
|
||||||
let mut highlights = Vec::new();
|
let mut highlights = Vec::new();
|
||||||
@ -1123,8 +1123,8 @@ impl LspCommand for GetHover {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
message: Option<lsp2::Hover>,
|
message: Option<lsp2::Hover>,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
_: LanguageServerId,
|
_: LanguageServerId,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self::Response> {
|
) -> Result<Self::Response> {
|
||||||
@ -1206,8 +1206,8 @@ impl LspCommand for GetHover {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: Self::ProtoRequest,
|
message: Self::ProtoRequest,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let position = message
|
let position = message
|
||||||
@ -1272,8 +1272,8 @@ impl LspCommand for GetHover {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: proto::GetHoverResponse,
|
message: proto::GetHoverResponse,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self::Response> {
|
) -> Result<Self::Response> {
|
||||||
let contents: Vec<_> = message
|
let contents: Vec<_> = message
|
||||||
@ -1341,8 +1341,8 @@ impl LspCommand for GetCompletions {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
completions: Option<lsp2::CompletionResponse>,
|
completions: Option<lsp2::CompletionResponse>,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Vec<Completion>> {
|
) -> Result<Vec<Completion>> {
|
||||||
@ -1484,8 +1484,8 @@ impl LspCommand for GetCompletions {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: proto::GetCompletions,
|
message: proto::GetCompletions,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let version = deserialize_version(&message.version);
|
let version = deserialize_version(&message.version);
|
||||||
@ -1523,8 +1523,8 @@ impl LspCommand for GetCompletions {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: proto::GetCompletionsResponse,
|
message: proto::GetCompletionsResponse,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Vec<Completion>> {
|
) -> Result<Vec<Completion>> {
|
||||||
buffer
|
buffer
|
||||||
@ -1589,8 +1589,8 @@ impl LspCommand for GetCodeActions {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
actions: Option<lsp2::CodeActionResponse>,
|
actions: Option<lsp2::CodeActionResponse>,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
_: Handle<Buffer>,
|
_: Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
_: AsyncAppContext,
|
_: AsyncAppContext,
|
||||||
) -> Result<Vec<CodeAction>> {
|
) -> Result<Vec<CodeAction>> {
|
||||||
@ -1623,8 +1623,8 @@ impl LspCommand for GetCodeActions {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: proto::GetCodeActions,
|
message: proto::GetCodeActions,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let start = message
|
let start = message
|
||||||
@ -1663,8 +1663,8 @@ impl LspCommand for GetCodeActions {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: proto::GetCodeActionsResponse,
|
message: proto::GetCodeActionsResponse,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Vec<CodeAction>> {
|
) -> Result<Vec<CodeAction>> {
|
||||||
buffer
|
buffer
|
||||||
@ -1726,8 +1726,8 @@ impl LspCommand for OnTypeFormatting {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
message: Option<Vec<lsp2::TextEdit>>,
|
message: Option<Vec<lsp2::TextEdit>>,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Option<Transaction>> {
|
) -> Result<Option<Transaction>> {
|
||||||
@ -1763,8 +1763,8 @@ impl LspCommand for OnTypeFormatting {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: proto::OnTypeFormatting,
|
message: proto::OnTypeFormatting,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let position = message
|
let position = message
|
||||||
@ -1805,8 +1805,8 @@ impl LspCommand for OnTypeFormatting {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: proto::OnTypeFormattingResponse,
|
message: proto::OnTypeFormattingResponse,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
_: Handle<Buffer>,
|
_: Model<Buffer>,
|
||||||
_: AsyncAppContext,
|
_: AsyncAppContext,
|
||||||
) -> Result<Option<Transaction>> {
|
) -> Result<Option<Transaction>> {
|
||||||
let Some(transaction) = message.transaction else {
|
let Some(transaction) = message.transaction else {
|
||||||
@ -1825,7 +1825,7 @@ impl LspCommand for OnTypeFormatting {
|
|||||||
impl InlayHints {
|
impl InlayHints {
|
||||||
pub async fn lsp_to_project_hint(
|
pub async fn lsp_to_project_hint(
|
||||||
lsp_hint: lsp2::InlayHint,
|
lsp_hint: lsp2::InlayHint,
|
||||||
buffer_handle: &Handle<Buffer>,
|
buffer_handle: &Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
resolve_state: ResolveState,
|
resolve_state: ResolveState,
|
||||||
force_no_type_left_padding: bool,
|
force_no_type_left_padding: bool,
|
||||||
@ -2230,8 +2230,8 @@ impl LspCommand for InlayHints {
|
|||||||
async fn response_from_lsp(
|
async fn response_from_lsp(
|
||||||
self,
|
self,
|
||||||
message: Option<Vec<lsp2::InlayHint>>,
|
message: Option<Vec<lsp2::InlayHint>>,
|
||||||
project: Handle<Project>,
|
project: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
server_id: LanguageServerId,
|
server_id: LanguageServerId,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> anyhow::Result<Vec<InlayHint>> {
|
) -> anyhow::Result<Vec<InlayHint>> {
|
||||||
@ -2286,8 +2286,8 @@ impl LspCommand for InlayHints {
|
|||||||
|
|
||||||
async fn from_proto(
|
async fn from_proto(
|
||||||
message: proto::InlayHints,
|
message: proto::InlayHints,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let start = message
|
let start = message
|
||||||
@ -2326,8 +2326,8 @@ impl LspCommand for InlayHints {
|
|||||||
async fn response_from_proto(
|
async fn response_from_proto(
|
||||||
self,
|
self,
|
||||||
message: proto::InlayHintsResponse,
|
message: proto::InlayHintsResponse,
|
||||||
_: Handle<Project>,
|
_: Model<Project>,
|
||||||
buffer: Handle<Buffer>,
|
buffer: Model<Buffer>,
|
||||||
mut cx: AsyncAppContext,
|
mut cx: AsyncAppContext,
|
||||||
) -> anyhow::Result<Vec<InlayHint>> {
|
) -> anyhow::Result<Vec<InlayHint>> {
|
||||||
buffer
|
buffer
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
|||||||
use crate::Project;
|
use crate::Project;
|
||||||
use gpui2::{AnyWindowHandle, Context, Handle, ModelContext, WeakHandle};
|
use gpui2::{AnyWindowHandle, Context, Model, ModelContext, WeakModel};
|
||||||
use settings2::Settings;
|
use settings2::Settings;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use terminal2::{
|
use terminal2::{
|
||||||
@ -11,7 +11,7 @@ use terminal2::{
|
|||||||
use std::os::unix::ffi::OsStrExt;
|
use std::os::unix::ffi::OsStrExt;
|
||||||
|
|
||||||
pub struct Terminals {
|
pub struct Terminals {
|
||||||
pub(crate) local_handles: Vec<WeakHandle<terminal2::Terminal>>,
|
pub(crate) local_handles: Vec<WeakModel<terminal2::Terminal>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Project {
|
impl Project {
|
||||||
@ -20,7 +20,7 @@ impl Project {
|
|||||||
working_directory: Option<PathBuf>,
|
working_directory: Option<PathBuf>,
|
||||||
window: AnyWindowHandle,
|
window: AnyWindowHandle,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> anyhow::Result<Handle<Terminal>> {
|
) -> anyhow::Result<Model<Terminal>> {
|
||||||
if self.is_remote() {
|
if self.is_remote() {
|
||||||
return Err(anyhow::anyhow!(
|
return Err(anyhow::anyhow!(
|
||||||
"creating terminals as a guest is not supported yet"
|
"creating terminals as a guest is not supported yet"
|
||||||
@ -40,7 +40,7 @@ impl Project {
|
|||||||
|_, _| todo!("color_for_index"),
|
|_, _| todo!("color_for_index"),
|
||||||
)
|
)
|
||||||
.map(|builder| {
|
.map(|builder| {
|
||||||
let terminal_handle = cx.entity(|cx| builder.subscribe(cx));
|
let terminal_handle = cx.build_model(|cx| builder.subscribe(cx));
|
||||||
|
|
||||||
self.terminals
|
self.terminals
|
||||||
.local_handles
|
.local_handles
|
||||||
@ -108,7 +108,7 @@ impl Project {
|
|||||||
fn activate_python_virtual_environment(
|
fn activate_python_virtual_environment(
|
||||||
&mut self,
|
&mut self,
|
||||||
activate_script: Option<PathBuf>,
|
activate_script: Option<PathBuf>,
|
||||||
terminal_handle: &Handle<Terminal>,
|
terminal_handle: &Model<Terminal>,
|
||||||
cx: &mut ModelContext<Project>,
|
cx: &mut ModelContext<Project>,
|
||||||
) {
|
) {
|
||||||
if let Some(activate_script) = activate_script {
|
if let Some(activate_script) = activate_script {
|
||||||
@ -121,7 +121,7 @@ impl Project {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn local_terminal_handles(&self) -> &Vec<WeakHandle<terminal2::Terminal>> {
|
pub fn local_terminal_handles(&self) -> &Vec<WeakModel<terminal2::Terminal>> {
|
||||||
&self.terminals.local_handles
|
&self.terminals.local_handles
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ use anyhow::{anyhow, Context as _, Result};
|
|||||||
use client2::{proto, Client};
|
use client2::{proto, Client};
|
||||||
use clock::ReplicaId;
|
use clock::ReplicaId;
|
||||||
use collections::{HashMap, HashSet, VecDeque};
|
use collections::{HashMap, HashSet, VecDeque};
|
||||||
use fs::{
|
use fs2::{
|
||||||
repository::{GitFileStatus, GitRepository, RepoPath},
|
repository::{GitFileStatus, GitRepository, RepoPath},
|
||||||
Fs,
|
Fs,
|
||||||
};
|
};
|
||||||
@ -22,7 +22,7 @@ use futures::{
|
|||||||
use fuzzy2::CharBag;
|
use fuzzy2::CharBag;
|
||||||
use git::{DOT_GIT, GITIGNORE};
|
use git::{DOT_GIT, GITIGNORE};
|
||||||
use gpui2::{
|
use gpui2::{
|
||||||
AppContext, AsyncAppContext, Context, EventEmitter, Executor, Handle, ModelContext, Task,
|
AppContext, AsyncAppContext, Context, EventEmitter, Executor, Model, ModelContext, Task,
|
||||||
};
|
};
|
||||||
use language2::{
|
use language2::{
|
||||||
proto::{
|
proto::{
|
||||||
@ -292,7 +292,7 @@ impl Worktree {
|
|||||||
fs: Arc<dyn Fs>,
|
fs: Arc<dyn Fs>,
|
||||||
next_entry_id: Arc<AtomicUsize>,
|
next_entry_id: Arc<AtomicUsize>,
|
||||||
cx: &mut AsyncAppContext,
|
cx: &mut AsyncAppContext,
|
||||||
) -> Result<Handle<Self>> {
|
) -> Result<Model<Self>> {
|
||||||
// After determining whether the root entry is a file or a directory, populate the
|
// After determining whether the root entry is a file or a directory, populate the
|
||||||
// snapshot's "root name", which will be used for the purpose of fuzzy matching.
|
// snapshot's "root name", which will be used for the purpose of fuzzy matching.
|
||||||
let abs_path = path.into();
|
let abs_path = path.into();
|
||||||
@ -301,7 +301,7 @@ impl Worktree {
|
|||||||
.await
|
.await
|
||||||
.context("failed to stat worktree path")?;
|
.context("failed to stat worktree path")?;
|
||||||
|
|
||||||
cx.entity(move |cx: &mut ModelContext<Worktree>| {
|
cx.build_model(move |cx: &mut ModelContext<Worktree>| {
|
||||||
let root_name = abs_path
|
let root_name = abs_path
|
||||||
.file_name()
|
.file_name()
|
||||||
.map_or(String::new(), |f| f.to_string_lossy().to_string());
|
.map_or(String::new(), |f| f.to_string_lossy().to_string());
|
||||||
@ -406,8 +406,8 @@ impl Worktree {
|
|||||||
worktree: proto::WorktreeMetadata,
|
worktree: proto::WorktreeMetadata,
|
||||||
client: Arc<Client>,
|
client: Arc<Client>,
|
||||||
cx: &mut AppContext,
|
cx: &mut AppContext,
|
||||||
) -> Handle<Self> {
|
) -> Model<Self> {
|
||||||
cx.entity(|cx: &mut ModelContext<Self>| {
|
cx.build_model(|cx: &mut ModelContext<Self>| {
|
||||||
let snapshot = Snapshot {
|
let snapshot = Snapshot {
|
||||||
id: WorktreeId(worktree.id as usize),
|
id: WorktreeId(worktree.id as usize),
|
||||||
abs_path: Arc::from(PathBuf::from(worktree.abs_path)),
|
abs_path: Arc::from(PathBuf::from(worktree.abs_path)),
|
||||||
@ -593,7 +593,7 @@ impl LocalWorktree {
|
|||||||
id: u64,
|
id: u64,
|
||||||
path: &Path,
|
path: &Path,
|
||||||
cx: &mut ModelContext<Worktree>,
|
cx: &mut ModelContext<Worktree>,
|
||||||
) -> Task<Result<Handle<Buffer>>> {
|
) -> Task<Result<Model<Buffer>>> {
|
||||||
let path = Arc::from(path);
|
let path = Arc::from(path);
|
||||||
cx.spawn(move |this, mut cx| async move {
|
cx.spawn(move |this, mut cx| async move {
|
||||||
let (file, contents, diff_base) = this
|
let (file, contents, diff_base) = this
|
||||||
@ -603,7 +603,7 @@ impl LocalWorktree {
|
|||||||
.executor()
|
.executor()
|
||||||
.spawn(async move { text::Buffer::new(0, id, contents) })
|
.spawn(async move { text::Buffer::new(0, id, contents) })
|
||||||
.await;
|
.await;
|
||||||
cx.entity(|_| Buffer::build(text_buffer, diff_base, Some(Arc::new(file))))
|
cx.build_model(|_| Buffer::build(text_buffer, diff_base, Some(Arc::new(file))))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -920,7 +920,7 @@ impl LocalWorktree {
|
|||||||
|
|
||||||
pub fn save_buffer(
|
pub fn save_buffer(
|
||||||
&self,
|
&self,
|
||||||
buffer_handle: Handle<Buffer>,
|
buffer_handle: Model<Buffer>,
|
||||||
path: Arc<Path>,
|
path: Arc<Path>,
|
||||||
has_changed_file: bool,
|
has_changed_file: bool,
|
||||||
cx: &mut ModelContext<Worktree>,
|
cx: &mut ModelContext<Worktree>,
|
||||||
@ -1331,7 +1331,7 @@ impl RemoteWorktree {
|
|||||||
|
|
||||||
pub fn save_buffer(
|
pub fn save_buffer(
|
||||||
&self,
|
&self,
|
||||||
buffer_handle: Handle<Buffer>,
|
buffer_handle: Model<Buffer>,
|
||||||
cx: &mut ModelContext<Worktree>,
|
cx: &mut ModelContext<Worktree>,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Result<()>> {
|
||||||
let buffer = buffer_handle.read(cx);
|
let buffer = buffer_handle.read(cx);
|
||||||
@ -2577,7 +2577,7 @@ impl fmt::Debug for Snapshot {
|
|||||||
|
|
||||||
#[derive(Clone, PartialEq)]
|
#[derive(Clone, PartialEq)]
|
||||||
pub struct File {
|
pub struct File {
|
||||||
pub worktree: Handle<Worktree>,
|
pub worktree: Model<Worktree>,
|
||||||
pub path: Arc<Path>,
|
pub path: Arc<Path>,
|
||||||
pub mtime: SystemTime,
|
pub mtime: SystemTime,
|
||||||
pub(crate) entry_id: ProjectEntryId,
|
pub(crate) entry_id: ProjectEntryId,
|
||||||
@ -2701,7 +2701,7 @@ impl language2::LocalFile for File {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl File {
|
impl File {
|
||||||
pub fn for_entry(entry: Entry, worktree: Handle<Worktree>) -> Arc<Self> {
|
pub fn for_entry(entry: Entry, worktree: Model<Worktree>) -> Arc<Self> {
|
||||||
Arc::new(Self {
|
Arc::new(Self {
|
||||||
worktree,
|
worktree,
|
||||||
path: entry.path.clone(),
|
path: entry.path.clone(),
|
||||||
@ -2714,7 +2714,7 @@ impl File {
|
|||||||
|
|
||||||
pub fn from_proto(
|
pub fn from_proto(
|
||||||
proto: rpc2::proto::File,
|
proto: rpc2::proto::File,
|
||||||
worktree: Handle<Worktree>,
|
worktree: Model<Worktree>,
|
||||||
cx: &AppContext,
|
cx: &AppContext,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let worktree_id = worktree
|
let worktree_id = worktree
|
||||||
@ -2815,7 +2815,7 @@ pub type UpdatedGitRepositoriesSet = Arc<[(Arc<Path>, GitRepositoryChange)]>;
|
|||||||
impl Entry {
|
impl Entry {
|
||||||
fn new(
|
fn new(
|
||||||
path: Arc<Path>,
|
path: Arc<Path>,
|
||||||
metadata: &fs::Metadata,
|
metadata: &fs2::Metadata,
|
||||||
next_entry_id: &AtomicUsize,
|
next_entry_id: &AtomicUsize,
|
||||||
root_char_bag: CharBag,
|
root_char_bag: CharBag,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
@ -42,6 +42,7 @@ sha1 = "0.10.5"
|
|||||||
ndarray = { version = "0.15.0" }
|
ndarray = { version = "0.15.0" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
ai = { path = "../ai", features = ["test-support"] }
|
||||||
collections = { path = "../collections", features = ["test-support"] }
|
collections = { path = "../collections", features = ["test-support"] }
|
||||||
gpui = { path = "../gpui", features = ["test-support"] }
|
gpui = { path = "../gpui", features = ["test-support"] }
|
||||||
language = { path = "../language", features = ["test-support"] }
|
language = { path = "../language", features = ["test-support"] }
|
||||||
|
@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
|
|||||||
pending_batch_token_count: usize,
|
pending_batch_token_count: usize,
|
||||||
finished_files_tx: channel::Sender<FileToEmbed>,
|
finished_files_tx: channel::Sender<FileToEmbed>,
|
||||||
finished_files_rx: channel::Receiver<FileToEmbed>,
|
finished_files_rx: channel::Receiver<FileToEmbed>,
|
||||||
api_key: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EmbeddingQueue {
|
impl EmbeddingQueue {
|
||||||
pub fn new(
|
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
|
||||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
|
||||||
executor: Arc<Background>,
|
|
||||||
api_key: Option<String>,
|
|
||||||
) -> Self {
|
|
||||||
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
||||||
Self {
|
Self {
|
||||||
embedding_provider,
|
embedding_provider,
|
||||||
@ -64,14 +59,9 @@ impl EmbeddingQueue {
|
|||||||
pending_batch_token_count: 0,
|
pending_batch_token_count: 0,
|
||||||
finished_files_tx,
|
finished_files_tx,
|
||||||
finished_files_rx,
|
finished_files_rx,
|
||||||
api_key,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_api_key(&mut self, api_key: Option<String>) {
|
|
||||||
self.api_key = api_key
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn push(&mut self, file: FileToEmbed) {
|
pub fn push(&mut self, file: FileToEmbed) {
|
||||||
if file.spans.is_empty() {
|
if file.spans.is_empty() {
|
||||||
self.finished_files_tx.try_send(file).unwrap();
|
self.finished_files_tx.try_send(file).unwrap();
|
||||||
@ -118,7 +108,6 @@ impl EmbeddingQueue {
|
|||||||
|
|
||||||
let finished_files_tx = self.finished_files_tx.clone();
|
let finished_files_tx = self.finished_files_tx.clone();
|
||||||
let embedding_provider = self.embedding_provider.clone();
|
let embedding_provider = self.embedding_provider.clone();
|
||||||
let api_key = self.api_key.clone();
|
|
||||||
|
|
||||||
self.executor
|
self.executor
|
||||||
.spawn(async move {
|
.spawn(async move {
|
||||||
@ -143,7 +132,7 @@ impl EmbeddingQueue {
|
|||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
match embedding_provider.embed_batch(spans, api_key).await {
|
match embedding_provider.embed_batch(spans).await {
|
||||||
Ok(embeddings) => {
|
Ok(embeddings) => {
|
||||||
let mut embeddings = embeddings.into_iter();
|
let mut embeddings = embeddings.into_iter();
|
||||||
for fragment in batch {
|
for fragment in batch {
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
use ai::embedding::{Embedding, EmbeddingProvider};
|
use ai::{
|
||||||
|
embedding::{Embedding, EmbeddingProvider},
|
||||||
|
models::TruncationDirection,
|
||||||
|
};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use language::{Grammar, Language};
|
use language::{Grammar, Language};
|
||||||
use rusqlite::{
|
use rusqlite::{
|
||||||
@ -108,7 +111,14 @@ impl CodeContextRetriever {
|
|||||||
.replace("<language>", language_name.as_ref())
|
.replace("<language>", language_name.as_ref())
|
||||||
.replace("<item>", &content);
|
.replace("<item>", &content);
|
||||||
let digest = SpanDigest::from(document_span.as_str());
|
let digest = SpanDigest::from(document_span.as_str());
|
||||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
let model = self.embedding_provider.base_model();
|
||||||
|
let document_span = model.truncate(
|
||||||
|
&document_span,
|
||||||
|
model.capacity()?,
|
||||||
|
ai::models::TruncationDirection::End,
|
||||||
|
)?;
|
||||||
|
let token_count = model.count_tokens(&document_span)?;
|
||||||
|
|
||||||
Ok(vec![Span {
|
Ok(vec![Span {
|
||||||
range: 0..content.len(),
|
range: 0..content.len(),
|
||||||
content: document_span,
|
content: document_span,
|
||||||
@ -131,7 +141,15 @@ impl CodeContextRetriever {
|
|||||||
)
|
)
|
||||||
.replace("<item>", &content);
|
.replace("<item>", &content);
|
||||||
let digest = SpanDigest::from(document_span.as_str());
|
let digest = SpanDigest::from(document_span.as_str());
|
||||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
|
||||||
|
let model = self.embedding_provider.base_model();
|
||||||
|
let document_span = model.truncate(
|
||||||
|
&document_span,
|
||||||
|
model.capacity()?,
|
||||||
|
ai::models::TruncationDirection::End,
|
||||||
|
)?;
|
||||||
|
let token_count = model.count_tokens(&document_span)?;
|
||||||
|
|
||||||
Ok(vec![Span {
|
Ok(vec![Span {
|
||||||
range: 0..content.len(),
|
range: 0..content.len(),
|
||||||
content: document_span,
|
content: document_span,
|
||||||
@ -222,8 +240,13 @@ impl CodeContextRetriever {
|
|||||||
.replace("<language>", language_name.as_ref())
|
.replace("<language>", language_name.as_ref())
|
||||||
.replace("item", &span.content);
|
.replace("item", &span.content);
|
||||||
|
|
||||||
let (document_content, token_count) =
|
let model = self.embedding_provider.base_model();
|
||||||
self.embedding_provider.truncate(&document_content);
|
let document_content = model.truncate(
|
||||||
|
&document_content,
|
||||||
|
model.capacity()?,
|
||||||
|
TruncationDirection::End,
|
||||||
|
)?;
|
||||||
|
let token_count = model.count_tokens(&document_content)?;
|
||||||
|
|
||||||
span.content = document_content;
|
span.content = document_content;
|
||||||
span.token_count = token_count;
|
span.token_count = token_count;
|
||||||
|
@ -7,7 +7,8 @@ pub mod semantic_index_settings;
|
|||||||
mod semantic_index_tests;
|
mod semantic_index_tests;
|
||||||
|
|
||||||
use crate::semantic_index_settings::SemanticIndexSettings;
|
use crate::semantic_index_settings::SemanticIndexSettings;
|
||||||
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
|
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||||
|
use ai::providers::open_ai::OpenAIEmbeddingProvider;
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use collections::{BTreeMap, HashMap, HashSet};
|
use collections::{BTreeMap, HashMap, HashSet};
|
||||||
use db::VectorDatabase;
|
use db::VectorDatabase;
|
||||||
@ -88,7 +89,7 @@ pub fn init(
|
|||||||
let semantic_index = SemanticIndex::new(
|
let semantic_index = SemanticIndex::new(
|
||||||
fs,
|
fs,
|
||||||
db_file_path,
|
db_file_path,
|
||||||
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
|
Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
|
||||||
language_registry,
|
language_registry,
|
||||||
cx.clone(),
|
cx.clone(),
|
||||||
)
|
)
|
||||||
@ -123,8 +124,6 @@ pub struct SemanticIndex {
|
|||||||
_embedding_task: Task<()>,
|
_embedding_task: Task<()>,
|
||||||
_parsing_files_tasks: Vec<Task<()>>,
|
_parsing_files_tasks: Vec<Task<()>>,
|
||||||
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
||||||
api_key: Option<String>,
|
|
||||||
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ProjectState {
|
struct ProjectState {
|
||||||
@ -278,18 +277,18 @@ impl SemanticIndex {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn authenticate(&mut self, cx: &AppContext) {
|
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
|
||||||
if self.api_key.is_none() {
|
if !self.embedding_provider.has_credentials() {
|
||||||
self.api_key = self.embedding_provider.retrieve_credentials(cx);
|
self.embedding_provider.retrieve_credentials(cx);
|
||||||
|
} else {
|
||||||
self.embedding_queue
|
return true;
|
||||||
.lock()
|
|
||||||
.set_api_key(self.api_key.clone());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.embedding_provider.has_credentials()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_authenticated(&self) -> bool {
|
pub fn is_authenticated(&self) -> bool {
|
||||||
self.api_key.is_some()
|
self.embedding_provider.has_credentials()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn enabled(cx: &AppContext) -> bool {
|
pub fn enabled(cx: &AppContext) -> bool {
|
||||||
@ -339,7 +338,7 @@ impl SemanticIndex {
|
|||||||
Ok(cx.add_model(|cx| {
|
Ok(cx.add_model(|cx| {
|
||||||
let t0 = Instant::now();
|
let t0 = Instant::now();
|
||||||
let embedding_queue =
|
let embedding_queue =
|
||||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
|
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
|
||||||
let _embedding_task = cx.background().spawn({
|
let _embedding_task = cx.background().spawn({
|
||||||
let embedded_files = embedding_queue.finished_files();
|
let embedded_files = embedding_queue.finished_files();
|
||||||
let db = db.clone();
|
let db = db.clone();
|
||||||
@ -404,8 +403,6 @@ impl SemanticIndex {
|
|||||||
_embedding_task,
|
_embedding_task,
|
||||||
_parsing_files_tasks,
|
_parsing_files_tasks,
|
||||||
projects: Default::default(),
|
projects: Default::default(),
|
||||||
api_key: None,
|
|
||||||
embedding_queue
|
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
@ -720,13 +717,13 @@ impl SemanticIndex {
|
|||||||
|
|
||||||
let index = self.index_project(project.clone(), cx);
|
let index = self.index_project(project.clone(), cx);
|
||||||
let embedding_provider = self.embedding_provider.clone();
|
let embedding_provider = self.embedding_provider.clone();
|
||||||
let api_key = self.api_key.clone();
|
|
||||||
|
|
||||||
cx.spawn(|this, mut cx| async move {
|
cx.spawn(|this, mut cx| async move {
|
||||||
index.await?;
|
index.await?;
|
||||||
let t0 = Instant::now();
|
let t0 = Instant::now();
|
||||||
|
|
||||||
let query = embedding_provider
|
let query = embedding_provider
|
||||||
.embed_batch(vec![query], api_key)
|
.embed_batch(vec![query])
|
||||||
.await?
|
.await?
|
||||||
.pop()
|
.pop()
|
||||||
.ok_or_else(|| anyhow!("could not embed query"))?;
|
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||||
@ -944,7 +941,6 @@ impl SemanticIndex {
|
|||||||
let fs = self.fs.clone();
|
let fs = self.fs.clone();
|
||||||
let db_path = self.db.path().clone();
|
let db_path = self.db.path().clone();
|
||||||
let background = cx.background().clone();
|
let background = cx.background().clone();
|
||||||
let api_key = self.api_key.clone();
|
|
||||||
cx.background().spawn(async move {
|
cx.background().spawn(async move {
|
||||||
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
||||||
let mut results = Vec::<SearchResult>::new();
|
let mut results = Vec::<SearchResult>::new();
|
||||||
@ -959,15 +955,10 @@ impl SemanticIndex {
|
|||||||
.parse_file_with_template(None, &snapshot.text(), language)
|
.parse_file_with_template(None, &snapshot.text(), language)
|
||||||
.log_err()
|
.log_err()
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
if Self::embed_spans(
|
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
|
||||||
&mut spans,
|
.await
|
||||||
embedding_provider.as_ref(),
|
.log_err()
|
||||||
&db,
|
.is_some()
|
||||||
api_key.clone(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.log_err()
|
|
||||||
.is_some()
|
|
||||||
{
|
{
|
||||||
for span in spans {
|
for span in spans {
|
||||||
let similarity = span.embedding.unwrap().similarity(&query);
|
let similarity = span.embedding.unwrap().similarity(&query);
|
||||||
@ -1007,9 +998,8 @@ impl SemanticIndex {
|
|||||||
project: ModelHandle<Project>,
|
project: ModelHandle<Project>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<Result<()>> {
|
) -> Task<Result<()>> {
|
||||||
if self.api_key.is_none() {
|
if !self.is_authenticated() {
|
||||||
self.authenticate(cx);
|
if !self.authenticate(cx) {
|
||||||
if self.api_key.is_none() {
|
|
||||||
return Task::ready(Err(anyhow!("user is not authenticated")));
|
return Task::ready(Err(anyhow!("user is not authenticated")));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1192,7 +1182,6 @@ impl SemanticIndex {
|
|||||||
spans: &mut [Span],
|
spans: &mut [Span],
|
||||||
embedding_provider: &dyn EmbeddingProvider,
|
embedding_provider: &dyn EmbeddingProvider,
|
||||||
db: &VectorDatabase,
|
db: &VectorDatabase,
|
||||||
api_key: Option<String>,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut batch = Vec::new();
|
let mut batch = Vec::new();
|
||||||
let mut batch_tokens = 0;
|
let mut batch_tokens = 0;
|
||||||
@ -1215,7 +1204,7 @@ impl SemanticIndex {
|
|||||||
|
|
||||||
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
|
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
|
||||||
let batch_embeddings = embedding_provider
|
let batch_embeddings = embedding_provider
|
||||||
.embed_batch(mem::take(&mut batch), api_key.clone())
|
.embed_batch(mem::take(&mut batch))
|
||||||
.await?;
|
.await?;
|
||||||
embeddings.extend(batch_embeddings);
|
embeddings.extend(batch_embeddings);
|
||||||
batch_tokens = 0;
|
batch_tokens = 0;
|
||||||
@ -1227,7 +1216,7 @@ impl SemanticIndex {
|
|||||||
|
|
||||||
if !batch.is_empty() {
|
if !batch.is_empty() {
|
||||||
let batch_embeddings = embedding_provider
|
let batch_embeddings = embedding_provider
|
||||||
.embed_batch(mem::take(&mut batch), api_key)
|
.embed_batch(mem::take(&mut batch))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
embeddings.extend(batch_embeddings);
|
embeddings.extend(batch_embeddings);
|
||||||
|
@ -4,10 +4,9 @@ use crate::{
|
|||||||
semantic_index_settings::SemanticIndexSettings,
|
semantic_index_settings::SemanticIndexSettings,
|
||||||
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
|
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
|
||||||
};
|
};
|
||||||
use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
|
use ai::test::FakeEmbeddingProvider;
|
||||||
use anyhow::Result;
|
|
||||||
use async_trait::async_trait;
|
use gpui::{executor::Deterministic, Task, TestAppContext};
|
||||||
use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
|
|
||||||
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
|
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
@ -15,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
|
|||||||
use rand::{rngs::StdRng, Rng};
|
use rand::{rngs::StdRng, Rng};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use settings::SettingsStore;
|
use settings::SettingsStore;
|
||||||
use std::{
|
use std::{path::Path, sync::Arc, time::SystemTime};
|
||||||
path::Path,
|
|
||||||
sync::{
|
|
||||||
atomic::{self, AtomicUsize},
|
|
||||||
Arc,
|
|
||||||
},
|
|
||||||
time::{Instant, SystemTime},
|
|
||||||
};
|
|
||||||
use unindent::Unindent;
|
use unindent::Unindent;
|
||||||
use util::RandomCharIter;
|
use util::RandomCharIter;
|
||||||
|
|
||||||
@ -228,7 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
|||||||
|
|
||||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
|
|
||||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
|
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
|
||||||
for file in &files {
|
for file in &files {
|
||||||
queue.push(file.clone());
|
queue.push(file.clone());
|
||||||
}
|
}
|
||||||
@ -280,7 +272,7 @@ fn assert_search_results(
|
|||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_rust() {
|
async fn test_code_context_retrieval_rust() {
|
||||||
let language = rust_lang();
|
let language = rust_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = "
|
let text = "
|
||||||
@ -382,7 +374,7 @@ async fn test_code_context_retrieval_rust() {
|
|||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_json() {
|
async fn test_code_context_retrieval_json() {
|
||||||
let language = json_lang();
|
let language = json_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = r#"
|
let text = r#"
|
||||||
@ -466,7 +458,7 @@ fn assert_documents_eq(
|
|||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_javascript() {
|
async fn test_code_context_retrieval_javascript() {
|
||||||
let language = js_lang();
|
let language = js_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = "
|
let text = "
|
||||||
@ -565,7 +557,7 @@ async fn test_code_context_retrieval_javascript() {
|
|||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_lua() {
|
async fn test_code_context_retrieval_lua() {
|
||||||
let language = lua_lang();
|
let language = lua_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = r#"
|
let text = r#"
|
||||||
@ -639,7 +631,7 @@ async fn test_code_context_retrieval_lua() {
|
|||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_elixir() {
|
async fn test_code_context_retrieval_elixir() {
|
||||||
let language = elixir_lang();
|
let language = elixir_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = r#"
|
let text = r#"
|
||||||
@ -756,7 +748,7 @@ async fn test_code_context_retrieval_elixir() {
|
|||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_cpp() {
|
async fn test_code_context_retrieval_cpp() {
|
||||||
let language = cpp_lang();
|
let language = cpp_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = "
|
let text = "
|
||||||
@ -909,7 +901,7 @@ async fn test_code_context_retrieval_cpp() {
|
|||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_ruby() {
|
async fn test_code_context_retrieval_ruby() {
|
||||||
let language = ruby_lang();
|
let language = ruby_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = r#"
|
let text = r#"
|
||||||
@ -1100,7 +1092,7 @@ async fn test_code_context_retrieval_ruby() {
|
|||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_code_context_retrieval_php() {
|
async fn test_code_context_retrieval_php() {
|
||||||
let language = php_lang();
|
let language = php_lang();
|
||||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||||
|
|
||||||
let text = r#"
|
let text = r#"
|
||||||
@ -1248,65 +1240,6 @@ async fn test_code_context_retrieval_php() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
struct FakeEmbeddingProvider {
|
|
||||||
embedding_count: AtomicUsize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FakeEmbeddingProvider {
|
|
||||||
fn embedding_count(&self) -> usize {
|
|
||||||
self.embedding_count.load(atomic::Ordering::SeqCst)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn embed_sync(&self, span: &str) -> Embedding {
|
|
||||||
let mut result = vec![1.0; 26];
|
|
||||||
for letter in span.chars() {
|
|
||||||
let letter = letter.to_ascii_lowercase();
|
|
||||||
if letter as u32 >= 'a' as u32 {
|
|
||||||
let ix = (letter as u32) - ('a' as u32);
|
|
||||||
if ix < 26 {
|
|
||||||
result[ix as usize] += 1.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
||||||
for x in &mut result {
|
|
||||||
*x /= norm;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.into()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl EmbeddingProvider for FakeEmbeddingProvider {
|
|
||||||
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
|
|
||||||
Some("Fake Credentials".to_string())
|
|
||||||
}
|
|
||||||
fn truncate(&self, span: &str) -> (String, usize) {
|
|
||||||
(span.to_string(), 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn max_tokens_per_batch(&self) -> usize {
|
|
||||||
200
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn embed_batch(
|
|
||||||
&self,
|
|
||||||
spans: Vec<String>,
|
|
||||||
_api_key: Option<String>,
|
|
||||||
) -> Result<Vec<Embedding>> {
|
|
||||||
self.embedding_count
|
|
||||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
|
||||||
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn js_lang() -> Arc<Language> {
|
fn js_lang() -> Arc<Language> {
|
||||||
Arc::new(
|
Arc::new(
|
||||||
Language::new(
|
Language::new(
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
|
mod colors;
|
||||||
mod focus;
|
mod focus;
|
||||||
mod kitchen_sink;
|
mod kitchen_sink;
|
||||||
mod scroll;
|
mod scroll;
|
||||||
mod text;
|
mod text;
|
||||||
mod z_index;
|
mod z_index;
|
||||||
|
|
||||||
|
pub use colors::*;
|
||||||
pub use focus::*;
|
pub use focus::*;
|
||||||
pub use kitchen_sink::*;
|
pub use kitchen_sink::*;
|
||||||
pub use scroll::*;
|
pub use scroll::*;
|
||||||
|
38
crates/storybook2/src/stories/colors.rs
Normal file
38
crates/storybook2/src/stories/colors.rs
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
use crate::story::Story;
|
||||||
|
use gpui2::{px, Div, Render};
|
||||||
|
use ui::prelude::*;
|
||||||
|
|
||||||
|
pub struct ColorsStory;
|
||||||
|
|
||||||
|
impl Render for ColorsStory {
|
||||||
|
type Element = Div<Self>;
|
||||||
|
|
||||||
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
|
||||||
|
let color_scales = theme2::default_color_scales();
|
||||||
|
|
||||||
|
Story::container(cx)
|
||||||
|
.child(Story::title(cx, "Colors"))
|
||||||
|
.child(
|
||||||
|
div()
|
||||||
|
.id("colors")
|
||||||
|
.flex()
|
||||||
|
.flex_col()
|
||||||
|
.gap_1()
|
||||||
|
.overflow_y_scroll()
|
||||||
|
.text_color(gpui2::white())
|
||||||
|
.children(color_scales.into_iter().map(|(name, scale)| {
|
||||||
|
div()
|
||||||
|
.flex()
|
||||||
|
.child(
|
||||||
|
div()
|
||||||
|
.w(px(75.))
|
||||||
|
.line_height(px(24.))
|
||||||
|
.child(name.to_string()),
|
||||||
|
)
|
||||||
|
.child(div().flex().gap_1().children(
|
||||||
|
(1..=12).map(|step| div().flex().size_6().bg(scale.step(cx, step))),
|
||||||
|
))
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
@ -1,9 +1,9 @@
|
|||||||
use crate::themes::rose_pine;
|
|
||||||
use gpui2::{
|
use gpui2::{
|
||||||
div, Focusable, KeyBinding, ParentElement, StatelessInteractive, Styled, View, VisualContext,
|
div, Div, FocusEnabled, Focusable, KeyBinding, ParentElement, Render, StatefulInteraction,
|
||||||
WindowContext,
|
StatelessInteractive, Styled, View, VisualContext, WindowContext,
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
use theme2::theme;
|
||||||
|
|
||||||
#[derive(Clone, Default, PartialEq, Deserialize)]
|
#[derive(Clone, Default, PartialEq, Deserialize)]
|
||||||
struct ActionA;
|
struct ActionA;
|
||||||
@ -14,12 +14,10 @@ struct ActionB;
|
|||||||
#[derive(Clone, Default, PartialEq, Deserialize)]
|
#[derive(Clone, Default, PartialEq, Deserialize)]
|
||||||
struct ActionC;
|
struct ActionC;
|
||||||
|
|
||||||
pub struct FocusStory {
|
pub struct FocusStory {}
|
||||||
text: View<()>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FocusStory {
|
impl FocusStory {
|
||||||
pub fn view(cx: &mut WindowContext) -> View<()> {
|
pub fn view(cx: &mut WindowContext) -> View<Self> {
|
||||||
cx.bind_keys([
|
cx.bind_keys([
|
||||||
KeyBinding::new("cmd-a", ActionA, Some("parent")),
|
KeyBinding::new("cmd-a", ActionA, Some("parent")),
|
||||||
KeyBinding::new("cmd-a", ActionB, Some("child-1")),
|
KeyBinding::new("cmd-a", ActionB, Some("child-1")),
|
||||||
@ -27,91 +25,92 @@ impl FocusStory {
|
|||||||
]);
|
]);
|
||||||
cx.register_action_type::<ActionA>();
|
cx.register_action_type::<ActionA>();
|
||||||
cx.register_action_type::<ActionB>();
|
cx.register_action_type::<ActionB>();
|
||||||
let theme = rose_pine();
|
|
||||||
|
|
||||||
let color_1 = theme.lowest.negative.default.foreground;
|
cx.build_view(move |cx| Self {})
|
||||||
let color_2 = theme.lowest.positive.default.foreground;
|
}
|
||||||
let color_3 = theme.lowest.warning.default.foreground;
|
}
|
||||||
let color_4 = theme.lowest.accent.default.foreground;
|
|
||||||
let color_5 = theme.lowest.variant.default.foreground;
|
impl Render for FocusStory {
|
||||||
let color_6 = theme.highest.negative.default.foreground;
|
type Element = Div<Self, StatefulInteraction<Self>, FocusEnabled<Self>>;
|
||||||
|
|
||||||
|
fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
|
||||||
|
let theme = theme(cx);
|
||||||
|
let color_1 = theme.git_created;
|
||||||
|
let color_2 = theme.git_modified;
|
||||||
|
let color_3 = theme.git_deleted;
|
||||||
|
let color_4 = theme.git_conflict;
|
||||||
|
let color_5 = theme.git_ignored;
|
||||||
|
let color_6 = theme.git_renamed;
|
||||||
let child_1 = cx.focus_handle();
|
let child_1 = cx.focus_handle();
|
||||||
let child_2 = cx.focus_handle();
|
let child_2 = cx.focus_handle();
|
||||||
|
|
||||||
cx.build_view(
|
div()
|
||||||
|_| (),
|
.id("parent")
|
||||||
move |_, cx| {
|
.focusable()
|
||||||
|
.context("parent")
|
||||||
|
.on_action(|_, action: &ActionA, phase, cx| {
|
||||||
|
println!("Action A dispatched on parent during {:?}", phase);
|
||||||
|
})
|
||||||
|
.on_action(|_, action: &ActionB, phase, cx| {
|
||||||
|
println!("Action B dispatched on parent during {:?}", phase);
|
||||||
|
})
|
||||||
|
.on_focus(|_, _, _| println!("Parent focused"))
|
||||||
|
.on_blur(|_, _, _| println!("Parent blurred"))
|
||||||
|
.on_focus_in(|_, _, _| println!("Parent focus_in"))
|
||||||
|
.on_focus_out(|_, _, _| println!("Parent focus_out"))
|
||||||
|
.on_key_down(|_, event, phase, _| {
|
||||||
|
println!("Key down on parent {:?} {:?}", phase, event)
|
||||||
|
})
|
||||||
|
.on_key_up(|_, event, phase, _| println!("Key up on parent {:?} {:?}", phase, event))
|
||||||
|
.size_full()
|
||||||
|
.bg(color_1)
|
||||||
|
.focus(|style| style.bg(color_2))
|
||||||
|
.focus_in(|style| style.bg(color_3))
|
||||||
|
.child(
|
||||||
div()
|
div()
|
||||||
.id("parent")
|
.track_focus(&child_1)
|
||||||
.focusable()
|
.context("child-1")
|
||||||
.context("parent")
|
|
||||||
.on_action(|_, action: &ActionA, phase, cx| {
|
|
||||||
println!("Action A dispatched on parent during {:?}", phase);
|
|
||||||
})
|
|
||||||
.on_action(|_, action: &ActionB, phase, cx| {
|
.on_action(|_, action: &ActionB, phase, cx| {
|
||||||
println!("Action B dispatched on parent during {:?}", phase);
|
println!("Action B dispatched on child 1 during {:?}", phase);
|
||||||
})
|
})
|
||||||
.on_focus(|_, _, _| println!("Parent focused"))
|
.w_full()
|
||||||
.on_blur(|_, _, _| println!("Parent blurred"))
|
.h_6()
|
||||||
.on_focus_in(|_, _, _| println!("Parent focus_in"))
|
.bg(color_4)
|
||||||
.on_focus_out(|_, _, _| println!("Parent focus_out"))
|
.focus(|style| style.bg(color_5))
|
||||||
|
.in_focus(|style| style.bg(color_6))
|
||||||
|
.on_focus(|_, _, _| println!("Child 1 focused"))
|
||||||
|
.on_blur(|_, _, _| println!("Child 1 blurred"))
|
||||||
|
.on_focus_in(|_, _, _| println!("Child 1 focus_in"))
|
||||||
|
.on_focus_out(|_, _, _| println!("Child 1 focus_out"))
|
||||||
.on_key_down(|_, event, phase, _| {
|
.on_key_down(|_, event, phase, _| {
|
||||||
println!("Key down on parent {:?} {:?}", phase, event)
|
println!("Key down on child 1 {:?} {:?}", phase, event)
|
||||||
})
|
})
|
||||||
.on_key_up(|_, event, phase, _| {
|
.on_key_up(|_, event, phase, _| {
|
||||||
println!("Key up on parent {:?} {:?}", phase, event)
|
println!("Key up on child 1 {:?} {:?}", phase, event)
|
||||||
})
|
})
|
||||||
.size_full()
|
.child("Child 1"),
|
||||||
.bg(color_1)
|
)
|
||||||
.focus(|style| style.bg(color_2))
|
.child(
|
||||||
.focus_in(|style| style.bg(color_3))
|
div()
|
||||||
.child(
|
.track_focus(&child_2)
|
||||||
div()
|
.context("child-2")
|
||||||
.track_focus(&child_1)
|
.on_action(|_, action: &ActionC, phase, cx| {
|
||||||
.context("child-1")
|
println!("Action C dispatched on child 2 during {:?}", phase);
|
||||||
.on_action(|_, action: &ActionB, phase, cx| {
|
})
|
||||||
println!("Action B dispatched on child 1 during {:?}", phase);
|
.w_full()
|
||||||
})
|
.h_6()
|
||||||
.w_full()
|
.bg(color_4)
|
||||||
.h_6()
|
.on_focus(|_, _, _| println!("Child 2 focused"))
|
||||||
.bg(color_4)
|
.on_blur(|_, _, _| println!("Child 2 blurred"))
|
||||||
.focus(|style| style.bg(color_5))
|
.on_focus_in(|_, _, _| println!("Child 2 focus_in"))
|
||||||
.in_focus(|style| style.bg(color_6))
|
.on_focus_out(|_, _, _| println!("Child 2 focus_out"))
|
||||||
.on_focus(|_, _, _| println!("Child 1 focused"))
|
.on_key_down(|_, event, phase, _| {
|
||||||
.on_blur(|_, _, _| println!("Child 1 blurred"))
|
println!("Key down on child 2 {:?} {:?}", phase, event)
|
||||||
.on_focus_in(|_, _, _| println!("Child 1 focus_in"))
|
})
|
||||||
.on_focus_out(|_, _, _| println!("Child 1 focus_out"))
|
.on_key_up(|_, event, phase, _| {
|
||||||
.on_key_down(|_, event, phase, _| {
|
println!("Key up on child 2 {:?} {:?}", phase, event)
|
||||||
println!("Key down on child 1 {:?} {:?}", phase, event)
|
})
|
||||||
})
|
.child("Child 2"),
|
||||||
.on_key_up(|_, event, phase, _| {
|
)
|
||||||
println!("Key up on child 1 {:?} {:?}", phase, event)
|
|
||||||
})
|
|
||||||
.child("Child 1"),
|
|
||||||
)
|
|
||||||
.child(
|
|
||||||
div()
|
|
||||||
.track_focus(&child_2)
|
|
||||||
.context("child-2")
|
|
||||||
.on_action(|_, action: &ActionC, phase, cx| {
|
|
||||||
println!("Action C dispatched on child 2 during {:?}", phase);
|
|
||||||
})
|
|
||||||
.w_full()
|
|
||||||
.h_6()
|
|
||||||
.bg(color_4)
|
|
||||||
.on_focus(|_, _, _| println!("Child 2 focused"))
|
|
||||||
.on_blur(|_, _, _| println!("Child 2 blurred"))
|
|
||||||
.on_focus_in(|_, _, _| println!("Child 2 focus_in"))
|
|
||||||
.on_focus_out(|_, _, _| println!("Child 2 focus_out"))
|
|
||||||
.on_key_down(|_, event, phase, _| {
|
|
||||||
println!("Key down on child 2 {:?} {:?}", phase, event)
|
|
||||||
})
|
|
||||||
.on_key_up(|_, event, phase, _| {
|
|
||||||
println!("Key up on child 2 {:?} {:?}", phase, event)
|
|
||||||
})
|
|
||||||
.child("Child 2"),
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,26 +1,23 @@
|
|||||||
use gpui2::{AppContext, Context, View};
|
use crate::{
|
||||||
|
story::Story,
|
||||||
|
story_selector::{ComponentStory, ElementStory},
|
||||||
|
};
|
||||||
|
use gpui2::{Div, Render, StatefulInteraction, View, VisualContext};
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
|
|
||||||
use crate::story::Story;
|
pub struct KitchenSinkStory;
|
||||||
use crate::story_selector::{ComponentStory, ElementStory};
|
|
||||||
|
|
||||||
pub struct KitchenSinkStory {}
|
|
||||||
|
|
||||||
impl KitchenSinkStory {
|
impl KitchenSinkStory {
|
||||||
pub fn new() -> Self {
|
pub fn view(cx: &mut WindowContext) -> View<Self> {
|
||||||
Self {}
|
cx.build_view(|cx| Self)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn view(cx: &mut AppContext) -> View<Self> {
|
impl Render for KitchenSinkStory {
|
||||||
{
|
type Element = Div<Self, StatefulInteraction<Self>>;
|
||||||
let state = cx.entity(|cx| Self::new());
|
|
||||||
let render = Self::render;
|
|
||||||
View::for_handle(state, render)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Component<Self> {
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
|
||||||
let element_stories = ElementStory::iter()
|
let element_stories = ElementStory::iter()
|
||||||
.map(|selector| selector.story(cx))
|
.map(|selector| selector.story(cx))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
@ -1,58 +1,54 @@
|
|||||||
use crate::themes::rose_pine;
|
|
||||||
use gpui2::{
|
use gpui2::{
|
||||||
div, px, Component, ParentElement, SharedString, Styled, View, VisualContext, WindowContext,
|
div, px, Component, Div, ParentElement, Render, SharedString, StatefulInteraction, Styled,
|
||||||
|
View, VisualContext, WindowContext,
|
||||||
};
|
};
|
||||||
|
use theme2::theme;
|
||||||
|
|
||||||
pub struct ScrollStory {
|
pub struct ScrollStory;
|
||||||
text: View<()>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ScrollStory {
|
impl ScrollStory {
|
||||||
pub fn view(cx: &mut WindowContext) -> View<()> {
|
pub fn view(cx: &mut WindowContext) -> View<ScrollStory> {
|
||||||
let theme = rose_pine();
|
cx.build_view(|cx| ScrollStory)
|
||||||
|
|
||||||
{
|
|
||||||
cx.build_view(|cx| (), move |_, cx| checkerboard(1))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn checkerboard<S>(depth: usize) -> impl Component<S>
|
impl Render for ScrollStory {
|
||||||
where
|
type Element = Div<Self, StatefulInteraction<Self>>;
|
||||||
S: 'static + Send + Sync,
|
|
||||||
{
|
|
||||||
let theme = rose_pine();
|
|
||||||
let color_1 = theme.lowest.positive.default.background;
|
|
||||||
let color_2 = theme.lowest.warning.default.background;
|
|
||||||
|
|
||||||
div()
|
fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
|
||||||
.id("parent")
|
let theme = theme(cx);
|
||||||
.bg(theme.lowest.base.default.background)
|
let color_1 = theme.git_created;
|
||||||
.size_full()
|
let color_2 = theme.git_modified;
|
||||||
.overflow_scroll()
|
|
||||||
.children((0..10).map(|row| {
|
div()
|
||||||
div()
|
.id("parent")
|
||||||
.w(px(1000.))
|
.bg(theme.background)
|
||||||
.h(px(100.))
|
.size_full()
|
||||||
.flex()
|
.overflow_scroll()
|
||||||
.flex_row()
|
.children((0..10).map(|row| {
|
||||||
.children((0..10).map(|column| {
|
div()
|
||||||
let id = SharedString::from(format!("{}, {}", row, column));
|
.w(px(1000.))
|
||||||
let bg = if row % 2 == column % 2 {
|
.h(px(100.))
|
||||||
color_1
|
.flex()
|
||||||
} else {
|
.flex_row()
|
||||||
color_2
|
.children((0..10).map(|column| {
|
||||||
};
|
let id = SharedString::from(format!("{}, {}", row, column));
|
||||||
div().id(id).bg(bg).size(px(100. / depth as f32)).when(
|
let bg = if row % 2 == column % 2 {
|
||||||
row >= 5 && column >= 5,
|
color_1
|
||||||
|d| {
|
} else {
|
||||||
d.overflow_scroll()
|
color_2
|
||||||
.child(div().size(px(50.)).bg(color_1))
|
};
|
||||||
.child(div().size(px(50.)).bg(color_2))
|
div().id(id).bg(bg).size(px(100. as f32)).when(
|
||||||
.child(div().size(px(50.)).bg(color_1))
|
row >= 5 && column >= 5,
|
||||||
.child(div().size(px(50.)).bg(color_2))
|
|d| {
|
||||||
},
|
d.overflow_scroll()
|
||||||
)
|
.child(div().size(px(50.)).bg(color_1))
|
||||||
}))
|
.child(div().size(px(50.)).bg(color_2))
|
||||||
}))
|
.child(div().size(px(50.)).bg(color_1))
|
||||||
|
.child(div().size(px(50.)).bg(color_2))
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
}))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,20 +1,21 @@
|
|||||||
use gpui2::{div, white, ParentElement, Styled, View, VisualContext, WindowContext};
|
use gpui2::{div, white, Div, ParentElement, Render, Styled, View, VisualContext, WindowContext};
|
||||||
|
|
||||||
pub struct TextStory {
|
pub struct TextStory;
|
||||||
text: View<()>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextStory {
|
impl TextStory {
|
||||||
pub fn view(cx: &mut WindowContext) -> View<()> {
|
pub fn view(cx: &mut WindowContext) -> View<Self> {
|
||||||
cx.build_view(|cx| (), |_, cx| {
|
cx.build_view(|cx| Self)
|
||||||
div()
|
}
|
||||||
.size_full()
|
}
|
||||||
.bg(white())
|
|
||||||
.child(concat!(
|
impl Render for TextStory {
|
||||||
"The quick brown fox jumps over the lazy dog. ",
|
type Element = Div<Self>;
|
||||||
"Meanwhile, the lazy dog decided it was time for a change. ",
|
|
||||||
"He started daily workout routines, ate healthier and became the fastest dog in town.",
|
fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
|
||||||
))
|
div().size_full().bg(white()).child(concat!(
|
||||||
})
|
"The quick brown fox jumps over the lazy dog. ",
|
||||||
|
"Meanwhile, the lazy dog decided it was time for a change. ",
|
||||||
|
"He started daily workout routines, ate healthier and became the fastest dog in town.",
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,15 +1,16 @@
|
|||||||
use gpui2::{px, rgb, Div, Hsla};
|
use gpui2::{px, rgb, Div, Hsla, Render};
|
||||||
use ui::prelude::*;
|
use ui::prelude::*;
|
||||||
|
|
||||||
use crate::story::Story;
|
use crate::story::Story;
|
||||||
|
|
||||||
/// A reimplementation of the MDN `z-index` example, found here:
|
/// A reimplementation of the MDN `z-index` example, found here:
|
||||||
/// [https://developer.mozilla.org/en-US/docs/Web/CSS/z-index](https://developer.mozilla.org/en-US/docs/Web/CSS/z-index).
|
/// [https://developer.mozilla.org/en-US/docs/Web/CSS/z-index](https://developer.mozilla.org/en-US/docs/Web/CSS/z-index).
|
||||||
#[derive(Component)]
|
|
||||||
pub struct ZIndexStory;
|
pub struct ZIndexStory;
|
||||||
|
|
||||||
impl ZIndexStory {
|
impl Render for ZIndexStory {
|
||||||
fn render<V: 'static>(self, _view: &mut V, cx: &mut ViewContext<V>) -> impl Component<V> {
|
type Element = Div<Self>;
|
||||||
|
|
||||||
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
|
||||||
Story::container(cx)
|
Story::container(cx)
|
||||||
.child(Story::title(cx, "z-index"))
|
.child(Story::title(cx, "z-index"))
|
||||||
.child(
|
.child(
|
||||||
|
@ -7,13 +7,14 @@ use clap::builder::PossibleValue;
|
|||||||
use clap::ValueEnum;
|
use clap::ValueEnum;
|
||||||
use gpui2::{AnyView, VisualContext};
|
use gpui2::{AnyView, VisualContext};
|
||||||
use strum::{EnumIter, EnumString, IntoEnumIterator};
|
use strum::{EnumIter, EnumString, IntoEnumIterator};
|
||||||
use ui::prelude::*;
|
use ui::{prelude::*, AvatarStory, ButtonStory, DetailsStory, IconStory, InputStory, LabelStory};
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, strum::Display, EnumString, EnumIter)]
|
#[derive(Debug, PartialEq, Eq, Clone, Copy, strum::Display, EnumString, EnumIter)]
|
||||||
#[strum(serialize_all = "snake_case")]
|
#[strum(serialize_all = "snake_case")]
|
||||||
pub enum ElementStory {
|
pub enum ElementStory {
|
||||||
Avatar,
|
Avatar,
|
||||||
Button,
|
Button,
|
||||||
|
Colors,
|
||||||
Details,
|
Details,
|
||||||
Focus,
|
Focus,
|
||||||
Icon,
|
Icon,
|
||||||
@ -27,18 +28,17 @@ pub enum ElementStory {
|
|||||||
impl ElementStory {
|
impl ElementStory {
|
||||||
pub fn story(&self, cx: &mut WindowContext) -> AnyView {
|
pub fn story(&self, cx: &mut WindowContext) -> AnyView {
|
||||||
match self {
|
match self {
|
||||||
Self::Avatar => { cx.build_view(|cx| (), |_, _| ui::AvatarStory.render()) }.into_any(),
|
Self::Colors => cx.build_view(|_| ColorsStory).into_any(),
|
||||||
Self::Button => { cx.build_view(|cx| (), |_, _| ui::ButtonStory.render()) }.into_any(),
|
Self::Avatar => cx.build_view(|_| AvatarStory).into_any(),
|
||||||
Self::Details => {
|
Self::Button => cx.build_view(|_| ButtonStory).into_any(),
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::DetailsStory.render()) }.into_any()
|
Self::Details => cx.build_view(|_| DetailsStory).into_any(),
|
||||||
}
|
|
||||||
Self::Focus => FocusStory::view(cx).into_any(),
|
Self::Focus => FocusStory::view(cx).into_any(),
|
||||||
Self::Icon => { cx.build_view(|cx| (), |_, _| ui::IconStory.render()) }.into_any(),
|
Self::Icon => cx.build_view(|_| IconStory).into_any(),
|
||||||
Self::Input => { cx.build_view(|cx| (), |_, _| ui::InputStory.render()) }.into_any(),
|
Self::Input => cx.build_view(|_| InputStory).into_any(),
|
||||||
Self::Label => { cx.build_view(|cx| (), |_, _| ui::LabelStory.render()) }.into_any(),
|
Self::Label => cx.build_view(|_| LabelStory).into_any(),
|
||||||
Self::Scroll => ScrollStory::view(cx).into_any(),
|
Self::Scroll => ScrollStory::view(cx).into_any(),
|
||||||
Self::Text => TextStory::view(cx).into_any(),
|
Self::Text => TextStory::view(cx).into_any(),
|
||||||
Self::ZIndex => { cx.build_view(|cx| (), |_, _| ZIndexStory.render()) }.into_any(),
|
Self::ZIndex => cx.build_view(|_| ZIndexStory).into_any(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -77,69 +77,31 @@ pub enum ComponentStory {
|
|||||||
impl ComponentStory {
|
impl ComponentStory {
|
||||||
pub fn story(&self, cx: &mut WindowContext) -> AnyView {
|
pub fn story(&self, cx: &mut WindowContext) -> AnyView {
|
||||||
match self {
|
match self {
|
||||||
Self::AssistantPanel => {
|
Self::AssistantPanel => cx.build_view(|_| ui::AssistantPanelStory).into_any(),
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::AssistantPanelStory.render()) }.into_any()
|
Self::Buffer => cx.build_view(|_| ui::BufferStory).into_any(),
|
||||||
}
|
Self::Breadcrumb => cx.build_view(|_| ui::BreadcrumbStory).into_any(),
|
||||||
Self::Buffer => { cx.build_view(|cx| (), |_, _| ui::BufferStory.render()) }.into_any(),
|
Self::ChatPanel => cx.build_view(|_| ui::ChatPanelStory).into_any(),
|
||||||
Self::Breadcrumb => {
|
Self::CollabPanel => cx.build_view(|_| ui::CollabPanelStory).into_any(),
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::BreadcrumbStory.render()) }.into_any()
|
Self::CommandPalette => cx.build_view(|_| ui::CommandPaletteStory).into_any(),
|
||||||
}
|
Self::ContextMenu => cx.build_view(|_| ui::ContextMenuStory).into_any(),
|
||||||
Self::ChatPanel => {
|
Self::Facepile => cx.build_view(|_| ui::FacepileStory).into_any(),
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::ChatPanelStory.render()) }.into_any()
|
Self::Keybinding => cx.build_view(|_| ui::KeybindingStory).into_any(),
|
||||||
}
|
Self::LanguageSelector => cx.build_view(|_| ui::LanguageSelectorStory).into_any(),
|
||||||
Self::CollabPanel => {
|
Self::MultiBuffer => cx.build_view(|_| ui::MultiBufferStory).into_any(),
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::CollabPanelStory.render()) }.into_any()
|
Self::NotificationsPanel => cx.build_view(|cx| ui::NotificationsPanelStory).into_any(),
|
||||||
}
|
Self::Palette => cx.build_view(|cx| ui::PaletteStory).into_any(),
|
||||||
Self::CommandPalette => {
|
Self::Panel => cx.build_view(|cx| ui::PanelStory).into_any(),
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::CommandPaletteStory.render()) }.into_any()
|
Self::ProjectPanel => cx.build_view(|_| ui::ProjectPanelStory).into_any(),
|
||||||
}
|
Self::RecentProjects => cx.build_view(|_| ui::RecentProjectsStory).into_any(),
|
||||||
Self::ContextMenu => {
|
Self::Tab => cx.build_view(|_| ui::TabStory).into_any(),
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::ContextMenuStory.render()) }.into_any()
|
Self::TabBar => cx.build_view(|_| ui::TabBarStory).into_any(),
|
||||||
}
|
Self::Terminal => cx.build_view(|_| ui::TerminalStory).into_any(),
|
||||||
Self::Facepile => {
|
Self::ThemeSelector => cx.build_view(|_| ui::ThemeSelectorStory).into_any(),
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::FacepileStory.render()) }.into_any()
|
Self::Toast => cx.build_view(|_| ui::ToastStory).into_any(),
|
||||||
}
|
Self::Toolbar => cx.build_view(|_| ui::ToolbarStory).into_any(),
|
||||||
Self::Keybinding => {
|
Self::TrafficLights => cx.build_view(|_| ui::TrafficLightsStory).into_any(),
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::KeybindingStory.render()) }.into_any()
|
Self::Copilot => cx.build_view(|_| ui::CopilotModalStory).into_any(),
|
||||||
}
|
|
||||||
Self::LanguageSelector => {
|
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::LanguageSelectorStory.render()) }.into_any()
|
|
||||||
}
|
|
||||||
Self::MultiBuffer => {
|
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::MultiBufferStory.render()) }.into_any()
|
|
||||||
}
|
|
||||||
Self::NotificationsPanel => {
|
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::NotificationsPanelStory.render()) }.into_any()
|
|
||||||
}
|
|
||||||
Self::Palette => {
|
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::PaletteStory.render()) }.into_any()
|
|
||||||
}
|
|
||||||
Self::Panel => { cx.build_view(|cx| (), |_, _| ui::PanelStory.render()) }.into_any(),
|
|
||||||
Self::ProjectPanel => {
|
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::ProjectPanelStory.render()) }.into_any()
|
|
||||||
}
|
|
||||||
Self::RecentProjects => {
|
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::RecentProjectsStory.render()) }.into_any()
|
|
||||||
}
|
|
||||||
Self::Tab => { cx.build_view(|cx| (), |_, _| ui::TabStory.render()) }.into_any(),
|
|
||||||
Self::TabBar => { cx.build_view(|cx| (), |_, _| ui::TabBarStory.render()) }.into_any(),
|
|
||||||
Self::Terminal => {
|
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::TerminalStory.render()) }.into_any()
|
|
||||||
}
|
|
||||||
Self::ThemeSelector => {
|
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::ThemeSelectorStory.render()) }.into_any()
|
|
||||||
}
|
|
||||||
Self::TitleBar => ui::TitleBarStory::view(cx).into_any(),
|
Self::TitleBar => ui::TitleBarStory::view(cx).into_any(),
|
||||||
Self::Toast => { cx.build_view(|cx| (), |_, _| ui::ToastStory.render()) }.into_any(),
|
|
||||||
Self::Toolbar => {
|
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::ToolbarStory.render()) }.into_any()
|
|
||||||
}
|
|
||||||
Self::TrafficLights => {
|
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::TrafficLightsStory.render()) }.into_any()
|
|
||||||
}
|
|
||||||
Self::Copilot => {
|
|
||||||
{ cx.build_view(|cx| (), |_, _| ui::CopilotModalStory.render()) }.into_any()
|
|
||||||
}
|
|
||||||
Self::Workspace => ui::WorkspaceStory::view(cx).into_any(),
|
Self::Workspace => ui::WorkspaceStory::view(cx).into_any(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,21 +4,20 @@ mod assets;
|
|||||||
mod stories;
|
mod stories;
|
||||||
mod story;
|
mod story;
|
||||||
mod story_selector;
|
mod story_selector;
|
||||||
mod themes;
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use gpui2::{
|
use gpui2::{
|
||||||
div, px, size, AnyView, AppContext, Bounds, ViewContext, VisualContext, WindowBounds,
|
div, px, size, AnyView, AppContext, Bounds, Div, Render, ViewContext, VisualContext,
|
||||||
WindowOptions,
|
WindowBounds, WindowOptions,
|
||||||
};
|
};
|
||||||
use log::LevelFilter;
|
use log::LevelFilter;
|
||||||
use settings2::{default_settings, Settings, SettingsStore};
|
use settings2::{default_settings, Settings, SettingsStore};
|
||||||
use simplelog::SimpleLogger;
|
use simplelog::SimpleLogger;
|
||||||
use story_selector::ComponentStory;
|
use story_selector::ComponentStory;
|
||||||
use theme2::{ThemeRegistry, ThemeSettings};
|
use theme2::{ThemeRegistry, ThemeSettings};
|
||||||
use ui::{prelude::*, themed};
|
use ui::prelude::*;
|
||||||
|
|
||||||
use crate::assets::Assets;
|
use crate::assets::Assets;
|
||||||
use crate::story_selector::StorySelector;
|
use crate::story_selector::StorySelector;
|
||||||
@ -50,7 +49,6 @@ fn main() {
|
|||||||
|
|
||||||
let story_selector = args.story.clone();
|
let story_selector = args.story.clone();
|
||||||
let theme_name = args.theme.unwrap_or("One Dark".to_string());
|
let theme_name = args.theme.unwrap_or("One Dark".to_string());
|
||||||
let theme = themes::load_theme(theme_name.clone()).unwrap();
|
|
||||||
|
|
||||||
let asset_source = Arc::new(Assets);
|
let asset_source = Arc::new(Assets);
|
||||||
gpui2::App::production(asset_source).run(move |cx| {
|
gpui2::App::production(asset_source).run(move |cx| {
|
||||||
@ -84,12 +82,7 @@ fn main() {
|
|||||||
}),
|
}),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
move |cx| {
|
move |cx| cx.build_view(|cx| StoryWrapper::new(selector.story(cx))),
|
||||||
cx.build_view(
|
|
||||||
|cx| StoryWrapper::new(selector.story(cx), theme),
|
|
||||||
StoryWrapper::render,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
);
|
);
|
||||||
|
|
||||||
cx.activate(true);
|
cx.activate(true);
|
||||||
@ -99,22 +92,23 @@ fn main() {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct StoryWrapper {
|
pub struct StoryWrapper {
|
||||||
story: AnyView,
|
story: AnyView,
|
||||||
theme: Theme,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StoryWrapper {
|
impl StoryWrapper {
|
||||||
pub(crate) fn new(story: AnyView, theme: Theme) -> Self {
|
pub(crate) fn new(story: AnyView) -> Self {
|
||||||
Self { story, theme }
|
Self { story }
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Component<Self> {
|
impl Render for StoryWrapper {
|
||||||
themed(self.theme.clone(), cx, |cx| {
|
type Element = Div<Self>;
|
||||||
div()
|
|
||||||
.flex()
|
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
|
||||||
.flex_col()
|
div()
|
||||||
.size_full()
|
.flex()
|
||||||
.child(self.story.clone())
|
.flex_col()
|
||||||
})
|
.size_full()
|
||||||
|
.child(self.story.clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,30 +0,0 @@
|
|||||||
mod rose_pine;
|
|
||||||
|
|
||||||
pub use rose_pine::*;
|
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use gpui2::serde_json;
|
|
||||||
use serde::Deserialize;
|
|
||||||
use ui::Theme;
|
|
||||||
|
|
||||||
use crate::assets::Assets;
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct LegacyTheme {
|
|
||||||
pub base_theme: serde_json::Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Loads the [`Theme`] with the given name.
|
|
||||||
pub fn load_theme(name: String) -> Result<Theme> {
|
|
||||||
let theme_contents = Assets::get(&format!("themes/{name}.json"))
|
|
||||||
.with_context(|| format!("theme file not found: '{name}'"))?;
|
|
||||||
|
|
||||||
let legacy_theme: LegacyTheme =
|
|
||||||
serde_json::from_str(std::str::from_utf8(&theme_contents.data)?)
|
|
||||||
.context("failed to parse legacy theme")?;
|
|
||||||
|
|
||||||
let theme: Theme = serde_json::from_value(legacy_theme.base_theme.clone())
|
|
||||||
.context("failed to parse `base_theme`")?;
|
|
||||||
|
|
||||||
Ok(theme)
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
2118
crates/theme2/src/default.rs
Normal file
2118
crates/theme2/src/default.rs
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,7 +1,4 @@
|
|||||||
use crate::{
|
use crate::{themes, Theme, ThemeMetadata};
|
||||||
themes::{one_dark, rose_pine, rose_pine_dawn, rose_pine_moon, sandcastle},
|
|
||||||
Theme, ThemeMetadata,
|
|
||||||
};
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use gpui2::SharedString;
|
use gpui2::SharedString;
|
||||||
use std::{collections::HashMap, sync::Arc};
|
use std::{collections::HashMap, sync::Arc};
|
||||||
@ -41,11 +38,45 @@ impl Default for ThemeRegistry {
|
|||||||
};
|
};
|
||||||
|
|
||||||
this.insert_themes([
|
this.insert_themes([
|
||||||
one_dark(),
|
themes::andromeda(),
|
||||||
rose_pine(),
|
themes::atelier_cave_dark(),
|
||||||
rose_pine_dawn(),
|
themes::atelier_cave_light(),
|
||||||
rose_pine_moon(),
|
themes::atelier_dune_dark(),
|
||||||
sandcastle(),
|
themes::atelier_dune_light(),
|
||||||
|
themes::atelier_estuary_dark(),
|
||||||
|
themes::atelier_estuary_light(),
|
||||||
|
themes::atelier_forest_dark(),
|
||||||
|
themes::atelier_forest_light(),
|
||||||
|
themes::atelier_heath_dark(),
|
||||||
|
themes::atelier_heath_light(),
|
||||||
|
themes::atelier_lakeside_dark(),
|
||||||
|
themes::atelier_lakeside_light(),
|
||||||
|
themes::atelier_plateau_dark(),
|
||||||
|
themes::atelier_plateau_light(),
|
||||||
|
themes::atelier_savanna_dark(),
|
||||||
|
themes::atelier_savanna_light(),
|
||||||
|
themes::atelier_seaside_dark(),
|
||||||
|
themes::atelier_seaside_light(),
|
||||||
|
themes::atelier_sulphurpool_dark(),
|
||||||
|
themes::atelier_sulphurpool_light(),
|
||||||
|
themes::ayu_dark(),
|
||||||
|
themes::ayu_light(),
|
||||||
|
themes::ayu_mirage(),
|
||||||
|
themes::gruvbox_dark(),
|
||||||
|
themes::gruvbox_dark_hard(),
|
||||||
|
themes::gruvbox_dark_soft(),
|
||||||
|
themes::gruvbox_light(),
|
||||||
|
themes::gruvbox_light_hard(),
|
||||||
|
themes::gruvbox_light_soft(),
|
||||||
|
themes::one_dark(),
|
||||||
|
themes::one_light(),
|
||||||
|
themes::rose_pine(),
|
||||||
|
themes::rose_pine_dawn(),
|
||||||
|
themes::rose_pine_moon(),
|
||||||
|
themes::sandcastle(),
|
||||||
|
themes::solarized_dark(),
|
||||||
|
themes::solarized_light(),
|
||||||
|
themes::summercamp(),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
this
|
this
|
||||||
|
164
crates/theme2/src/scale.rs
Normal file
164
crates/theme2/src/scale.rs
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
use gpui2::{AppContext, Hsla};
|
||||||
|
use indexmap::IndexMap;
|
||||||
|
|
||||||
|
use crate::{theme, Appearance};
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
|
pub enum ColorScaleName {
|
||||||
|
Gray,
|
||||||
|
Mauve,
|
||||||
|
Slate,
|
||||||
|
Sage,
|
||||||
|
Olive,
|
||||||
|
Sand,
|
||||||
|
Gold,
|
||||||
|
Bronze,
|
||||||
|
Brown,
|
||||||
|
Yellow,
|
||||||
|
Amber,
|
||||||
|
Orange,
|
||||||
|
Tomato,
|
||||||
|
Red,
|
||||||
|
Ruby,
|
||||||
|
Crimson,
|
||||||
|
Pink,
|
||||||
|
Plum,
|
||||||
|
Purple,
|
||||||
|
Violet,
|
||||||
|
Iris,
|
||||||
|
Indigo,
|
||||||
|
Blue,
|
||||||
|
Cyan,
|
||||||
|
Teal,
|
||||||
|
Jade,
|
||||||
|
Green,
|
||||||
|
Grass,
|
||||||
|
Lime,
|
||||||
|
Mint,
|
||||||
|
Sky,
|
||||||
|
Black,
|
||||||
|
White,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ColorScaleName {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"{}",
|
||||||
|
match self {
|
||||||
|
Self::Gray => "Gray",
|
||||||
|
Self::Mauve => "Mauve",
|
||||||
|
Self::Slate => "Slate",
|
||||||
|
Self::Sage => "Sage",
|
||||||
|
Self::Olive => "Olive",
|
||||||
|
Self::Sand => "Sand",
|
||||||
|
Self::Gold => "Gold",
|
||||||
|
Self::Bronze => "Bronze",
|
||||||
|
Self::Brown => "Brown",
|
||||||
|
Self::Yellow => "Yellow",
|
||||||
|
Self::Amber => "Amber",
|
||||||
|
Self::Orange => "Orange",
|
||||||
|
Self::Tomato => "Tomato",
|
||||||
|
Self::Red => "Red",
|
||||||
|
Self::Ruby => "Ruby",
|
||||||
|
Self::Crimson => "Crimson",
|
||||||
|
Self::Pink => "Pink",
|
||||||
|
Self::Plum => "Plum",
|
||||||
|
Self::Purple => "Purple",
|
||||||
|
Self::Violet => "Violet",
|
||||||
|
Self::Iris => "Iris",
|
||||||
|
Self::Indigo => "Indigo",
|
||||||
|
Self::Blue => "Blue",
|
||||||
|
Self::Cyan => "Cyan",
|
||||||
|
Self::Teal => "Teal",
|
||||||
|
Self::Jade => "Jade",
|
||||||
|
Self::Green => "Green",
|
||||||
|
Self::Grass => "Grass",
|
||||||
|
Self::Lime => "Lime",
|
||||||
|
Self::Mint => "Mint",
|
||||||
|
Self::Sky => "Sky",
|
||||||
|
Self::Black => "Black",
|
||||||
|
Self::White => "White",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type ColorScale = [Hsla; 12];
|
||||||
|
|
||||||
|
pub type ColorScales = IndexMap<ColorScaleName, ColorScaleSet>;
|
||||||
|
|
||||||
|
/// A one-based step in a [`ColorScale`].
|
||||||
|
pub type ColorScaleStep = usize;
|
||||||
|
|
||||||
|
pub struct ColorScaleSet {
|
||||||
|
name: ColorScaleName,
|
||||||
|
light: ColorScale,
|
||||||
|
dark: ColorScale,
|
||||||
|
light_alpha: ColorScale,
|
||||||
|
dark_alpha: ColorScale,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ColorScaleSet {
|
||||||
|
pub fn new(
|
||||||
|
name: ColorScaleName,
|
||||||
|
light: ColorScale,
|
||||||
|
light_alpha: ColorScale,
|
||||||
|
dark: ColorScale,
|
||||||
|
dark_alpha: ColorScale,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
name,
|
||||||
|
light,
|
||||||
|
light_alpha,
|
||||||
|
dark,
|
||||||
|
dark_alpha,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn name(&self) -> String {
|
||||||
|
self.name.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn light(&self, step: ColorScaleStep) -> Hsla {
|
||||||
|
self.light[step - 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn light_alpha(&self, step: ColorScaleStep) -> Hsla {
|
||||||
|
self.light_alpha[step - 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dark(&self, step: ColorScaleStep) -> Hsla {
|
||||||
|
self.dark[step - 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dark_alpha(&self, step: ColorScaleStep) -> Hsla {
|
||||||
|
self.dark_alpha[step - 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn current_appearance(cx: &AppContext) -> Appearance {
|
||||||
|
let theme = theme(cx);
|
||||||
|
if theme.metadata.is_light {
|
||||||
|
Appearance::Light
|
||||||
|
} else {
|
||||||
|
Appearance::Dark
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn step(&self, cx: &AppContext, step: ColorScaleStep) -> Hsla {
|
||||||
|
let appearance = Self::current_appearance(cx);
|
||||||
|
|
||||||
|
match appearance {
|
||||||
|
Appearance::Light => self.light(step),
|
||||||
|
Appearance::Dark => self.dark(step),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn step_alpha(&self, cx: &AppContext, step: ColorScaleStep) -> Hsla {
|
||||||
|
let appearance = Self::current_appearance(cx);
|
||||||
|
match appearance {
|
||||||
|
Appearance::Light => self.light_alpha(step),
|
||||||
|
Appearance::Dark => self.dark_alpha(step),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,14 +1,24 @@
|
|||||||
|
mod default;
|
||||||
mod registry;
|
mod registry;
|
||||||
|
mod scale;
|
||||||
mod settings;
|
mod settings;
|
||||||
mod themes;
|
mod themes;
|
||||||
|
|
||||||
|
pub use default::*;
|
||||||
pub use registry::*;
|
pub use registry::*;
|
||||||
|
pub use scale::*;
|
||||||
pub use settings::*;
|
pub use settings::*;
|
||||||
|
|
||||||
use gpui2::{AppContext, HighlightStyle, Hsla, SharedString};
|
use gpui2::{AppContext, HighlightStyle, Hsla, SharedString};
|
||||||
use settings2::Settings;
|
use settings2::Settings;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum Appearance {
|
||||||
|
Light,
|
||||||
|
Dark,
|
||||||
|
}
|
||||||
|
|
||||||
pub fn init(cx: &mut AppContext) {
|
pub fn init(cx: &mut AppContext) {
|
||||||
cx.set_global(ThemeRegistry::default());
|
cx.set_global(ThemeRegistry::default());
|
||||||
ThemeSettings::register(cx);
|
ThemeSettings::register(cx);
|
||||||
@ -18,6 +28,10 @@ pub fn active_theme<'a>(cx: &'a AppContext) -> &'a Arc<Theme> {
|
|||||||
&ThemeSettings::get_global(cx).active_theme
|
&ThemeSettings::get_global(cx).active_theme
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn theme(cx: &AppContext) -> Arc<Theme> {
|
||||||
|
active_theme(cx).clone()
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Theme {
|
pub struct Theme {
|
||||||
pub metadata: ThemeMetadata,
|
pub metadata: ThemeMetadata,
|
||||||
|
|
||||||
|
130
crates/theme2/src/themes/andromeda.rs
Normal file
130
crates/theme2/src/themes/andromeda.rs
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
use gpui2::rgba;
|
||||||
|
|
||||||
|
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
|
||||||
|
|
||||||
|
pub fn andromeda() -> Theme {
|
||||||
|
Theme {
|
||||||
|
metadata: ThemeMetadata {
|
||||||
|
name: "Andromeda".into(),
|
||||||
|
is_light: false,
|
||||||
|
},
|
||||||
|
transparent: rgba(0x00000000).into(),
|
||||||
|
mac_os_traffic_light_red: rgba(0xec695eff).into(),
|
||||||
|
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
|
||||||
|
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
|
||||||
|
border: rgba(0x2b2f38ff).into(),
|
||||||
|
border_variant: rgba(0x2b2f38ff).into(),
|
||||||
|
border_focused: rgba(0x183934ff).into(),
|
||||||
|
border_transparent: rgba(0x00000000).into(),
|
||||||
|
elevated_surface: rgba(0x262933ff).into(),
|
||||||
|
surface: rgba(0x21242bff).into(),
|
||||||
|
background: rgba(0x262933ff).into(),
|
||||||
|
filled_element: rgba(0x262933ff).into(),
|
||||||
|
filled_element_hover: rgba(0xffffff1e).into(),
|
||||||
|
filled_element_active: rgba(0xffffff28).into(),
|
||||||
|
filled_element_selected: rgba(0x12231fff).into(),
|
||||||
|
filled_element_disabled: rgba(0x00000000).into(),
|
||||||
|
ghost_element: rgba(0x00000000).into(),
|
||||||
|
ghost_element_hover: rgba(0xffffff14).into(),
|
||||||
|
ghost_element_active: rgba(0xffffff1e).into(),
|
||||||
|
ghost_element_selected: rgba(0x12231fff).into(),
|
||||||
|
ghost_element_disabled: rgba(0x00000000).into(),
|
||||||
|
text: rgba(0xf7f7f8ff).into(),
|
||||||
|
text_muted: rgba(0xaca8aeff).into(),
|
||||||
|
text_placeholder: rgba(0xf82871ff).into(),
|
||||||
|
text_disabled: rgba(0x6b6b73ff).into(),
|
||||||
|
text_accent: rgba(0x10a793ff).into(),
|
||||||
|
icon_muted: rgba(0xaca8aeff).into(),
|
||||||
|
syntax: SyntaxTheme {
|
||||||
|
highlights: vec![
|
||||||
|
("emphasis".into(), rgba(0x10a793ff).into()),
|
||||||
|
("punctuation.bracket".into(), rgba(0xd8d5dbff).into()),
|
||||||
|
("attribute".into(), rgba(0x10a793ff).into()),
|
||||||
|
("variable".into(), rgba(0xf7f7f8ff).into()),
|
||||||
|
("predictive".into(), rgba(0x315f70ff).into()),
|
||||||
|
("property".into(), rgba(0x10a793ff).into()),
|
||||||
|
("variant".into(), rgba(0x10a793ff).into()),
|
||||||
|
("embedded".into(), rgba(0xf7f7f8ff).into()),
|
||||||
|
("string.special".into(), rgba(0xf29c14ff).into()),
|
||||||
|
("keyword".into(), rgba(0x10a793ff).into()),
|
||||||
|
("tag".into(), rgba(0x10a793ff).into()),
|
||||||
|
("enum".into(), rgba(0xf29c14ff).into()),
|
||||||
|
("link_text".into(), rgba(0xf29c14ff).into()),
|
||||||
|
("primary".into(), rgba(0xf7f7f8ff).into()),
|
||||||
|
("punctuation".into(), rgba(0xd8d5dbff).into()),
|
||||||
|
("punctuation.special".into(), rgba(0xd8d5dbff).into()),
|
||||||
|
("function".into(), rgba(0xfee56cff).into()),
|
||||||
|
("number".into(), rgba(0x96df71ff).into()),
|
||||||
|
("preproc".into(), rgba(0xf7f7f8ff).into()),
|
||||||
|
("operator".into(), rgba(0xf29c14ff).into()),
|
||||||
|
("constructor".into(), rgba(0x10a793ff).into()),
|
||||||
|
("string.escape".into(), rgba(0xafabb1ff).into()),
|
||||||
|
("string.special.symbol".into(), rgba(0xf29c14ff).into()),
|
||||||
|
("string".into(), rgba(0xf29c14ff).into()),
|
||||||
|
("comment".into(), rgba(0xafabb1ff).into()),
|
||||||
|
("hint".into(), rgba(0x618399ff).into()),
|
||||||
|
("type".into(), rgba(0x08e7c5ff).into()),
|
||||||
|
("label".into(), rgba(0x10a793ff).into()),
|
||||||
|
("comment.doc".into(), rgba(0xafabb1ff).into()),
|
||||||
|
("text.literal".into(), rgba(0xf29c14ff).into()),
|
||||||
|
("constant".into(), rgba(0x96df71ff).into()),
|
||||||
|
("string.regex".into(), rgba(0xf29c14ff).into()),
|
||||||
|
("emphasis.strong".into(), rgba(0x10a793ff).into()),
|
||||||
|
("title".into(), rgba(0xf7f7f8ff).into()),
|
||||||
|
("punctuation.delimiter".into(), rgba(0xd8d5dbff).into()),
|
||||||
|
("link_uri".into(), rgba(0x96df71ff).into()),
|
||||||
|
("boolean".into(), rgba(0x96df71ff).into()),
|
||||||
|
("punctuation.list_marker".into(), rgba(0xd8d5dbff).into()),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
status_bar: rgba(0x262933ff).into(),
|
||||||
|
title_bar: rgba(0x262933ff).into(),
|
||||||
|
toolbar: rgba(0x1e2025ff).into(),
|
||||||
|
tab_bar: rgba(0x21242bff).into(),
|
||||||
|
editor: rgba(0x1e2025ff).into(),
|
||||||
|
editor_subheader: rgba(0x21242bff).into(),
|
||||||
|
editor_active_line: rgba(0x21242bff).into(),
|
||||||
|
terminal: rgba(0x1e2025ff).into(),
|
||||||
|
image_fallback_background: rgba(0x262933ff).into(),
|
||||||
|
git_created: rgba(0x96df71ff).into(),
|
||||||
|
git_modified: rgba(0x10a793ff).into(),
|
||||||
|
git_deleted: rgba(0xf82871ff).into(),
|
||||||
|
git_conflict: rgba(0xfee56cff).into(),
|
||||||
|
git_ignored: rgba(0x6b6b73ff).into(),
|
||||||
|
git_renamed: rgba(0xfee56cff).into(),
|
||||||
|
players: [
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x10a793ff).into(),
|
||||||
|
selection: rgba(0x10a7933d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x96df71ff).into(),
|
||||||
|
selection: rgba(0x96df713d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xc74cecff).into(),
|
||||||
|
selection: rgba(0xc74cec3d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xf29c14ff).into(),
|
||||||
|
selection: rgba(0xf29c143d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x893ea6ff).into(),
|
||||||
|
selection: rgba(0x893ea63d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x08e7c5ff).into(),
|
||||||
|
selection: rgba(0x08e7c53d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xf82871ff).into(),
|
||||||
|
selection: rgba(0xf828713d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xfee56cff).into(),
|
||||||
|
selection: rgba(0xfee56c3d).into(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
136
crates/theme2/src/themes/atelier_cave_dark.rs
Normal file
136
crates/theme2/src/themes/atelier_cave_dark.rs
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
use gpui2::rgba;
|
||||||
|
|
||||||
|
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
|
||||||
|
|
||||||
|
pub fn atelier_cave_dark() -> Theme {
|
||||||
|
Theme {
|
||||||
|
metadata: ThemeMetadata {
|
||||||
|
name: "Atelier Cave Dark".into(),
|
||||||
|
is_light: false,
|
||||||
|
},
|
||||||
|
transparent: rgba(0x00000000).into(),
|
||||||
|
mac_os_traffic_light_red: rgba(0xec695eff).into(),
|
||||||
|
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
|
||||||
|
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
|
||||||
|
border: rgba(0x56505eff).into(),
|
||||||
|
border_variant: rgba(0x56505eff).into(),
|
||||||
|
border_focused: rgba(0x222953ff).into(),
|
||||||
|
border_transparent: rgba(0x00000000).into(),
|
||||||
|
elevated_surface: rgba(0x3a353fff).into(),
|
||||||
|
surface: rgba(0x221f26ff).into(),
|
||||||
|
background: rgba(0x3a353fff).into(),
|
||||||
|
filled_element: rgba(0x3a353fff).into(),
|
||||||
|
filled_element_hover: rgba(0xffffff1e).into(),
|
||||||
|
filled_element_active: rgba(0xffffff28).into(),
|
||||||
|
filled_element_selected: rgba(0x161a35ff).into(),
|
||||||
|
filled_element_disabled: rgba(0x00000000).into(),
|
||||||
|
ghost_element: rgba(0x00000000).into(),
|
||||||
|
ghost_element_hover: rgba(0xffffff14).into(),
|
||||||
|
ghost_element_active: rgba(0xffffff1e).into(),
|
||||||
|
ghost_element_selected: rgba(0x161a35ff).into(),
|
||||||
|
ghost_element_disabled: rgba(0x00000000).into(),
|
||||||
|
text: rgba(0xefecf4ff).into(),
|
||||||
|
text_muted: rgba(0x898591ff).into(),
|
||||||
|
text_placeholder: rgba(0xbe4677ff).into(),
|
||||||
|
text_disabled: rgba(0x756f7eff).into(),
|
||||||
|
text_accent: rgba(0x566ddaff).into(),
|
||||||
|
icon_muted: rgba(0x898591ff).into(),
|
||||||
|
syntax: SyntaxTheme {
|
||||||
|
highlights: vec![
|
||||||
|
("comment.doc".into(), rgba(0x8b8792ff).into()),
|
||||||
|
("tag".into(), rgba(0x566ddaff).into()),
|
||||||
|
("link_text".into(), rgba(0xaa563bff).into()),
|
||||||
|
("constructor".into(), rgba(0x566ddaff).into()),
|
||||||
|
("punctuation".into(), rgba(0xe2dfe7ff).into()),
|
||||||
|
("punctuation.special".into(), rgba(0xbf3fbfff).into()),
|
||||||
|
("string.special.symbol".into(), rgba(0x299292ff).into()),
|
||||||
|
("string.escape".into(), rgba(0x8b8792ff).into()),
|
||||||
|
("emphasis".into(), rgba(0x566ddaff).into()),
|
||||||
|
("type".into(), rgba(0xa06d3aff).into()),
|
||||||
|
("punctuation.delimiter".into(), rgba(0x8b8792ff).into()),
|
||||||
|
("variant".into(), rgba(0xa06d3aff).into()),
|
||||||
|
("variable.special".into(), rgba(0x9559e7ff).into()),
|
||||||
|
("text.literal".into(), rgba(0xaa563bff).into()),
|
||||||
|
("punctuation.list_marker".into(), rgba(0xe2dfe7ff).into()),
|
||||||
|
("comment".into(), rgba(0x655f6dff).into()),
|
||||||
|
("function.method".into(), rgba(0x576cdbff).into()),
|
||||||
|
("property".into(), rgba(0xbe4677ff).into()),
|
||||||
|
("operator".into(), rgba(0x8b8792ff).into()),
|
||||||
|
("emphasis.strong".into(), rgba(0x566ddaff).into()),
|
||||||
|
("label".into(), rgba(0x566ddaff).into()),
|
||||||
|
("enum".into(), rgba(0xaa563bff).into()),
|
||||||
|
("number".into(), rgba(0xaa563bff).into()),
|
||||||
|
("primary".into(), rgba(0xe2dfe7ff).into()),
|
||||||
|
("keyword".into(), rgba(0x9559e7ff).into()),
|
||||||
|
(
|
||||||
|
"function.special.definition".into(),
|
||||||
|
rgba(0xa06d3aff).into(),
|
||||||
|
),
|
||||||
|
("punctuation.bracket".into(), rgba(0x8b8792ff).into()),
|
||||||
|
("constant".into(), rgba(0x2b9292ff).into()),
|
||||||
|
("string.special".into(), rgba(0xbf3fbfff).into()),
|
||||||
|
("title".into(), rgba(0xefecf4ff).into()),
|
||||||
|
("preproc".into(), rgba(0xefecf4ff).into()),
|
||||||
|
("link_uri".into(), rgba(0x2b9292ff).into()),
|
||||||
|
("string".into(), rgba(0x299292ff).into()),
|
||||||
|
("embedded".into(), rgba(0xefecf4ff).into()),
|
||||||
|
("hint".into(), rgba(0x706897ff).into()),
|
||||||
|
("boolean".into(), rgba(0x2b9292ff).into()),
|
||||||
|
("variable".into(), rgba(0xe2dfe7ff).into()),
|
||||||
|
("predictive".into(), rgba(0x615787ff).into()),
|
||||||
|
("string.regex".into(), rgba(0x388bc6ff).into()),
|
||||||
|
("function".into(), rgba(0x576cdbff).into()),
|
||||||
|
("attribute".into(), rgba(0x566ddaff).into()),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
status_bar: rgba(0x3a353fff).into(),
|
||||||
|
title_bar: rgba(0x3a353fff).into(),
|
||||||
|
toolbar: rgba(0x19171cff).into(),
|
||||||
|
tab_bar: rgba(0x221f26ff).into(),
|
||||||
|
editor: rgba(0x19171cff).into(),
|
||||||
|
editor_subheader: rgba(0x221f26ff).into(),
|
||||||
|
editor_active_line: rgba(0x221f26ff).into(),
|
||||||
|
terminal: rgba(0x19171cff).into(),
|
||||||
|
image_fallback_background: rgba(0x3a353fff).into(),
|
||||||
|
git_created: rgba(0x2b9292ff).into(),
|
||||||
|
git_modified: rgba(0x566ddaff).into(),
|
||||||
|
git_deleted: rgba(0xbe4677ff).into(),
|
||||||
|
git_conflict: rgba(0xa06d3aff).into(),
|
||||||
|
git_ignored: rgba(0x756f7eff).into(),
|
||||||
|
git_renamed: rgba(0xa06d3aff).into(),
|
||||||
|
players: [
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x566ddaff).into(),
|
||||||
|
selection: rgba(0x566dda3d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x2b9292ff).into(),
|
||||||
|
selection: rgba(0x2b92923d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xbf41bfff).into(),
|
||||||
|
selection: rgba(0xbf41bf3d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xaa563bff).into(),
|
||||||
|
selection: rgba(0xaa563b3d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x955ae6ff).into(),
|
||||||
|
selection: rgba(0x955ae63d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x3a8bc6ff).into(),
|
||||||
|
selection: rgba(0x3a8bc63d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xbe4677ff).into(),
|
||||||
|
selection: rgba(0xbe46773d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xa06d3aff).into(),
|
||||||
|
selection: rgba(0xa06d3a3d).into(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
136
crates/theme2/src/themes/atelier_cave_light.rs
Normal file
136
crates/theme2/src/themes/atelier_cave_light.rs
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
use gpui2::rgba;
|
||||||
|
|
||||||
|
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
|
||||||
|
|
||||||
|
pub fn atelier_cave_light() -> Theme {
|
||||||
|
Theme {
|
||||||
|
metadata: ThemeMetadata {
|
||||||
|
name: "Atelier Cave Light".into(),
|
||||||
|
is_light: true,
|
||||||
|
},
|
||||||
|
transparent: rgba(0x00000000).into(),
|
||||||
|
mac_os_traffic_light_red: rgba(0xec695eff).into(),
|
||||||
|
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
|
||||||
|
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
|
||||||
|
border: rgba(0x8f8b96ff).into(),
|
||||||
|
border_variant: rgba(0x8f8b96ff).into(),
|
||||||
|
border_focused: rgba(0xc8c7f2ff).into(),
|
||||||
|
border_transparent: rgba(0x00000000).into(),
|
||||||
|
elevated_surface: rgba(0xbfbcc5ff).into(),
|
||||||
|
surface: rgba(0xe6e3ebff).into(),
|
||||||
|
background: rgba(0xbfbcc5ff).into(),
|
||||||
|
filled_element: rgba(0xbfbcc5ff).into(),
|
||||||
|
filled_element_hover: rgba(0xffffff1e).into(),
|
||||||
|
filled_element_active: rgba(0xffffff28).into(),
|
||||||
|
filled_element_selected: rgba(0xe1e0f9ff).into(),
|
||||||
|
filled_element_disabled: rgba(0x00000000).into(),
|
||||||
|
ghost_element: rgba(0x00000000).into(),
|
||||||
|
ghost_element_hover: rgba(0xffffff14).into(),
|
||||||
|
ghost_element_active: rgba(0xffffff1e).into(),
|
||||||
|
ghost_element_selected: rgba(0xe1e0f9ff).into(),
|
||||||
|
ghost_element_disabled: rgba(0x00000000).into(),
|
||||||
|
text: rgba(0x19171cff).into(),
|
||||||
|
text_muted: rgba(0x5a5462ff).into(),
|
||||||
|
text_placeholder: rgba(0xbd4677ff).into(),
|
||||||
|
text_disabled: rgba(0x6e6876ff).into(),
|
||||||
|
text_accent: rgba(0x586cdaff).into(),
|
||||||
|
icon_muted: rgba(0x5a5462ff).into(),
|
||||||
|
syntax: SyntaxTheme {
|
||||||
|
highlights: vec![
|
||||||
|
("link_text".into(), rgba(0xaa573cff).into()),
|
||||||
|
("string".into(), rgba(0x299292ff).into()),
|
||||||
|
("emphasis".into(), rgba(0x586cdaff).into()),
|
||||||
|
("label".into(), rgba(0x586cdaff).into()),
|
||||||
|
("property".into(), rgba(0xbe4677ff).into()),
|
||||||
|
("emphasis.strong".into(), rgba(0x586cdaff).into()),
|
||||||
|
("constant".into(), rgba(0x2b9292ff).into()),
|
||||||
|
(
|
||||||
|
"function.special.definition".into(),
|
||||||
|
rgba(0xa06d3aff).into(),
|
||||||
|
),
|
||||||
|
("embedded".into(), rgba(0x19171cff).into()),
|
||||||
|
("punctuation.special".into(), rgba(0xbf3fbfff).into()),
|
||||||
|
("function".into(), rgba(0x576cdbff).into()),
|
||||||
|
("tag".into(), rgba(0x586cdaff).into()),
|
||||||
|
("number".into(), rgba(0xaa563bff).into()),
|
||||||
|
("primary".into(), rgba(0x26232aff).into()),
|
||||||
|
("text.literal".into(), rgba(0xaa573cff).into()),
|
||||||
|
("variant".into(), rgba(0xa06d3aff).into()),
|
||||||
|
("type".into(), rgba(0xa06d3aff).into()),
|
||||||
|
("punctuation".into(), rgba(0x26232aff).into()),
|
||||||
|
("string.escape".into(), rgba(0x585260ff).into()),
|
||||||
|
("keyword".into(), rgba(0x9559e7ff).into()),
|
||||||
|
("title".into(), rgba(0x19171cff).into()),
|
||||||
|
("constructor".into(), rgba(0x586cdaff).into()),
|
||||||
|
("punctuation.list_marker".into(), rgba(0x26232aff).into()),
|
||||||
|
("string.special".into(), rgba(0xbf3fbfff).into()),
|
||||||
|
("operator".into(), rgba(0x585260ff).into()),
|
||||||
|
("function.method".into(), rgba(0x576cdbff).into()),
|
||||||
|
("link_uri".into(), rgba(0x2b9292ff).into()),
|
||||||
|
("variable.special".into(), rgba(0x9559e7ff).into()),
|
||||||
|
("hint".into(), rgba(0x776d9dff).into()),
|
||||||
|
("punctuation.bracket".into(), rgba(0x585260ff).into()),
|
||||||
|
("string.special.symbol".into(), rgba(0x299292ff).into()),
|
||||||
|
("predictive".into(), rgba(0x887fafff).into()),
|
||||||
|
("attribute".into(), rgba(0x586cdaff).into()),
|
||||||
|
("enum".into(), rgba(0xaa573cff).into()),
|
||||||
|
("preproc".into(), rgba(0x19171cff).into()),
|
||||||
|
("boolean".into(), rgba(0x2b9292ff).into()),
|
||||||
|
("variable".into(), rgba(0x26232aff).into()),
|
||||||
|
("comment.doc".into(), rgba(0x585260ff).into()),
|
||||||
|
("string.regex".into(), rgba(0x388bc6ff).into()),
|
||||||
|
("punctuation.delimiter".into(), rgba(0x585260ff).into()),
|
||||||
|
("comment".into(), rgba(0x7d7787ff).into()),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
status_bar: rgba(0xbfbcc5ff).into(),
|
||||||
|
title_bar: rgba(0xbfbcc5ff).into(),
|
||||||
|
toolbar: rgba(0xefecf4ff).into(),
|
||||||
|
tab_bar: rgba(0xe6e3ebff).into(),
|
||||||
|
editor: rgba(0xefecf4ff).into(),
|
||||||
|
editor_subheader: rgba(0xe6e3ebff).into(),
|
||||||
|
editor_active_line: rgba(0xe6e3ebff).into(),
|
||||||
|
terminal: rgba(0xefecf4ff).into(),
|
||||||
|
image_fallback_background: rgba(0xbfbcc5ff).into(),
|
||||||
|
git_created: rgba(0x2b9292ff).into(),
|
||||||
|
git_modified: rgba(0x586cdaff).into(),
|
||||||
|
git_deleted: rgba(0xbd4677ff).into(),
|
||||||
|
git_conflict: rgba(0xa06e3bff).into(),
|
||||||
|
git_ignored: rgba(0x6e6876ff).into(),
|
||||||
|
git_renamed: rgba(0xa06e3bff).into(),
|
||||||
|
players: [
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x586cdaff).into(),
|
||||||
|
selection: rgba(0x586cda3d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x2b9292ff).into(),
|
||||||
|
selection: rgba(0x2b92923d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xbf41bfff).into(),
|
||||||
|
selection: rgba(0xbf41bf3d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xaa573cff).into(),
|
||||||
|
selection: rgba(0xaa573c3d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x955ae6ff).into(),
|
||||||
|
selection: rgba(0x955ae63d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x3a8bc6ff).into(),
|
||||||
|
selection: rgba(0x3a8bc63d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xbd4677ff).into(),
|
||||||
|
selection: rgba(0xbd46773d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xa06e3bff).into(),
|
||||||
|
selection: rgba(0xa06e3b3d).into(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
136
crates/theme2/src/themes/atelier_dune_dark.rs
Normal file
136
crates/theme2/src/themes/atelier_dune_dark.rs
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
use gpui2::rgba;
|
||||||
|
|
||||||
|
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
|
||||||
|
|
||||||
|
pub fn atelier_dune_dark() -> Theme {
|
||||||
|
Theme {
|
||||||
|
metadata: ThemeMetadata {
|
||||||
|
name: "Atelier Dune Dark".into(),
|
||||||
|
is_light: false,
|
||||||
|
},
|
||||||
|
transparent: rgba(0x00000000).into(),
|
||||||
|
mac_os_traffic_light_red: rgba(0xec695eff).into(),
|
||||||
|
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
|
||||||
|
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
|
||||||
|
border: rgba(0x6c695cff).into(),
|
||||||
|
border_variant: rgba(0x6c695cff).into(),
|
||||||
|
border_focused: rgba(0x262f56ff).into(),
|
||||||
|
border_transparent: rgba(0x00000000).into(),
|
||||||
|
elevated_surface: rgba(0x45433bff).into(),
|
||||||
|
surface: rgba(0x262622ff).into(),
|
||||||
|
background: rgba(0x45433bff).into(),
|
||||||
|
filled_element: rgba(0x45433bff).into(),
|
||||||
|
filled_element_hover: rgba(0xffffff1e).into(),
|
||||||
|
filled_element_active: rgba(0xffffff28).into(),
|
||||||
|
filled_element_selected: rgba(0x171e38ff).into(),
|
||||||
|
filled_element_disabled: rgba(0x00000000).into(),
|
||||||
|
ghost_element: rgba(0x00000000).into(),
|
||||||
|
ghost_element_hover: rgba(0xffffff14).into(),
|
||||||
|
ghost_element_active: rgba(0xffffff1e).into(),
|
||||||
|
ghost_element_selected: rgba(0x171e38ff).into(),
|
||||||
|
ghost_element_disabled: rgba(0x00000000).into(),
|
||||||
|
text: rgba(0xfefbecff).into(),
|
||||||
|
text_muted: rgba(0xa4a08bff).into(),
|
||||||
|
text_placeholder: rgba(0xd73837ff).into(),
|
||||||
|
text_disabled: rgba(0x8f8b77ff).into(),
|
||||||
|
text_accent: rgba(0x6684e0ff).into(),
|
||||||
|
icon_muted: rgba(0xa4a08bff).into(),
|
||||||
|
syntax: SyntaxTheme {
|
||||||
|
highlights: vec![
|
||||||
|
("constructor".into(), rgba(0x6684e0ff).into()),
|
||||||
|
("punctuation".into(), rgba(0xe8e4cfff).into()),
|
||||||
|
("punctuation.delimiter".into(), rgba(0xa6a28cff).into()),
|
||||||
|
("string.special".into(), rgba(0xd43451ff).into()),
|
||||||
|
("string.escape".into(), rgba(0xa6a28cff).into()),
|
||||||
|
("comment".into(), rgba(0x7d7a68ff).into()),
|
||||||
|
("enum".into(), rgba(0xb65611ff).into()),
|
||||||
|
("variable.special".into(), rgba(0xb854d4ff).into()),
|
||||||
|
("primary".into(), rgba(0xe8e4cfff).into()),
|
||||||
|
("comment.doc".into(), rgba(0xa6a28cff).into()),
|
||||||
|
("label".into(), rgba(0x6684e0ff).into()),
|
||||||
|
("operator".into(), rgba(0xa6a28cff).into()),
|
||||||
|
("string".into(), rgba(0x5fac38ff).into()),
|
||||||
|
("variant".into(), rgba(0xae9512ff).into()),
|
||||||
|
("variable".into(), rgba(0xe8e4cfff).into()),
|
||||||
|
("function.method".into(), rgba(0x6583e1ff).into()),
|
||||||
|
(
|
||||||
|
"function.special.definition".into(),
|
||||||
|
rgba(0xae9512ff).into(),
|
||||||
|
),
|
||||||
|
("string.regex".into(), rgba(0x1ead82ff).into()),
|
||||||
|
("emphasis.strong".into(), rgba(0x6684e0ff).into()),
|
||||||
|
("punctuation.special".into(), rgba(0xd43451ff).into()),
|
||||||
|
("punctuation.bracket".into(), rgba(0xa6a28cff).into()),
|
||||||
|
("link_text".into(), rgba(0xb65611ff).into()),
|
||||||
|
("link_uri".into(), rgba(0x5fac39ff).into()),
|
||||||
|
("boolean".into(), rgba(0x5fac39ff).into()),
|
||||||
|
("hint".into(), rgba(0xb17272ff).into()),
|
||||||
|
("tag".into(), rgba(0x6684e0ff).into()),
|
||||||
|
("function".into(), rgba(0x6583e1ff).into()),
|
||||||
|
("title".into(), rgba(0xfefbecff).into()),
|
||||||
|
("property".into(), rgba(0xd73737ff).into()),
|
||||||
|
("type".into(), rgba(0xae9512ff).into()),
|
||||||
|
("constant".into(), rgba(0x5fac39ff).into()),
|
||||||
|
("attribute".into(), rgba(0x6684e0ff).into()),
|
||||||
|
("predictive".into(), rgba(0x9c6262ff).into()),
|
||||||
|
("string.special.symbol".into(), rgba(0x5fac38ff).into()),
|
||||||
|
("punctuation.list_marker".into(), rgba(0xe8e4cfff).into()),
|
||||||
|
("emphasis".into(), rgba(0x6684e0ff).into()),
|
||||||
|
("keyword".into(), rgba(0xb854d4ff).into()),
|
||||||
|
("text.literal".into(), rgba(0xb65611ff).into()),
|
||||||
|
("number".into(), rgba(0xb65610ff).into()),
|
||||||
|
("preproc".into(), rgba(0xfefbecff).into()),
|
||||||
|
("embedded".into(), rgba(0xfefbecff).into()),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
status_bar: rgba(0x45433bff).into(),
|
||||||
|
title_bar: rgba(0x45433bff).into(),
|
||||||
|
toolbar: rgba(0x20201dff).into(),
|
||||||
|
tab_bar: rgba(0x262622ff).into(),
|
||||||
|
editor: rgba(0x20201dff).into(),
|
||||||
|
editor_subheader: rgba(0x262622ff).into(),
|
||||||
|
editor_active_line: rgba(0x262622ff).into(),
|
||||||
|
terminal: rgba(0x20201dff).into(),
|
||||||
|
image_fallback_background: rgba(0x45433bff).into(),
|
||||||
|
git_created: rgba(0x5fac39ff).into(),
|
||||||
|
git_modified: rgba(0x6684e0ff).into(),
|
||||||
|
git_deleted: rgba(0xd73837ff).into(),
|
||||||
|
git_conflict: rgba(0xae9414ff).into(),
|
||||||
|
git_ignored: rgba(0x8f8b77ff).into(),
|
||||||
|
git_renamed: rgba(0xae9414ff).into(),
|
||||||
|
players: [
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x6684e0ff).into(),
|
||||||
|
selection: rgba(0x6684e03d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x5fac39ff).into(),
|
||||||
|
selection: rgba(0x5fac393d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xd43651ff).into(),
|
||||||
|
selection: rgba(0xd436513d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xb65611ff).into(),
|
||||||
|
selection: rgba(0xb656113d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xb854d3ff).into(),
|
||||||
|
selection: rgba(0xb854d33d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0x20ad83ff).into(),
|
||||||
|
selection: rgba(0x20ad833d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xd73837ff).into(),
|
||||||
|
selection: rgba(0xd738373d).into(),
|
||||||
|
},
|
||||||
|
PlayerTheme {
|
||||||
|
cursor: rgba(0xae9414ff).into(),
|
||||||
|
selection: rgba(0xae94143d).into(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user