Merge branch 'zed2' into zed2-workspace

This commit is contained in:
Mikayla 2023-10-30 16:53:21 -07:00
commit 3dadfb8ba8
No known key found for this signature in database
185 changed files with 12176 additions and 4816 deletions

42
Cargo.lock generated
View File

@ -108,6 +108,33 @@ dependencies = [
"util", "util",
] ]
[[package]]
name = "ai2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"bincode",
"futures 0.3.28",
"gpui2",
"isahc",
"language2",
"lazy_static",
"log",
"matrixmultiply",
"ordered-float 2.10.0",
"parking_lot 0.11.2",
"parse_duration",
"postage",
"rand 0.8.5",
"regex",
"rusqlite",
"serde",
"serde_json",
"tiktoken-rs",
"util",
]
[[package]] [[package]]
name = "alacritty_config" name = "alacritty_config"
version = "0.1.2-dev" version = "0.1.2-dev"
@ -1138,7 +1165,7 @@ dependencies = [
"audio2", "audio2",
"client2", "client2",
"collections", "collections",
"fs", "fs2",
"futures 0.3.28", "futures 0.3.28",
"gpui2", "gpui2",
"language2", "language2",
@ -4795,6 +4822,13 @@ dependencies = [
"gpui", "gpui",
] ]
[[package]]
name = "menu2"
version = "0.1.0"
dependencies = [
"gpui2",
]
[[package]] [[package]]
name = "metal" name = "metal"
version = "0.21.0" version = "0.21.0"
@ -6000,7 +6034,7 @@ dependencies = [
"anyhow", "anyhow",
"client2", "client2",
"collections", "collections",
"fs", "fs2",
"futures 0.3.28", "futures 0.3.28",
"gpui2", "gpui2",
"language2", "language2",
@ -6167,7 +6201,7 @@ dependencies = [
"ctor", "ctor",
"db2", "db2",
"env_logger 0.9.3", "env_logger 0.9.3",
"fs", "fs2",
"fsevent", "fsevent",
"futures 0.3.28", "futures 0.3.28",
"fuzzy2", "fuzzy2",
@ -8740,6 +8774,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"clap 4.4.4", "clap 4.4.4",
"convert_case 0.6.0",
"gpui2", "gpui2",
"log", "log",
"rust-embed", "rust-embed",
@ -10932,6 +10967,7 @@ dependencies = [
name = "zed2" name = "zed2"
version = "0.109.0" version = "0.109.0"
dependencies = [ dependencies = [
"ai2",
"anyhow", "anyhow",
"async-compression", "async-compression",
"async-recursion 0.3.2", "async-recursion 0.3.2",

View File

@ -59,6 +59,7 @@ members = [
"crates/lsp2", "crates/lsp2",
"crates/media", "crates/media",
"crates/menu", "crates/menu",
"crates/menu2",
"crates/multi_buffer", "crates/multi_buffer",
"crates/node_runtime", "crates/node_runtime",
"crates/notifications", "crates/notifications",

38
crates/Cargo.toml Normal file
View File

@ -0,0 +1,38 @@
[package]
name = "ai"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/ai.rs"
doctest = false
[features]
test-support = []
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
language = { path = "../language" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
lazy_static.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
isahc.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
postage.workspace = true
rand.workspace = true
log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }

View File

@ -8,6 +8,9 @@ publish = false
path = "src/ai.rs" path = "src/ai.rs"
doctest = false doctest = false
[features]
test-support = []
[dependencies] [dependencies]
gpui = { path = "../gpui" } gpui = { path = "../gpui" }
util = { path = "../util" } util = { path = "../util" }

View File

@ -1,4 +1,8 @@
pub mod auth;
pub mod completion; pub mod completion;
pub mod embedding; pub mod embedding;
pub mod models; pub mod models;
pub mod templates; pub mod prompts;
pub mod providers;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

15
crates/ai/src/auth.rs Normal file
View File

@ -0,0 +1,15 @@
use gpui::AppContext;
/// Authentication state for an AI provider.
#[derive(Clone, Debug)]
pub enum ProviderCredential {
    /// An API key is available for authenticating requests.
    Credentials { api_key: String },
    /// No credentials have been retrieved or saved yet.
    NoCredentials,
    /// The provider does not require authentication.
    NotNeeded,
}
/// Read/write access to a provider's credentials.
///
/// Implementors in this crate cache the credential in memory and persist it
/// via the platform keychain (see the OpenAI provider for an example).
pub trait CredentialProvider: Send + Sync {
    /// Whether usable credentials are currently held in memory.
    fn has_credentials(&self) -> bool;
    /// Load credentials (from cache or an external source) and return the
    /// resulting credential state.
    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
    /// Persist `credential` and make it the current in-memory state.
    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
    /// Remove any persisted credentials and reset the in-memory state.
    fn delete_credentials(&self, cx: &AppContext);
}

View File

@ -1,214 +1,23 @@
use anyhow::{anyhow, Result}; use anyhow::Result;
use futures::{ use futures::{future::BoxFuture, stream::BoxStream};
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui::executor::Background;
use isahc::{http::StatusCode, Request, RequestExt};
use serde::{Deserialize, Serialize};
use std::{
fmt::{self, Display},
io,
sync::Arc,
};
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; use crate::{auth::CredentialProvider, models::LanguageModel};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] pub trait CompletionRequest: Send + Sync {
#[serde(rename_all = "lowercase")] fn data(&self) -> serde_json::Result<String>;
pub enum Role {
User,
Assistant,
System,
} }
impl Role { pub trait CompletionProvider: CredentialProvider {
pub fn cycle(&mut self) { fn base_model(&self) -> Box<dyn LanguageModel>;
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<OpenAIUsage>,
}
pub async fn stream_completion(
api_key: String,
executor: Arc<Background>,
mut request: OpenAIRequest,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
request.stream = true;
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = serde_json::to_string(&request)?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAIResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(&data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAIResponse {
error: OpenAIError,
}
#[derive(Deserialize)]
struct OpenAIError {
message: String,
}
match serde_json::from_str::<OpenAIResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
pub trait CompletionProvider {
fn complete( fn complete(
&self, &self,
prompt: OpenAIRequest, prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>; ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn box_clone(&self) -> Box<dyn CompletionProvider>;
} }
pub struct OpenAICompletionProvider { impl Clone for Box<dyn CompletionProvider> {
api_key: String, fn clone(&self) -> Box<dyn CompletionProvider> {
executor: Arc<Background>, self.box_clone()
}
impl OpenAICompletionProvider {
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
Self { api_key, executor }
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
} }
} }

View File

@ -1,32 +1,13 @@
use anyhow::{anyhow, Result}; use std::time::Instant;
use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parse_duration::parse;
use postage::watch;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql; use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::completion::OPENAI_API_URL; use crate::auth::CredentialProvider;
use crate::models::LanguageModel;
lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>); pub struct Embedding(pub Vec<f32>);
@ -87,301 +68,14 @@ impl Embedding {
} }
} }
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
model: &'static str,
input: Vec<&'a str>,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
data: Vec<OpenAIEmbedding>,
usage: OpenAIEmbeddingUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
embedding: Vec<f32>,
index: usize,
object: String,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
prompt_tokens: usize,
total_tokens: usize,
}
#[async_trait] #[async_trait]
pub trait EmbeddingProvider: Sync + Send { pub trait EmbeddingProvider: CredentialProvider {
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>; fn base_model(&self) -> Box<dyn LanguageModel>;
async fn embed_batch( async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
&self,
spans: Vec<String>,
api_key: Option<String>,
) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize; fn max_tokens_per_batch(&self) -> usize;
fn truncate(&self, span: &str) -> (String, usize);
fn rate_limit_expiration(&self) -> Option<Instant>; fn rate_limit_expiration(&self) -> Option<Instant>;
} }
pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
Some("Dummy API KEY".to_string())
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(
&self,
spans: Vec<String>,
_api_key: Option<String>,
) -> Result<Vec<Embedding>> {
// 1024 is the OpenAI Embeddings size for ada models.
// the model we will likely be starting with.
let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
return Ok(vec![dummy_vec; spans.len()]);
}
fn max_tokens_per_batch(&self) -> usize {
OPENAI_INPUT_LIMIT
}
fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
let token_count = tokens.len();
let output = if token_count > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
new_input.ok().unwrap_or_else(|| span.to_string())
} else {
span.to_string()
};
(output, tokens.len())
}
}
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
OpenAIEmbeddings {
client,
executor,
rate_limit_count_rx,
rate_limit_count_tx,
}
}
fn resolve_rate_limit(&self) {
let reset_time = *self.rate_limit_count_tx.lock().borrow();
if let Some(reset_time) = reset_time {
if Instant::now() >= reset_time {
*self.rate_limit_count_tx.lock().borrow_mut() = None
}
}
log::trace!(
"resolving reset time: {:?}",
*self.rate_limit_count_tx.lock().borrow()
);
}
fn update_reset_time(&self, reset_time: Instant) {
let original_time = *self.rate_limit_count_tx.lock().borrow();
let updated_time = if let Some(original_time) = original_time {
if reset_time < original_time {
Some(reset_time)
} else {
Some(original_time)
}
} else {
Some(reset_time)
};
log::trace!("updating rate limit time: {:?}", updated_time);
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
}
async fn send_request(
&self,
api_key: &str,
spans: Vec<&str>,
request_timeout: u64,
) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAIEmbeddingRequest {
input: spans.clone(),
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
Ok(self.client.send(request).await?)
}
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
Some(api_key)
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
{
String::from_utf8(api_key).log_err()
} else {
None
}
}
fn max_tokens_per_batch(&self) -> usize {
50000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
*self.rate_limit_count_rx.borrow()
}
fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
OPENAI_BPE_TOKENIZER
.decode(tokens.clone())
.ok()
.unwrap_or_else(|| span.to_string())
} else {
span.to_string()
};
(output, tokens.len())
}
async fn embed_batch(
&self,
spans: Vec<String>,
api_key: Option<String>,
) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
let Some(api_key) = api_key else {
return Err(anyhow!("no open ai key provided"));
};
let mut request_number = 0;
let mut rate_limiting = false;
let mut request_timeout: u64 = 15;
let mut response: Response<AsyncBody>;
while request_number < MAX_RETRIES {
response = self
.send_request(
&api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
.await?;
request_number += 1;
match response.status() {
StatusCode::REQUEST_TIMEOUT => {
request_timeout += 5;
}
StatusCode::OK => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
log::trace!(
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
// If we complete a request successfully that was previously rate_limited
// resolve the rate limit
if rate_limiting {
self.resolve_rate_limit()
}
return Ok(response
.data
.into_iter()
.map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
rate_limiting = true;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let delay_duration = {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
if let Some(time_to_reset) =
response.headers().get("x-ratelimit-reset-tokens")
{
if let Ok(time_str) = time_to_reset.to_str() {
parse(time_str).unwrap_or(delay)
} else {
delay
}
} else {
delay
}
};
// If we've previously rate limited, increment the duration but not the count
let reset_time = Instant::now().add(delay_duration);
self.update_reset_time(reset_time);
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
);
self.executor.timer(delay_duration).await;
}
_ => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"open ai bad request: {:?} {:?}",
&response.status(),
body
));
}
}
}
Err(anyhow!("openai max retries"))
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -1,66 +1,16 @@
use anyhow::anyhow; pub enum TruncationDirection {
use tiktoken_rs::CoreBPE; Start,
use util::ResultExt; End,
}
pub trait LanguageModel { pub trait LanguageModel {
fn name(&self) -> String; fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>; fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>; fn truncate(
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>; &self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>; fn capacity(&self) -> anyhow::Result<usize>;
} }
pub struct OpenAILanguageModel {
name: String,
bpe: Option<CoreBPE>,
}
impl OpenAILanguageModel {
pub fn load(model_name: &str) -> Self {
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
OpenAILanguageModel {
name: model_name.to_string(),
bpe,
}
}
}
impl LanguageModel for OpenAILanguageModel {
fn name(&self) -> String {
self.name.clone()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
if let Some(bpe) = &self.bpe {
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
bpe.decode(tokens[..length].to_vec())
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
bpe.decode(tokens[length..].to_vec())
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
}
}

View File

@ -6,7 +6,7 @@ use language::BufferSnapshot;
use util::ResultExt; use util::ResultExt;
use crate::models::LanguageModel; use crate::models::LanguageModel;
use crate::templates::repository_context::PromptCodeSnippet; use crate::prompts::repository_context::PromptCodeSnippet;
pub(crate) enum PromptFileType { pub(crate) enum PromptFileType {
Text, Text,
@ -125,6 +125,9 @@ impl PromptChain {
#[cfg(test)] #[cfg(test)]
pub(crate) mod tests { pub(crate) mod tests {
use crate::models::TruncationDirection;
use crate::test::FakeLanguageModel;
use super::*; use super::*;
#[test] #[test]
@ -141,7 +144,11 @@ pub(crate) mod tests {
let mut token_count = args.model.count_tokens(&content)?; let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length { if let Some(max_token_length) = max_token_length {
if token_count > max_token_length { if token_count > max_token_length {
content = args.model.truncate(&content, max_token_length)?; content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length; token_count = max_token_length;
} }
} }
@ -162,7 +169,11 @@ pub(crate) mod tests {
let mut token_count = args.model.count_tokens(&content)?; let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length { if let Some(max_token_length) = max_token_length {
if token_count > max_token_length { if token_count > max_token_length {
content = args.model.truncate(&content, max_token_length)?; content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length; token_count = max_token_length;
} }
} }
@ -171,38 +182,7 @@ pub(crate) mod tests {
} }
} }
#[derive(Clone)] let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
struct DummyLanguageModel {
capacity: usize,
}
impl LanguageModel for DummyLanguageModel {
fn name(&self) -> String {
"dummy".to_string()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
}
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
anyhow::Ok(
content.chars().collect::<Vec<char>>()[..length]
.into_iter()
.collect::<String>(),
)
}
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
anyhow::Ok(
content.chars().collect::<Vec<char>>()[length..]
.into_iter()
.collect::<String>(),
)
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(self.capacity)
}
}
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
let args = PromptArguments { let args = PromptArguments {
model: model.clone(), model: model.clone(),
language_name: None, language_name: None,
@ -238,7 +218,7 @@ pub(crate) mod tests {
// Testing with Truncation Off // Testing with Truncation Off
// Should ignore capacity and return all prompts // Should ignore capacity and return all prompts
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 }); let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
let args = PromptArguments { let args = PromptArguments {
model: model.clone(), model: model.clone(),
language_name: None, language_name: None,
@ -275,7 +255,7 @@ pub(crate) mod tests {
// Testing with Truncation Off // Testing with Truncation Off
// Should ignore capacity and return all prompts // Should ignore capacity and return all prompts
let capacity = 20; let capacity = 20;
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity }); let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments { let args = PromptArguments {
model: model.clone(), model: model.clone(),
language_name: None, language_name: None,
@ -311,7 +291,7 @@ pub(crate) mod tests {
// Change Ordering of Prompts Based on Priority // Change Ordering of Prompts Based on Priority
let capacity = 120; let capacity = 120;
let reserved_tokens = 10; let reserved_tokens = 10;
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity }); let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments { let args = PromptArguments {
model: model.clone(), model: model.clone(),
language_name: None, language_name: None,

View File

@ -3,8 +3,9 @@ use language::BufferSnapshot;
use language::ToOffset; use language::ToOffset;
use crate::models::LanguageModel; use crate::models::LanguageModel;
use crate::templates::base::PromptArguments; use crate::models::TruncationDirection;
use crate::templates::base::PromptTemplate; use crate::prompts::base::PromptArguments;
use crate::prompts::base::PromptTemplate;
use std::fmt::Write; use std::fmt::Write;
use std::ops::Range; use std::ops::Range;
use std::sync::Arc; use std::sync::Arc;
@ -70,8 +71,9 @@ fn retrieve_context(
}; };
let truncated_start_window = let truncated_start_window =
model.truncate_start(&start_window, start_goal_tokens)?; model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?; let truncated_end_window =
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
writeln!( writeln!(
prompt, prompt,
"{truncated_start_window}{selected_window}{truncated_end_window}" "{truncated_start_window}{selected_window}{truncated_end_window}"
@ -89,7 +91,7 @@ fn retrieve_context(
if let Some(max_token_count) = max_token_count { if let Some(max_token_count) = max_token_count {
if model.count_tokens(&prompt)? > max_token_count { if model.count_tokens(&prompt)? > max_token_count {
truncated = true; truncated = true;
prompt = model.truncate(&prompt, max_token_count)?; prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
} }
} }
} }
@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
// Really dumb truncation strategy // Really dumb truncation strategy
if let Some(max_tokens) = max_token_length { if let Some(max_tokens) = max_token_length {
prompt = args.model.truncate(&prompt, max_tokens)?; prompt = args
.model
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
} }
let token_count = args.model.count_tokens(&prompt)?; let token_count = args.model.count_tokens(&prompt)?;

View File

@ -1,4 +1,4 @@
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow; use anyhow::anyhow;
use std::fmt::Write; use std::fmt::Write;
@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
// Really dumb truncation strategy // Really dumb truncation strategy
if let Some(max_tokens) = max_token_length { if let Some(max_tokens) = max_token_length {
prompt = args.model.truncate(&prompt, max_tokens)?; prompt = args.model.truncate(
&prompt,
max_tokens,
crate::models::TruncationDirection::End,
)?;
} }
let token_count = args.model.count_tokens(&prompt)?; let token_count = args.model.count_tokens(&prompt)?;

View File

@ -1,4 +1,4 @@
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write; use std::fmt::Write;
pub struct EngineerPreamble {} pub struct EngineerPreamble {}

View File

@ -1,4 +1,4 @@
use crate::templates::base::{PromptArguments, PromptTemplate}; use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write; use std::fmt::Write;
use std::{ops::Range, path::PathBuf}; use std::{ops::Range, path::PathBuf};

View File

@ -0,0 +1 @@
pub mod open_ai;

View File

@ -0,0 +1,298 @@
use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui::{executor::Background, AppContext};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
env,
fmt::{self, Display},
io,
sync::Arc,
};
use util::ResultExt;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
/// The author of a chat message; serialized lowercase
/// ("user" / "assistant" / "system") for the OpenAI chat API.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
    /// Writes the capitalized, human-readable role name (distinct from the
    /// lowercase serde wire representation).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Role::User => "User",
            Role::Assistant => "Assistant",
            Role::System => "System",
        };
        f.write_str(label)
    }
}
/// A single message in an outgoing chat completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    pub role: Role,
    pub content: String,
}
/// Request body for the OpenAI chat completions endpoint.
///
/// NOTE(review): unlike the implementation this replaces, `stream_completion`
/// no longer forces `stream = true` on the request — callers must set it
/// themselves. Confirm this is intentional.
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
}
impl CompletionRequest for OpenAIRequest {
    /// Serialize the full request as the JSON body sent to the API.
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}
/// A message fragment in a streamed response; fields are optional because
/// each stream delta may carry only part of a message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
/// Token accounting reported by the API.
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
/// One choice delta within a server-sent stream event.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    /// Present on the final delta of a choice; `stream_completion` stops
    /// reading the stream once any choice reports it.
    pub finish_reason: Option<String>,
}
/// A single `data:` event parsed from the streaming completions endpoint.
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<OpenAIUsage>,
}
/// Opens a streaming chat-completion request against the OpenAI API.
///
/// Issues the HTTP request before returning. On success, spawns a background
/// task that parses server-sent events from the response body and forwards
/// them over an unbounded channel; the returned stream yields those events.
///
/// # Errors
/// Fails if `credential` does not carry an API key, if the request cannot be
/// serialized or sent, or if the API responds with a non-OK status (in which
/// case the body is read and converted into an error message).
pub async fn stream_completion(
    credential: ProviderCredential,
    executor: Arc<Background>,
    request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
    // Only an explicit API key can authenticate this endpoint;
    // NoCredentials / NotNeeded are rejected up front.
    let api_key = match credential {
        ProviderCredential::Credentials { api_key } => api_key,
        _ => {
            return Err(anyhow!("no credentials provider for completion"));
        }
    };
    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
    let json_data = request.data()?;
    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(json_data)?
        .send_async()
        .await?;
    let status = response.status();
    if status == StatusCode::OK {
        // Parse the SSE body on a background task, forwarding events until
        // the stream completes or the receiver side is dropped.
        executor
            .spawn(async move {
                let mut lines = BufReader::new(response.body_mut()).lines();
                // Event payloads have the form "data: <json>"; lines without
                // that prefix (e.g. blank keep-alives) are skipped.
                // NOTE(review): the terminal "data: [DONE]" sentinel is not
                // special-cased here; termination relies on `finish_reason`
                // below. If [DONE] were ever reached first, it would surface
                // as a parse error on the stream — confirm acceptable.
                fn parse_line(
                    line: Result<String, io::Error>,
                ) -> Result<Option<OpenAIResponseStreamEvent>> {
                    if let Some(data) = line?.strip_prefix("data: ") {
                        let event = serde_json::from_str(&data)?;
                        Ok(Some(event))
                    } else {
                        Ok(None)
                    }
                }
                while let Some(line) = lines.next().await {
                    if let Some(event) = parse_line(line).transpose() {
                        // Stop once any choice reports a finish_reason, but
                        // still forward that final event first.
                        let done = event.as_ref().map_or(false, |event| {
                            event
                                .choices
                                .last()
                                .map_or(false, |choice| choice.finish_reason.is_some())
                        });
                        // A send error means the receiver was dropped.
                        if tx.unbounded_send(event).is_err() {
                            break;
                        }
                        if done {
                            break;
                        }
                    }
                }
                anyhow::Ok(())
            })
            .detach();
        Ok(rx)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        #[derive(Deserialize)]
        struct OpenAIResponse {
            error: OpenAIError,
        }
        #[derive(Deserialize)]
        struct OpenAIError {
            message: String,
        }
        // Prefer the structured OpenAI error message; fall back to the raw
        // status + body when the error payload cannot be parsed.
        match serde_json::from_str::<OpenAIResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),
            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
/// Streaming completion provider backed by the OpenAI chat API.
#[derive(Clone)]
pub struct OpenAICompletionProvider {
    model: OpenAILanguageModel,
    // Shared, interior-mutable credential state so that clones of this
    // provider observe credential updates.
    credential: Arc<RwLock<ProviderCredential>>,
    executor: Arc<Background>,
}
impl OpenAICompletionProvider {
    /// Create a provider for `model_name` with no credentials loaded yet;
    /// call `retrieve_credentials` before issuing completions.
    pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
        Self {
            model: OpenAILanguageModel::load(model_name),
            credential: Arc::new(RwLock::new(ProviderCredential::NoCredentials)),
            executor,
        }
    }
}
impl CredentialProvider for OpenAICompletionProvider {
    /// Reports whether an API key is currently held in memory.
    fn has_credentials(&self) -> bool {
        matches!(*self.credential.read(), ProviderCredential::Credentials { .. })
    }
    /// Returns the cached credential, populating it on first use from the
    /// `OPENAI_API_KEY` environment variable or, failing that, the platform
    /// keychain entry stored under `OPENAI_API_URL`.
    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
        let mut credential = self.credential.write();
        if !matches!(*credential, ProviderCredential::Credentials { .. }) {
            if let Ok(api_key) = env::var("OPENAI_API_KEY") {
                *credential = ProviderCredential::Credentials { api_key };
            } else if let Some((_, api_key)) = cx
                .platform()
                .read_credentials(OPENAI_API_URL)
                .log_err()
                .flatten()
            {
                if let Some(api_key) = String::from_utf8(api_key).log_err() {
                    *credential = ProviderCredential::Credentials { api_key };
                }
            }
            // No key found anywhere: leave `NoCredentials` in place.
            // (The previous version carried a dead empty `else {}` here.)
        }
        credential.clone()
    }
    /// Persists an API key to the platform keychain and caches it in memory.
    /// Non-key credentials are only cached, never written to the keychain.
    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
        // Borrow instead of cloning the whole credential just to read the key.
        if let ProviderCredential::Credentials { api_key } = &credential {
            cx.platform()
                .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
                .log_err();
        }
        *self.credential.write() = credential;
    }
    /// Removes the keychain entry and resets the cached credential.
    fn delete_credentials(&self, cx: &AppContext) {
        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
        *self.credential.write() = ProviderCredential::NoCredentials;
    }
}
impl CompletionProvider for OpenAICompletionProvider {
    /// Returns a boxed clone of the underlying language model, used by callers
    /// for token counting and truncation.
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }
    /// Streams completion text chunks from the OpenAI chat API. Events without
    /// a choice or without delta content are filtered out of the stream.
    fn complete(
        &self,
        prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        // Currently the CompletionRequest for OpenAI includes a 'model' parameter,
        // so the model is determined by the CompletionRequest and not by this
        // provider (which is itself model-based, due to the language model).
        // At some point in the future we should rectify this.
        let credential = self.credential.read().clone();
        let request = stream_completion(credential, self.executor.clone(), prompt);
        async move {
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        // `?` drops events with no last choice / no content.
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }
    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new((*self).clone())
    }
}

View File

@ -0,0 +1,306 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;
use crate::providers::open_ai::OPENAI_API_URL;
lazy_static! {
    // Shared cl100k_base tokenizer (the encoding used by the OpenAI embedding model).
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
/// An `EmbeddingProvider` backed by the OpenAI embeddings API
/// (`text-embedding-ada-002`), with client-side rate-limit tracking.
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    model: OpenAILanguageModel,
    credential: Arc<RwLock<ProviderCredential>>,
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Background>,
    // Instant until which requests are rate limited (`None` = not limited).
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
/// JSON request body for the `/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
/// JSON response body from the `/embeddings` endpoint.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
/// One embedding vector within the response payload.
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    index: usize,
    object: String,
}
/// Token accounting reported by the API.
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
impl OpenAIEmbeddingProvider {
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
let model = OpenAILanguageModel::load("text-embedding-ada-002");
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
OpenAIEmbeddingProvider {
model,
credential,
client,
executor,
rate_limit_count_rx,
rate_limit_count_tx,
}
}
fn get_api_key(&self) -> Result<String> {
match self.credential.read().clone() {
ProviderCredential::Credentials { api_key } => Ok(api_key),
_ => Err(anyhow!("api credentials not provided")),
}
}
fn resolve_rate_limit(&self) {
let reset_time = *self.rate_limit_count_tx.lock().borrow();
if let Some(reset_time) = reset_time {
if Instant::now() >= reset_time {
*self.rate_limit_count_tx.lock().borrow_mut() = None
}
}
log::trace!(
"resolving reset time: {:?}",
*self.rate_limit_count_tx.lock().borrow()
);
}
fn update_reset_time(&self, reset_time: Instant) {
let original_time = *self.rate_limit_count_tx.lock().borrow();
let updated_time = if let Some(original_time) = original_time {
if reset_time < original_time {
Some(reset_time)
} else {
Some(original_time)
}
} else {
Some(reset_time)
};
log::trace!("updating rate limit time: {:?}", updated_time);
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
}
async fn send_request(
&self,
api_key: &str,
spans: Vec<&str>,
request_timeout: u64,
) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAIEmbeddingRequest {
input: spans.clone(),
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
Ok(self.client.send(request).await?)
}
}
impl CredentialProvider for OpenAIEmbeddingProvider {
    /// Reports whether an API key is currently held in memory.
    fn has_credentials(&self) -> bool {
        matches!(*self.credential.read(), ProviderCredential::Credentials { .. })
    }
    /// Returns the cached credential, populating it on first use from the
    /// `OPENAI_API_KEY` environment variable or, failing that, the platform
    /// keychain entry stored under `OPENAI_API_URL`.
    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
        let mut credential = self.credential.write();
        if !matches!(*credential, ProviderCredential::Credentials { .. }) {
            if let Ok(api_key) = env::var("OPENAI_API_KEY") {
                *credential = ProviderCredential::Credentials { api_key };
            } else if let Some((_, api_key)) = cx
                .platform()
                .read_credentials(OPENAI_API_URL)
                .log_err()
                .flatten()
            {
                if let Some(api_key) = String::from_utf8(api_key).log_err() {
                    *credential = ProviderCredential::Credentials { api_key };
                }
            }
            // No key found anywhere: leave `NoCredentials` in place.
            // (The previous version carried a dead empty `else {}` here.)
        }
        credential.clone()
    }
    /// Persists an API key to the platform keychain and caches it in memory.
    /// Non-key credentials are only cached, never written to the keychain.
    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
        // Borrow instead of cloning the whole credential just to read the key.
        if let ProviderCredential::Credentials { api_key } = &credential {
            cx.platform()
                .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
                .log_err();
        }
        *self.credential.write() = credential;
    }
    /// Removes the keychain entry and resets the cached credential.
    fn delete_credentials(&self, cx: &AppContext) {
        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
        *self.credential.write() = ProviderCredential::NoCredentials;
    }
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }
    // Hard cap on the summed token count of a single batch request.
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }
    // When rate limited, the instant at which requests may resume.
    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }
    /// Embeds `spans`, retrying up to `MAX_RETRIES` times:
    /// * HTTP 408: bump the request timeout by 5s and retry.
    /// * HTTP 429: wait for the `x-ratelimit-reset-tokens` header duration (or
    ///   a fixed backoff), record the reset instant, then retry.
    /// * Any other non-200 status fails immediately.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        // Fallback backoff schedule, indexed by how many requests have been sent.
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;
        let api_key = self.get_api_key()?;
        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;
            request_number += 1;
            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );
                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }
                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    // Prefer the server-provided reset duration; fall back to
                    // the local backoff schedule when absent or unparsable.
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };
                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);
                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );
                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}

View File

@ -0,0 +1,9 @@
pub mod completion;
pub mod embedding;
pub mod model;
pub use completion::*;
pub use embedding::*;
pub use model::OpenAILanguageModel;
/// Base URL for the OpenAI REST API (also used as the keychain key under which
/// API credentials are stored). The `'static` lifetime is implicit on consts,
/// so the redundant annotation was dropped.
pub const OPENAI_API_URL: &str = "https://api.openai.com/v1";

View File

@ -0,0 +1,57 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;
use crate::models::{LanguageModel, TruncationDirection};
/// Token accounting for an OpenAI model, backed by tiktoken's BPE encodings.
#[derive(Clone)]
pub struct OpenAILanguageModel {
    name: String,
    // None when tiktoken has no encoding registered for `name`.
    bpe: Option<CoreBPE>,
}
impl OpenAILanguageModel {
    /// Looks up the tiktoken encoding for `model_name`; if none is known the
    /// token-based operations will return errors later.
    pub fn load(model_name: &str) -> Self {
        OpenAILanguageModel {
            name: model_name.to_string(),
            bpe: tiktoken_rs::get_bpe_from_model(model_name).log_err(),
        }
    }
}
impl LanguageModel for OpenAILanguageModel {
    fn name(&self) -> String {
        self.name.clone()
    }
    /// Counts BPE tokens in `content` (special tokens permitted in the input).
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        if let Some(bpe) = &self.bpe {
            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
        } else {
            Err(anyhow!("bpe for open ai model was not retrieved"))
        }
    }
    /// Truncates `content` to at most `length` tokens, dropping tokens from
    /// the given `direction`: `End` keeps the first `length` tokens, `Start`
    /// keeps the last `length` tokens.
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String> {
        if let Some(bpe) = &self.bpe {
            let tokens = bpe.encode_with_special_tokens(content);
            if tokens.len() > length {
                match direction {
                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
                    // Keep the *last* `length` tokens. The previous
                    // `tokens[length..]` kept `len - length` tokens, the
                    // complement of what was requested, which broke the token
                    // budgeting in callers such as `retrieve_context`.
                    TruncationDirection::Start => {
                        bpe.decode(tokens[tokens.len() - length..].to_vec())
                    }
                }
            } else {
                bpe.decode(tokens)
            }
        } else {
            Err(anyhow!("bpe for open ai model was not retrieved"))
        }
    }
    /// Context-window size for `name`, in tokens.
    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
    }
}

View File

@ -0,0 +1,11 @@
/// Minimal interface a language model must expose for prompt construction:
/// token counting, token-aware truncation, and context-window capacity.
pub trait LanguageModel {
    /// Model identifier.
    fn name(&self) -> String;
    /// Number of tokens `content` encodes to for this model.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    /// Truncates `content` to at most `length` tokens, dropping tokens from
    /// the given `direction`.
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String>;
    /// Maximum context-window size, in tokens.
    fn capacity(&self) -> anyhow::Result<usize>;
}

191
crates/ai/src/test.rs Normal file
View File

@ -0,0 +1,191 @@
use std::{
sync::atomic::{self, AtomicUsize, Ordering},
time::Instant,
};
use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::AppContext;
use parking_lot::Mutex;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
};
/// Test double for `LanguageModel`: counts one token per `char` and exposes a
/// fixed, configurable capacity.
#[derive(Clone)]
pub struct FakeLanguageModel {
    pub capacity: usize,
}
impl LanguageModel for FakeLanguageModel {
    fn name(&self) -> String {
        "dummy".to_string()
    }
    /// One token per character.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        anyhow::Ok(content.chars().count())
    }
    /// Truncates to at most `length` characters, dropping characters from the
    /// given `direction`: `End` keeps the first `length` characters, `Start`
    /// keeps the last `length` characters. (Leftover debug `println!`s were
    /// removed; the `Start` case previously kept the complement —
    /// `len - length` characters — instead of `length`.)
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String> {
        let chars: Vec<char> = content.chars().collect();
        if length > chars.len() {
            return anyhow::Ok(content.to_string());
        }
        anyhow::Ok(match direction {
            TruncationDirection::End => chars[..length].iter().collect(),
            TruncationDirection::Start => chars[chars.len() - length..].iter().collect(),
        })
    }
    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(self.capacity)
    }
}
/// Test double for `EmbeddingProvider` that produces deterministic
/// letter-frequency embeddings and counts how many spans were embedded.
pub struct FakeEmbeddingProvider {
    // Total number of spans passed to `embed_batch` so far.
    pub embedding_count: AtomicUsize,
}
impl Clone for FakeEmbeddingProvider {
fn clone(&self) -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
}
}
}
impl Default for FakeEmbeddingProvider {
fn default() -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::default(),
}
}
}
impl FakeEmbeddingProvider {
    /// Number of spans embedded so far across all calls.
    pub fn embedding_count(&self) -> usize {
        self.embedding_count.load(atomic::Ordering::SeqCst)
    }
    /// Deterministic 26-dimensional embedding: every dimension starts at 1.0,
    /// gains 1.0 per occurrence of the corresponding ASCII letter
    /// (case-insensitive), and the vector is then L2-normalized.
    pub fn embed_sync(&self, span: &str) -> Embedding {
        let mut counts = vec![1.0_f32; 26];
        for ch in span.chars() {
            let code = ch.to_ascii_lowercase() as u32;
            if code >= 'a' as u32 {
                let slot = code - 'a' as u32;
                if slot < 26 {
                    counts[slot as usize] += 1.0;
                }
            }
        }
        let norm = counts.iter().map(|v| v * v).sum::<f32>().sqrt();
        for v in &mut counts {
            *v /= norm;
        }
        counts.into()
    }
}
// The fake provider never requires authentication, so all credential
// operations are no-ops.
impl CredentialProvider for FakeEmbeddingProvider {
    fn has_credentials(&self) -> bool {
        true
    }
    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
        ProviderCredential::NotNeeded
    }
    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
    fn delete_credentials(&self, _cx: &AppContext) {}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model = FakeLanguageModel { capacity: 1000 };
        Box::new(model)
    }
    fn max_tokens_per_batch(&self) -> usize {
        1000
    }
    /// The fake provider is never rate limited.
    fn rate_limit_expiration(&self) -> Option<Instant> {
        None
    }
    /// Embeds every span synchronously and bumps the call counter by the
    /// number of spans.
    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
        self.embedding_count
            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
        let embeddings = spans.iter().map(|span| self.embed_sync(span)).collect();
        anyhow::Ok(embeddings)
    }
}
/// Test double for `CompletionProvider`: completions are driven manually by
/// the test via `send_completion` / `finish_completion`.
pub struct FakeCompletionProvider {
    // Sender side of the stream returned by the most recent `complete` call.
    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl Clone for FakeCompletionProvider {
    // NOTE(review): cloning drops any in-flight completion channel — the clone
    // starts with no active stream. Confirm this is intentional before relying
    // on clones made mid-completion.
    fn clone(&self) -> Self {
        Self {
            last_completion_tx: Mutex::new(None),
        }
    }
}
impl FakeCompletionProvider {
pub fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
pub fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
pub fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
// The fake provider never requires authentication, so all credential
// operations are no-ops.
impl CredentialProvider for FakeCompletionProvider {
    fn has_credentials(&self) -> bool {
        true
    }
    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
        ProviderCredential::NotNeeded
    }
    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
    fn delete_credentials(&self, _cx: &AppContext) {}
}
impl CompletionProvider for FakeCompletionProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model = FakeLanguageModel { capacity: 8190 };
        Box::new(model)
    }
    /// Returns a stream that yields whatever the test later feeds in through
    /// `send_completion`, ending when `finish_completion` drops the sender.
    fn complete(
        &self,
        _prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
        let (sender, receiver) = mpsc::channel(1);
        self.last_completion_tx.lock().replace(sender);
        async move { Ok(receiver.map(|message| Ok(message)).boxed()) }.boxed()
    }
    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new(self.clone())
    }
}

38
crates/ai2/Cargo.toml Normal file
View File

@ -0,0 +1,38 @@
[package]
name = "ai2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/ai2.rs"
doctest = false
[features]
test-support = []
[dependencies]
gpui2 = { path = "../gpui2" }
util = { path = "../util" }
language2 = { path = "../language2" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
lazy_static.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
isahc.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
postage.workspace = true
rand.workspace = true
log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
[dev-dependencies]
gpui2 = { path = "../gpui2", features = ["test-support"] }

8
crates/ai2/src/ai2.rs Normal file
View File

@ -0,0 +1,8 @@
pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
pub mod prompts;
pub mod providers;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

17
crates/ai2/src/auth.rs Normal file
View File

@ -0,0 +1,17 @@
use async_trait::async_trait;
use gpui2::AppContext;
/// Authentication state for an AI provider.
#[derive(Clone, Debug)]
pub enum ProviderCredential {
    /// An API key is available.
    Credentials { api_key: String },
    /// No API key has been found yet.
    NoCredentials,
    /// The provider does not require authentication (e.g. test fakes).
    NotNeeded,
}
/// Read/write access to a provider's stored credentials; asynchronous in this
/// gpui2 variant of the trait.
#[async_trait]
pub trait CredentialProvider: Send + Sync {
    /// Whether credentials are currently available.
    fn has_credentials(&self) -> bool;
    /// Loads credentials (e.g. from the environment or platform store).
    async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
    /// Persists `credential`.
    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
    /// Removes any stored credentials.
    async fn delete_credentials(&self, cx: &mut AppContext);
}

View File

@ -0,0 +1,23 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
use crate::{auth::CredentialProvider, models::LanguageModel};
/// A serializable completion request payload.
pub trait CompletionRequest: Send + Sync {
    /// JSON-encodes the request body.
    fn data(&self) -> serde_json::Result<String>;
}
/// A streaming text-completion backend, coupled with credential management.
pub trait CompletionProvider: CredentialProvider {
    /// The language model used for token accounting.
    fn base_model(&self) -> Box<dyn LanguageModel>;
    /// Starts a completion for `prompt`, resolving to a stream of text chunks.
    fn complete(
        &self,
        prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
    /// Clones the provider behind the trait object (enables `Clone` for
    /// `Box<dyn CompletionProvider>`).
    fn box_clone(&self) -> Box<dyn CompletionProvider>;
}
// `Clone` for boxed providers, delegating to the object-safe `box_clone`.
impl Clone for Box<dyn CompletionProvider> {
    fn clone(&self) -> Box<dyn CompletionProvider> {
        self.box_clone()
    }
}

123
crates/ai2/src/embedding.rs Normal file
View File

@ -0,0 +1,123 @@
use std::time::Instant;
use anyhow::Result;
use async_trait::async_trait;
use ordered_float::OrderedFloat;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use crate::auth::CredentialProvider;
use crate::models::LanguageModel;
/// A dense embedding vector of `f32` components.
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
// This is needed for semantic index functionality
// Unfortunately it has to live wherever the "Embedding" struct is created.
// Keeping this in here though, introduces a 'rusqlite' dependency into AI
// which is less than ideal
impl FromSql for Embedding {
    /// Decodes an embedding from a SQLite BLOB column containing a
    /// bincode-encoded `Vec<f32>`.
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        // Propagate bincode failures directly instead of the previous
        // `is_err()` / `unwrap_err()` / `unwrap()` sequence.
        let embedding: Vec<f32> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Embedding(embedding))
    }
}
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
// Zero-cost wrap of a raw vector.
impl From<Vec<f32>> for Embedding {
    fn from(value: Vec<f32>) -> Self {
        Embedding(value)
    }
}
impl Embedding {
    /// Dot product of two equal-length embeddings, computed via `sgemm` as a
    /// 1×len by len×1 matrix product.
    ///
    /// Panics if the vectors differ in length. Note this is a plain dot
    /// product; it equals cosine similarity only when both inputs are
    /// unit-norm.
    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
        let len = self.0.len();
        assert_eq!(len, other.0.len());
        let mut result = 0.0;
        // SAFETY: both input pointers are valid for `len` contiguous f32 reads
        // (lengths asserted equal above), `result` is a valid pointer for the
        // single 1x1 output element, and the row/column stride arguments
        // describe exactly those buffers.
        unsafe {
            matrixmultiply::sgemm(
                1,
                len,
                1,
                1.0,
                self.0.as_ptr(),
                len as isize,
                1,
                other.0.as_ptr(),
                1,
                len as isize,
                0.0,
                &mut result as *mut f32,
                1,
                1,
            );
        }
        OrderedFloat(result)
    }
}
/// A batch text-embedding backend, coupled with credential management and
/// rate-limit reporting.
#[async_trait]
pub trait EmbeddingProvider: CredentialProvider {
    /// The language model used for token accounting and truncation.
    fn base_model(&self) -> Box<dyn LanguageModel>;
    /// Embeds each span in `spans`.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    /// Maximum summed token count allowed in one `embed_batch` call.
    fn max_tokens_per_batch(&self) -> usize;
    /// When rate limited, the instant at which requests may resume.
    fn rate_limit_expiration(&self) -> Option<Instant>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;
    // `similarity` must agree with a scalar dot product: exactly on the
    // hand-picked vectors, and to one decimal place on random 1536-dim inputs.
    #[gpui2::test]
    fn test_similarity(mut rng: StdRng) {
        assert_eq!(
            Embedding::from(vec![1., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
            0.
        );
        assert_eq!(
            Embedding::from(vec![2., 0., 0., 0., 0.])
                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
            6.
        );
        for _ in 0..100 {
            let size = 1536;
            let mut a = vec![0.; size];
            let mut b = vec![0.; size];
            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
                *a = rng.gen();
                *b = rng.gen();
            }
            let a = Embedding::from(a);
            let b = Embedding::from(b);
            assert_eq!(
                round_to_decimals(a.similarity(&b), 1),
                round_to_decimals(reference_dot(&a.0, &b.0), 1)
            );
        }
        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
            let factor = (10.0 as f32).powi(decimal_places);
            (n * factor).round() / factor
        }
        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
        }
    }
}

16
crates/ai2/src/models.rs Normal file
View File

@ -0,0 +1,16 @@
/// Which end of the text truncation removes content from.
pub enum TruncationDirection {
    Start,
    End,
}
/// Minimal interface a language model must expose for prompt construction:
/// token counting, token-aware truncation, and context-window capacity.
pub trait LanguageModel {
    /// Model identifier.
    fn name(&self) -> String;
    /// Number of tokens `content` encodes to for this model.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    /// Truncates `content` to at most `length` tokens, dropping tokens from
    /// the given `direction`.
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String>;
    /// Maximum context-window size, in tokens.
    fn capacity(&self) -> anyhow::Result<usize>;
}

View File

@ -0,0 +1,330 @@
use std::cmp::Reverse;
use std::ops::Range;
use std::sync::Arc;
use language2::BufferSnapshot;
use util::ResultExt;
use crate::models::LanguageModel;
use crate::prompts::repository_context::PromptCodeSnippet;
/// Coarse classification of the buffer being prompted about.
pub(crate) enum PromptFileType {
    Text,
    Code,
}
// TODO: Set this up to manage for defaults well
/// Everything a `PromptTemplate` may draw on when rendering: the model (for
/// token accounting), user input, file/project context, and a token budget
/// reservation.
pub struct PromptArguments {
    pub model: Arc<dyn LanguageModel>,
    pub user_prompt: Option<String>,
    pub language_name: Option<String>,
    pub project_name: Option<String>,
    pub snippets: Vec<PromptCodeSnippet>,
    /// Tokens held back from the model's capacity when truncating.
    pub reserved_tokens: usize,
    pub buffer: Option<BufferSnapshot>,
    pub selected_range: Option<Range<usize>>,
}
impl PromptArguments {
    /// Treats everything except Markdown and plain text as code; an unknown
    /// language defaults to code.
    pub(crate) fn get_file_type(&self) -> PromptFileType {
        // `map` replaces the previous `and_then(|name| Some(..))`
        // (clippy::bind_instead_of_map).
        let is_code = self
            .language_name
            .as_ref()
            .map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
            .unwrap_or(true);
        if is_code {
            PromptFileType::Code
        } else {
            PromptFileType::Text
        }
    }
}
/// A composable prompt section.
pub trait PromptTemplate {
    /// Renders this template against `args`, returning the text and its token
    /// count. When `max_token_length` is `Some`, the output must fit within it.
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)>;
}
/// Priority of a prompt section: `Mandatory` outranks every `Ordered`
/// priority, and among `Ordered` values a *lower* `order` outranks a higher
/// one.
#[repr(i8)]
#[derive(PartialEq, Eq)]
pub enum PromptPriority {
    Mandatory,                // Ignores truncation
    Ordered { order: usize }, // Truncates based on priority
}
impl Ord for PromptPriority {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self, other) {
            (Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
            (Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
            (Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
            // Lower `order` values compare as greater (higher priority).
            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
        }
    }
}
impl PartialOrd for PromptPriority {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // The previous version combined `#[derive(Ord)]` with a hand-written
        // `PartialOrd` whose ordering *conflicted* with the derived one
        // (clippy::derive_ord_xor_partial_ord). `sort_by_key` in
        // `PromptChain::generate` uses `Ord`, so the derived (inverted)
        // ordering won. Keep a single source of truth in `cmp`.
        Some(self.cmp(other))
    }
}
/// An ordered collection of prompt templates plus the arguments to render
/// them with. Sections are generated in priority order but joined in
/// declaration order.
pub struct PromptChain {
    args: PromptArguments,
    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
impl PromptChain {
    pub fn new(
        args: PromptArguments,
        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
    ) -> Self {
        PromptChain { args, templates }
    }
    /// Renders every template and joins the results with newlines, returning
    /// the prompt and its total token count.
    ///
    /// Templates are *generated* highest-priority-first so that, when
    /// `truncate` is true, higher-priority sections consume the token budget
    /// first; the final string still preserves template declaration order.
    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
        let separator = "\n";
        let separator_tokens = self.args.model.count_tokens(separator)?;
        // Argsort template indices by descending priority.
        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
        // Remaining token budget; `None` disables truncation entirely.
        // `saturating_sub` guards against `reserved_tokens` > capacity.
        let mut tokens_outstanding = if truncate {
            Some(
                self.args
                    .model
                    .capacity()?
                    .saturating_sub(self.args.reserved_tokens),
            )
        } else {
            None
        };
        let mut prompts = vec!["".to_string(); sorted_indices.len()];
        for idx in sorted_indices {
            let (_, template) = &self.templates[idx];
            if let Some((template_prompt, prompt_token_count)) =
                template.generate(&self.args, tokens_outstanding).log_err()
            {
                if template_prompt != "" {
                    prompts[idx] = template_prompt;
                    if let Some(remaining_tokens) = tokens_outstanding {
                        // Charge each rendered section plus its separator.
                        let new_tokens = prompt_token_count + separator_tokens;
                        tokens_outstanding = Some(remaining_tokens.saturating_sub(new_tokens));
                    }
                }
            }
        }
        prompts.retain(|prompt| prompt != "");
        // Join once and count tokens on the exact string we return (the
        // previous code joined the prompts a second time for the return value).
        let full_prompt = prompts.join(separator);
        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
        anyhow::Ok((full_prompt, total_token_count))
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use crate::models::TruncationDirection;
    use crate::test::FakeLanguageModel;
    use super::*;
    // Exercises PromptChain::generate across: truncation off, truncation down
    // to capacity, and priority-driven generation order with reserved tokens.
    #[test]
    pub fn test_prompt_chain() {
        struct TestPromptTemplate {}
        impl PromptTemplate for TestPromptTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a test prompt template".to_string();
                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }
                anyhow::Ok((content, token_count))
            }
        }
        struct TestLowPriorityTemplate {}
        impl PromptTemplate for TestLowPriorityTemplate {
            fn generate(
                &self,
                args: &PromptArguments,
                max_token_length: Option<usize>,
            ) -> anyhow::Result<(String, usize)> {
                let mut content = "This is a low priority test prompt template".to_string();
                let mut token_count = args.model.count_tokens(&content)?;
                if let Some(max_token_length) = max_token_length {
                    if token_count > max_token_length {
                        content = args.model.truncate(
                            &content,
                            max_token_length,
                            TruncationDirection::End,
                        )?;
                        token_count = max_token_length;
                    }
                }
                anyhow::Ok((content, token_count))
            }
        }
        // Truncation off, ample capacity: both templates appear in full.
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);
        let (prompt, token_count) = chain.generate(false).unwrap();
        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );
        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
        // Testing with Truncation Off
        // Should ignore capacity and return all prompts
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);
        let (prompt, token_count) = chain.generate(false).unwrap();
        assert_eq!(
            prompt,
            "This is a test prompt template\nThis is a low priority test prompt template"
                .to_string()
        );
        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
        // Testing with truncation ON (note: the previous comment here wrongly
        // said "Truncation Off"): the joined prompt must be cut down to the
        // model's capacity, so only the highest-priority template survives,
        // itself truncated to 20 tokens.
        let capacity = 20;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens: 0,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 2 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);
        let (prompt, token_count) = chain.generate(true).unwrap();
        assert_eq!(prompt, "This is a test promp".to_string());
        assert_eq!(token_count, capacity);
        // Change Ordering of Prompts Based on Priority
        let capacity = 120;
        let reserved_tokens = 10;
        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
        let args = PromptArguments {
            model: model.clone(),
            language_name: None,
            project_name: None,
            snippets: Vec::new(),
            reserved_tokens,
            buffer: None,
            selected_range: None,
            user_prompt: None,
        };
        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
            (
                PromptPriority::Mandatory,
                Box::new(TestLowPriorityTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 0 },
                Box::new(TestPromptTemplate {}),
            ),
            (
                PromptPriority::Ordered { order: 1 },
                Box::new(TestLowPriorityTemplate {}),
            ),
        ];
        let chain = PromptChain::new(args, templates);
        let (prompt, token_count) = chain.generate(true).unwrap();
        assert_eq!(
            prompt,
            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
                .to_string()
        );
        assert_eq!(token_count, capacity - reserved_tokens);
    }
}

View File

@ -0,0 +1,164 @@
use anyhow::anyhow;
use language2::BufferSnapshot;
use language2::ToOffset;
use crate::models::LanguageModel;
use crate::models::TruncationDirection;
use crate::prompts::base::PromptArguments;
use crate::prompts::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;
/// Renders `buffer` (optionally annotated with the user's selection) into a
/// prompt string, returning `(prompt, token_count, was_truncated)`.
///
/// With a selection and a token budget, the text before and after the
/// selection is truncated so the whole rendering fits `max_token_count`.
fn retrieve_context(
    buffer: &BufferSnapshot,
    selected_range: &Option<Range<usize>>,
    model: Arc<dyn LanguageModel>,
    max_token_count: Option<usize>,
) -> anyhow::Result<(String, usize, bool)> {
    let mut prompt = String::new();
    let mut truncated = false;
    if let Some(selected_range) = selected_range {
        let start = selected_range.start.to_offset(buffer);
        let end = selected_range.end.to_offset(buffer);
        // Text before the selection.
        let start_window = buffer.text_for_range(0..start).collect::<String>();
        let mut selected_window = String::new();
        // A caret (empty selection) gets a single `<|START|>` marker; a real
        // selection is wrapped in `<|START|` ... `|END|>`.
        if start == end {
            write!(selected_window, "<|START|>").unwrap();
        } else {
            write!(selected_window, "<|START|").unwrap();
        }
        write!(
            selected_window,
            "{}",
            buffer.text_for_range(start..end).collect::<String>()
        )
        .unwrap();
        if start != end {
            write!(selected_window, "|END|>").unwrap();
        }
        // Text after the selection.
        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
        if let Some(max_token_count) = max_token_count {
            let selected_tokens = model.count_tokens(&selected_window)?;
            if selected_tokens > max_token_count {
                return Err(anyhow!(
                    "selected range is greater than model context window, truncation not possible"
                ));
            };
            let mut remaining_tokens = max_token_count - selected_tokens;
            let start_window_tokens = model.count_tokens(&start_window)?;
            let end_window_tokens = model.count_tokens(&end_window)?;
            let outside_tokens = start_window_tokens + end_window_tokens;
            if outside_tokens > remaining_tokens {
                // Split the leftover budget between the two sides: the smaller
                // side gets at most half, the other side takes the remainder.
                let (start_goal_tokens, end_goal_tokens) =
                    if start_window_tokens < end_window_tokens {
                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
                        remaining_tokens -= start_goal_tokens;
                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
                        (start_goal_tokens, end_goal_tokens)
                    } else {
                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
                        remaining_tokens -= end_goal_tokens;
                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
                        (start_goal_tokens, end_goal_tokens)
                    };
                let truncated_start_window =
                    model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
                let truncated_end_window =
                    model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
                writeln!(
                    prompt,
                    "{truncated_start_window}{selected_window}{truncated_end_window}"
                )
                .unwrap();
                truncated = true;
            } else {
                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
            }
        } else {
            // If we dont have a selected range, include entire file.
            // NOTE(review): despite the comment above, this `else` pairs with
            // the `max_token_count` check, not the selection check — it runs
            // when there IS a selection but no token budget, emitting the raw
            // file text without the selection markers. When `selected_range`
            // is `None` the prompt stays empty. Confirm this nesting is
            // intended.
            writeln!(prompt, "{}", &buffer.text()).unwrap();
            // Dumb truncation strategy
            // NOTE(review): dead branch — `max_token_count` is always `None`
            // on this path.
            if let Some(max_token_count) = max_token_count {
                if model.count_tokens(&prompt)? > max_token_count {
                    truncated = true;
                    prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
                }
            }
        }
    }
    let token_count = model.count_tokens(&prompt)?;
    anyhow::Ok((prompt, token_count, truncated))
}
/// Prompt template that embeds the current file's contents (windowed around
/// the selection by `retrieve_context`) into the prompt.
pub struct FileContext {}
impl PromptTemplate for FileContext {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)> {
        if let Some(buffer) = &args.buffer {
            let mut prompt = String::new();
            // Add Initial Preamble
            // TODO: Do we want to add the path in here?
            writeln!(
                prompt,
                "The file you are currently working on has the following content:"
            )
            .unwrap();
            let language_name = args
                .language_name
                .clone()
                .unwrap_or("".to_string())
                .to_lowercase();
            let (context, _, truncated) = retrieve_context(
                buffer,
                &args.selected_range,
                args.model.clone(),
                max_token_length,
            )?;
            // Fenced code block holding the (possibly truncated) file content.
            writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
            if truncated {
                writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
            }
            // Explain what the marker spans inserted by `retrieve_context` mean.
            if let Some(selected_range) = &args.selected_range {
                let start = selected_range.start.to_offset(buffer);
                let end = selected_range.end.to_offset(buffer);
                if start == end {
                    writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
                } else {
                    writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
                }
            }
            // Really dumb truncation strategy
            // NOTE(review): the context was already budgeted above; this second
            // pass can clip the trailing instructions when the preamble pushes
            // the prompt over the limit — confirm that is acceptable.
            if let Some(max_tokens) = max_token_length {
                prompt = args
                    .model
                    .truncate(&prompt, max_tokens, TruncationDirection::End)?;
            }
            let token_count = args.model.count_tokens(&prompt)?;
            anyhow::Ok((prompt, token_count))
        } else {
            Err(anyhow!("no buffer provided to retrieve file context from"))
        }
    }
}

View File

@ -0,0 +1,99 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow;
use std::fmt::Write;
/// Returns `s` with its first character uppercased (Unicode-aware, so a
/// single char may expand, e.g. 'ß' → "SS"); the rest is left untouched.
/// An empty string yields an empty string.
pub fn capitalize(s: &str) -> String {
    let mut chars = s.chars();
    if let Some(first) = chars.next() {
        let mut capitalized: String = first.to_uppercase().collect();
        capitalized.push_str(chars.as_str());
        capitalized
    } else {
        String::new()
    }
}
/// Prompt template instructing the model to generate content at the cursor or
/// rewrite the selected span, based on the user's free-form prompt.
pub struct GenerateInlineContent {}
impl PromptTemplate for GenerateInlineContent {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)> {
        let Some(user_prompt) = &args.user_prompt else {
            return Err(anyhow!("user prompt not provided"));
        };
        let file_type = args.get_file_type();
        let content_type = match &file_type {
            PromptFileType::Code => "code",
            PromptFileType::Text => "text",
        };
        let mut prompt = String::new();
        if let Some(selected_range) = &args.selected_range {
            if selected_range.start == selected_range.end {
                // Empty selection: generate at the cursor marker.
                writeln!(
                    prompt,
                    "Assume the cursor is located where the `<|START|>` span is."
                )
                .unwrap();
                writeln!(
                    prompt,
                    "{} can't be replaced, so assume your answer will be inserted at the cursor.",
                    capitalize(content_type)
                )
                .unwrap();
                writeln!(
                    prompt,
                    "Generate {content_type} based on the users prompt: {user_prompt}",
                )
                .unwrap();
            } else {
                // Non-empty selection: rewrite the marked span only.
                writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
                writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
                // Fixed misplaced quote: the closing marker is `|END|>`
                // (previously rendered as `|END|'>`), matching the marker used
                // everywhere else in these prompts.
                writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|>' spans").unwrap();
            }
        } else {
            writeln!(
                prompt,
                "Generate {content_type} based on the users prompt: {user_prompt}"
            )
            .unwrap();
        }
        if let Some(language_name) = &args.language_name {
            writeln!(
                prompt,
                "Your answer MUST always and only be valid {}.",
                language_name
            )
            .unwrap();
        }
        writeln!(prompt, "Never make remarks about the output.").unwrap();
        writeln!(
            prompt,
            "Do not return anything else, except the generated {content_type}."
        )
        .unwrap();
        match file_type {
            PromptFileType::Code => {
                // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
            }
            _ => {}
        }
        // Really dumb truncation strategy
        if let Some(max_tokens) = max_token_length {
            prompt = args.model.truncate(
                &prompt,
                max_tokens,
                crate::models::TruncationDirection::End,
            )?;
        }
        let token_count = args.model.count_tokens(&prompt)?;
        anyhow::Ok((prompt, token_count))
    }
}

View File

@ -0,0 +1,5 @@
pub mod base;
pub mod file_context;
pub mod generate;
pub mod preamble;
pub mod repository_context;

View File

@ -0,0 +1,52 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;
/// Prompt preamble establishing the assistant's engineer persona and, when
/// known, the current project name.
pub struct EngineerPreamble {}
impl PromptTemplate for EngineerPreamble {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)> {
        let mut prompts = Vec::new();
        match args.get_file_type() {
            PromptFileType::Code => {
                // Only append the trailing space when a language is known;
                // previously a missing language produced
                // "You are an expert  engineer." with a double space.
                let language_name = args
                    .language_name
                    .clone()
                    .map(|name| name + " ")
                    .unwrap_or_default();
                prompts.push(format!("You are an expert {}engineer.", language_name));
            }
            PromptFileType::Text => {
                prompts.push("You are an expert engineer.".to_string());
            }
        }
        if let Some(project_name) = args.project_name.clone() {
            prompts.push(format!(
                "You are currently working inside the '{project_name}' project in code editor Zed."
            ));
        }
        if let Some(mut remaining_tokens) = max_token_length {
            // Greedily keep whole sentences that fit the remaining budget;
            // each piece costs its own tokens plus one newline.
            let mut prompt = String::new();
            let mut total_count = 0;
            for prompt_piece in prompts {
                let prompt_token_count =
                    args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
                // `>=` so a piece that exactly exhausts the budget still fits
                // (was `>`, an off-by-one that dropped such pieces).
                if remaining_tokens >= prompt_token_count {
                    writeln!(prompt, "{prompt_piece}").unwrap();
                    remaining_tokens -= prompt_token_count;
                    total_count += prompt_token_count;
                }
            }
            anyhow::Ok((prompt, total_count))
        } else {
            let prompt = prompts.join("\n");
            let token_count = args.model.count_tokens(&prompt)?;
            anyhow::Ok((prompt, token_count))
        }
    }
}

View File

@ -0,0 +1,98 @@
use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
use gpui2::{AsyncAppContext, Model};
use language2::{Anchor, Buffer};
/// A code snippet extracted from a buffer, carried into repository-context
/// prompts along with its provenance.
#[derive(Clone)]
pub struct PromptCodeSnippet {
    // Path of the file the snippet came from, if the buffer has one.
    path: Option<PathBuf>,
    // Lowercased language name, if the buffer's language is known.
    language_name: Option<String>,
    // The snippet text itself.
    content: String,
}
impl PromptCodeSnippet {
    /// Extracts the text in `range` from `buffer`, capturing the buffer's
    /// (lowercased) language name and file path alongside the content.
    pub fn new(
        buffer: Model<Buffer>,
        range: Range<Anchor>,
        cx: &mut AsyncAppContext,
    ) -> anyhow::Result<Self> {
        let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
            let snapshot = buffer.snapshot();
            let content = snapshot.text_for_range(range.clone()).collect::<String>();
            // `.map` instead of `.and_then(|x| Some(..))` — same behavior,
            // idiomatic form.
            let language_name = buffer
                .language()
                .map(|language| language.name().to_string().to_lowercase());
            let file_path = buffer.file().map(|file| file.path().to_path_buf());
            (content, language_name, file_path)
        })?;
        anyhow::Ok(PromptCodeSnippet {
            path: file_path,
            language_name,
            content,
        })
    }
}
/// Renders the snippet as an annotated, fenced Markdown block.
///
/// Implements `Display` rather than a manual `ToString` impl (the standard
/// blanket `impl<T: Display> ToString for T` keeps `.to_string()` callers
/// working), avoiding the intermediate `String` allocations.
impl std::fmt::Display for PromptCodeSnippet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let path = self
            .path
            .as_ref()
            .map(|path| path.to_string_lossy().to_string())
            .unwrap_or_default();
        let language_name = self.language_name.clone().unwrap_or_default();
        write!(
            f,
            "The below code snippet may be relevant from file: {path}\n```{language_name}\n{}\n```",
            self.content
        )
    }
}
/// Prompt template that injects repository code snippets that may be relevant
/// to the current task.
pub struct RepositoryContext {}
impl PromptTemplate for RepositoryContext {
    fn generate(
        &self,
        args: &PromptArguments,
        max_token_length: Option<usize>,
    ) -> anyhow::Result<(String, usize)> {
        // Snippets larger than this are skipped outright, regardless of the
        // remaining prompt budget.
        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
        let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
        let mut prompt = String::new();
        // `Option<usize>` is `Copy`; the previous `.clone()` was redundant.
        let mut remaining_tokens = max_token_length;
        let separator_token_length = args.model.count_tokens("\n")?;
        for snippet in &args.snippets {
            let mut snippet_prompt = template.to_string();
            let content = snippet.to_string();
            writeln!(snippet_prompt, "{content}").unwrap();
            let token_count = args.model.count_tokens(&snippet_prompt)?;
            if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
                if let Some(tokens_left) = remaining_tokens {
                    if tokens_left >= token_count {
                        writeln!(prompt, "{snippet_prompt}").unwrap();
                        // Charge the snippet plus its separator newline,
                        // clamping at zero.
                        remaining_tokens = Some(
                            tokens_left.saturating_sub(token_count + separator_token_length),
                        );
                    }
                } else {
                    // No budget was given: include every acceptable snippet.
                    writeln!(prompt, "{snippet_prompt}").unwrap();
                }
            }
        }
        let total_token_count = args.model.count_tokens(&prompt)?;
        anyhow::Ok((prompt, total_token_count))
    }
}

View File

@ -0,0 +1 @@
pub mod open_ai;

View File

@ -0,0 +1,306 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui2::{AppContext, Executor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
env,
fmt::{self, Display},
io,
sync::Arc,
};
use util::ResultExt;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
/// The speaker of a chat message in the OpenAI chat-completions API.
/// Serializes to the lowercase wire form ("user", "assistant", "system").
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
impl Role {
    /// Advances to the next role in the fixed
    /// User → Assistant → System → User cycle.
    pub fn cycle(&mut self) {
        *self = match *self {
            Role::User => Role::Assistant,
            Role::Assistant => Role::System,
            Role::System => Role::User,
        };
    }
}
impl Display for Role {
    /// Human-readable capitalized name (distinct from the lowercase serde
    /// wire form).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Role::User => "User",
            Role::Assistant => "Assistant",
            Role::System => "System",
        };
        write!(f, "{name}")
    }
}
/// A chat message sent to the OpenAI chat-completions endpoint.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    pub role: Role,
    pub content: String,
}
/// Request body for `POST /chat/completions`.
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
}
impl CompletionRequest for OpenAIRequest {
    // Serializes the request as the JSON body to send.
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}
/// A (possibly partial) message in a streamed response delta; fields are
/// `None` when a chunk doesn't carry them.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
/// Token accounting reported by the API.
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
/// One choice in a streamed chunk; `finish_reason` is set on the final delta.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}
/// A single server-sent event from the streaming chat-completions endpoint.
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<OpenAIUsage>,
}
/// Opens a streaming chat-completion request against the OpenAI API.
///
/// On HTTP 200, spawns a background task that parses `data: ` server-sent
/// events off the response body and forwards them over an unbounded channel,
/// stopping at the first event whose last choice carries a `finish_reason`
/// (or when the receiver is dropped). On any other status, reads the body and
/// surfaces the API error message if one can be parsed.
pub async fn stream_completion(
    credential: ProviderCredential,
    executor: Arc<Executor>,
    request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
    let api_key = match credential {
        ProviderCredential::Credentials { api_key } => api_key,
        _ => {
            return Err(anyhow!("no credentials provider for completion"));
        }
    };
    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
    let json_data = request.data()?;
    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(json_data)?
        .send_async()
        .await?;
    let status = response.status();
    if status == StatusCode::OK {
        executor
            .spawn(async move {
                let mut lines = BufReader::new(response.body_mut()).lines();
                // SSE payload lines look like `data: {json}`; anything else
                // (e.g. blank keep-alive lines) yields `Ok(None)` and is
                // skipped.
                // NOTE(review): a terminal `data: [DONE]` sentinel would fail
                // JSON parsing here and be forwarded as an `Err`; we normally
                // stop earlier on `finish_reason` — confirm against the API's
                // actual framing.
                fn parse_line(
                    line: Result<String, io::Error>,
                ) -> Result<Option<OpenAIResponseStreamEvent>> {
                    if let Some(data) = line?.strip_prefix("data: ") {
                        let event = serde_json::from_str(&data)?;
                        Ok(Some(event))
                    } else {
                        Ok(None)
                    }
                }
                while let Some(line) = lines.next().await {
                    if let Some(event) = parse_line(line).transpose() {
                        // The stream is finished once any event's final choice
                        // carries a `finish_reason`.
                        let done = event.as_ref().map_or(false, |event| {
                            event
                                .choices
                                .last()
                                .map_or(false, |choice| choice.finish_reason.is_some())
                        });
                        // Receiver dropped: stop reading the body.
                        if tx.unbounded_send(event).is_err() {
                            break;
                        }
                        if done {
                            break;
                        }
                    }
                }
                anyhow::Ok(())
            })
            .detach();
        Ok(rx)
    } else {
        // Non-200: try to extract the structured API error; otherwise report
        // the raw status and body.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        #[derive(Deserialize)]
        struct OpenAIResponse {
            error: OpenAIError,
        }
        #[derive(Deserialize)]
        struct OpenAIError {
            message: String,
        }
        match serde_json::from_str::<OpenAIResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),
            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
/// `CompletionProvider` backed by the OpenAI chat-completions API.
#[derive(Clone)]
pub struct OpenAICompletionProvider {
    // Tokenizer/metadata wrapper for the configured model.
    model: OpenAILanguageModel,
    // Shared mutable credential state; read when issuing requests.
    credential: Arc<RwLock<ProviderCredential>>,
    executor: Arc<Executor>,
}
impl OpenAICompletionProvider {
    /// Creates a provider for `model_name` with no credentials loaded yet;
    /// call `retrieve_credentials` before issuing completions.
    pub fn new(model_name: &str, executor: Arc<Executor>) -> Self {
        Self {
            model: OpenAILanguageModel::load(model_name),
            credential: Arc::new(RwLock::new(ProviderCredential::NoCredentials)),
            executor,
        }
    }
}
#[async_trait]
impl CredentialProvider for OpenAICompletionProvider {
    fn has_credentials(&self) -> bool {
        match *self.credential.read() {
            ProviderCredential::Credentials { .. } => true,
            _ => false,
        }
    }
    /// Resolves credentials in priority order: the already-cached value, the
    /// `OPENAI_API_KEY` environment variable, then the OS keychain entry for
    /// `OPENAI_API_URL`. The result (including `NoCredentials`) is cached.
    async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
        let existing_credential = self.credential.read().clone();
        // Keychain/env access must happen on the main thread.
        let retrieved_credential = cx
            .run_on_main(move |cx| match existing_credential {
                ProviderCredential::Credentials { .. } => {
                    return existing_credential.clone();
                }
                _ => {
                    if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
                        return ProviderCredential::Credentials { api_key };
                    }
                    if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
                    {
                        if let Some(api_key) = String::from_utf8(api_key).log_err() {
                            return ProviderCredential::Credentials { api_key };
                        } else {
                            return ProviderCredential::NoCredentials;
                        }
                    } else {
                        return ProviderCredential::NoCredentials;
                    }
                }
            })
            .await;
        *self.credential.write() = retrieved_credential.clone();
        retrieved_credential
    }
    /// Caches `credential` in memory and, for real credentials, persists the
    /// API key to the OS keychain.
    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
        *self.credential.write() = credential.clone();
        let credential = credential.clone();
        cx.run_on_main(move |cx| match credential {
            ProviderCredential::Credentials { api_key } => {
                cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
                    .log_err();
            }
            _ => {}
        })
        .await;
    }
    /// Removes the keychain entry and clears the in-memory credential.
    async fn delete_credentials(&self, cx: &mut AppContext) {
        cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
            .await;
        *self.credential.write() = ProviderCredential::NoCredentials;
    }
}
impl CompletionProvider for OpenAICompletionProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }
    /// Streams completion text chunks for `prompt`; stream events without a
    /// content delta in their last choice are silently dropped.
    fn complete(
        &self,
        prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
        // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
        // which is currently model based, due to the language model.
        // At some point in the future we should rectify this.
        let credential = self.credential.read().clone();
        let request = stream_completion(credential, self.executor.clone(), prompt);
        async move {
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        // Extract the content delta of the last choice, if any.
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }
    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new((*self).clone())
    }
}

View File

@ -0,0 +1,313 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui2::Executor;
use gpui2::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;
use crate::providers::open_ai::OPENAI_API_URL;
lazy_static! {
    // Shared cl100k_base tokenizer used for embedding-side token counting.
    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
/// `EmbeddingProvider` backed by the OpenAI `text-embedding-ada-002` endpoint,
/// with retry/backoff handling for rate limits.
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
    model: OpenAILanguageModel,
    // Shared mutable credential state; read when issuing requests.
    credential: Arc<RwLock<ProviderCredential>>,
    pub client: Arc<dyn HttpClient>,
    pub executor: Arc<Executor>,
    // Watch pair publishing the instant at which an active rate limit is
    // expected to lift (None when not rate limited).
    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
/// Request body for `POST /embeddings`.
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
    model: &'static str,
    input: Vec<&'a str>,
}
/// Response body: one embedding per input span, plus token usage.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbedding>,
    usage: OpenAIEmbeddingUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
    embedding: Vec<f32>,
    // Deserialized for completeness; not currently read.
    index: usize,
    object: String,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
impl OpenAIEmbeddingProvider {
    /// Creates a provider with no credentials loaded and no rate limit
    /// published; call `retrieve_credentials` before embedding.
    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Executor>) -> Self {
        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
        let model = OpenAILanguageModel::load("text-embedding-ada-002");
        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
        OpenAIEmbeddingProvider {
            model,
            credential,
            client,
            executor,
            rate_limit_count_rx,
            rate_limit_count_tx,
        }
    }
    // Returns the cached API key, erroring if credentials were never loaded.
    fn get_api_key(&self) -> Result<String> {
        match self.credential.read().clone() {
            ProviderCredential::Credentials { api_key } => Ok(api_key),
            _ => Err(anyhow!("api credentials not provided")),
        }
    }
    // Clears the published rate-limit expiry once its reset time has passed.
    fn resolve_rate_limit(&self) {
        let reset_time = *self.rate_limit_count_tx.lock().borrow();
        if let Some(reset_time) = reset_time {
            if Instant::now() >= reset_time {
                *self.rate_limit_count_tx.lock().borrow_mut() = None
            }
        }
        log::trace!(
            "resolving reset time: {:?}",
            *self.rate_limit_count_tx.lock().borrow()
        );
    }
    // Publishes `reset_time`, keeping the *earlier* of it and any already
    // published reset time.
    fn update_reset_time(&self, reset_time: Instant) {
        let original_time = *self.rate_limit_count_tx.lock().borrow();
        let updated_time = if let Some(original_time) = original_time {
            if reset_time < original_time {
                Some(reset_time)
            } else {
                Some(original_time)
            }
        } else {
            Some(reset_time)
        };
        log::trace!("updating rate limit time: {:?}", updated_time);
        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
    }
    // Issues a single embeddings request for `spans` with the given timeout
    // (seconds).
    async fn send_request(
        &self,
        api_key: &str,
        spans: Vec<&str>,
        request_timeout: u64,
    ) -> Result<Response<AsyncBody>> {
        let request = Request::post("https://api.openai.com/v1/embeddings")
            .redirect_policy(isahc::config::RedirectPolicy::Follow)
            .timeout(Duration::from_secs(request_timeout))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(
                serde_json::to_string(&OpenAIEmbeddingRequest {
                    input: spans.clone(),
                    model: "text-embedding-ada-002",
                })
                .unwrap()
                .into(),
            )?;
        Ok(self.client.send(request).await?)
    }
}
// NOTE(review): this duplicates the CredentialProvider logic on
// OpenAICompletionProvider; consider factoring the shared lookup into a
// helper.
#[async_trait]
impl CredentialProvider for OpenAIEmbeddingProvider {
    fn has_credentials(&self) -> bool {
        match *self.credential.read() {
            ProviderCredential::Credentials { .. } => true,
            _ => false,
        }
    }
    /// Resolves credentials in priority order: the already-cached value, the
    /// `OPENAI_API_KEY` environment variable, then the OS keychain entry for
    /// `OPENAI_API_URL`. The result (including `NoCredentials`) is cached.
    async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
        let existing_credential = self.credential.read().clone();
        let retrieved_credential = cx
            .run_on_main(move |cx| match existing_credential {
                ProviderCredential::Credentials { .. } => {
                    return existing_credential.clone();
                }
                _ => {
                    if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
                        return ProviderCredential::Credentials { api_key };
                    }
                    if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
                    {
                        if let Some(api_key) = String::from_utf8(api_key).log_err() {
                            return ProviderCredential::Credentials { api_key };
                        } else {
                            return ProviderCredential::NoCredentials;
                        }
                    } else {
                        return ProviderCredential::NoCredentials;
                    }
                }
            })
            .await;
        *self.credential.write() = retrieved_credential.clone();
        retrieved_credential
    }
    /// Caches `credential` in memory and, for real credentials, persists the
    /// API key to the OS keychain.
    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
        *self.credential.write() = credential.clone();
        let credential = credential.clone();
        cx.run_on_main(move |cx| match credential {
            ProviderCredential::Credentials { api_key } => {
                cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
                    .log_err();
            }
            _ => {}
        })
        .await;
    }
    /// Removes the keychain entry and clears the in-memory credential.
    async fn delete_credentials(&self, cx: &mut AppContext) {
        cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
            .await;
        *self.credential.write() = ProviderCredential::NoCredentials;
    }
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
        model
    }
    // Token budget per embeddings batch request.
    fn max_tokens_per_batch(&self) -> usize {
        50000
    }
    // When rate limited, the instant at which the limit is expected to lift.
    fn rate_limit_expiration(&self) -> Option<Instant> {
        *self.rate_limit_count_rx.borrow()
    }
    /// Embeds `spans`, retrying up to `MAX_RETRIES` times: timeouts grow the
    /// request timeout, 429s back off (honoring the
    /// `x-ratelimit-reset-tokens` header when parseable), and any other
    /// non-200 status fails immediately.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
        const MAX_RETRIES: usize = 4;
        let api_key = self.get_api_key()?;
        let mut request_number = 0;
        let mut rate_limiting = false;
        let mut request_timeout: u64 = 15;
        let mut response: Response<AsyncBody>;
        while request_number < MAX_RETRIES {
            response = self
                .send_request(
                    &api_key,
                    spans.iter().map(|x| &**x).collect(),
                    request_timeout,
                )
                .await?;
            request_number += 1;
            match response.status() {
                StatusCode::REQUEST_TIMEOUT => {
                    // Give the server more time on the next attempt.
                    request_timeout += 5;
                }
                StatusCode::OK => {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
                    log::trace!(
                        "openai embedding completed. tokens: {:?}",
                        response.usage.total_tokens
                    );
                    // If we complete a request successfully that was previously rate_limited
                    // resolve the rate limit
                    if rate_limiting {
                        self.resolve_rate_limit()
                    }
                    return Ok(response
                        .data
                        .into_iter()
                        .map(|embedding| Embedding::from(embedding.embedding))
                        .collect());
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    rate_limiting = true;
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    // Prefer the server-provided reset delay; fall back to the
                    // fixed backoff schedule (indexed by attempt number).
                    let delay_duration = {
                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
                        if let Some(time_to_reset) =
                            response.headers().get("x-ratelimit-reset-tokens")
                        {
                            if let Ok(time_str) = time_to_reset.to_str() {
                                parse(time_str).unwrap_or(delay)
                            } else {
                                delay
                            }
                        } else {
                            delay
                        }
                    };
                    // If we've previously rate limited, increment the duration but not the count
                    let reset_time = Instant::now().add(delay_duration);
                    self.update_reset_time(reset_time);
                    log::trace!(
                        "openai rate limiting: waiting {:?} until lifted",
                        &delay_duration
                    );
                    self.executor.timer(delay_duration).await;
                }
                _ => {
                    // Unexpected status: surface it with the body and stop.
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "open ai bad request: {:?} {:?}",
                        &response.status(),
                        body
                    ));
                }
            }
        }
        Err(anyhow!("openai max retries"))
    }
}

View File

@ -0,0 +1,9 @@
pub mod completion;
pub mod embedding;
pub mod model;
pub use completion::*;
pub use embedding::*;
pub use model::OpenAILanguageModel;
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";

View File

@ -0,0 +1,57 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;
use crate::models::{LanguageModel, TruncationDirection};
/// `LanguageModel` backed by an OpenAI tiktoken tokenizer.
#[derive(Clone)]
pub struct OpenAILanguageModel {
    // OpenAI model name, e.g. "gpt-4" or "text-embedding-ada-002".
    name: String,
    // `None` when no tokenizer is known for the model; token operations then
    // return errors.
    bpe: Option<CoreBPE>,
}
impl OpenAILanguageModel {
    /// Creates a model descriptor for `model_name`, attempting to load the
    /// matching tiktoken BPE. If the model is unknown the failure is logged
    /// and the tokenizer is left `None`.
    pub fn load(model_name: &str) -> Self {
        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
        OpenAILanguageModel {
            name: model_name.to_string(),
            bpe,
        }
    }
}
impl LanguageModel for OpenAILanguageModel {
    fn name(&self) -> String {
        self.name.clone()
    }
    /// Token count under the model's BPE (including special tokens); errors
    /// when no tokenizer was loaded.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        if let Some(bpe) = &self.bpe {
            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
        } else {
            Err(anyhow!("bpe for open ai model was not retrieved"))
        }
    }
    /// Truncates `content` to at most `length` tokens: `End` keeps the first
    /// `length` tokens, `Start` keeps the last `length` tokens.
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String> {
        if let Some(bpe) = &self.bpe {
            let tokens = bpe.encode_with_special_tokens(content);
            if tokens.len() > length {
                match direction {
                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
                    // Keep the *last* `length` tokens. The previous
                    // `tokens[length..]` kept `len - length` tokens, exceeding
                    // the requested budget whenever `length < len / 2`, which
                    // broke the window budgeting in `retrieve_context`.
                    TruncationDirection::Start => {
                        bpe.decode(tokens[tokens.len() - length..].to_vec())
                    }
                }
            } else {
                bpe.decode(tokens)
            }
        } else {
            Err(anyhow!("bpe for open ai model was not retrieved"))
        }
    }
    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
    }
}

View File

@ -0,0 +1,11 @@
/// Abstraction over a tokenizer-aware language model.
pub trait LanguageModel {
    /// Model identifier (e.g. the OpenAI model name).
    fn name(&self) -> String;
    /// Number of tokens `content` encodes to for this model.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    /// Truncates `content` to fit `length` tokens; `direction` selects which
    /// end of the text is dropped.
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String>;
    /// Maximum context window, in tokens.
    fn capacity(&self) -> anyhow::Result<usize>;
}

193
crates/ai2/src/test.rs Normal file
View File

@ -0,0 +1,193 @@
use std::{
sync::atomic::{self, AtomicUsize, Ordering},
time::Instant,
};
use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui2::AppContext;
use parking_lot::Mutex;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
};
/// Test double for `LanguageModel` that treats every `char` as one token.
#[derive(Clone)]
pub struct FakeLanguageModel {
    // Pretend context-window size returned by `capacity`.
    pub capacity: usize,
}
impl LanguageModel for FakeLanguageModel {
    fn name(&self) -> String {
        "dummy".to_string()
    }
    /// One token per `char`.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        anyhow::Ok(content.chars().count())
    }
    /// Truncates to at most `length` chars: `End` keeps the first `length`,
    /// `Start` keeps the last `length`. (Removed leftover `println!` debug
    /// output and a `.clone()` on the `Copy` `length`; `Start` previously
    /// dropped the first `length` chars instead of keeping the last `length`,
    /// mirroring the bug in the OpenAI model's `truncate`.)
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String> {
        let chars: Vec<char> = content.chars().collect();
        if length > chars.len() {
            return anyhow::Ok(content.to_string());
        }
        anyhow::Ok(match direction {
            TruncationDirection::End => chars[..length].iter().collect(),
            TruncationDirection::Start => chars[chars.len() - length..].iter().collect(),
        })
    }
    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(self.capacity)
    }
}
/// Test double for `EmbeddingProvider` that computes deterministic
/// letter-frequency embeddings and counts how many spans were embedded.
pub struct FakeEmbeddingProvider {
    // Number of spans embedded so far; atomic so tests can share the provider.
    pub embedding_count: AtomicUsize,
}
impl Clone for FakeEmbeddingProvider {
    // Manual impl because `AtomicUsize` is not `Clone`; snapshots the counter.
    fn clone(&self) -> Self {
        FakeEmbeddingProvider {
            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
        }
    }
}
impl Default for FakeEmbeddingProvider {
    fn default() -> Self {
        FakeEmbeddingProvider {
            embedding_count: AtomicUsize::default(),
        }
    }
}
impl FakeEmbeddingProvider {
    /// Number of spans embedded so far.
    pub fn embedding_count(&self) -> usize {
        self.embedding_count.load(atomic::Ordering::SeqCst)
    }
    /// Deterministic fake embedding: 26 letter-frequency buckets (each seeded
    /// at 1.0 so the vector is never zero), L2-normalized.
    pub fn embed_sync(&self, span: &str) -> Embedding {
        let mut counts = vec![1.0; 26];
        for c in span.chars().map(|c| c.to_ascii_lowercase()) {
            if ('a'..='z').contains(&c) {
                counts[(c as u32 - 'a' as u32) as usize] += 1.0;
            }
        }
        let norm = counts.iter().map(|x| x * x).sum::<f32>().sqrt();
        for value in counts.iter_mut() {
            *value /= norm;
        }
        counts.into()
    }
}
#[async_trait]
impl CredentialProvider for FakeEmbeddingProvider {
    // The fake never needs real credentials.
    fn has_credentials(&self) -> bool {
        true
    }
    async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
        ProviderCredential::NotNeeded
    }
    async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
    async fn delete_credentials(&self, _cx: &mut AppContext) {}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        Box::new(FakeLanguageModel { capacity: 1000 })
    }
    fn max_tokens_per_batch(&self) -> usize {
        1000
    }
    // The fake is never rate limited.
    fn rate_limit_expiration(&self) -> Option<Instant> {
        None
    }
    /// Embeds every span synchronously and bumps `embedding_count` by the
    /// batch size.
    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
        self.embedding_count
            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
    }
}
/// Test double for `CompletionProvider`: completions are driven manually via
/// `send_completion` / `finish_completion`.
pub struct FakeCompletionProvider {
    // Sender side of the stream returned by the most recent `complete` call.
    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl Clone for FakeCompletionProvider {
    // Manual impl: a clone deliberately starts with no in-flight completion
    // rather than sharing the sender.
    fn clone(&self) -> Self {
        Self {
            last_completion_tx: Mutex::new(None),
        }
    }
}
impl FakeCompletionProvider {
    pub fn new() -> Self {
        Self {
            last_completion_tx: Mutex::new(None),
        }
    }
    /// Pushes `completion` into the stream returned by the last `complete`
    /// call. Panics if `complete` has not been called or the channel is full.
    pub fn send_completion(&self, completion: impl Into<String>) {
        let mut tx = self.last_completion_tx.lock();
        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
    }
    /// Drops the sender, ending the stream. Panics if there is none.
    pub fn finish_completion(&self) {
        self.last_completion_tx.lock().take().unwrap();
    }
}
#[async_trait]
impl CredentialProvider for FakeCompletionProvider {
    // The fake never needs real credentials.
    fn has_credentials(&self) -> bool {
        true
    }
    async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
        ProviderCredential::NotNeeded
    }
    async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
    async fn delete_credentials(&self, _cx: &mut AppContext) {}
}
impl CompletionProvider for FakeCompletionProvider {
    /// Fake model with an 8190-token capacity.
    fn base_model(&self) -> Box<dyn LanguageModel> {
        Box::new(FakeLanguageModel { capacity: 8190 })
    }
    /// Returns a stream fed by `send_completion`; the stream ends when
    /// `finish_completion` drops the sender.
    fn complete(
        &self,
        _prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
        let (tx, rx) = mpsc::channel(1);
        *self.last_completion_tx.lock() = Some(tx);
        async move { Ok(rx.map(Ok).boxed()) }.boxed()
    }
    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new(self.clone())
    }
}

View File

@ -45,6 +45,7 @@ tiktoken-rs = "0.5"
[dev-dependencies] [dev-dependencies]
editor = { path = "../editor", features = ["test-support"] } editor = { path = "../editor", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] } project = { path = "../project", features = ["test-support"] }
ai = { path = "../ai", features = ["test-support"]}
ctor.workspace = true ctor.workspace = true
env_logger.workspace = true env_logger.workspace = true

View File

@ -4,7 +4,7 @@ mod codegen;
mod prompts; mod prompts;
mod streaming_diff; mod streaming_diff;
use ai::completion::Role; use ai::providers::open_ai::Role;
use anyhow::Result; use anyhow::Result;
pub use assistant_panel::AssistantPanel; pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAIModel; use assistant_settings::OpenAIModel;

View File

@ -5,12 +5,14 @@ use crate::{
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata, MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage, SavedMessage,
}; };
use ai::{ use ai::{
completion::{ auth::ProviderCredential,
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, completion::{CompletionProvider, CompletionRequest},
}, providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
templates::repository_context::PromptCodeSnippet,
}; };
use ai::prompts::repository_context::PromptCodeSnippet;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::{DateTime, Local}; use chrono::{DateTime, Local};
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings}; use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
@ -43,8 +45,8 @@ use search::BufferSearchBar;
use semantic_index::{SemanticIndex, SemanticIndexStatus}; use semantic_index::{SemanticIndex, SemanticIndexStatus};
use settings::SettingsStore; use settings::SettingsStore;
use std::{ use std::{
cell::{Cell, RefCell}, cell::Cell,
cmp, env, cmp,
fmt::Write, fmt::Write,
iter, iter,
ops::Range, ops::Range,
@ -97,8 +99,8 @@ pub fn init(cx: &mut AppContext) {
cx.capture_action(ConversationEditor::copy); cx.capture_action(ConversationEditor::copy);
cx.add_action(ConversationEditor::split); cx.add_action(ConversationEditor::split);
cx.capture_action(ConversationEditor::cycle_message_role); cx.capture_action(ConversationEditor::cycle_message_role);
cx.add_action(AssistantPanel::save_api_key); cx.add_action(AssistantPanel::save_credentials);
cx.add_action(AssistantPanel::reset_api_key); cx.add_action(AssistantPanel::reset_credentials);
cx.add_action(AssistantPanel::toggle_zoom); cx.add_action(AssistantPanel::toggle_zoom);
cx.add_action(AssistantPanel::deploy); cx.add_action(AssistantPanel::deploy);
cx.add_action(AssistantPanel::select_next_match); cx.add_action(AssistantPanel::select_next_match);
@ -140,9 +142,8 @@ pub struct AssistantPanel {
zoomed: bool, zoomed: bool,
has_focus: bool, has_focus: bool,
toolbar: ViewHandle<Toolbar>, toolbar: ViewHandle<Toolbar>,
api_key: Rc<RefCell<Option<String>>>, completion_provider: Box<dyn CompletionProvider>,
api_key_editor: Option<ViewHandle<Editor>>, api_key_editor: Option<ViewHandle<Editor>>,
has_read_credentials: bool,
languages: Arc<LanguageRegistry>, languages: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
subscriptions: Vec<Subscription>, subscriptions: Vec<Subscription>,
@ -202,6 +203,11 @@ impl AssistantPanel {
}); });
let semantic_index = SemanticIndex::global(cx); let semantic_index = SemanticIndex::global(cx);
// Defaulting currently to GPT4, allow for this to be set via config.
let completion_provider = Box::new(OpenAICompletionProvider::new(
"gpt-4",
cx.background().clone(),
));
let mut this = Self { let mut this = Self {
workspace: workspace_handle, workspace: workspace_handle,
@ -213,9 +219,8 @@ impl AssistantPanel {
zoomed: false, zoomed: false,
has_focus: false, has_focus: false,
toolbar, toolbar,
api_key: Rc::new(RefCell::new(None)), completion_provider,
api_key_editor: None, api_key_editor: None,
has_read_credentials: false,
languages: workspace.app_state().languages.clone(), languages: workspace.app_state().languages.clone(),
fs: workspace.app_state().fs.clone(), fs: workspace.app_state().fs.clone(),
width: None, width: None,
@ -254,10 +259,7 @@ impl AssistantPanel {
cx: &mut ViewContext<Workspace>, cx: &mut ViewContext<Workspace>,
) { ) {
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) { let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
if this if this.update(cx, |assistant, _| assistant.has_credentials()) {
.update(cx, |assistant, cx| assistant.load_api_key(cx))
.is_some()
{
this this
} else { } else {
workspace.focus_panel::<AssistantPanel>(cx); workspace.focus_panel::<AssistantPanel>(cx);
@ -289,12 +291,6 @@ impl AssistantPanel {
cx: &mut ViewContext<Self>, cx: &mut ViewContext<Self>,
project: &ModelHandle<Project>, project: &ModelHandle<Project>,
) { ) {
let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
api_key
} else {
return;
};
let selection = editor.read(cx).selections.newest_anchor().clone(); let selection = editor.read(cx).selections.newest_anchor().clone();
if selection.start.excerpt_id != selection.end.excerpt_id { if selection.start.excerpt_id != selection.end.excerpt_id {
return; return;
@ -325,10 +321,13 @@ impl AssistantPanel {
let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let provider = Arc::new(OpenAICompletionProvider::new( let provider = Arc::new(OpenAICompletionProvider::new(
api_key, "gpt-4",
cx.background().clone(), cx.background().clone(),
)); ));
// Retrieve Credentials Authenticates the Provider
// provider.retrieve_credentials(cx);
let codegen = cx.add_model(|cx| { let codegen = cx.add_model(|cx| {
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
}); });
@ -745,13 +744,14 @@ impl AssistantPanel {
content: prompt, content: prompt,
}); });
let request = OpenAIRequest { let request = Box::new(OpenAIRequest {
model: model.full_name().into(), model: model.full_name().into(),
messages, messages,
stream: true, stream: true,
stop: vec!["|END|>".to_string()], stop: vec!["|END|>".to_string()],
temperature, temperature,
}; });
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx)); codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
anyhow::Ok(()) anyhow::Ok(())
}) })
@ -811,7 +811,7 @@ impl AssistantPanel {
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> { fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
let editor = cx.add_view(|cx| { let editor = cx.add_view(|cx| {
ConversationEditor::new( ConversationEditor::new(
self.api_key.clone(), self.completion_provider.clone(),
self.languages.clone(), self.languages.clone(),
self.fs.clone(), self.fs.clone(),
self.workspace.clone(), self.workspace.clone(),
@ -870,17 +870,19 @@ impl AssistantPanel {
} }
} }
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) { fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
if let Some(api_key) = self if let Some(api_key) = self
.api_key_editor .api_key_editor
.as_ref() .as_ref()
.map(|editor| editor.read(cx).text(cx)) .map(|editor| editor.read(cx).text(cx))
{ {
if !api_key.is_empty() { if !api_key.is_empty() {
cx.platform() let credential = ProviderCredential::Credentials {
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) api_key: api_key.clone(),
.log_err(); };
*self.api_key.borrow_mut() = Some(api_key);
self.completion_provider.save_credentials(cx, credential);
self.api_key_editor.take(); self.api_key_editor.take();
cx.focus_self(); cx.focus_self();
cx.notify(); cx.notify();
@ -890,9 +892,8 @@ impl AssistantPanel {
} }
} }
fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) { fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
cx.platform().delete_credentials(OPENAI_API_URL).log_err(); self.completion_provider.delete_credentials(cx);
self.api_key.take();
self.api_key_editor = Some(build_api_key_editor(cx)); self.api_key_editor = Some(build_api_key_editor(cx));
cx.focus_self(); cx.focus_self();
cx.notify(); cx.notify();
@ -1151,13 +1152,12 @@ impl AssistantPanel {
let fs = self.fs.clone(); let fs = self.fs.clone();
let workspace = self.workspace.clone(); let workspace = self.workspace.clone();
let api_key = self.api_key.clone();
let languages = self.languages.clone(); let languages = self.languages.clone();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
let saved_conversation = fs.load(&path).await?; let saved_conversation = fs.load(&path).await?;
let saved_conversation = serde_json::from_str(&saved_conversation)?; let saved_conversation = serde_json::from_str(&saved_conversation)?;
let conversation = cx.add_model(|cx| { let conversation = cx.add_model(|cx| {
Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx) Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
}); });
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
// If, by the time we've loaded the conversation, the user has already opened // If, by the time we've loaded the conversation, the user has already opened
@ -1181,30 +1181,12 @@ impl AssistantPanel {
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
} }
fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> { fn has_credentials(&mut self) -> bool {
if self.api_key.borrow().is_none() && !self.has_read_credentials { self.completion_provider.has_credentials()
self.has_read_credentials = true; }
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
Some(api_key)
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
{
String::from_utf8(api_key).log_err()
} else {
None
};
if let Some(api_key) = api_key {
*self.api_key.borrow_mut() = Some(api_key);
} else if self.api_key_editor.is_none() {
self.api_key_editor = Some(build_api_key_editor(cx));
cx.notify();
}
}
self.api_key.borrow().clone() fn load_credentials(&mut self, cx: &mut ViewContext<Self>) {
self.completion_provider.retrieve_credentials(cx);
} }
} }
@ -1389,7 +1371,7 @@ impl Panel for AssistantPanel {
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) { fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
if active { if active {
self.load_api_key(cx); self.load_credentials(cx);
if self.editors.is_empty() { if self.editors.is_empty() {
self.new_conversation(cx); self.new_conversation(cx);
@ -1454,10 +1436,10 @@ struct Conversation {
token_count: Option<usize>, token_count: Option<usize>,
max_token_count: usize, max_token_count: usize,
pending_token_count: Task<Option<()>>, pending_token_count: Task<Option<()>>,
api_key: Rc<RefCell<Option<String>>>,
pending_save: Task<Result<()>>, pending_save: Task<Result<()>>,
path: Option<PathBuf>, path: Option<PathBuf>,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
completion_provider: Box<dyn CompletionProvider>,
} }
impl Entity for Conversation { impl Entity for Conversation {
@ -1466,9 +1448,9 @@ impl Entity for Conversation {
impl Conversation { impl Conversation {
fn new( fn new(
api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
completion_provider: Box<dyn CompletionProvider>,
) -> Self { ) -> Self {
let markdown = language_registry.language_for_name("Markdown"); let markdown = language_registry.language_for_name("Markdown");
let buffer = cx.add_model(|cx| { let buffer = cx.add_model(|cx| {
@ -1507,8 +1489,8 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())), pending_save: Task::ready(Ok(())),
path: None, path: None,
api_key,
buffer, buffer,
completion_provider,
}; };
let message = MessageAnchor { let message = MessageAnchor {
id: MessageId(post_inc(&mut this.next_message_id.0)), id: MessageId(post_inc(&mut this.next_message_id.0)),
@ -1554,7 +1536,6 @@ impl Conversation {
fn deserialize( fn deserialize(
saved_conversation: SavedConversation, saved_conversation: SavedConversation,
path: PathBuf, path: PathBuf,
api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Self { ) -> Self {
@ -1563,6 +1544,10 @@ impl Conversation {
None => Some(Uuid::new_v4().to_string()), None => Some(Uuid::new_v4().to_string()),
}; };
let model = saved_conversation.model; let model = saved_conversation.model;
let completion_provider: Box<dyn CompletionProvider> = Box::new(
OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
);
completion_provider.retrieve_credentials(cx);
let markdown = language_registry.language_for_name("Markdown"); let markdown = language_registry.language_for_name("Markdown");
let mut message_anchors = Vec::new(); let mut message_anchors = Vec::new();
let mut next_message_id = MessageId(0); let mut next_message_id = MessageId(0);
@ -1609,8 +1594,8 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())), pending_save: Task::ready(Ok(())),
path: Some(path), path: Some(path),
api_key,
buffer, buffer,
completion_provider,
}; };
this.count_remaining_tokens(cx); this.count_remaining_tokens(cx);
this this
@ -1731,11 +1716,11 @@ impl Conversation {
} }
if should_assist { if should_assist {
let Some(api_key) = self.api_key.borrow().clone() else { if !self.completion_provider.has_credentials() {
return Default::default(); return Default::default();
}; }
let request = OpenAIRequest { let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
model: self.model.full_name().to_string(), model: self.model.full_name().to_string(),
messages: self messages: self
.messages(cx) .messages(cx)
@ -1745,9 +1730,9 @@ impl Conversation {
stream: true, stream: true,
stop: vec![], stop: vec![],
temperature: 1.0, temperature: 1.0,
}; });
let stream = stream_completion(api_key, cx.background().clone(), request); let stream = self.completion_provider.complete(request);
let assistant_message = self let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap(); .unwrap();
@ -1765,33 +1750,28 @@ impl Conversation {
let mut messages = stream.await?; let mut messages = stream.await?;
while let Some(message) = messages.next().await { while let Some(message) = messages.next().await {
let mut message = message?; let text = message?;
if let Some(choice) = message.choices.pop() {
this.upgrade(&cx)
.ok_or_else(|| anyhow!("conversation was dropped"))?
.update(&mut cx, |this, cx| {
let text: Arc<str> = choice.delta.content?.into();
let message_ix =
this.message_anchors.iter().position(|message| {
message.id == assistant_message_id
})?;
this.buffer.update(cx, |buffer, cx| {
let offset = this.message_anchors[message_ix + 1..]
.iter()
.find(|message| message.start.is_valid(buffer))
.map_or(buffer.len(), |message| {
message
.start
.to_offset(buffer)
.saturating_sub(1)
});
buffer.edit([(offset..offset, text)], None, cx);
});
cx.emit(ConversationEvent::StreamedCompletion);
Some(()) this.upgrade(&cx)
.ok_or_else(|| anyhow!("conversation was dropped"))?
.update(&mut cx, |this, cx| {
let message_ix = this
.message_anchors
.iter()
.position(|message| message.id == assistant_message_id)?;
this.buffer.update(cx, |buffer, cx| {
let offset = this.message_anchors[message_ix + 1..]
.iter()
.find(|message| message.start.is_valid(buffer))
.map_or(buffer.len(), |message| {
message.start.to_offset(buffer).saturating_sub(1)
});
buffer.edit([(offset..offset, text)], None, cx);
}); });
} cx.emit(ConversationEvent::StreamedCompletion);
Some(())
});
smol::future::yield_now().await; smol::future::yield_now().await;
} }
@ -2013,57 +1993,54 @@ impl Conversation {
fn summarize(&mut self, cx: &mut ModelContext<Self>) { fn summarize(&mut self, cx: &mut ModelContext<Self>) {
if self.message_anchors.len() >= 2 && self.summary.is_none() { if self.message_anchors.len() >= 2 && self.summary.is_none() {
let api_key = self.api_key.borrow().clone(); if !self.completion_provider.has_credentials() {
if let Some(api_key) = api_key { return;
let messages = self
.messages(cx)
.take(2)
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
.chain(Some(RequestMessage {
role: Role::User,
content:
"Summarize the conversation into a short title without punctuation"
.into(),
}));
let request = OpenAIRequest {
model: self.model.full_name().to_string(),
messages: messages.collect(),
stream: true,
stop: vec![],
temperature: 1.0,
};
let stream = stream_completion(api_key, cx.background().clone(), request);
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
let text = choice.delta.content.unwrap_or_default();
this.update(&mut cx, |this, cx| {
this.summary
.get_or_insert(Default::default())
.text
.push_str(&text);
cx.emit(ConversationEvent::SummaryChanged);
});
}
}
this.update(&mut cx, |this, cx| {
if let Some(summary) = this.summary.as_mut() {
summary.done = true;
cx.emit(ConversationEvent::SummaryChanged);
}
});
anyhow::Ok(())
}
.log_err()
});
} }
let messages = self
.messages(cx)
.take(2)
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
.chain(Some(RequestMessage {
role: Role::User,
content: "Summarize the conversation into a short title without punctuation"
.into(),
}));
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
model: self.model.full_name().to_string(),
messages: messages.collect(),
stream: true,
stop: vec![],
temperature: 1.0,
});
let stream = self.completion_provider.complete(request);
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let text = message?;
this.update(&mut cx, |this, cx| {
this.summary
.get_or_insert(Default::default())
.text
.push_str(&text);
cx.emit(ConversationEvent::SummaryChanged);
});
}
this.update(&mut cx, |this, cx| {
if let Some(summary) = this.summary.as_mut() {
summary.done = true;
cx.emit(ConversationEvent::SummaryChanged);
}
});
anyhow::Ok(())
}
.log_err()
});
} }
} }
@ -2224,13 +2201,14 @@ struct ConversationEditor {
impl ConversationEditor { impl ConversationEditor {
fn new( fn new(
api_key: Rc<RefCell<Option<String>>>, completion_provider: Box<dyn CompletionProvider>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
workspace: WeakViewHandle<Workspace>, workspace: WeakViewHandle<Workspace>,
cx: &mut ViewContext<Self>, cx: &mut ViewContext<Self>,
) -> Self { ) -> Self {
let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx)); let conversation =
cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider));
Self::for_conversation(conversation, fs, workspace, cx) Self::for_conversation(conversation, fs, workspace, cx)
} }
@ -3419,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
mod tests { mod tests {
use super::*; use super::*;
use crate::MessageId; use crate::MessageId;
use ai::test::FakeCompletionProvider;
use gpui::AppContext; use gpui::AppContext;
#[gpui::test] #[gpui::test]
@ -3426,7 +3405,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx)); cx.set_global(SettingsStore::test(cx));
init(cx); init(cx);
let registry = Arc::new(LanguageRegistry::test()); let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone(); let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone(); let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -3554,7 +3535,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx)); cx.set_global(SettingsStore::test(cx));
init(cx); init(cx);
let registry = Arc::new(LanguageRegistry::test()); let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone(); let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone(); let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -3650,7 +3633,8 @@ mod tests {
cx.set_global(SettingsStore::test(cx)); cx.set_global(SettingsStore::test(cx));
init(cx); init(cx);
let registry = Arc::new(LanguageRegistry::test()); let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone(); let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone(); let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -3732,8 +3716,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx)); cx.set_global(SettingsStore::test(cx));
init(cx); init(cx);
let registry = Arc::new(LanguageRegistry::test()); let registry = Arc::new(LanguageRegistry::test());
let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation = let conversation =
cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx)); cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone(); let buffer = conversation.read(cx).buffer.clone();
let message_0 = conversation.read(cx).message_anchors[0].id; let message_0 = conversation.read(cx).message_anchors[0].id;
let message_1 = conversation.update(cx, |conversation, cx| { let message_1 = conversation.update(cx, |conversation, cx| {
@ -3770,7 +3755,6 @@ mod tests {
Conversation::deserialize( Conversation::deserialize(
conversation.read(cx).serialize(cx), conversation.read(cx).serialize(cx),
Default::default(), Default::default(),
Default::default(),
registry.clone(), registry.clone(),
cx, cx,
) )

View File

@ -1,5 +1,5 @@
use crate::streaming_diff::{Hunk, StreamingDiff}; use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::{CompletionProvider, OpenAIRequest}; use ai::completion::{CompletionProvider, CompletionRequest};
use anyhow::Result; use anyhow::Result;
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
@ -96,7 +96,7 @@ impl Codegen {
self.error.as_ref() self.error.as_ref()
} }
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) { pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
let range = self.range(); let range = self.range();
let snapshot = self.snapshot.clone(); let snapshot = self.snapshot.clone();
let selected_text = snapshot let selected_text = snapshot
@ -336,17 +336,25 @@ fn strip_markdown_codeblock(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use futures::{ use ai::test::FakeCompletionProvider;
future::BoxFuture, use futures::stream::{self};
stream::{self, BoxStream},
};
use gpui::{executor::Deterministic, TestAppContext}; use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc; use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use parking_lot::Mutex;
use rand::prelude::*; use rand::prelude::*;
use serde::Serialize;
use settings::SettingsStore; use settings::SettingsStore;
use smol::future::FutureExt;
#[derive(Serialize)]
pub struct DummyCompletionRequest {
pub name: String,
}
impl CompletionRequest for DummyCompletionRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[gpui::test(iterations = 10)] #[gpui::test(iterations = 10)]
async fn test_transform_autoindent( async fn test_transform_autoindent(
@ -372,7 +380,7 @@ mod tests {
let snapshot = buffer.snapshot(cx); let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
}); });
let provider = Arc::new(TestCompletionProvider::new()); let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| { let codegen = cx.add_model(|cx| {
Codegen::new( Codegen::new(
buffer.clone(), buffer.clone(),
@ -381,7 +389,11 @@ mod tests {
cx, cx,
) )
}); });
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!( let mut new_text = concat!(
" let mut x = 0;\n", " let mut x = 0;\n",
@ -434,7 +446,7 @@ mod tests {
let snapshot = buffer.snapshot(cx); let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6)) snapshot.anchor_before(Point::new(1, 6))
}); });
let provider = Arc::new(TestCompletionProvider::new()); let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| { let codegen = cx.add_model(|cx| {
Codegen::new( Codegen::new(
buffer.clone(), buffer.clone(),
@ -443,7 +455,11 @@ mod tests {
cx, cx,
) )
}); });
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!( let mut new_text = concat!(
"t mut x = 0;\n", "t mut x = 0;\n",
@ -496,7 +512,7 @@ mod tests {
let snapshot = buffer.snapshot(cx); let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2)) snapshot.anchor_before(Point::new(1, 2))
}); });
let provider = Arc::new(TestCompletionProvider::new()); let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| { let codegen = cx.add_model(|cx| {
Codegen::new( Codegen::new(
buffer.clone(), buffer.clone(),
@ -505,7 +521,11 @@ mod tests {
cx, cx,
) )
}); });
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!( let mut new_text = concat!(
"let mut x = 0;\n", "let mut x = 0;\n",
@ -593,38 +613,6 @@ mod tests {
} }
} }
struct TestCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl TestCompletionProvider {
fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
impl CompletionProvider for TestCompletionProvider {
fn complete(
&self,
_prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
}
fn rust_lang() -> Language { fn rust_lang() -> Language {
Language::new( Language::new(
LanguageConfig { LanguageConfig {

View File

@ -1,9 +1,10 @@
use ai::models::{LanguageModel, OpenAILanguageModel}; use ai::models::LanguageModel;
use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate}; use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::templates::file_context::FileContext; use ai::prompts::file_context::FileContext;
use ai::templates::generate::GenerateInlineContent; use ai::prompts::generate::GenerateInlineContent;
use ai::templates::preamble::EngineerPreamble; use ai::prompts::preamble::EngineerPreamble;
use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext}; use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
use ai::providers::open_ai::OpenAILanguageModel;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp::{self, Reverse}; use std::cmp::{self, Reverse};
use std::ops::Range; use std::ops::Range;

View File

@ -25,7 +25,7 @@ collections = { path = "../collections" }
gpui2 = { path = "../gpui2" } gpui2 = { path = "../gpui2" }
log.workspace = true log.workspace = true
live_kit_client = { path = "../live_kit_client" } live_kit_client = { path = "../live_kit_client" }
fs = { path = "../fs" } fs2 = { path = "../fs2" }
language2 = { path = "../language2" } language2 = { path = "../language2" }
media = { path = "../media" } media = { path = "../media" }
project2 = { path = "../project2" } project2 = { path = "../project2" }
@ -43,7 +43,7 @@ serde_derive.workspace = true
[dev-dependencies] [dev-dependencies]
client2 = { path = "../client2", features = ["test-support"] } client2 = { path = "../client2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] } fs2 = { path = "../fs2", features = ["test-support"] }
language2 = { path = "../language2", features = ["test-support"] } language2 = { path = "../language2", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] }
gpui2 = { path = "../gpui2", features = ["test-support"] } gpui2 = { path = "../gpui2", features = ["test-support"] }

View File

@ -12,8 +12,8 @@ use client2::{
use collections::HashSet; use collections::HashSet;
use futures::{future::Shared, FutureExt}; use futures::{future::Shared, FutureExt};
use gpui2::{ use gpui2::{
AppContext, AsyncAppContext, Context, EventEmitter, Handle, ModelContext, Subscription, Task, AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task,
WeakHandle, WeakModel,
}; };
use postage::watch; use postage::watch;
use project2::Project; use project2::Project;
@ -23,10 +23,10 @@ use std::sync::Arc;
pub use participant::ParticipantLocation; pub use participant::ParticipantLocation;
pub use room::Room; pub use room::Room;
pub fn init(client: Arc<Client>, user_store: Handle<UserStore>, cx: &mut AppContext) { pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
CallSettings::register(cx); CallSettings::register(cx);
let active_call = cx.entity(|cx| ActiveCall::new(client, user_store, cx)); let active_call = cx.build_model(|cx| ActiveCall::new(client, user_store, cx));
cx.set_global(active_call); cx.set_global(active_call);
} }
@ -40,16 +40,16 @@ pub struct IncomingCall {
/// Singleton global maintaining the user's participation in a room across workspaces. /// Singleton global maintaining the user's participation in a room across workspaces.
pub struct ActiveCall { pub struct ActiveCall {
room: Option<(Handle<Room>, Vec<Subscription>)>, room: Option<(Model<Room>, Vec<Subscription>)>,
pending_room_creation: Option<Shared<Task<Result<Handle<Room>, Arc<anyhow::Error>>>>>, pending_room_creation: Option<Shared<Task<Result<Model<Room>, Arc<anyhow::Error>>>>>,
location: Option<WeakHandle<Project>>, location: Option<WeakModel<Project>>,
pending_invites: HashSet<u64>, pending_invites: HashSet<u64>,
incoming_call: ( incoming_call: (
watch::Sender<Option<IncomingCall>>, watch::Sender<Option<IncomingCall>>,
watch::Receiver<Option<IncomingCall>>, watch::Receiver<Option<IncomingCall>>,
), ),
client: Arc<Client>, client: Arc<Client>,
user_store: Handle<UserStore>, user_store: Model<UserStore>,
_subscriptions: Vec<client2::Subscription>, _subscriptions: Vec<client2::Subscription>,
} }
@ -58,11 +58,7 @@ impl EventEmitter for ActiveCall {
} }
impl ActiveCall { impl ActiveCall {
fn new( fn new(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut ModelContext<Self>) -> Self {
client: Arc<Client>,
user_store: Handle<UserStore>,
cx: &mut ModelContext<Self>,
) -> Self {
Self { Self {
room: None, room: None,
pending_room_creation: None, pending_room_creation: None,
@ -84,7 +80,7 @@ impl ActiveCall {
} }
async fn handle_incoming_call( async fn handle_incoming_call(
this: Handle<Self>, this: Model<Self>,
envelope: TypedEnvelope<proto::IncomingCall>, envelope: TypedEnvelope<proto::IncomingCall>,
_: Arc<Client>, _: Arc<Client>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
@ -112,7 +108,7 @@ impl ActiveCall {
} }
async fn handle_call_canceled( async fn handle_call_canceled(
this: Handle<Self>, this: Model<Self>,
envelope: TypedEnvelope<proto::CallCanceled>, envelope: TypedEnvelope<proto::CallCanceled>,
_: Arc<Client>, _: Arc<Client>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
@ -129,14 +125,14 @@ impl ActiveCall {
Ok(()) Ok(())
} }
pub fn global(cx: &AppContext) -> Handle<Self> { pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<Handle<Self>>().clone() cx.global::<Model<Self>>().clone()
} }
pub fn invite( pub fn invite(
&mut self, &mut self,
called_user_id: u64, called_user_id: u64,
initial_project: Option<Handle<Project>>, initial_project: Option<Model<Project>>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
if !self.pending_invites.insert(called_user_id) { if !self.pending_invites.insert(called_user_id) {
@ -291,7 +287,7 @@ impl ActiveCall {
&mut self, &mut self,
channel_id: u64, channel_id: u64,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<Handle<Room>>> { ) -> Task<Result<Model<Room>>> {
if let Some(room) = self.room().cloned() { if let Some(room) = self.room().cloned() {
if room.read(cx).channel_id() == Some(channel_id) { if room.read(cx).channel_id() == Some(channel_id) {
return Task::ready(Ok(room)); return Task::ready(Ok(room));
@ -327,7 +323,7 @@ impl ActiveCall {
pub fn share_project( pub fn share_project(
&mut self, &mut self,
project: Handle<Project>, project: Model<Project>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<u64>> { ) -> Task<Result<u64>> {
if let Some((room, _)) = self.room.as_ref() { if let Some((room, _)) = self.room.as_ref() {
@ -340,7 +336,7 @@ impl ActiveCall {
pub fn unshare_project( pub fn unshare_project(
&mut self, &mut self,
project: Handle<Project>, project: Model<Project>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Result<()> { ) -> Result<()> {
if let Some((room, _)) = self.room.as_ref() { if let Some((room, _)) = self.room.as_ref() {
@ -351,13 +347,13 @@ impl ActiveCall {
} }
} }
pub fn location(&self) -> Option<&WeakHandle<Project>> { pub fn location(&self) -> Option<&WeakModel<Project>> {
self.location.as_ref() self.location.as_ref()
} }
pub fn set_location( pub fn set_location(
&mut self, &mut self,
project: Option<&Handle<Project>>, project: Option<&Model<Project>>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
if project.is_some() || !*ZED_ALWAYS_ACTIVE { if project.is_some() || !*ZED_ALWAYS_ACTIVE {
@ -371,7 +367,7 @@ impl ActiveCall {
fn set_room( fn set_room(
&mut self, &mut self,
room: Option<Handle<Room>>, room: Option<Model<Room>>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
if room.as_ref() != self.room.as_ref().map(|room| &room.0) { if room.as_ref() != self.room.as_ref().map(|room| &room.0) {
@ -407,7 +403,7 @@ impl ActiveCall {
} }
} }
pub fn room(&self) -> Option<&Handle<Room>> { pub fn room(&self) -> Option<&Model<Room>> {
self.room.as_ref().map(|(room, _)| room) self.room.as_ref().map(|(room, _)| room)
} }

View File

@ -1,7 +1,7 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use client2::ParticipantIndex; use client2::ParticipantIndex;
use client2::{proto, User}; use client2::{proto, User};
use gpui2::WeakHandle; use gpui2::WeakModel;
pub use live_kit_client::Frame; pub use live_kit_client::Frame;
use project2::Project; use project2::Project;
use std::{fmt, sync::Arc}; use std::{fmt, sync::Arc};
@ -33,7 +33,7 @@ impl ParticipantLocation {
#[derive(Clone, Default)] #[derive(Clone, Default)]
pub struct LocalParticipant { pub struct LocalParticipant {
pub projects: Vec<proto::ParticipantProject>, pub projects: Vec<proto::ParticipantProject>,
pub active_project: Option<WeakHandle<Project>>, pub active_project: Option<WeakModel<Project>>,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]

View File

@ -13,10 +13,10 @@ use client2::{
Client, ParticipantIndex, TypedEnvelope, User, UserStore, Client, ParticipantIndex, TypedEnvelope, User, UserStore,
}; };
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use fs::Fs; use fs2::Fs;
use futures::{FutureExt, StreamExt}; use futures::{FutureExt, StreamExt};
use gpui2::{ use gpui2::{
AppContext, AsyncAppContext, Context, EventEmitter, Handle, ModelContext, Task, WeakHandle, AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel,
}; };
use language2::LanguageRegistry; use language2::LanguageRegistry;
use live_kit_client::{LocalTrackPublication, RemoteAudioTrackUpdate, RemoteVideoTrackUpdate}; use live_kit_client::{LocalTrackPublication, RemoteAudioTrackUpdate, RemoteVideoTrackUpdate};
@ -61,8 +61,8 @@ pub struct Room {
channel_id: Option<u64>, channel_id: Option<u64>,
// live_kit: Option<LiveKitRoom>, // live_kit: Option<LiveKitRoom>,
status: RoomStatus, status: RoomStatus,
shared_projects: HashSet<WeakHandle<Project>>, shared_projects: HashSet<WeakModel<Project>>,
joined_projects: HashSet<WeakHandle<Project>>, joined_projects: HashSet<WeakModel<Project>>,
local_participant: LocalParticipant, local_participant: LocalParticipant,
remote_participants: BTreeMap<u64, RemoteParticipant>, remote_participants: BTreeMap<u64, RemoteParticipant>,
pending_participants: Vec<Arc<User>>, pending_participants: Vec<Arc<User>>,
@ -70,7 +70,7 @@ pub struct Room {
pending_call_count: usize, pending_call_count: usize,
leave_when_empty: bool, leave_when_empty: bool,
client: Arc<Client>, client: Arc<Client>,
user_store: Handle<UserStore>, user_store: Model<UserStore>,
follows_by_leader_id_project_id: HashMap<(PeerId, u64), Vec<PeerId>>, follows_by_leader_id_project_id: HashMap<(PeerId, u64), Vec<PeerId>>,
client_subscriptions: Vec<client2::Subscription>, client_subscriptions: Vec<client2::Subscription>,
_subscriptions: Vec<gpui2::Subscription>, _subscriptions: Vec<gpui2::Subscription>,
@ -111,7 +111,7 @@ impl Room {
channel_id: Option<u64>, channel_id: Option<u64>,
live_kit_connection_info: Option<proto::LiveKitConnectionInfo>, live_kit_connection_info: Option<proto::LiveKitConnectionInfo>,
client: Arc<Client>, client: Arc<Client>,
user_store: Handle<UserStore>, user_store: Model<UserStore>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Self { ) -> Self {
todo!() todo!()
@ -237,15 +237,15 @@ impl Room {
pub(crate) fn create( pub(crate) fn create(
called_user_id: u64, called_user_id: u64,
initial_project: Option<Handle<Project>>, initial_project: Option<Model<Project>>,
client: Arc<Client>, client: Arc<Client>,
user_store: Handle<UserStore>, user_store: Model<UserStore>,
cx: &mut AppContext, cx: &mut AppContext,
) -> Task<Result<Handle<Self>>> { ) -> Task<Result<Model<Self>>> {
cx.spawn(move |mut cx| async move { cx.spawn(move |mut cx| async move {
let response = client.request(proto::CreateRoom {}).await?; let response = client.request(proto::CreateRoom {}).await?;
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?; let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
let room = cx.entity(|cx| { let room = cx.build_model(|cx| {
Self::new( Self::new(
room_proto.id, room_proto.id,
None, None,
@ -283,9 +283,9 @@ impl Room {
pub(crate) fn join_channel( pub(crate) fn join_channel(
channel_id: u64, channel_id: u64,
client: Arc<Client>, client: Arc<Client>,
user_store: Handle<UserStore>, user_store: Model<UserStore>,
cx: &mut AppContext, cx: &mut AppContext,
) -> Task<Result<Handle<Self>>> { ) -> Task<Result<Model<Self>>> {
cx.spawn(move |cx| async move { cx.spawn(move |cx| async move {
Self::from_join_response( Self::from_join_response(
client.request(proto::JoinChannel { channel_id }).await?, client.request(proto::JoinChannel { channel_id }).await?,
@ -299,9 +299,9 @@ impl Room {
pub(crate) fn join( pub(crate) fn join(
call: &IncomingCall, call: &IncomingCall,
client: Arc<Client>, client: Arc<Client>,
user_store: Handle<UserStore>, user_store: Model<UserStore>,
cx: &mut AppContext, cx: &mut AppContext,
) -> Task<Result<Handle<Self>>> { ) -> Task<Result<Model<Self>>> {
let id = call.room_id; let id = call.room_id;
cx.spawn(move |cx| async move { cx.spawn(move |cx| async move {
Self::from_join_response( Self::from_join_response(
@ -343,11 +343,11 @@ impl Room {
fn from_join_response( fn from_join_response(
response: proto::JoinRoomResponse, response: proto::JoinRoomResponse,
client: Arc<Client>, client: Arc<Client>,
user_store: Handle<UserStore>, user_store: Model<UserStore>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Handle<Self>> { ) -> Result<Model<Self>> {
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?; let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
let room = cx.entity(|cx| { let room = cx.build_model(|cx| {
Self::new( Self::new(
room_proto.id, room_proto.id,
response.channel_id, response.channel_id,
@ -424,7 +424,7 @@ impl Room {
} }
async fn maintain_connection( async fn maintain_connection(
this: WeakHandle<Self>, this: WeakModel<Self>,
client: Arc<Client>, client: Arc<Client>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<()> { ) -> Result<()> {
@ -661,7 +661,7 @@ impl Room {
} }
async fn handle_room_updated( async fn handle_room_updated(
this: Handle<Self>, this: Model<Self>,
envelope: TypedEnvelope<proto::RoomUpdated>, envelope: TypedEnvelope<proto::RoomUpdated>,
_: Arc<Client>, _: Arc<Client>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
@ -1101,7 +1101,7 @@ impl Room {
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<Handle<Project>>> { ) -> Task<Result<Model<Project>>> {
let client = self.client.clone(); let client = self.client.clone();
let user_store = self.user_store.clone(); let user_store = self.user_store.clone();
cx.emit(Event::RemoteProjectJoined { project_id: id }); cx.emit(Event::RemoteProjectJoined { project_id: id });
@ -1125,7 +1125,7 @@ impl Room {
pub(crate) fn share_project( pub(crate) fn share_project(
&mut self, &mut self,
project: Handle<Project>, project: Model<Project>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<u64>> { ) -> Task<Result<u64>> {
if let Some(project_id) = project.read(cx).remote_id() { if let Some(project_id) = project.read(cx).remote_id() {
@ -1161,7 +1161,7 @@ impl Room {
pub(crate) fn unshare_project( pub(crate) fn unshare_project(
&mut self, &mut self,
project: Handle<Project>, project: Model<Project>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Result<()> { ) -> Result<()> {
let project_id = match project.read(cx).remote_id() { let project_id = match project.read(cx).remote_id() {
@ -1175,7 +1175,7 @@ impl Room {
pub(crate) fn set_location( pub(crate) fn set_location(
&mut self, &mut self,
project: Option<&Handle<Project>>, project: Option<&Model<Project>>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
if self.status.is_offline() { if self.status.is_offline() {

View File

@ -14,8 +14,8 @@ use futures::{
future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _, TryStreamExt, future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _, TryStreamExt,
}; };
use gpui2::{ use gpui2::{
serde_json, AnyHandle, AnyWeakHandle, AppContext, AsyncAppContext, Handle, SemanticVersion, serde_json, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Model, SemanticVersion, Task,
Task, WeakHandle, WeakModel,
}; };
use lazy_static::lazy_static; use lazy_static::lazy_static;
use parking_lot::RwLock; use parking_lot::RwLock;
@ -227,7 +227,7 @@ struct ClientState {
_reconnect_task: Option<Task<()>>, _reconnect_task: Option<Task<()>>,
reconnect_interval: Duration, reconnect_interval: Duration,
entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>, entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
models_by_message_type: HashMap<TypeId, AnyWeakHandle>, models_by_message_type: HashMap<TypeId, AnyWeakModel>,
entity_types_by_message_type: HashMap<TypeId, TypeId>, entity_types_by_message_type: HashMap<TypeId, TypeId>,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
message_handlers: HashMap< message_handlers: HashMap<
@ -236,7 +236,7 @@ struct ClientState {
dyn Send dyn Send
+ Sync + Sync
+ Fn( + Fn(
AnyHandle, AnyModel,
Box<dyn AnyTypedEnvelope>, Box<dyn AnyTypedEnvelope>,
&Arc<Client>, &Arc<Client>,
AsyncAppContext, AsyncAppContext,
@ -246,7 +246,7 @@ struct ClientState {
} }
enum WeakSubscriber { enum WeakSubscriber {
Entity { handle: AnyWeakHandle }, Entity { handle: AnyWeakModel },
Pending(Vec<Box<dyn AnyTypedEnvelope>>), Pending(Vec<Box<dyn AnyTypedEnvelope>>),
} }
@ -314,7 +314,7 @@ impl<T> PendingEntitySubscription<T>
where where
T: 'static + Send, T: 'static + Send,
{ {
pub fn set_model(mut self, model: &Handle<T>, cx: &mut AsyncAppContext) -> Subscription { pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
self.consumed = true; self.consumed = true;
let mut state = self.client.state.write(); let mut state = self.client.state.write();
let id = (TypeId::of::<T>(), self.remote_id); let id = (TypeId::of::<T>(), self.remote_id);
@ -552,13 +552,13 @@ impl Client {
#[track_caller] #[track_caller]
pub fn add_message_handler<M, E, H, F>( pub fn add_message_handler<M, E, H, F>(
self: &Arc<Self>, self: &Arc<Self>,
entity: WeakHandle<E>, entity: WeakModel<E>,
handler: H, handler: H,
) -> Subscription ) -> Subscription
where where
M: EnvelopedMessage, M: EnvelopedMessage,
E: 'static + Send, E: 'static + Send,
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F, H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
F: 'static + Future<Output = Result<()>> + Send, F: 'static + Future<Output = Result<()>> + Send,
{ {
let message_type_id = TypeId::of::<M>(); let message_type_id = TypeId::of::<M>();
@ -594,13 +594,13 @@ impl Client {
pub fn add_request_handler<M, E, H, F>( pub fn add_request_handler<M, E, H, F>(
self: &Arc<Self>, self: &Arc<Self>,
model: WeakHandle<E>, model: WeakModel<E>,
handler: H, handler: H,
) -> Subscription ) -> Subscription
where where
M: RequestMessage, M: RequestMessage,
E: 'static + Send, E: 'static + Send,
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F, H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
F: 'static + Future<Output = Result<M::Response>> + Send, F: 'static + Future<Output = Result<M::Response>> + Send,
{ {
self.add_message_handler(model, move |handle, envelope, this, cx| { self.add_message_handler(model, move |handle, envelope, this, cx| {
@ -616,7 +616,7 @@ impl Client {
where where
M: EntityMessage, M: EntityMessage,
E: 'static + Send, E: 'static + Send,
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F, H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
F: 'static + Future<Output = Result<()>> + Send, F: 'static + Future<Output = Result<()>> + Send,
{ {
self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| { self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
@ -628,7 +628,7 @@ impl Client {
where where
M: EntityMessage, M: EntityMessage,
E: 'static + Send, E: 'static + Send,
H: 'static + Send + Sync + Fn(AnyHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F, H: 'static + Send + Sync + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
F: 'static + Future<Output = Result<()>> + Send, F: 'static + Future<Output = Result<()>> + Send,
{ {
let model_type_id = TypeId::of::<E>(); let model_type_id = TypeId::of::<E>();
@ -667,7 +667,7 @@ impl Client {
where where
M: EntityMessage + RequestMessage, M: EntityMessage + RequestMessage,
E: 'static + Send, E: 'static + Send,
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F, H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
F: 'static + Future<Output = Result<M::Response>> + Send, F: 'static + Future<Output = Result<M::Response>> + Send,
{ {
self.add_model_message_handler(move |entity, envelope, client, cx| { self.add_model_message_handler(move |entity, envelope, client, cx| {
@ -1546,7 +1546,7 @@ mod tests {
let (done_tx1, mut done_rx1) = smol::channel::unbounded(); let (done_tx1, mut done_rx1) = smol::channel::unbounded();
let (done_tx2, mut done_rx2) = smol::channel::unbounded(); let (done_tx2, mut done_rx2) = smol::channel::unbounded();
client.add_model_message_handler( client.add_model_message_handler(
move |model: Handle<Model>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| { move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
match model.update(&mut cx, |model, _| model.id).unwrap() { match model.update(&mut cx, |model, _| model.id).unwrap() {
1 => done_tx1.try_send(()).unwrap(), 1 => done_tx1.try_send(()).unwrap(),
2 => done_tx2.try_send(()).unwrap(), 2 => done_tx2.try_send(()).unwrap(),
@ -1555,15 +1555,15 @@ mod tests {
async { Ok(()) } async { Ok(()) }
}, },
); );
let model1 = cx.entity(|_| Model { let model1 = cx.build_model(|_| TestModel {
id: 1, id: 1,
subscription: None, subscription: None,
}); });
let model2 = cx.entity(|_| Model { let model2 = cx.build_model(|_| TestModel {
id: 2, id: 2,
subscription: None, subscription: None,
}); });
let model3 = cx.entity(|_| Model { let model3 = cx.build_model(|_| TestModel {
id: 3, id: 3,
subscription: None, subscription: None,
}); });
@ -1596,7 +1596,7 @@ mod tests {
let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx)); let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
let server = FakeServer::for_client(user_id, &client, cx).await; let server = FakeServer::for_client(user_id, &client, cx).await;
let model = cx.entity(|_| Model::default()); let model = cx.build_model(|_| TestModel::default());
let (done_tx1, _done_rx1) = smol::channel::unbounded(); let (done_tx1, _done_rx1) = smol::channel::unbounded();
let (done_tx2, mut done_rx2) = smol::channel::unbounded(); let (done_tx2, mut done_rx2) = smol::channel::unbounded();
let subscription1 = client.add_message_handler( let subscription1 = client.add_message_handler(
@ -1624,11 +1624,11 @@ mod tests {
let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx)); let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
let server = FakeServer::for_client(user_id, &client, cx).await; let server = FakeServer::for_client(user_id, &client, cx).await;
let model = cx.entity(|_| Model::default()); let model = cx.build_model(|_| TestModel::default());
let (done_tx, mut done_rx) = smol::channel::unbounded(); let (done_tx, mut done_rx) = smol::channel::unbounded();
let subscription = client.add_message_handler( let subscription = client.add_message_handler(
model.clone().downgrade(), model.clone().downgrade(),
move |model: Handle<Model>, _: TypedEnvelope<proto::Ping>, _, mut cx| { move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
model model
.update(&mut cx, |model, _| model.subscription.take()) .update(&mut cx, |model, _| model.subscription.take())
.unwrap(); .unwrap();
@ -1644,7 +1644,7 @@ mod tests {
} }
#[derive(Default)] #[derive(Default)]
struct Model { struct TestModel {
id: usize, id: usize,
subscription: Option<Subscription>, subscription: Option<Subscription>,
} }

View File

@ -5,7 +5,9 @@ use parking_lot::Mutex;
use serde::Serialize; use serde::Serialize;
use settings2::Settings; use settings2::Settings;
use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration}; use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration};
use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt}; use sysinfo::{
CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt,
};
use tempfile::NamedTempFile; use tempfile::NamedTempFile;
use util::http::HttpClient; use util::http::HttpClient;
use util::{channel::ReleaseChannel, TryFutureExt}; use util::{channel::ReleaseChannel, TryFutureExt};
@ -161,8 +163,16 @@ impl Telemetry {
let this = self.clone(); let this = self.clone();
cx.spawn(|cx| async move { cx.spawn(|cx| async move {
let mut system = System::new_all(); // Avoiding calling `System::new_all()`, as there have been crashes related to it
system.refresh_all(); let refresh_kind = RefreshKind::new()
.with_memory() // For memory usage
.with_processes(ProcessRefreshKind::everything()) // For process usage
.with_cpu(CpuRefreshKind::everything()); // For core count
let mut system = System::new_with_specifics(refresh_kind);
// Avoiding calling `refresh_all()`, just update what we need
system.refresh_specifics(refresh_kind);
loop { loop {
// Waiting some amount of time before the first query is important to get a reasonable value // Waiting some amount of time before the first query is important to get a reasonable value
@ -170,8 +180,7 @@ impl Telemetry {
const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60); const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60);
smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await; smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await;
system.refresh_memory(); system.refresh_specifics(refresh_kind);
system.refresh_processes();
let current_process = Pid::from_u32(std::process::id()); let current_process = Pid::from_u32(std::process::id());
let Some(process) = system.processes().get(&current_process) else { let Some(process) = system.processes().get(&current_process) else {

View File

@ -1,7 +1,7 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use futures::{stream::BoxStream, StreamExt}; use futures::{stream::BoxStream, StreamExt};
use gpui2::{Context, Executor, Handle, TestAppContext}; use gpui2::{Context, Executor, Model, TestAppContext};
use parking_lot::Mutex; use parking_lot::Mutex;
use rpc2::{ use rpc2::{
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse}, proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
@ -194,9 +194,9 @@ impl FakeServer {
&self, &self,
client: Arc<Client>, client: Arc<Client>,
cx: &mut TestAppContext, cx: &mut TestAppContext,
) -> Handle<UserStore> { ) -> Model<UserStore> {
let http_client = FakeHttpClient::with_404_response(); let http_client = FakeHttpClient::with_404_response();
let user_store = cx.entity(|cx| UserStore::new(client, http_client, cx)); let user_store = cx.build_model(|cx| UserStore::new(client, http_client, cx));
assert_eq!( assert_eq!(
self.receive::<proto::GetUsers>() self.receive::<proto::GetUsers>()
.await .await

View File

@ -3,7 +3,7 @@ use anyhow::{anyhow, Context, Result};
use collections::{hash_map::Entry, HashMap, HashSet}; use collections::{hash_map::Entry, HashMap, HashSet};
use feature_flags2::FeatureFlagAppExt; use feature_flags2::FeatureFlagAppExt;
use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt}; use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt};
use gpui2::{AsyncAppContext, EventEmitter, Handle, ImageData, ModelContext, Task}; use gpui2::{AsyncAppContext, EventEmitter, ImageData, Model, ModelContext, Task};
use postage::{sink::Sink, watch}; use postage::{sink::Sink, watch};
use rpc2::proto::{RequestMessage, UsersResponse}; use rpc2::proto::{RequestMessage, UsersResponse};
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
@ -213,7 +213,7 @@ impl UserStore {
} }
async fn handle_update_invite_info( async fn handle_update_invite_info(
this: Handle<Self>, this: Model<Self>,
message: TypedEnvelope<proto::UpdateInviteInfo>, message: TypedEnvelope<proto::UpdateInviteInfo>,
_: Arc<Client>, _: Arc<Client>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
@ -229,7 +229,7 @@ impl UserStore {
} }
async fn handle_show_contacts( async fn handle_show_contacts(
this: Handle<Self>, this: Model<Self>,
_: TypedEnvelope<proto::ShowContacts>, _: TypedEnvelope<proto::ShowContacts>,
_: Arc<Client>, _: Arc<Client>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
@ -243,7 +243,7 @@ impl UserStore {
} }
async fn handle_update_contacts( async fn handle_update_contacts(
this: Handle<Self>, this: Model<Self>,
message: TypedEnvelope<proto::UpdateContacts>, message: TypedEnvelope<proto::UpdateContacts>,
_: Arc<Client>, _: Arc<Client>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
@ -690,7 +690,7 @@ impl User {
impl Contact { impl Contact {
async fn from_proto( async fn from_proto(
contact: proto::Contact, contact: proto::Contact,
user_store: &Handle<UserStore>, user_store: &Model<UserStore>,
cx: &mut AsyncAppContext, cx: &mut AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let user = user_store let user = user_store

View File

@ -7,8 +7,8 @@ use async_tar::Archive;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt}; use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
use gpui2::{ use gpui2::{
AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Handle, ModelContext, Task, AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Model, ModelContext, Task,
WeakHandle, WeakModel,
}; };
use language2::{ use language2::{
language_settings::{all_language_settings, language_settings}, language_settings::{all_language_settings, language_settings},
@ -49,7 +49,7 @@ pub fn init(
node_runtime: Arc<dyn NodeRuntime>, node_runtime: Arc<dyn NodeRuntime>,
cx: &mut AppContext, cx: &mut AppContext,
) { ) {
let copilot = cx.entity({ let copilot = cx.build_model({
let node_runtime = node_runtime.clone(); let node_runtime = node_runtime.clone();
move |cx| Copilot::start(new_server_id, http, node_runtime, cx) move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
}); });
@ -183,7 +183,7 @@ struct RegisteredBuffer {
impl RegisteredBuffer { impl RegisteredBuffer {
fn report_changes( fn report_changes(
&mut self, &mut self,
buffer: &Handle<Buffer>, buffer: &Model<Buffer>,
cx: &mut ModelContext<Copilot>, cx: &mut ModelContext<Copilot>,
) -> oneshot::Receiver<(i32, BufferSnapshot)> { ) -> oneshot::Receiver<(i32, BufferSnapshot)> {
let (done_tx, done_rx) = oneshot::channel(); let (done_tx, done_rx) = oneshot::channel();
@ -278,7 +278,7 @@ pub struct Copilot {
http: Arc<dyn HttpClient>, http: Arc<dyn HttpClient>,
node_runtime: Arc<dyn NodeRuntime>, node_runtime: Arc<dyn NodeRuntime>,
server: CopilotServer, server: CopilotServer,
buffers: HashSet<WeakHandle<Buffer>>, buffers: HashSet<WeakModel<Buffer>>,
server_id: LanguageServerId, server_id: LanguageServerId,
_subscription: gpui2::Subscription, _subscription: gpui2::Subscription,
} }
@ -292,9 +292,9 @@ impl EventEmitter for Copilot {
} }
impl Copilot { impl Copilot {
pub fn global(cx: &AppContext) -> Option<Handle<Self>> { pub fn global(cx: &AppContext) -> Option<Model<Self>> {
if cx.has_global::<Handle<Self>>() { if cx.has_global::<Model<Self>>() {
Some(cx.global::<Handle<Self>>().clone()) Some(cx.global::<Model<Self>>().clone())
} else { } else {
None None
} }
@ -383,7 +383,7 @@ impl Copilot {
new_server_id: LanguageServerId, new_server_id: LanguageServerId,
http: Arc<dyn HttpClient>, http: Arc<dyn HttpClient>,
node_runtime: Arc<dyn NodeRuntime>, node_runtime: Arc<dyn NodeRuntime>,
this: WeakHandle<Self>, this: WeakModel<Self>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> impl Future<Output = ()> { ) -> impl Future<Output = ()> {
async move { async move {
@ -590,7 +590,7 @@ impl Copilot {
} }
} }
pub fn register_buffer(&mut self, buffer: &Handle<Buffer>, cx: &mut ModelContext<Self>) { pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
let weak_buffer = buffer.downgrade(); let weak_buffer = buffer.downgrade();
self.buffers.insert(weak_buffer.clone()); self.buffers.insert(weak_buffer.clone());
@ -646,7 +646,7 @@ impl Copilot {
fn handle_buffer_event( fn handle_buffer_event(
&mut self, &mut self,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
event: &language2::Event, event: &language2::Event,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Result<()> { ) -> Result<()> {
@ -706,7 +706,7 @@ impl Copilot {
Ok(()) Ok(())
} }
fn unregister_buffer(&mut self, buffer: &WeakHandle<Buffer>) { fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
if let Ok(server) = self.server.as_running() { if let Ok(server) = self.server.as_running() {
if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) { if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
server server
@ -723,7 +723,7 @@ impl Copilot {
pub fn completions<T>( pub fn completions<T>(
&mut self, &mut self,
buffer: &Handle<Buffer>, buffer: &Model<Buffer>,
position: T, position: T,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Completion>>> ) -> Task<Result<Vec<Completion>>>
@ -735,7 +735,7 @@ impl Copilot {
pub fn completions_cycling<T>( pub fn completions_cycling<T>(
&mut self, &mut self,
buffer: &Handle<Buffer>, buffer: &Model<Buffer>,
position: T, position: T,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Completion>>> ) -> Task<Result<Vec<Completion>>>
@ -792,7 +792,7 @@ impl Copilot {
fn request_completions<R, T>( fn request_completions<R, T>(
&mut self, &mut self,
buffer: &Handle<Buffer>, buffer: &Model<Buffer>,
position: T, position: T,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Completion>>> ) -> Task<Result<Vec<Completion>>>
@ -926,7 +926,7 @@ fn id_for_language(language: Option<&Arc<Language>>) -> String {
} }
} }
fn uri_for_buffer(buffer: &Handle<Buffer>, cx: &AppContext) -> lsp2::Url { fn uri_for_buffer(buffer: &Model<Buffer>, cx: &AppContext) -> lsp2::Url {
if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) { if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
lsp2::Url::from_file_path(file.abs_path(cx)).unwrap() lsp2::Url::from_file_path(file.abs_path(cx)).unwrap()
} else { } else {

View File

@ -967,7 +967,6 @@ impl CompletionsMenu {
self.selected_item -= 1; self.selected_item -= 1;
} else { } else {
self.selected_item = self.matches.len() - 1; self.selected_item = self.matches.len() - 1;
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
} }
self.list.scroll_to(ScrollTarget::Show(self.selected_item)); self.list.scroll_to(ScrollTarget::Show(self.selected_item));
self.attempt_resolve_selected_completion_documentation(project, cx); self.attempt_resolve_selected_completion_documentation(project, cx);
@ -1538,7 +1537,6 @@ impl CodeActionsMenu {
self.selected_item -= 1; self.selected_item -= 1;
} else { } else {
self.selected_item = self.actions.len() - 1; self.selected_item = self.actions.len() - 1;
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
} }
self.list.scroll_to(ScrollTarget::Show(self.selected_item)); self.list.scroll_to(ScrollTarget::Show(self.selected_item));
cx.notify(); cx.notify();
@ -1547,11 +1545,10 @@ impl CodeActionsMenu {
fn select_next(&mut self, cx: &mut ViewContext<Editor>) { fn select_next(&mut self, cx: &mut ViewContext<Editor>) {
if self.selected_item + 1 < self.actions.len() { if self.selected_item + 1 < self.actions.len() {
self.selected_item += 1; self.selected_item += 1;
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
} else { } else {
self.selected_item = 0; self.selected_item = 0;
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
} }
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
cx.notify(); cx.notify();
} }

View File

@ -16,7 +16,7 @@ use crate::{
current_platform, image_cache::ImageCache, Action, AnyBox, AnyView, AnyWindowHandle, current_platform, image_cache::ImageCache, Action, AnyBox, AnyView, AnyWindowHandle,
AppMetadata, AssetSource, ClipboardItem, Context, DispatchPhase, DisplayId, Executor, AppMetadata, AssetSource, ClipboardItem, Context, DispatchPhase, DisplayId, Executor,
FocusEvent, FocusHandle, FocusId, KeyBinding, Keymap, LayoutId, MainThread, MainThreadOnly, FocusEvent, FocusHandle, FocusId, KeyBinding, Keymap, LayoutId, MainThread, MainThreadOnly,
Pixels, Platform, Point, SharedString, SubscriberSet, Subscription, SvgRenderer, Task, Pixels, Platform, Point, Render, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
TextStyle, TextStyleRefinement, TextSystem, View, ViewContext, Window, WindowContext, TextStyle, TextStyleRefinement, TextSystem, View, ViewContext, Window, WindowContext,
WindowHandle, WindowId, WindowHandle, WindowId,
}; };
@ -309,10 +309,17 @@ impl AppContext {
update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R, update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R,
) -> Result<R> ) -> Result<R>
where where
V: 'static, V: 'static + Send,
{ {
self.update_window(handle.any_handle, |cx| { self.update_window(handle.any_handle, |cx| {
let root_view = cx.window.root_view.as_ref().unwrap().downcast().unwrap(); let root_view = cx
.window
.root_view
.as_ref()
.unwrap()
.clone()
.downcast()
.unwrap();
root_view.update(cx, update) root_view.update(cx, update)
}) })
} }
@ -685,7 +692,7 @@ impl AppContext {
pub fn observe_release<E: 'static>( pub fn observe_release<E: 'static>(
&mut self, &mut self,
handle: &Handle<E>, handle: &Model<E>,
mut on_release: impl FnMut(&mut E, &mut AppContext) + Send + 'static, mut on_release: impl FnMut(&mut E, &mut AppContext) + Send + 'static,
) -> Subscription { ) -> Subscription {
self.release_listeners.insert( self.release_listeners.insert(
@ -750,35 +757,35 @@ impl AppContext {
} }
impl Context for AppContext { impl Context for AppContext {
type EntityContext<'a, T> = ModelContext<'a, T>; type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = T; type Result<T> = T;
/// Build an entity that is owned by the application. The given function will be invoked with /// Build an entity that is owned by the application. The given function will be invoked with
/// a `ModelContext` and must return an object representing the entity. A `Handle` will be returned /// a `ModelContext` and must return an object representing the entity. A `Model` will be returned
/// which can be used to access the entity in a context. /// which can be used to access the entity in a context.
fn entity<T: 'static + Send>( fn build_model<T: 'static + Send>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T, build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Handle<T> { ) -> Model<T> {
self.update(|cx| { self.update(|cx| {
let slot = cx.entities.reserve(); let slot = cx.entities.reserve();
let entity = build_entity(&mut ModelContext::mutable(cx, slot.downgrade())); let entity = build_model(&mut ModelContext::mutable(cx, slot.downgrade()));
cx.entities.insert(slot, entity) cx.entities.insert(slot, entity)
}) })
} }
/// Update the entity referenced by the given handle. The function is passed a mutable reference to the /// Update the entity referenced by the given model. The function is passed a mutable reference to the
/// entity along with a `ModelContext` for the entity. /// entity along with a `ModelContext` for the entity.
fn update_entity<T: 'static, R>( fn update_entity<T: 'static, R>(
&mut self, &mut self,
handle: &Handle<T>, model: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R, update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> R { ) -> R {
self.update(|cx| { self.update(|cx| {
let mut entity = cx.entities.lease(handle); let mut entity = cx.entities.lease(model);
let result = update( let result = update(
&mut entity, &mut entity,
&mut ModelContext::mutable(cx, handle.downgrade()), &mut ModelContext::mutable(cx, model.downgrade()),
); );
cx.entities.end_lease(entity); cx.entities.end_lease(entity);
result result
@ -861,10 +868,17 @@ impl MainThread<AppContext> {
update: impl FnOnce(&mut V, &mut MainThread<ViewContext<'_, '_, V>>) -> R, update: impl FnOnce(&mut V, &mut MainThread<ViewContext<'_, '_, V>>) -> R,
) -> Result<R> ) -> Result<R>
where where
V: 'static, V: 'static + Send,
{ {
self.update_window(handle.any_handle, |cx| { self.update_window(handle.any_handle, |cx| {
let root_view = cx.window.root_view.as_ref().unwrap().downcast().unwrap(); let root_view = cx
.window
.root_view
.as_ref()
.unwrap()
.clone()
.downcast()
.unwrap();
root_view.update(cx, update) root_view.update(cx, update)
}) })
} }
@ -872,7 +886,7 @@ impl MainThread<AppContext> {
/// Opens a new window with the given option and the root view returned by the given function. /// Opens a new window with the given option and the root view returned by the given function.
/// The function is invoked with a `WindowContext`, which can be used to interact with window-specific /// The function is invoked with a `WindowContext`, which can be used to interact with window-specific
/// functionality. /// functionality.
pub fn open_window<V: 'static>( pub fn open_window<V: Render>(
&mut self, &mut self,
options: crate::WindowOptions, options: crate::WindowOptions,
build_root_view: impl FnOnce(&mut WindowContext) -> View<V> + Send + 'static, build_root_view: impl FnOnce(&mut WindowContext) -> View<V> + Send + 'static,
@ -955,10 +969,8 @@ impl<G: 'static> DerefMut for GlobalLease<G> {
/// Contains state associated with an active drag operation, started by dragging an element /// Contains state associated with an active drag operation, started by dragging an element
/// within the window or by dragging into the app from the underlying platform. /// within the window or by dragging into the app from the underlying platform.
pub(crate) struct AnyDrag { pub(crate) struct AnyDrag {
pub drag_handle_view: Option<AnyView>, pub view: AnyView,
pub cursor_offset: Point<Pixels>, pub cursor_offset: Point<Pixels>,
pub state: AnyBox,
pub state_type: TypeId,
} }
#[cfg(test)] #[cfg(test)]

View File

@ -1,6 +1,6 @@
use crate::{ use crate::{
AnyWindowHandle, AppContext, Component, Context, Executor, Handle, MainThread, ModelContext, AnyWindowHandle, AppContext, Context, Executor, MainThread, Model, ModelContext, Result, Task,
Result, Task, View, ViewContext, VisualContext, WindowContext, WindowHandle, View, ViewContext, VisualContext, WindowContext, WindowHandle,
}; };
use anyhow::Context as _; use anyhow::Context as _;
use derive_more::{Deref, DerefMut}; use derive_more::{Deref, DerefMut};
@ -14,25 +14,25 @@ pub struct AsyncAppContext {
} }
impl Context for AsyncAppContext { impl Context for AsyncAppContext {
type EntityContext<'a, T> = ModelContext<'a, T>; type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = Result<T>; type Result<T> = Result<T>;
fn entity<T: 'static>( fn build_model<T: 'static>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T, build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Handle<T>> ) -> Self::Result<Model<T>>
where where
T: 'static + Send, T: 'static + Send,
{ {
let app = self.app.upgrade().context("app was released")?; let app = self.app.upgrade().context("app was released")?;
let mut lock = app.lock(); // Need this to compile let mut lock = app.lock(); // Need this to compile
Ok(lock.entity(build_entity)) Ok(lock.build_model(build_model))
} }
fn update_entity<T: 'static, R>( fn update_entity<T: 'static, R>(
&mut self, &mut self,
handle: &Handle<T>, handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R, update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R> { ) -> Self::Result<R> {
let app = self.app.upgrade().context("app was released")?; let app = self.app.upgrade().context("app was released")?;
let mut lock = app.lock(); // Need this to compile let mut lock = app.lock(); // Need this to compile
@ -84,7 +84,7 @@ impl AsyncAppContext {
update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R, update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R,
) -> Result<R> ) -> Result<R>
where where
V: 'static, V: 'static + Send,
{ {
let app = self.app.upgrade().context("app was released")?; let app = self.app.upgrade().context("app was released")?;
let mut app_context = app.lock(); let mut app_context = app.lock();
@ -234,24 +234,24 @@ impl AsyncWindowContext {
} }
impl Context for AsyncWindowContext { impl Context for AsyncWindowContext {
type EntityContext<'a, T> = ModelContext<'a, T>; type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = Result<T>; type Result<T> = Result<T>;
fn entity<T>( fn build_model<T>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T, build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Result<Handle<T>> ) -> Result<Model<T>>
where where
T: 'static + Send, T: 'static + Send,
{ {
self.app self.app
.update_window(self.window, |cx| cx.entity(build_entity)) .update_window(self.window, |cx| cx.build_model(build_model))
} }
fn update_entity<T: 'static, R>( fn update_entity<T: 'static, R>(
&mut self, &mut self,
handle: &Handle<T>, handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R, update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Result<R> { ) -> Result<R> {
self.app self.app
.update_window(self.window, |cx| cx.update_entity(handle, update)) .update_window(self.window, |cx| cx.update_entity(handle, update))
@ -261,17 +261,15 @@ impl Context for AsyncWindowContext {
impl VisualContext for AsyncWindowContext { impl VisualContext for AsyncWindowContext {
type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>; type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>;
fn build_view<E, V>( fn build_view<V>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V, build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
) -> Self::Result<View<V>> ) -> Self::Result<View<V>>
where where
E: Component<V>,
V: 'static + Send, V: 'static + Send,
{ {
self.app self.app
.update_window(self.window, |cx| cx.build_view(build_entity, render)) .update_window(self.window, |cx| cx.build_view(build_view_state))
} }
fn update_view<V: 'static, R>( fn update_view<V: 'static, R>(

View File

@ -1,4 +1,4 @@
use crate::{AnyBox, AppContext, Context, EntityHandle}; use crate::{AnyBox, AppContext, Context};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use derive_more::{Deref, DerefMut}; use derive_more::{Deref, DerefMut};
use parking_lot::{RwLock, RwLockUpgradableReadGuard}; use parking_lot::{RwLock, RwLockUpgradableReadGuard};
@ -53,29 +53,29 @@ impl EntityMap {
/// Reserve a slot for an entity, which you can subsequently use with `insert`. /// Reserve a slot for an entity, which you can subsequently use with `insert`.
pub fn reserve<T: 'static>(&self) -> Slot<T> { pub fn reserve<T: 'static>(&self) -> Slot<T> {
let id = self.ref_counts.write().counts.insert(1.into()); let id = self.ref_counts.write().counts.insert(1.into());
Slot(Handle::new(id, Arc::downgrade(&self.ref_counts))) Slot(Model::new(id, Arc::downgrade(&self.ref_counts)))
} }
/// Insert an entity into a slot obtained by calling `reserve`. /// Insert an entity into a slot obtained by calling `reserve`.
pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Handle<T> pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Model<T>
where where
T: 'static + Send, T: 'static + Send,
{ {
let handle = slot.0; let model = slot.0;
self.entities.insert(handle.entity_id, Box::new(entity)); self.entities.insert(model.entity_id, Box::new(entity));
handle model
} }
/// Move an entity to the stack. /// Move an entity to the stack.
pub fn lease<'a, T>(&mut self, handle: &'a Handle<T>) -> Lease<'a, T> { pub fn lease<'a, T>(&mut self, model: &'a Model<T>) -> Lease<'a, T> {
self.assert_valid_context(handle); self.assert_valid_context(model);
let entity = Some( let entity = Some(
self.entities self.entities
.remove(handle.entity_id) .remove(model.entity_id)
.expect("Circular entity lease. Is the entity already being updated?"), .expect("Circular entity lease. Is the entity already being updated?"),
); );
Lease { Lease {
handle, model,
entity, entity,
entity_type: PhantomData, entity_type: PhantomData,
} }
@ -84,18 +84,18 @@ impl EntityMap {
/// Return an entity after moving it to the stack. /// Return an entity after moving it to the stack.
pub fn end_lease<T>(&mut self, mut lease: Lease<T>) { pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
self.entities self.entities
.insert(lease.handle.entity_id, lease.entity.take().unwrap()); .insert(lease.model.entity_id, lease.entity.take().unwrap());
} }
pub fn read<T: 'static>(&self, handle: &Handle<T>) -> &T { pub fn read<T: 'static>(&self, model: &Model<T>) -> &T {
self.assert_valid_context(handle); self.assert_valid_context(model);
self.entities[handle.entity_id].downcast_ref().unwrap() self.entities[model.entity_id].downcast_ref().unwrap()
} }
fn assert_valid_context(&self, handle: &AnyHandle) { fn assert_valid_context(&self, model: &AnyModel) {
debug_assert!( debug_assert!(
Weak::ptr_eq(&handle.entity_map, &Arc::downgrade(&self.ref_counts)), Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)),
"used a handle with the wrong context" "used a model with the wrong context"
); );
} }
@ -115,7 +115,7 @@ impl EntityMap {
pub struct Lease<'a, T> { pub struct Lease<'a, T> {
entity: Option<AnyBox>, entity: Option<AnyBox>,
pub handle: &'a Handle<T>, pub model: &'a Model<T>,
entity_type: PhantomData<T>, entity_type: PhantomData<T>,
} }
@ -143,15 +143,15 @@ impl<'a, T> Drop for Lease<'a, T> {
} }
#[derive(Deref, DerefMut)] #[derive(Deref, DerefMut)]
pub struct Slot<T>(Handle<T>); pub struct Slot<T>(Model<T>);
pub struct AnyHandle { pub struct AnyModel {
pub(crate) entity_id: EntityId, pub(crate) entity_id: EntityId,
entity_type: TypeId, pub(crate) entity_type: TypeId,
entity_map: Weak<RwLock<EntityRefCounts>>, entity_map: Weak<RwLock<EntityRefCounts>>,
} }
impl AnyHandle { impl AnyModel {
fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self { fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
Self { Self {
entity_id: id, entity_id: id,
@ -164,18 +164,18 @@ impl AnyHandle {
self.entity_id self.entity_id
} }
pub fn downgrade(&self) -> AnyWeakHandle { pub fn downgrade(&self) -> AnyWeakModel {
AnyWeakHandle { AnyWeakModel {
entity_id: self.entity_id, entity_id: self.entity_id,
entity_type: self.entity_type, entity_type: self.entity_type,
entity_ref_counts: self.entity_map.clone(), entity_ref_counts: self.entity_map.clone(),
} }
} }
pub fn downcast<T: 'static>(&self) -> Option<Handle<T>> { pub fn downcast<T: 'static>(&self) -> Option<Model<T>> {
if TypeId::of::<T>() == self.entity_type { if TypeId::of::<T>() == self.entity_type {
Some(Handle { Some(Model {
any_handle: self.clone(), any_model: self.clone(),
entity_type: PhantomData, entity_type: PhantomData,
}) })
} else { } else {
@ -184,16 +184,16 @@ impl AnyHandle {
} }
} }
impl Clone for AnyHandle { impl Clone for AnyModel {
fn clone(&self) -> Self { fn clone(&self) -> Self {
if let Some(entity_map) = self.entity_map.upgrade() { if let Some(entity_map) = self.entity_map.upgrade() {
let entity_map = entity_map.read(); let entity_map = entity_map.read();
let count = entity_map let count = entity_map
.counts .counts
.get(self.entity_id) .get(self.entity_id)
.expect("detected over-release of a handle"); .expect("detected over-release of a model");
let prev_count = count.fetch_add(1, SeqCst); let prev_count = count.fetch_add(1, SeqCst);
assert_ne!(prev_count, 0, "Detected over-release of a handle."); assert_ne!(prev_count, 0, "Detected over-release of a model.");
} }
Self { Self {
@ -204,16 +204,16 @@ impl Clone for AnyHandle {
} }
} }
impl Drop for AnyHandle { impl Drop for AnyModel {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(entity_map) = self.entity_map.upgrade() { if let Some(entity_map) = self.entity_map.upgrade() {
let entity_map = entity_map.upgradable_read(); let entity_map = entity_map.upgradable_read();
let count = entity_map let count = entity_map
.counts .counts
.get(self.entity_id) .get(self.entity_id)
.expect("Detected over-release of a handle."); .expect("Detected over-release of a model.");
let prev_count = count.fetch_sub(1, SeqCst); let prev_count = count.fetch_sub(1, SeqCst);
assert_ne!(prev_count, 0, "Detected over-release of a handle."); assert_ne!(prev_count, 0, "Detected over-release of a model.");
if prev_count == 1 { if prev_count == 1 {
// We were the last reference to this entity, so we can remove it. // We were the last reference to this entity, so we can remove it.
let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map); let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
@ -223,60 +223,65 @@ impl Drop for AnyHandle {
} }
} }
impl<T> From<Handle<T>> for AnyHandle { impl<T> From<Model<T>> for AnyModel {
fn from(handle: Handle<T>) -> Self { fn from(model: Model<T>) -> Self {
handle.any_handle model.any_model
} }
} }
impl Hash for AnyHandle { impl Hash for AnyModel {
fn hash<H: Hasher>(&self, state: &mut H) { fn hash<H: Hasher>(&self, state: &mut H) {
self.entity_id.hash(state); self.entity_id.hash(state);
} }
} }
impl PartialEq for AnyHandle { impl PartialEq for AnyModel {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.entity_id == other.entity_id self.entity_id == other.entity_id
} }
} }
impl Eq for AnyHandle {} impl Eq for AnyModel {}
#[derive(Deref, DerefMut)] #[derive(Deref, DerefMut)]
pub struct Handle<T> { pub struct Model<T> {
#[deref] #[deref]
#[deref_mut] #[deref_mut]
any_handle: AnyHandle, pub(crate) any_model: AnyModel,
entity_type: PhantomData<T>, pub(crate) entity_type: PhantomData<T>,
} }
unsafe impl<T> Send for Handle<T> {} unsafe impl<T> Send for Model<T> {}
unsafe impl<T> Sync for Handle<T> {} unsafe impl<T> Sync for Model<T> {}
impl<T: 'static> Handle<T> { impl<T: 'static> Model<T> {
fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
where where
T: 'static, T: 'static,
{ {
Self { Self {
any_handle: AnyHandle::new(id, TypeId::of::<T>(), entity_map), any_model: AnyModel::new(id, TypeId::of::<T>(), entity_map),
entity_type: PhantomData, entity_type: PhantomData,
} }
} }
pub fn downgrade(&self) -> WeakHandle<T> { pub fn downgrade(&self) -> WeakModel<T> {
WeakHandle { WeakModel {
any_handle: self.any_handle.downgrade(), any_model: self.any_model.downgrade(),
entity_type: self.entity_type, entity_type: self.entity_type,
} }
} }
/// Convert this into a dynamically typed model.
pub fn into_any(self) -> AnyModel {
self.any_model
}
pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T { pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
cx.entities.read(self) cx.entities.read(self)
} }
/// Update the entity referenced by this handle with the given function. /// Update the entity referenced by this model with the given function.
/// ///
/// The update function receives a context appropriate for its environment. /// The update function receives a context appropriate for its environment.
/// When updating in an `AppContext`, it receives a `ModelContext`. /// When updating in an `AppContext`, it receives a `ModelContext`.
@ -284,7 +289,7 @@ impl<T: 'static> Handle<T> {
pub fn update<C, R>( pub fn update<C, R>(
&self, &self,
cx: &mut C, cx: &mut C,
update: impl FnOnce(&mut T, &mut C::EntityContext<'_, T>) -> R, update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R,
) -> C::Result<R> ) -> C::Result<R>
where where
C: Context, C: Context,
@ -293,73 +298,54 @@ impl<T: 'static> Handle<T> {
} }
} }
impl<T> Clone for Handle<T> { impl<T> Clone for Model<T> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
any_handle: self.any_handle.clone(), any_model: self.any_model.clone(),
entity_type: self.entity_type, entity_type: self.entity_type,
} }
} }
} }
impl<T> std::fmt::Debug for Handle<T> { impl<T> std::fmt::Debug for Model<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!( write!(
f, f,
"Handle {{ entity_id: {:?}, entity_type: {:?} }}", "Model {{ entity_id: {:?}, entity_type: {:?} }}",
self.any_handle.entity_id, self.any_model.entity_id,
type_name::<T>() type_name::<T>()
) )
} }
} }
impl<T> Hash for Handle<T> { impl<T> Hash for Model<T> {
fn hash<H: Hasher>(&self, state: &mut H) { fn hash<H: Hasher>(&self, state: &mut H) {
self.any_handle.hash(state); self.any_model.hash(state);
} }
} }
impl<T> PartialEq for Handle<T> { impl<T> PartialEq for Model<T> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.any_handle == other.any_handle self.any_model == other.any_model
} }
} }
impl<T> Eq for Handle<T> {} impl<T> Eq for Model<T> {}
impl<T> PartialEq<WeakHandle<T>> for Handle<T> { impl<T> PartialEq<WeakModel<T>> for Model<T> {
fn eq(&self, other: &WeakHandle<T>) -> bool { fn eq(&self, other: &WeakModel<T>) -> bool {
self.entity_id == other.entity_id self.entity_id() == other.entity_id()
}
}
impl<T: 'static> EntityHandle<T> for Handle<T> {
type Weak = WeakHandle<T>;
fn entity_id(&self) -> EntityId {
self.entity_id
}
fn downgrade(&self) -> Self::Weak {
self.downgrade()
}
fn upgrade_from(weak: &Self::Weak) -> Option<Self>
where
Self: Sized,
{
weak.upgrade()
} }
} }
#[derive(Clone)] #[derive(Clone)]
pub struct AnyWeakHandle { pub struct AnyWeakModel {
pub(crate) entity_id: EntityId, pub(crate) entity_id: EntityId,
entity_type: TypeId, entity_type: TypeId,
entity_ref_counts: Weak<RwLock<EntityRefCounts>>, entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
} }
impl AnyWeakHandle { impl AnyWeakModel {
pub fn entity_id(&self) -> EntityId { pub fn entity_id(&self) -> EntityId {
self.entity_id self.entity_id
} }
@ -373,14 +359,14 @@ impl AnyWeakHandle {
ref_count > 0 ref_count > 0
} }
pub fn upgrade(&self) -> Option<AnyHandle> { pub fn upgrade(&self) -> Option<AnyModel> {
let entity_map = self.entity_ref_counts.upgrade()?; let entity_map = self.entity_ref_counts.upgrade()?;
entity_map entity_map
.read() .read()
.counts .counts
.get(self.entity_id)? .get(self.entity_id)?
.fetch_add(1, SeqCst); .fetch_add(1, SeqCst);
Some(AnyHandle { Some(AnyModel {
entity_id: self.entity_id, entity_id: self.entity_id,
entity_type: self.entity_type, entity_type: self.entity_type,
entity_map: self.entity_ref_counts.clone(), entity_map: self.entity_ref_counts.clone(),
@ -388,55 +374,55 @@ impl AnyWeakHandle {
} }
} }
impl<T> From<WeakHandle<T>> for AnyWeakHandle { impl<T> From<WeakModel<T>> for AnyWeakModel {
fn from(handle: WeakHandle<T>) -> Self { fn from(model: WeakModel<T>) -> Self {
handle.any_handle model.any_model
} }
} }
impl Hash for AnyWeakHandle { impl Hash for AnyWeakModel {
fn hash<H: Hasher>(&self, state: &mut H) { fn hash<H: Hasher>(&self, state: &mut H) {
self.entity_id.hash(state); self.entity_id.hash(state);
} }
} }
impl PartialEq for AnyWeakHandle { impl PartialEq for AnyWeakModel {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.entity_id == other.entity_id self.entity_id == other.entity_id
} }
} }
impl Eq for AnyWeakHandle {} impl Eq for AnyWeakModel {}
#[derive(Deref, DerefMut)] #[derive(Deref, DerefMut)]
pub struct WeakHandle<T> { pub struct WeakModel<T> {
#[deref] #[deref]
#[deref_mut] #[deref_mut]
any_handle: AnyWeakHandle, any_model: AnyWeakModel,
entity_type: PhantomData<T>, entity_type: PhantomData<T>,
} }
unsafe impl<T> Send for WeakHandle<T> {} unsafe impl<T> Send for WeakModel<T> {}
unsafe impl<T> Sync for WeakHandle<T> {} unsafe impl<T> Sync for WeakModel<T> {}
impl<T> Clone for WeakHandle<T> { impl<T> Clone for WeakModel<T> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
any_handle: self.any_handle.clone(), any_model: self.any_model.clone(),
entity_type: self.entity_type, entity_type: self.entity_type,
} }
} }
} }
impl<T: 'static> WeakHandle<T> { impl<T: 'static> WeakModel<T> {
pub fn upgrade(&self) -> Option<Handle<T>> { pub fn upgrade(&self) -> Option<Model<T>> {
Some(Handle { Some(Model {
any_handle: self.any_handle.upgrade()?, any_model: self.any_model.upgrade()?,
entity_type: self.entity_type, entity_type: self.entity_type,
}) })
} }
/// Update the entity referenced by this handle with the given function if /// Update the entity referenced by this model with the given function if
/// the referenced entity still exists. Returns an error if the entity has /// the referenced entity still exists. Returns an error if the entity has
/// been released. /// been released.
/// ///
@ -446,7 +432,7 @@ impl<T: 'static> WeakHandle<T> {
pub fn update<C, R>( pub fn update<C, R>(
&self, &self,
cx: &mut C, cx: &mut C,
update: impl FnOnce(&mut T, &mut C::EntityContext<'_, T>) -> R, update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R,
) -> Result<R> ) -> Result<R>
where where
C: Context, C: Context,
@ -460,22 +446,22 @@ impl<T: 'static> WeakHandle<T> {
} }
} }
impl<T> Hash for WeakHandle<T> { impl<T> Hash for WeakModel<T> {
fn hash<H: Hasher>(&self, state: &mut H) { fn hash<H: Hasher>(&self, state: &mut H) {
self.any_handle.hash(state); self.any_model.hash(state);
} }
} }
impl<T> PartialEq for WeakHandle<T> { impl<T> PartialEq for WeakModel<T> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.any_handle == other.any_handle self.any_model == other.any_model
} }
} }
impl<T> Eq for WeakHandle<T> {} impl<T> Eq for WeakModel<T> {}
impl<T> PartialEq<Handle<T>> for WeakHandle<T> { impl<T> PartialEq<Model<T>> for WeakModel<T> {
fn eq(&self, other: &Handle<T>) -> bool { fn eq(&self, other: &Model<T>) -> bool {
self.entity_id == other.entity_id self.entity_id() == other.entity_id()
} }
} }

View File

@ -1,6 +1,6 @@
use crate::{ use crate::{
AppContext, AsyncAppContext, Context, Effect, EntityId, EventEmitter, Handle, MainThread, AppContext, AsyncAppContext, Context, Effect, EntityId, EventEmitter, MainThread, Model,
Reference, Subscription, Task, WeakHandle, Reference, Subscription, Task, WeakModel,
}; };
use derive_more::{Deref, DerefMut}; use derive_more::{Deref, DerefMut};
use futures::FutureExt; use futures::FutureExt;
@ -15,11 +15,11 @@ pub struct ModelContext<'a, T> {
#[deref] #[deref]
#[deref_mut] #[deref_mut]
app: Reference<'a, AppContext>, app: Reference<'a, AppContext>,
model_state: WeakHandle<T>, model_state: WeakModel<T>,
} }
impl<'a, T: 'static> ModelContext<'a, T> { impl<'a, T: 'static> ModelContext<'a, T> {
pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakHandle<T>) -> Self { pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakModel<T>) -> Self {
Self { Self {
app: Reference::Mutable(app), app: Reference::Mutable(app),
model_state, model_state,
@ -30,20 +30,20 @@ impl<'a, T: 'static> ModelContext<'a, T> {
self.model_state.entity_id self.model_state.entity_id
} }
pub fn handle(&self) -> Handle<T> { pub fn handle(&self) -> Model<T> {
self.weak_handle() self.weak_handle()
.upgrade() .upgrade()
.expect("The entity must be alive if we have a model context") .expect("The entity must be alive if we have a model context")
} }
pub fn weak_handle(&self) -> WeakHandle<T> { pub fn weak_handle(&self) -> WeakModel<T> {
self.model_state.clone() self.model_state.clone()
} }
pub fn observe<T2: 'static>( pub fn observe<T2: 'static>(
&mut self, &mut self,
handle: &Handle<T2>, handle: &Model<T2>,
mut on_notify: impl FnMut(&mut T, Handle<T2>, &mut ModelContext<'_, T>) + Send + 'static, mut on_notify: impl FnMut(&mut T, Model<T2>, &mut ModelContext<'_, T>) + Send + 'static,
) -> Subscription ) -> Subscription
where where
T: 'static + Send, T: 'static + Send,
@ -65,10 +65,8 @@ impl<'a, T: 'static> ModelContext<'a, T> {
pub fn subscribe<E: 'static + EventEmitter>( pub fn subscribe<E: 'static + EventEmitter>(
&mut self, &mut self,
handle: &Handle<E>, handle: &Model<E>,
mut on_event: impl FnMut(&mut T, Handle<E>, &E::Event, &mut ModelContext<'_, T>) mut on_event: impl FnMut(&mut T, Model<E>, &E::Event, &mut ModelContext<'_, T>) + Send + 'static,
+ Send
+ 'static,
) -> Subscription ) -> Subscription
where where
T: 'static + Send, T: 'static + Send,
@ -107,7 +105,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
pub fn observe_release<E: 'static>( pub fn observe_release<E: 'static>(
&mut self, &mut self,
handle: &Handle<E>, handle: &Model<E>,
mut on_release: impl FnMut(&mut T, &mut E, &mut ModelContext<'_, T>) + Send + 'static, mut on_release: impl FnMut(&mut T, &mut E, &mut ModelContext<'_, T>) + Send + 'static,
) -> Subscription ) -> Subscription
where where
@ -182,7 +180,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
pub fn spawn<Fut, R>( pub fn spawn<Fut, R>(
&self, &self,
f: impl FnOnce(WeakHandle<T>, AsyncAppContext) -> Fut + Send + 'static, f: impl FnOnce(WeakModel<T>, AsyncAppContext) -> Fut + Send + 'static,
) -> Task<R> ) -> Task<R>
where where
T: 'static, T: 'static,
@ -195,7 +193,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
pub fn spawn_on_main<Fut, R>( pub fn spawn_on_main<Fut, R>(
&self, &self,
f: impl FnOnce(WeakHandle<T>, MainThread<AsyncAppContext>) -> Fut + Send + 'static, f: impl FnOnce(WeakModel<T>, MainThread<AsyncAppContext>) -> Fut + Send + 'static,
) -> Task<R> ) -> Task<R>
where where
Fut: Future<Output = R> + 'static, Fut: Future<Output = R> + 'static,
@ -220,23 +218,23 @@ where
} }
impl<'a, T> Context for ModelContext<'a, T> { impl<'a, T> Context for ModelContext<'a, T> {
type EntityContext<'b, U> = ModelContext<'b, U>; type ModelContext<'b, U> = ModelContext<'b, U>;
type Result<U> = U; type Result<U> = U;
fn entity<U>( fn build_model<U>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, U>) -> U, build_model: impl FnOnce(&mut Self::ModelContext<'_, U>) -> U,
) -> Handle<U> ) -> Model<U>
where where
U: 'static + Send, U: 'static + Send,
{ {
self.app.entity(build_entity) self.app.build_model(build_model)
} }
fn update_entity<U: 'static, R>( fn update_entity<U: 'static, R>(
&mut self, &mut self,
handle: &Handle<U>, handle: &Model<U>,
update: impl FnOnce(&mut U, &mut Self::EntityContext<'_, U>) -> R, update: impl FnOnce(&mut U, &mut Self::ModelContext<'_, U>) -> R,
) -> R { ) -> R {
self.app.update_entity(handle, update) self.app.update_entity(handle, update)
} }

View File

@ -1,5 +1,5 @@
use crate::{ use crate::{
AnyWindowHandle, AppContext, AsyncAppContext, Context, Executor, Handle, MainThread, AnyWindowHandle, AppContext, AsyncAppContext, Context, Executor, MainThread, Model,
ModelContext, Result, Task, TestDispatcher, TestPlatform, WindowContext, ModelContext, Result, Task, TestDispatcher, TestPlatform, WindowContext,
}; };
use parking_lot::Mutex; use parking_lot::Mutex;
@ -12,24 +12,24 @@ pub struct TestAppContext {
} }
impl Context for TestAppContext { impl Context for TestAppContext {
type EntityContext<'a, T> = ModelContext<'a, T>; type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = T; type Result<T> = T;
fn entity<T: 'static>( fn build_model<T: 'static>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T, build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Handle<T>> ) -> Self::Result<Model<T>>
where where
T: 'static + Send, T: 'static + Send,
{ {
let mut lock = self.app.lock(); let mut lock = self.app.lock();
lock.entity(build_entity) lock.build_model(build_model)
} }
fn update_entity<T: 'static, R>( fn update_entity<T: 'static, R>(
&mut self, &mut self,
handle: &Handle<T>, handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R, update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R> { ) -> Self::Result<R> {
let mut lock = self.app.lock(); let mut lock = self.app.lock();
lock.update_entity(handle, update) lock.update_entity(handle, update)

View File

@ -4,7 +4,7 @@ pub(crate) use smallvec::SmallVec;
use std::{any::Any, mem}; use std::{any::Any, mem};
pub trait Element<V: 'static> { pub trait Element<V: 'static> {
type ElementState: 'static; type ElementState: 'static + Send;
fn id(&self) -> Option<ElementId>; fn id(&self) -> Option<ElementId>;

View File

@ -70,33 +70,31 @@ use taffy::TaffyLayoutEngine;
type AnyBox = Box<dyn Any + Send>; type AnyBox = Box<dyn Any + Send>;
pub trait Context { pub trait Context {
type EntityContext<'a, T>; type ModelContext<'a, T>;
type Result<T>; type Result<T>;
fn entity<T>( fn build_model<T>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T, build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Handle<T>> ) -> Self::Result<Model<T>>
where where
T: 'static + Send; T: 'static + Send;
fn update_entity<T: 'static, R>( fn update_entity<T: 'static, R>(
&mut self, &mut self,
handle: &Handle<T>, handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R, update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R>; ) -> Self::Result<R>;
} }
pub trait VisualContext: Context { pub trait VisualContext: Context {
type ViewContext<'a, 'w, V>; type ViewContext<'a, 'w, V>;
fn build_view<E, V>( fn build_view<V>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V, build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
) -> Self::Result<View<V>> ) -> Self::Result<View<V>>
where where
E: Component<V>,
V: 'static + Send; V: 'static + Send;
fn update_view<V: 'static, R>( fn update_view<V: 'static, R>(
@ -140,37 +138,37 @@ impl<T> DerefMut for MainThread<T> {
} }
impl<C: Context> Context for MainThread<C> { impl<C: Context> Context for MainThread<C> {
type EntityContext<'a, T> = MainThread<C::EntityContext<'a, T>>; type ModelContext<'a, T> = MainThread<C::ModelContext<'a, T>>;
type Result<T> = C::Result<T>; type Result<T> = C::Result<T>;
fn entity<T>( fn build_model<T>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T, build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Handle<T>> ) -> Self::Result<Model<T>>
where where
T: 'static + Send, T: 'static + Send,
{ {
self.0.entity(|cx| { self.0.build_model(|cx| {
let cx = unsafe { let cx = unsafe {
mem::transmute::< mem::transmute::<
&mut C::EntityContext<'_, T>, &mut C::ModelContext<'_, T>,
&mut MainThread<C::EntityContext<'_, T>>, &mut MainThread<C::ModelContext<'_, T>>,
>(cx) >(cx)
}; };
build_entity(cx) build_model(cx)
}) })
} }
fn update_entity<T: 'static, R>( fn update_entity<T: 'static, R>(
&mut self, &mut self,
handle: &Handle<T>, handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R, update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R> { ) -> Self::Result<R> {
self.0.update_entity(handle, |entity, cx| { self.0.update_entity(handle, |entity, cx| {
let cx = unsafe { let cx = unsafe {
mem::transmute::< mem::transmute::<
&mut C::EntityContext<'_, T>, &mut C::ModelContext<'_, T>,
&mut MainThread<C::EntityContext<'_, T>>, &mut MainThread<C::ModelContext<'_, T>>,
>(cx) >(cx)
}; };
update(entity, cx) update(entity, cx)
@ -181,27 +179,22 @@ impl<C: Context> Context for MainThread<C> {
impl<C: VisualContext> VisualContext for MainThread<C> { impl<C: VisualContext> VisualContext for MainThread<C> {
type ViewContext<'a, 'w, V> = MainThread<C::ViewContext<'a, 'w, V>>; type ViewContext<'a, 'w, V> = MainThread<C::ViewContext<'a, 'w, V>>;
fn build_view<E, V>( fn build_view<V>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V, build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
) -> Self::Result<View<V>> ) -> Self::Result<View<V>>
where where
E: Component<V>,
V: 'static + Send, V: 'static + Send,
{ {
self.0.build_view( self.0.build_view(|cx| {
|cx| { let cx = unsafe {
let cx = unsafe { mem::transmute::<
mem::transmute::< &mut C::ViewContext<'_, '_, V>,
&mut C::ViewContext<'_, '_, V>, &mut MainThread<C::ViewContext<'_, '_, V>>,
&mut MainThread<C::ViewContext<'_, '_, V>>, >(cx)
>(cx) };
}; build_view_state(cx)
build_entity(cx) })
},
render,
)
} }
fn update_view<V: 'static, R>( fn update_view<V: 'static, R>(

View File

@ -1,7 +1,7 @@
use crate::{ use crate::{
point, px, Action, AnyBox, AnyDrag, AppContext, BorrowWindow, Bounds, Component, div, point, px, Action, AnyDrag, AnyView, AppContext, BorrowWindow, Bounds, Component,
DispatchContext, DispatchPhase, Element, ElementId, FocusHandle, KeyMatch, Keystroke, DispatchContext, DispatchPhase, Div, Element, ElementId, FocusHandle, KeyMatch, Keystroke,
Modifiers, Overflow, Pixels, Point, SharedString, Size, Style, StyleRefinement, View, Modifiers, Overflow, Pixels, Point, Render, SharedString, Size, Style, StyleRefinement, View,
ViewContext, ViewContext,
}; };
use collections::HashMap; use collections::HashMap;
@ -258,17 +258,17 @@ pub trait StatelessInteractive<V: 'static>: Element<V> {
self self
} }
fn on_drop<S: 'static>( fn on_drop<W: 'static + Send>(
mut self, mut self,
listener: impl Fn(&mut V, S, &mut ViewContext<V>) + Send + 'static, listener: impl Fn(&mut V, View<W>, &mut ViewContext<V>) + Send + 'static,
) -> Self ) -> Self
where where
Self: Sized, Self: Sized,
{ {
self.stateless_interaction().drop_listeners.push(( self.stateless_interaction().drop_listeners.push((
TypeId::of::<S>(), TypeId::of::<W>(),
Box::new(move |view, drag_state, cx| { Box::new(move |view, dragged_view, cx| {
listener(view, *drag_state.downcast().unwrap(), cx); listener(view, dragged_view.downcast().unwrap(), cx);
}), }),
)); ));
self self
@ -314,36 +314,22 @@ pub trait StatefulInteractive<V: 'static>: StatelessInteractive<V> {
self self
} }
fn on_drag<S, R, E>( fn on_drag<W>(
mut self, mut self,
listener: impl Fn(&mut V, &mut ViewContext<V>) -> Drag<S, R, V, E> + Send + 'static, listener: impl Fn(&mut V, &mut ViewContext<V>) -> View<W> + Send + 'static,
) -> Self ) -> Self
where where
Self: Sized, Self: Sized,
S: Any + Send, W: 'static + Send + Render,
R: Fn(&mut V, &mut ViewContext<V>) -> E,
R: 'static + Send,
E: Component<V>,
{ {
debug_assert!( debug_assert!(
self.stateful_interaction().drag_listener.is_none(), self.stateful_interaction().drag_listener.is_none(),
"calling on_drag more than once on the same element is not supported" "calling on_drag more than once on the same element is not supported"
); );
self.stateful_interaction().drag_listener = self.stateful_interaction().drag_listener =
Some(Box::new(move |view_state, cursor_offset, cx| { Some(Box::new(move |view_state, cursor_offset, cx| AnyDrag {
let drag = listener(view_state, cx); view: listener(view_state, cx).into_any(),
let drag_handle_view = Some( cursor_offset,
View::for_handle(cx.handle().upgrade().unwrap(), move |view_state, cx| {
(drag.render_drag_handle)(view_state, cx)
})
.into_any(),
);
AnyDrag {
drag_handle_view,
cursor_offset,
state: Box::new(drag.state),
state_type: TypeId::of::<S>(),
}
})); }));
self self
} }
@ -412,7 +398,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
if let Some(drag) = cx.active_drag.take() { if let Some(drag) = cx.active_drag.take() {
for (state_type, group_drag_style) in &self.as_stateless().group_drag_over_styles { for (state_type, group_drag_style) in &self.as_stateless().group_drag_over_styles {
if let Some(group_bounds) = GroupBounds::get(&group_drag_style.group, cx) { if let Some(group_bounds) = GroupBounds::get(&group_drag_style.group, cx) {
if *state_type == drag.state_type if *state_type == drag.view.entity_type()
&& group_bounds.contains_point(&mouse_position) && group_bounds.contains_point(&mouse_position)
{ {
style.refine(&group_drag_style.style); style.refine(&group_drag_style.style);
@ -421,7 +407,8 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
} }
for (state_type, drag_over_style) in &self.as_stateless().drag_over_styles { for (state_type, drag_over_style) in &self.as_stateless().drag_over_styles {
if *state_type == drag.state_type && bounds.contains_point(&mouse_position) { if *state_type == drag.view.entity_type() && bounds.contains_point(&mouse_position)
{
style.refine(drag_over_style); style.refine(drag_over_style);
} }
} }
@ -509,7 +496,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
cx.on_mouse_event(move |view, event: &MouseUpEvent, phase, cx| { cx.on_mouse_event(move |view, event: &MouseUpEvent, phase, cx| {
if phase == DispatchPhase::Bubble && bounds.contains_point(&event.position) { if phase == DispatchPhase::Bubble && bounds.contains_point(&event.position) {
if let Some(drag_state_type) = if let Some(drag_state_type) =
cx.active_drag.as_ref().map(|drag| drag.state_type) cx.active_drag.as_ref().map(|drag| drag.view.entity_type())
{ {
for (drop_state_type, listener) in &drop_listeners { for (drop_state_type, listener) in &drop_listeners {
if *drop_state_type == drag_state_type { if *drop_state_type == drag_state_type {
@ -517,7 +504,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
.active_drag .active_drag
.take() .take()
.expect("checked for type drag state type above"); .expect("checked for type drag state type above");
listener(view, drag.state, cx); listener(view, drag.view.clone(), cx);
cx.notify(); cx.notify();
cx.stop_propagation(); cx.stop_propagation();
} }
@ -685,7 +672,7 @@ impl<V> From<ElementId> for StatefulInteraction<V> {
} }
} }
type DropListener<V> = dyn Fn(&mut V, AnyBox, &mut ViewContext<V>) + 'static + Send; type DropListener<V> = dyn Fn(&mut V, AnyView, &mut ViewContext<V>) + 'static + Send;
pub struct StatelessInteraction<V> { pub struct StatelessInteraction<V> {
pub dispatch_context: DispatchContext, pub dispatch_context: DispatchContext,
@ -866,7 +853,7 @@ pub struct Drag<S, R, V, E>
where where
R: Fn(&mut V, &mut ViewContext<V>) -> E, R: Fn(&mut V, &mut ViewContext<V>) -> E,
V: 'static, V: 'static,
E: Component<V>, E: Component<()>,
{ {
pub state: S, pub state: S,
pub render_drag_handle: R, pub render_drag_handle: R,
@ -877,7 +864,7 @@ impl<S, R, V, E> Drag<S, R, V, E>
where where
R: Fn(&mut V, &mut ViewContext<V>) -> E, R: Fn(&mut V, &mut ViewContext<V>) -> E,
V: 'static, V: 'static,
E: Component<V>, E: Component<()>,
{ {
pub fn new(state: S, render_drag_handle: R) -> Self { pub fn new(state: S, render_drag_handle: R) -> Self {
Drag { Drag {
@ -888,6 +875,10 @@ where
} }
} }
// impl<S, R, V, E> Render for Drag<S, R, V, E> {
// // fn render(&mut self, cx: ViewContext<Self>) ->
// }
#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug)] #[derive(Hash, PartialEq, Eq, Copy, Clone, Debug)]
pub enum MouseButton { pub enum MouseButton {
Left, Left,
@ -995,6 +986,14 @@ impl Deref for MouseExitEvent {
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct ExternalPaths(pub(crate) SmallVec<[PathBuf; 2]>); pub struct ExternalPaths(pub(crate) SmallVec<[PathBuf; 2]>);
impl Render for ExternalPaths {
type Element = Div<Self>;
fn render(&mut self, _: &mut ViewContext<Self>) -> Self::Element {
div() // Intentionally left empty because the platform will render icons for the dragged files
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum FileDropEvent { pub enum FileDropEvent {
Entered { Entered {

View File

@ -1,45 +1,35 @@
use crate::{ use crate::{
AnyBox, AnyElement, AppContext, AvailableSpace, BorrowWindow, Bounds, Component, Element, AnyBox, AnyElement, AnyModel, AppContext, AvailableSpace, BorrowWindow, Bounds, Component,
ElementId, EntityHandle, EntityId, Flatten, Handle, LayoutId, Pixels, Size, ViewContext, Element, ElementId, EntityHandle, EntityId, Flatten, LayoutId, Model, Pixels, Size,
VisualContext, WeakHandle, WindowContext, ViewContext, VisualContext, WeakModel, WindowContext,
}; };
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use parking_lot::Mutex;
use std::{ use std::{
any::Any, any::{Any, TypeId},
marker::PhantomData, marker::PhantomData,
sync::{Arc, Weak}, sync::Arc,
}; };
pub struct View<V> { pub trait Render: 'static + Sized {
pub(crate) state: Handle<V>, type Element: Element<Self> + 'static + Send;
render: Arc<Mutex<dyn Fn(&mut V, &mut ViewContext<V>) -> AnyElement<V> + Send + 'static>>,
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element;
} }
impl<V: 'static> View<V> { pub struct View<V> {
pub fn for_handle<E>( pub(crate) model: Model<V>,
state: Handle<V>, }
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
) -> View<V>
where
E: Component<V>,
{
View {
state,
render: Arc::new(Mutex::new(
move |state: &mut V, cx: &mut ViewContext<'_, '_, V>| render(state, cx).render(),
)),
}
}
impl<V: Render> View<V> {
pub fn into_any(self) -> AnyView { pub fn into_any(self) -> AnyView {
AnyView(Arc::new(self)) AnyView(Arc::new(self))
} }
}
impl<V: 'static> View<V> {
pub fn downgrade(&self) -> WeakView<V> { pub fn downgrade(&self) -> WeakView<V> {
WeakView { WeakView {
state: self.state.downgrade(), model: self.model.downgrade(),
render: Arc::downgrade(&self.render),
} }
} }
@ -55,20 +45,19 @@ impl<V: 'static> View<V> {
} }
pub fn read<'a>(&self, cx: &'a AppContext) -> &'a V { pub fn read<'a>(&self, cx: &'a AppContext) -> &'a V {
cx.entities.read(&self.state) self.model.read(cx)
} }
} }
impl<V> Clone for View<V> { impl<V> Clone for View<V> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
state: self.state.clone(), model: self.model.clone(),
render: self.render.clone(),
} }
} }
} }
impl<V: 'static, ParentViewState: 'static> Component<ParentViewState> for View<V> { impl<V: Render, ParentViewState: 'static> Component<ParentViewState> for View<V> {
fn render(self) -> AnyElement<ParentViewState> { fn render(self) -> AnyElement<ParentViewState> {
AnyElement::new(EraseViewState { AnyElement::new(EraseViewState {
view: self, view: self,
@ -77,11 +66,14 @@ impl<V: 'static, ParentViewState: 'static> Component<ParentViewState> for View<V
} }
} }
impl<V: 'static> Element<()> for View<V> { impl<V> Element<()> for View<V>
where
V: Render,
{
type ElementState = AnyElement<V>; type ElementState = AnyElement<V>;
fn id(&self) -> Option<ElementId> { fn id(&self) -> Option<crate::ElementId> {
Some(ElementId::View(self.state.entity_id)) Some(ElementId::View(self.model.entity_id))
} }
fn initialize( fn initialize(
@ -91,7 +83,7 @@ impl<V: 'static> Element<()> for View<V> {
cx: &mut ViewContext<()>, cx: &mut ViewContext<()>,
) -> Self::ElementState { ) -> Self::ElementState {
self.update(cx, |state, cx| { self.update(cx, |state, cx| {
let mut any_element = (self.render.lock())(state, cx); let mut any_element = AnyElement::new(state.render(cx));
any_element.initialize(state, cx); any_element.initialize(state, cx);
any_element any_element
}) })
@ -121,7 +113,7 @@ impl<T: 'static> EntityHandle<T> for View<T> {
type Weak = WeakView<T>; type Weak = WeakView<T>;
fn entity_id(&self) -> EntityId { fn entity_id(&self) -> EntityId {
self.state.entity_id self.model.entity_id
} }
fn downgrade(&self) -> Self::Weak { fn downgrade(&self) -> Self::Weak {
@ -137,15 +129,13 @@ impl<T: 'static> EntityHandle<T> for View<T> {
} }
pub struct WeakView<V> { pub struct WeakView<V> {
pub(crate) state: WeakHandle<V>, pub(crate) model: WeakModel<V>,
render: Weak<Mutex<dyn Fn(&mut V, &mut ViewContext<V>) -> AnyElement<V> + Send + 'static>>,
} }
impl<V: 'static> WeakView<V> { impl<V: 'static> WeakView<V> {
pub fn upgrade(&self) -> Option<View<V>> { pub fn upgrade(&self) -> Option<View<V>> {
let state = self.state.upgrade()?; let model = self.model.upgrade()?;
let render = self.render.upgrade()?; Some(View { model })
Some(View { state, render })
} }
pub fn update<C, R>( pub fn update<C, R>(
@ -165,8 +155,7 @@ impl<V: 'static> WeakView<V> {
impl<V> Clone for WeakView<V> { impl<V> Clone for WeakView<V> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
state: self.state.clone(), model: self.model.clone(),
render: self.render.clone(),
} }
} }
} }
@ -178,13 +167,13 @@ struct EraseViewState<V, ParentV> {
unsafe impl<V, ParentV> Send for EraseViewState<V, ParentV> {} unsafe impl<V, ParentV> Send for EraseViewState<V, ParentV> {}
impl<V: 'static, ParentV: 'static> Component<ParentV> for EraseViewState<V, ParentV> { impl<V: Render, ParentV: 'static> Component<ParentV> for EraseViewState<V, ParentV> {
fn render(self) -> AnyElement<ParentV> { fn render(self) -> AnyElement<ParentV> {
AnyElement::new(self) AnyElement::new(self)
} }
} }
impl<V: 'static, ParentV: 'static> Element<ParentV> for EraseViewState<V, ParentV> { impl<V: Render, ParentV: 'static> Element<ParentV> for EraseViewState<V, ParentV> {
type ElementState = AnyBox; type ElementState = AnyBox;
fn id(&self) -> Option<ElementId> { fn id(&self) -> Option<ElementId> {
@ -221,30 +210,43 @@ impl<V: 'static, ParentV: 'static> Element<ParentV> for EraseViewState<V, Parent
} }
trait ViewObject: Send + Sync { trait ViewObject: Send + Sync {
fn entity_type(&self) -> TypeId;
fn entity_id(&self) -> EntityId; fn entity_id(&self) -> EntityId;
fn model(&self) -> AnyModel;
fn initialize(&self, cx: &mut WindowContext) -> AnyBox; fn initialize(&self, cx: &mut WindowContext) -> AnyBox;
fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId; fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId;
fn paint(&self, bounds: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext); fn paint(&self, bounds: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext);
fn as_any(&self) -> &dyn Any; fn as_any(&self) -> &dyn Any;
} }
impl<V: 'static> ViewObject for View<V> { impl<V> ViewObject for View<V>
where
V: Render,
{
fn entity_type(&self) -> TypeId {
TypeId::of::<V>()
}
fn entity_id(&self) -> EntityId { fn entity_id(&self) -> EntityId {
self.state.entity_id self.model.entity_id
}
fn model(&self) -> AnyModel {
self.model.clone().into_any()
} }
fn initialize(&self, cx: &mut WindowContext) -> AnyBox { fn initialize(&self, cx: &mut WindowContext) -> AnyBox {
cx.with_element_id(self.state.entity_id, |_global_id, cx| { cx.with_element_id(self.model.entity_id, |_global_id, cx| {
self.update(cx, |state, cx| { self.update(cx, |state, cx| {
let mut any_element = Box::new((self.render.lock())(state, cx)); let mut any_element = Box::new(AnyElement::new(state.render(cx)));
any_element.initialize(state, cx); any_element.initialize(state, cx);
any_element as AnyBox any_element
}) })
}) })
} }
fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId { fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId {
cx.with_element_id(self.state.entity_id, |_global_id, cx| { cx.with_element_id(self.model.entity_id, |_global_id, cx| {
self.update(cx, |state, cx| { self.update(cx, |state, cx| {
let element = element.downcast_mut::<AnyElement<V>>().unwrap(); let element = element.downcast_mut::<AnyElement<V>>().unwrap();
element.layout(state, cx) element.layout(state, cx)
@ -253,7 +255,7 @@ impl<V: 'static> ViewObject for View<V> {
} }
fn paint(&self, _: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext) { fn paint(&self, _: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext) {
cx.with_element_id(self.state.entity_id, |_global_id, cx| { cx.with_element_id(self.model.entity_id, |_global_id, cx| {
self.update(cx, |state, cx| { self.update(cx, |state, cx| {
let element = element.downcast_mut::<AnyElement<V>>().unwrap(); let element = element.downcast_mut::<AnyElement<V>>().unwrap();
element.paint(state, cx); element.paint(state, cx);
@ -270,8 +272,12 @@ impl<V: 'static> ViewObject for View<V> {
pub struct AnyView(Arc<dyn ViewObject>); pub struct AnyView(Arc<dyn ViewObject>);
impl AnyView { impl AnyView {
pub fn downcast<V: 'static>(&self) -> Option<View<V>> { pub fn downcast<V: 'static + Send>(self) -> Option<View<V>> {
self.0.as_any().downcast_ref().cloned() self.0.model().downcast().map(|model| View { model })
}
pub(crate) fn entity_type(&self) -> TypeId {
self.0.entity_type()
} }
pub(crate) fn draw(&self, available_space: Size<AvailableSpace>, cx: &mut WindowContext) { pub(crate) fn draw(&self, available_space: Size<AvailableSpace>, cx: &mut WindowContext) {
@ -343,6 +349,18 @@ impl<ParentV: 'static> Component<ParentV> for EraseAnyViewState<ParentV> {
} }
} }
impl<T, E> Render for T
where
T: 'static + FnMut(&mut WindowContext) -> E,
E: 'static + Send + Element<T>,
{
type Element = E;
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
(self)(cx)
}
}
impl<ParentV: 'static> Element<ParentV> for EraseAnyViewState<ParentV> { impl<ParentV: 'static> Element<ParentV> for EraseAnyViewState<ParentV> {
type ElementState = AnyBox; type ElementState = AnyBox;

View File

@ -1,14 +1,14 @@
use crate::{ use crate::{
px, size, Action, AnyBox, AnyDrag, AnyView, AppContext, AsyncWindowContext, AvailableSpace, px, size, Action, AnyBox, AnyDrag, AnyView, AppContext, AsyncWindowContext, AvailableSpace,
Bounds, BoxShadow, Context, Corners, DevicePixels, DispatchContext, DisplayId, Edges, Effect, Bounds, BoxShadow, Context, Corners, DevicePixels, DispatchContext, DisplayId, Edges, Effect,
EntityHandle, EntityId, EventEmitter, ExternalPaths, FileDropEvent, FocusEvent, FontId, EntityHandle, EntityId, EventEmitter, FileDropEvent, FocusEvent, FontId, GlobalElementId,
GlobalElementId, GlyphId, Handle, Hsla, ImageData, InputEvent, IsZero, KeyListener, KeyMatch, GlyphId, Hsla, ImageData, InputEvent, IsZero, KeyListener, KeyMatch, KeyMatcher, Keystroke,
KeyMatcher, Keystroke, LayoutId, MainThread, MainThreadOnly, ModelContext, Modifiers, LayoutId, MainThread, MainThreadOnly, Model, ModelContext, Modifiers, MonochromeSprite,
MonochromeSprite, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Path, Pixels, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Path, Pixels, PlatformAtlas,
PlatformAtlas, PlatformWindow, Point, PolychromeSprite, Quad, Reference, RenderGlyphParams, PlatformWindow, Point, PolychromeSprite, Quad, Reference, RenderGlyphParams, RenderImageParams,
RenderImageParams, RenderSvgParams, ScaledPixels, SceneBuilder, Shadow, SharedString, Size, RenderSvgParams, ScaledPixels, SceneBuilder, Shadow, SharedString, Size, Style, Subscription,
Style, Subscription, TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext, TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext, WeakModel, WeakView,
WeakHandle, WeakView, WindowOptions, SUBPIXEL_VARIANTS, WindowOptions, SUBPIXEL_VARIANTS,
}; };
use anyhow::Result; use anyhow::Result;
use collections::HashMap; use collections::HashMap;
@ -918,15 +918,13 @@ impl<'a, 'w> WindowContext<'a, 'w> {
root_view.draw(available_space, cx); root_view.draw(available_space, cx);
}); });
if let Some(mut active_drag) = self.app.active_drag.take() { if let Some(active_drag) = self.app.active_drag.take() {
self.stack(1, |cx| { self.stack(1, |cx| {
let offset = cx.mouse_position() - active_drag.cursor_offset; let offset = cx.mouse_position() - active_drag.cursor_offset;
cx.with_element_offset(Some(offset), |cx| { cx.with_element_offset(Some(offset), |cx| {
let available_space = let available_space =
size(AvailableSpace::MinContent, AvailableSpace::MinContent); size(AvailableSpace::MinContent, AvailableSpace::MinContent);
if let Some(drag_handle_view) = &mut active_drag.drag_handle_view { active_drag.view.draw(available_space, cx);
drag_handle_view.draw(available_space, cx);
}
cx.active_drag = Some(active_drag); cx.active_drag = Some(active_drag);
}); });
}); });
@ -994,12 +992,12 @@ impl<'a, 'w> WindowContext<'a, 'w> {
InputEvent::FileDrop(file_drop) => match file_drop { InputEvent::FileDrop(file_drop) => match file_drop {
FileDropEvent::Entered { position, files } => { FileDropEvent::Entered { position, files } => {
self.window.mouse_position = position; self.window.mouse_position = position;
self.active_drag.get_or_insert_with(|| AnyDrag { if self.active_drag.is_none() {
drag_handle_view: None, self.active_drag = Some(AnyDrag {
cursor_offset: position, view: self.build_view(|_| files).into_any(),
state: Box::new(files), cursor_offset: position,
state_type: TypeId::of::<ExternalPaths>(), });
}); }
InputEvent::MouseDown(MouseDownEvent { InputEvent::MouseDown(MouseDownEvent {
position, position,
button: MouseButton::Left, button: MouseButton::Left,
@ -1267,30 +1265,30 @@ impl<'a, 'w> WindowContext<'a, 'w> {
} }
impl Context for WindowContext<'_, '_> { impl Context for WindowContext<'_, '_> {
type EntityContext<'a, T> = ModelContext<'a, T>; type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = T; type Result<T> = T;
fn entity<T>( fn build_model<T>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T, build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Handle<T> ) -> Model<T>
where where
T: 'static + Send, T: 'static + Send,
{ {
let slot = self.app.entities.reserve(); let slot = self.app.entities.reserve();
let entity = build_entity(&mut ModelContext::mutable(&mut *self.app, slot.downgrade())); let model = build_model(&mut ModelContext::mutable(&mut *self.app, slot.downgrade()));
self.entities.insert(slot, entity) self.entities.insert(slot, model)
} }
fn update_entity<T: 'static, R>( fn update_entity<T: 'static, R>(
&mut self, &mut self,
handle: &Handle<T>, model: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R, update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> R { ) -> R {
let mut entity = self.entities.lease(handle); let mut entity = self.entities.lease(model);
let result = update( let result = update(
&mut *entity, &mut *entity,
&mut ModelContext::mutable(&mut *self.app, handle.downgrade()), &mut ModelContext::mutable(&mut *self.app, model.downgrade()),
); );
self.entities.end_lease(entity); self.entities.end_lease(entity);
result result
@ -1300,21 +1298,17 @@ impl Context for WindowContext<'_, '_> {
impl VisualContext for WindowContext<'_, '_> { impl VisualContext for WindowContext<'_, '_> {
type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>; type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>;
/// Builds a new view in the current window. The first argument is a function that builds fn build_view<V>(
/// an entity representing the view's state. It is invoked with a `ViewContext` that provides
/// entity-specific access to the window and application state during construction. The second
/// argument is a render function that returns a component based on the view's state.
fn build_view<E, V>(
&mut self, &mut self,
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V, build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
) -> Self::Result<View<V>> ) -> Self::Result<View<V>>
where where
E: crate::Component<V>,
V: 'static + Send, V: 'static + Send,
{ {
let slot = self.app.entities.reserve(); let slot = self.app.entities.reserve();
let view = View::for_handle(slot.clone(), render); let view = View {
model: slot.clone(),
};
let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade()); let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade());
let entity = build_view_state(&mut cx); let entity = build_view_state(&mut cx);
self.entities.insert(slot, entity); self.entities.insert(slot, entity);
@ -1327,7 +1321,7 @@ impl VisualContext for WindowContext<'_, '_> {
view: &View<T>, view: &View<T>,
update: impl FnOnce(&mut T, &mut Self::ViewContext<'_, '_, T>) -> R, update: impl FnOnce(&mut T, &mut Self::ViewContext<'_, '_, T>) -> R,
) -> Self::Result<R> { ) -> Self::Result<R> {
let mut lease = self.app.entities.lease(&view.state); let mut lease = self.app.entities.lease(&view.model);
let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade()); let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade());
let result = update(&mut *lease, &mut cx); let result = update(&mut *lease, &mut cx);
cx.app.entities.end_lease(lease); cx.app.entities.end_lease(lease);
@ -1582,8 +1576,8 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
self.view.clone() self.view.clone()
} }
pub fn handle(&self) -> WeakHandle<V> { pub fn model(&self) -> WeakModel<V> {
self.view.state.clone() self.view.model.clone()
} }
pub fn stack<R>(&mut self, order: u32, f: impl FnOnce(&mut Self) -> R) -> R { pub fn stack<R>(&mut self, order: u32, f: impl FnOnce(&mut Self) -> R) -> R {
@ -1603,8 +1597,8 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
pub fn observe<E>( pub fn observe<E>(
&mut self, &mut self,
handle: &Handle<E>, handle: &Model<E>,
mut on_notify: impl FnMut(&mut V, Handle<E>, &mut ViewContext<'_, '_, V>) + Send + 'static, mut on_notify: impl FnMut(&mut V, Model<E>, &mut ViewContext<'_, '_, V>) + Send + 'static,
) -> Subscription ) -> Subscription
where where
E: 'static, E: 'static,
@ -1665,7 +1659,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
) -> Subscription { ) -> Subscription {
let window_handle = self.window.handle; let window_handle = self.window.handle;
self.app.release_listeners.insert( self.app.release_listeners.insert(
self.view.state.entity_id, self.view.model.entity_id,
Box::new(move |this, cx| { Box::new(move |this, cx| {
let this = this.downcast_mut().expect("invalid entity type"); let this = this.downcast_mut().expect("invalid entity type");
// todo!("are we okay with silently swallowing the error?") // todo!("are we okay with silently swallowing the error?")
@ -1676,7 +1670,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
pub fn observe_release<T: 'static>( pub fn observe_release<T: 'static>(
&mut self, &mut self,
handle: &Handle<T>, handle: &Model<T>,
mut on_release: impl FnMut(&mut V, &mut T, &mut ViewContext<'_, '_, V>) + Send + 'static, mut on_release: impl FnMut(&mut V, &mut T, &mut ViewContext<'_, '_, V>) + Send + 'static,
) -> Subscription ) -> Subscription
where where
@ -1698,7 +1692,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
pub fn notify(&mut self) { pub fn notify(&mut self) {
self.window_cx.notify(); self.window_cx.notify();
self.window_cx.app.push_effect(Effect::Notify { self.window_cx.app.push_effect(Effect::Notify {
emitter: self.view.state.entity_id, emitter: self.view.model.entity_id,
}); });
} }
@ -1878,7 +1872,7 @@ where
V::Event: Any + Send, V::Event: Any + Send,
{ {
pub fn emit(&mut self, event: V::Event) { pub fn emit(&mut self, event: V::Event) {
let emitter = self.view.state.entity_id; let emitter = self.view.model.entity_id;
self.app.push_effect(Effect::Emit { self.app.push_effect(Effect::Emit {
emitter, emitter,
event: Box::new(event), event: Box::new(event),
@ -1897,41 +1891,36 @@ impl<'a, 'w, V: 'static> MainThread<ViewContext<'a, 'w, V>> {
} }
impl<'a, 'w, V> Context for ViewContext<'a, 'w, V> { impl<'a, 'w, V> Context for ViewContext<'a, 'w, V> {
type EntityContext<'b, U> = ModelContext<'b, U>; type ModelContext<'b, U> = ModelContext<'b, U>;
type Result<U> = U; type Result<U> = U;
fn entity<T>( fn build_model<T>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T, build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Handle<T> ) -> Model<T>
where where
T: 'static + Send, T: 'static + Send,
{ {
self.window_cx.entity(build_entity) self.window_cx.build_model(build_model)
} }
fn update_entity<T: 'static, R>( fn update_entity<T: 'static, R>(
&mut self, &mut self,
handle: &Handle<T>, model: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R, update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> R { ) -> R {
self.window_cx.update_entity(handle, update) self.window_cx.update_entity(model, update)
} }
} }
impl<V: 'static> VisualContext for ViewContext<'_, '_, V> { impl<V: 'static> VisualContext for ViewContext<'_, '_, V> {
type ViewContext<'a, 'w, V2> = ViewContext<'a, 'w, V2>; type ViewContext<'a, 'w, V2> = ViewContext<'a, 'w, V2>;
fn build_view<E, V2>( fn build_view<W: 'static + Send>(
&mut self, &mut self,
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V2>) -> V2, build_view: impl FnOnce(&mut Self::ViewContext<'_, '_, W>) -> W,
render: impl Fn(&mut V2, &mut ViewContext<'_, '_, V2>) -> E + Send + 'static, ) -> Self::Result<View<W>> {
) -> Self::Result<View<V2>> self.window_cx.build_view(build_view)
where
E: crate::Component<V2>,
V2: 'static + Send,
{
self.window_cx.build_view(build_entity, render)
} }
fn update_view<V2: 'static, R>( fn update_view<V2: 'static, R>(

View File

@ -5,7 +5,7 @@ use crate::language_settings::{
use crate::Buffer; use crate::Buffer;
use clock::ReplicaId; use clock::ReplicaId;
use collections::BTreeMap; use collections::BTreeMap;
use gpui2::{AppContext, Handle}; use gpui2::{AppContext, Model};
use gpui2::{Context, TestAppContext}; use gpui2::{Context, TestAppContext};
use indoc::indoc; use indoc::indoc;
use proto::deserialize_operation; use proto::deserialize_operation;
@ -42,7 +42,7 @@ fn init_logger() {
fn test_line_endings(cx: &mut gpui2::AppContext) { fn test_line_endings(cx: &mut gpui2::AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "one\r\ntwo\rthree") let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "one\r\ntwo\rthree")
.with_language(Arc::new(rust_lang()), cx); .with_language(Arc::new(rust_lang()), cx);
assert_eq!(buffer.text(), "one\ntwo\nthree"); assert_eq!(buffer.text(), "one\ntwo\nthree");
@ -138,8 +138,8 @@ fn test_edit_events(cx: &mut gpui2::AppContext) {
let buffer_1_events = Arc::new(Mutex::new(Vec::new())); let buffer_1_events = Arc::new(Mutex::new(Vec::new()));
let buffer_2_events = Arc::new(Mutex::new(Vec::new())); let buffer_2_events = Arc::new(Mutex::new(Vec::new()));
let buffer1 = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), "abcdef")); let buffer1 = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), "abcdef"));
let buffer2 = cx.entity(|cx| Buffer::new(1, cx.entity_id().as_u64(), "abcdef")); let buffer2 = cx.build_model(|cx| Buffer::new(1, cx.entity_id().as_u64(), "abcdef"));
let buffer1_ops = Arc::new(Mutex::new(Vec::new())); let buffer1_ops = Arc::new(Mutex::new(Vec::new()));
buffer1.update(cx, { buffer1.update(cx, {
let buffer1_ops = buffer1_ops.clone(); let buffer1_ops = buffer1_ops.clone();
@ -218,7 +218,7 @@ fn test_edit_events(cx: &mut gpui2::AppContext) {
#[gpui2::test] #[gpui2::test]
async fn test_apply_diff(cx: &mut TestAppContext) { async fn test_apply_diff(cx: &mut TestAppContext) {
let text = "a\nbb\nccc\ndddd\neeeee\nffffff\n"; let text = "a\nbb\nccc\ndddd\neeeee\nffffff\n";
let buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), text)); let buffer = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
let anchor = buffer.update(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3))); let anchor = buffer.update(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)));
let text = "a\nccc\ndddd\nffffff\n"; let text = "a\nccc\ndddd\nffffff\n";
@ -250,7 +250,7 @@ async fn test_normalize_whitespace(cx: &mut gpui2::TestAppContext) {
] ]
.join("\n"); .join("\n");
let buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), text)); let buffer = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
// Spawn a task to format the buffer's whitespace. // Spawn a task to format the buffer's whitespace.
// Pause so that the foratting task starts running. // Pause so that the foratting task starts running.
@ -314,7 +314,7 @@ async fn test_normalize_whitespace(cx: &mut gpui2::TestAppContext) {
#[gpui2::test] #[gpui2::test]
async fn test_reparse(cx: &mut gpui2::TestAppContext) { async fn test_reparse(cx: &mut gpui2::TestAppContext) {
let text = "fn a() {}"; let text = "fn a() {}";
let buffer = cx.entity(|cx| { let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx) Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
}); });
@ -442,7 +442,7 @@ async fn test_reparse(cx: &mut gpui2::TestAppContext) {
#[gpui2::test] #[gpui2::test]
async fn test_resetting_language(cx: &mut gpui2::TestAppContext) { async fn test_resetting_language(cx: &mut gpui2::TestAppContext) {
let buffer = cx.entity(|cx| { let buffer = cx.build_model(|cx| {
let mut buffer = let mut buffer =
Buffer::new(0, cx.entity_id().as_u64(), "{}").with_language(Arc::new(rust_lang()), cx); Buffer::new(0, cx.entity_id().as_u64(), "{}").with_language(Arc::new(rust_lang()), cx);
buffer.set_sync_parse_timeout(Duration::ZERO); buffer.set_sync_parse_timeout(Duration::ZERO);
@ -492,7 +492,7 @@ async fn test_outline(cx: &mut gpui2::TestAppContext) {
"# "#
.unindent(); .unindent();
let buffer = cx.entity(|cx| { let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx) Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
}); });
let outline = buffer let outline = buffer
@ -578,7 +578,7 @@ async fn test_outline_nodes_with_newlines(cx: &mut gpui2::TestAppContext) {
"# "#
.unindent(); .unindent();
let buffer = cx.entity(|cx| { let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx) Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
}); });
let outline = buffer let outline = buffer
@ -616,7 +616,7 @@ async fn test_outline_with_extra_context(cx: &mut gpui2::TestAppContext) {
"# "#
.unindent(); .unindent();
let buffer = cx.entity(|cx| { let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(language), cx) Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(language), cx)
}); });
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
@ -660,7 +660,7 @@ async fn test_symbols_containing(cx: &mut gpui2::TestAppContext) {
"# "#
.unindent(); .unindent();
let buffer = cx.entity(|cx| { let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx) Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
}); });
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
@ -881,7 +881,7 @@ fn test_enclosing_bracket_ranges_where_brackets_are_not_outermost_children(cx: &
#[gpui2::test] #[gpui2::test]
fn test_range_for_syntax_ancestor(cx: &mut AppContext) { fn test_range_for_syntax_ancestor(cx: &mut AppContext) {
cx.entity(|cx| { cx.build_model(|cx| {
let text = "fn a() { b(|c| {}) }"; let text = "fn a() { b(|c| {}) }";
let buffer = let buffer =
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx); Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
@ -922,7 +922,7 @@ fn test_range_for_syntax_ancestor(cx: &mut AppContext) {
fn test_autoindent_with_soft_tabs(cx: &mut AppContext) { fn test_autoindent_with_soft_tabs(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let text = "fn a() {}"; let text = "fn a() {}";
let mut buffer = let mut buffer =
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx); Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
@ -965,7 +965,7 @@ fn test_autoindent_with_hard_tabs(cx: &mut AppContext) {
settings.defaults.hard_tabs = Some(true); settings.defaults.hard_tabs = Some(true);
}); });
cx.entity(|cx| { cx.build_model(|cx| {
let text = "fn a() {}"; let text = "fn a() {}";
let mut buffer = let mut buffer =
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx); Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
@ -1006,7 +1006,7 @@ fn test_autoindent_with_hard_tabs(cx: &mut AppContext) {
fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppContext) { fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let entity_id = cx.entity_id(); let entity_id = cx.entity_id();
let mut buffer = Buffer::new( let mut buffer = Buffer::new(
0, 0,
@ -1080,7 +1080,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppC
buffer buffer
}); });
cx.entity(|cx| { cx.build_model(|cx| {
eprintln!("second buffer: {:?}", cx.entity_id()); eprintln!("second buffer: {:?}", cx.entity_id());
let mut buffer = Buffer::new( let mut buffer = Buffer::new(
@ -1147,7 +1147,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppC
fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut AppContext) { fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let mut buffer = Buffer::new( let mut buffer = Buffer::new(
0, 0,
cx.entity_id().as_u64(), cx.entity_id().as_u64(),
@ -1209,7 +1209,7 @@ fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut Ap
fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut AppContext) { fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let mut buffer = Buffer::new( let mut buffer = Buffer::new(
0, 0,
cx.entity_id().as_u64(), cx.entity_id().as_u64(),
@ -1266,7 +1266,7 @@ fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut AppContext) {
fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut AppContext) { fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let text = "a\nb"; let text = "a\nb";
let mut buffer = let mut buffer =
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx); Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
@ -1284,7 +1284,7 @@ fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut AppContext) {
fn test_autoindent_multi_line_insertion(cx: &mut AppContext) { fn test_autoindent_multi_line_insertion(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let text = " let text = "
const a: usize = 1; const a: usize = 1;
fn b() { fn b() {
@ -1326,7 +1326,7 @@ fn test_autoindent_multi_line_insertion(cx: &mut AppContext) {
fn test_autoindent_block_mode(cx: &mut AppContext) { fn test_autoindent_block_mode(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let text = r#" let text = r#"
fn a() { fn a() {
b(); b();
@ -1410,7 +1410,7 @@ fn test_autoindent_block_mode(cx: &mut AppContext) {
fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut AppContext) { fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let text = r#" let text = r#"
fn a() { fn a() {
if b() { if b() {
@ -1490,7 +1490,7 @@ fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut AppContex
fn test_autoindent_language_without_indents_query(cx: &mut AppContext) { fn test_autoindent_language_without_indents_query(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let text = " let text = "
* one * one
- a - a
@ -1559,7 +1559,7 @@ fn test_autoindent_with_injected_languages(cx: &mut AppContext) {
language_registry.add(html_language.clone()); language_registry.add(html_language.clone());
language_registry.add(javascript_language.clone()); language_registry.add(javascript_language.clone());
cx.entity(|cx| { cx.build_model(|cx| {
let (text, ranges) = marked_text_ranges( let (text, ranges) = marked_text_ranges(
&" &"
<div>ˇ <div>ˇ
@ -1610,7 +1610,7 @@ fn test_autoindent_query_with_outdent_captures(cx: &mut AppContext) {
settings.defaults.tab_size = Some(2.try_into().unwrap()); settings.defaults.tab_size = Some(2.try_into().unwrap());
}); });
cx.entity(|cx| { cx.build_model(|cx| {
let mut buffer = let mut buffer =
Buffer::new(0, cx.entity_id().as_u64(), "").with_language(Arc::new(ruby_lang()), cx); Buffer::new(0, cx.entity_id().as_u64(), "").with_language(Arc::new(ruby_lang()), cx);
@ -1653,7 +1653,7 @@ fn test_autoindent_query_with_outdent_captures(cx: &mut AppContext) {
fn test_language_scope_at_with_javascript(cx: &mut AppContext) { fn test_language_scope_at_with_javascript(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let language = Language::new( let language = Language::new(
LanguageConfig { LanguageConfig {
name: "JavaScript".into(), name: "JavaScript".into(),
@ -1742,7 +1742,7 @@ fn test_language_scope_at_with_javascript(cx: &mut AppContext) {
fn test_language_scope_at_with_rust(cx: &mut AppContext) { fn test_language_scope_at_with_rust(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let language = Language::new( let language = Language::new(
LanguageConfig { LanguageConfig {
name: "Rust".into(), name: "Rust".into(),
@ -1810,7 +1810,7 @@ fn test_language_scope_at_with_rust(cx: &mut AppContext) {
fn test_language_scope_at_with_combined_injections(cx: &mut AppContext) { fn test_language_scope_at_with_combined_injections(cx: &mut AppContext) {
init_settings(cx, |_| {}); init_settings(cx, |_| {});
cx.entity(|cx| { cx.build_model(|cx| {
let text = r#" let text = r#"
<ol> <ol>
<% people.each do |person| %> <% people.each do |person| %>
@ -1858,7 +1858,7 @@ fn test_language_scope_at_with_combined_injections(cx: &mut AppContext) {
fn test_serialization(cx: &mut gpui2::AppContext) { fn test_serialization(cx: &mut gpui2::AppContext) {
let mut now = Instant::now(); let mut now = Instant::now();
let buffer1 = cx.entity(|cx| { let buffer1 = cx.build_model(|cx| {
let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "abc"); let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "abc");
buffer.edit([(3..3, "D")], None, cx); buffer.edit([(3..3, "D")], None, cx);
@ -1881,7 +1881,7 @@ fn test_serialization(cx: &mut gpui2::AppContext) {
let ops = cx let ops = cx
.executor() .executor()
.block(buffer1.read(cx).serialize_ops(None, cx)); .block(buffer1.read(cx).serialize_ops(None, cx));
let buffer2 = cx.entity(|cx| { let buffer2 = cx.build_model(|cx| {
let mut buffer = Buffer::from_proto(1, state, None).unwrap(); let mut buffer = Buffer::from_proto(1, state, None).unwrap();
buffer buffer
.apply_ops( .apply_ops(
@ -1914,10 +1914,11 @@ fn test_random_collaboration(cx: &mut AppContext, mut rng: StdRng) {
let mut replica_ids = Vec::new(); let mut replica_ids = Vec::new();
let mut buffers = Vec::new(); let mut buffers = Vec::new();
let network = Arc::new(Mutex::new(Network::new(rng.clone()))); let network = Arc::new(Mutex::new(Network::new(rng.clone())));
let base_buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), base_text.as_str())); let base_buffer =
cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), base_text.as_str()));
for i in 0..rng.gen_range(min_peers..=max_peers) { for i in 0..rng.gen_range(min_peers..=max_peers) {
let buffer = cx.entity(|cx| { let buffer = cx.build_model(|cx| {
let state = base_buffer.read(cx).to_proto(); let state = base_buffer.read(cx).to_proto();
let ops = cx let ops = cx
.executor() .executor()
@ -2034,7 +2035,7 @@ fn test_random_collaboration(cx: &mut AppContext, mut rng: StdRng) {
new_replica_id, new_replica_id,
replica_id replica_id
); );
new_buffer = Some(cx.entity(|cx| { new_buffer = Some(cx.build_model(|cx| {
let mut new_buffer = let mut new_buffer =
Buffer::from_proto(new_replica_id, old_buffer_state, None).unwrap(); Buffer::from_proto(new_replica_id, old_buffer_state, None).unwrap();
new_buffer new_buffer
@ -2396,7 +2397,7 @@ fn javascript_lang() -> Language {
.unwrap() .unwrap()
} }
fn get_tree_sexp(buffer: &Handle<Buffer>, cx: &mut gpui2::TestAppContext) -> String { fn get_tree_sexp(buffer: &Model<Buffer>, cx: &mut gpui2::TestAppContext) -> String {
buffer.update(cx, |buffer, _| { buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot(); let snapshot = buffer.snapshot();
let layers = snapshot.syntax.layers(buffer.as_text_snapshot()); let layers = snapshot.syntax.layers(buffer.as_text_snapshot());
@ -2412,7 +2413,7 @@ fn assert_bracket_pairs(
cx: &mut AppContext, cx: &mut AppContext,
) { ) {
let (expected_text, selection_ranges) = marked_text_ranges(selection_text, false); let (expected_text, selection_ranges) = marked_text_ranges(selection_text, false);
let buffer = cx.entity(|cx| { let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), expected_text.clone()) Buffer::new(0, cx.entity_id().as_u64(), expected_text.clone())
.with_language(Arc::new(language), cx) .with_language(Arc::new(language), cx)
}); });

12
crates/menu2/Cargo.toml Normal file
View File

@ -0,0 +1,12 @@
[package]
name = "menu2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/menu2.rs"
doctest = false
[dependencies]
gpui2 = { path = "../gpui2" }

25
crates/menu2/src/menu2.rs Normal file
View File

@ -0,0 +1,25 @@
// todo!(use actions! macro)
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Cancel;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Confirm;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SecondaryConfirm;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SelectPrev;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SelectNext;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SelectFirst;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SelectLast;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ShowContextMenu;

View File

@ -16,7 +16,7 @@ client2 = { path = "../client2" }
collections = { path = "../collections"} collections = { path = "../collections"}
language2 = { path = "../language2" } language2 = { path = "../language2" }
gpui2 = { path = "../gpui2" } gpui2 = { path = "../gpui2" }
fs = { path = "../fs" } fs2 = { path = "../fs2" }
lsp2 = { path = "../lsp2" } lsp2 = { path = "../lsp2" }
node_runtime = { path = "../node_runtime"} node_runtime = { path = "../node_runtime"}
util = { path = "../util" } util = { path = "../util" }
@ -32,4 +32,4 @@ parking_lot.workspace = true
[dev-dependencies] [dev-dependencies]
language2 = { path = "../language2", features = ["test-support"] } language2 = { path = "../language2", features = ["test-support"] }
gpui2 = { path = "../gpui2", features = ["test-support"] } gpui2 = { path = "../gpui2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] } fs2 = { path = "../fs2", features = ["test-support"] }

View File

@ -1,7 +1,7 @@
use anyhow::Context; use anyhow::Context;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use fs::Fs; use fs2::Fs;
use gpui2::{AsyncAppContext, Handle}; use gpui2::{AsyncAppContext, Model};
use language2::{language_settings::language_settings, Buffer, BundledFormatter, Diff}; use language2::{language_settings::language_settings, Buffer, BundledFormatter, Diff};
use lsp2::{LanguageServer, LanguageServerId}; use lsp2::{LanguageServer, LanguageServerId};
use node_runtime::NodeRuntime; use node_runtime::NodeRuntime;
@ -183,7 +183,7 @@ impl Prettier {
pub async fn format( pub async fn format(
&self, &self,
buffer: &Handle<Buffer>, buffer: &Model<Buffer>,
buffer_path: Option<PathBuf>, buffer_path: Option<PathBuf>,
cx: &mut AsyncAppContext, cx: &mut AsyncAppContext,
) -> anyhow::Result<Diff> { ) -> anyhow::Result<Diff> {

View File

@ -25,7 +25,7 @@ client2 = { path = "../client2" }
clock = { path = "../clock" } clock = { path = "../clock" }
collections = { path = "../collections" } collections = { path = "../collections" }
db2 = { path = "../db2" } db2 = { path = "../db2" }
fs = { path = "../fs" } fs2 = { path = "../fs2" }
fsevent = { path = "../fsevent" } fsevent = { path = "../fsevent" }
fuzzy2 = { path = "../fuzzy2" } fuzzy2 = { path = "../fuzzy2" }
git = { path = "../git" } git = { path = "../git" }
@ -71,7 +71,7 @@ pretty_assertions.workspace = true
client2 = { path = "../client2", features = ["test-support"] } client2 = { path = "../client2", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] }
db2 = { path = "../db2", features = ["test-support"] } db2 = { path = "../db2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] } fs2 = { path = "../fs2", features = ["test-support"] }
gpui2 = { path = "../gpui2", features = ["test-support"] } gpui2 = { path = "../gpui2", features = ["test-support"] }
language2 = { path = "../language2", features = ["test-support"] } language2 = { path = "../language2", features = ["test-support"] }
lsp2 = { path = "../lsp2", features = ["test-support"] } lsp2 = { path = "../lsp2", features = ["test-support"] }

View File

@ -7,7 +7,7 @@ use anyhow::{anyhow, Context, Result};
use async_trait::async_trait; use async_trait::async_trait;
use client2::proto::{self, PeerId}; use client2::proto::{self, PeerId};
use futures::future; use futures::future;
use gpui2::{AppContext, AsyncAppContext, Handle}; use gpui2::{AppContext, AsyncAppContext, Model};
use language2::{ use language2::{
language_settings::{language_settings, InlayHintKind}, language_settings::{language_settings, InlayHintKind},
point_from_lsp, point_to_lsp, point_from_lsp, point_to_lsp,
@ -53,8 +53,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
message: <Self::LspRequest as lsp2::request::Request>::Result, message: <Self::LspRequest as lsp2::request::Request>::Result,
project: Handle<Project>, project: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
cx: AsyncAppContext, cx: AsyncAppContext,
) -> Result<Self::Response>; ) -> Result<Self::Response>;
@ -63,8 +63,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
async fn from_proto( async fn from_proto(
message: Self::ProtoRequest, message: Self::ProtoRequest,
project: Handle<Project>, project: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
cx: AsyncAppContext, cx: AsyncAppContext,
) -> Result<Self>; ) -> Result<Self>;
@ -79,8 +79,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: <Self::ProtoRequest as proto::RequestMessage>::Response, message: <Self::ProtoRequest as proto::RequestMessage>::Response,
project: Handle<Project>, project: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
cx: AsyncAppContext, cx: AsyncAppContext,
) -> Result<Self::Response>; ) -> Result<Self::Response>;
@ -180,8 +180,8 @@ impl LspCommand for PrepareRename {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
message: Option<lsp2::PrepareRenameResponse>, message: Option<lsp2::PrepareRenameResponse>,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
_: LanguageServerId, _: LanguageServerId,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Option<Range<Anchor>>> { ) -> Result<Option<Range<Anchor>>> {
@ -215,8 +215,8 @@ impl LspCommand for PrepareRename {
async fn from_proto( async fn from_proto(
message: proto::PrepareRename, message: proto::PrepareRename,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let position = message let position = message
@ -256,8 +256,8 @@ impl LspCommand for PrepareRename {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: proto::PrepareRenameResponse, message: proto::PrepareRenameResponse,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Option<Range<Anchor>>> { ) -> Result<Option<Range<Anchor>>> {
if message.can_rename { if message.can_rename {
@ -307,8 +307,8 @@ impl LspCommand for PerformRename {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
message: Option<lsp2::WorkspaceEdit>, message: Option<lsp2::WorkspaceEdit>,
project: Handle<Project>, project: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<ProjectTransaction> { ) -> Result<ProjectTransaction> {
@ -343,8 +343,8 @@ impl LspCommand for PerformRename {
async fn from_proto( async fn from_proto(
message: proto::PerformRename, message: proto::PerformRename,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let position = message let position = message
@ -379,8 +379,8 @@ impl LspCommand for PerformRename {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: proto::PerformRenameResponse, message: proto::PerformRenameResponse,
project: Handle<Project>, project: Model<Project>,
_: Handle<Buffer>, _: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<ProjectTransaction> { ) -> Result<ProjectTransaction> {
let message = message let message = message
@ -426,8 +426,8 @@ impl LspCommand for GetDefinition {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
message: Option<lsp2::GotoDefinitionResponse>, message: Option<lsp2::GotoDefinitionResponse>,
project: Handle<Project>, project: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
cx: AsyncAppContext, cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> { ) -> Result<Vec<LocationLink>> {
@ -447,8 +447,8 @@ impl LspCommand for GetDefinition {
async fn from_proto( async fn from_proto(
message: proto::GetDefinition, message: proto::GetDefinition,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let position = message let position = message
@ -479,8 +479,8 @@ impl LspCommand for GetDefinition {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: proto::GetDefinitionResponse, message: proto::GetDefinitionResponse,
project: Handle<Project>, project: Model<Project>,
_: Handle<Buffer>, _: Model<Buffer>,
cx: AsyncAppContext, cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> { ) -> Result<Vec<LocationLink>> {
location_links_from_proto(message.links, project, cx).await location_links_from_proto(message.links, project, cx).await
@ -527,8 +527,8 @@ impl LspCommand for GetTypeDefinition {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
message: Option<lsp2::GotoTypeDefinitionResponse>, message: Option<lsp2::GotoTypeDefinitionResponse>,
project: Handle<Project>, project: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
cx: AsyncAppContext, cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> { ) -> Result<Vec<LocationLink>> {
@ -548,8 +548,8 @@ impl LspCommand for GetTypeDefinition {
async fn from_proto( async fn from_proto(
message: proto::GetTypeDefinition, message: proto::GetTypeDefinition,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let position = message let position = message
@ -580,8 +580,8 @@ impl LspCommand for GetTypeDefinition {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: proto::GetTypeDefinitionResponse, message: proto::GetTypeDefinitionResponse,
project: Handle<Project>, project: Model<Project>,
_: Handle<Buffer>, _: Model<Buffer>,
cx: AsyncAppContext, cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> { ) -> Result<Vec<LocationLink>> {
location_links_from_proto(message.links, project, cx).await location_links_from_proto(message.links, project, cx).await
@ -593,8 +593,8 @@ impl LspCommand for GetTypeDefinition {
} }
fn language_server_for_buffer( fn language_server_for_buffer(
project: &Handle<Project>, project: &Model<Project>,
buffer: &Handle<Buffer>, buffer: &Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
cx: &mut AsyncAppContext, cx: &mut AsyncAppContext,
) -> Result<(Arc<CachedLspAdapter>, Arc<LanguageServer>)> { ) -> Result<(Arc<CachedLspAdapter>, Arc<LanguageServer>)> {
@ -609,7 +609,7 @@ fn language_server_for_buffer(
async fn location_links_from_proto( async fn location_links_from_proto(
proto_links: Vec<proto::LocationLink>, proto_links: Vec<proto::LocationLink>,
project: Handle<Project>, project: Model<Project>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> { ) -> Result<Vec<LocationLink>> {
let mut links = Vec::new(); let mut links = Vec::new();
@ -671,8 +671,8 @@ async fn location_links_from_proto(
async fn location_links_from_lsp( async fn location_links_from_lsp(
message: Option<lsp2::GotoDefinitionResponse>, message: Option<lsp2::GotoDefinitionResponse>,
project: Handle<Project>, project: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> { ) -> Result<Vec<LocationLink>> {
@ -814,8 +814,8 @@ impl LspCommand for GetReferences {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
locations: Option<Vec<lsp2::Location>>, locations: Option<Vec<lsp2::Location>>,
project: Handle<Project>, project: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Vec<Location>> { ) -> Result<Vec<Location>> {
@ -868,8 +868,8 @@ impl LspCommand for GetReferences {
async fn from_proto( async fn from_proto(
message: proto::GetReferences, message: proto::GetReferences,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let position = message let position = message
@ -910,8 +910,8 @@ impl LspCommand for GetReferences {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: proto::GetReferencesResponse, message: proto::GetReferencesResponse,
project: Handle<Project>, project: Model<Project>,
_: Handle<Buffer>, _: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Vec<Location>> { ) -> Result<Vec<Location>> {
let mut locations = Vec::new(); let mut locations = Vec::new();
@ -977,8 +977,8 @@ impl LspCommand for GetDocumentHighlights {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
lsp_highlights: Option<Vec<lsp2::DocumentHighlight>>, lsp_highlights: Option<Vec<lsp2::DocumentHighlight>>,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
_: LanguageServerId, _: LanguageServerId,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Vec<DocumentHighlight>> { ) -> Result<Vec<DocumentHighlight>> {
@ -1016,8 +1016,8 @@ impl LspCommand for GetDocumentHighlights {
async fn from_proto( async fn from_proto(
message: proto::GetDocumentHighlights, message: proto::GetDocumentHighlights,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let position = message let position = message
@ -1060,8 +1060,8 @@ impl LspCommand for GetDocumentHighlights {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: proto::GetDocumentHighlightsResponse, message: proto::GetDocumentHighlightsResponse,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Vec<DocumentHighlight>> { ) -> Result<Vec<DocumentHighlight>> {
let mut highlights = Vec::new(); let mut highlights = Vec::new();
@ -1123,8 +1123,8 @@ impl LspCommand for GetHover {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
message: Option<lsp2::Hover>, message: Option<lsp2::Hover>,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
_: LanguageServerId, _: LanguageServerId,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self::Response> { ) -> Result<Self::Response> {
@ -1206,8 +1206,8 @@ impl LspCommand for GetHover {
async fn from_proto( async fn from_proto(
message: Self::ProtoRequest, message: Self::ProtoRequest,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let position = message let position = message
@ -1272,8 +1272,8 @@ impl LspCommand for GetHover {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: proto::GetHoverResponse, message: proto::GetHoverResponse,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self::Response> { ) -> Result<Self::Response> {
let contents: Vec<_> = message let contents: Vec<_> = message
@ -1341,8 +1341,8 @@ impl LspCommand for GetCompletions {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
completions: Option<lsp2::CompletionResponse>, completions: Option<lsp2::CompletionResponse>,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Vec<Completion>> { ) -> Result<Vec<Completion>> {
@ -1484,8 +1484,8 @@ impl LspCommand for GetCompletions {
async fn from_proto( async fn from_proto(
message: proto::GetCompletions, message: proto::GetCompletions,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let version = deserialize_version(&message.version); let version = deserialize_version(&message.version);
@ -1523,8 +1523,8 @@ impl LspCommand for GetCompletions {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: proto::GetCompletionsResponse, message: proto::GetCompletionsResponse,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Vec<Completion>> { ) -> Result<Vec<Completion>> {
buffer buffer
@ -1589,8 +1589,8 @@ impl LspCommand for GetCodeActions {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
actions: Option<lsp2::CodeActionResponse>, actions: Option<lsp2::CodeActionResponse>,
_: Handle<Project>, _: Model<Project>,
_: Handle<Buffer>, _: Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
_: AsyncAppContext, _: AsyncAppContext,
) -> Result<Vec<CodeAction>> { ) -> Result<Vec<CodeAction>> {
@ -1623,8 +1623,8 @@ impl LspCommand for GetCodeActions {
async fn from_proto( async fn from_proto(
message: proto::GetCodeActions, message: proto::GetCodeActions,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let start = message let start = message
@ -1663,8 +1663,8 @@ impl LspCommand for GetCodeActions {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: proto::GetCodeActionsResponse, message: proto::GetCodeActionsResponse,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Vec<CodeAction>> { ) -> Result<Vec<CodeAction>> {
buffer buffer
@ -1726,8 +1726,8 @@ impl LspCommand for OnTypeFormatting {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
message: Option<Vec<lsp2::TextEdit>>, message: Option<Vec<lsp2::TextEdit>>,
project: Handle<Project>, project: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Option<Transaction>> { ) -> Result<Option<Transaction>> {
@ -1763,8 +1763,8 @@ impl LspCommand for OnTypeFormatting {
async fn from_proto( async fn from_proto(
message: proto::OnTypeFormatting, message: proto::OnTypeFormatting,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let position = message let position = message
@ -1805,8 +1805,8 @@ impl LspCommand for OnTypeFormatting {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: proto::OnTypeFormattingResponse, message: proto::OnTypeFormattingResponse,
_: Handle<Project>, _: Model<Project>,
_: Handle<Buffer>, _: Model<Buffer>,
_: AsyncAppContext, _: AsyncAppContext,
) -> Result<Option<Transaction>> { ) -> Result<Option<Transaction>> {
let Some(transaction) = message.transaction else { let Some(transaction) = message.transaction else {
@ -1825,7 +1825,7 @@ impl LspCommand for OnTypeFormatting {
impl InlayHints { impl InlayHints {
pub async fn lsp_to_project_hint( pub async fn lsp_to_project_hint(
lsp_hint: lsp2::InlayHint, lsp_hint: lsp2::InlayHint,
buffer_handle: &Handle<Buffer>, buffer_handle: &Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
resolve_state: ResolveState, resolve_state: ResolveState,
force_no_type_left_padding: bool, force_no_type_left_padding: bool,
@ -2230,8 +2230,8 @@ impl LspCommand for InlayHints {
async fn response_from_lsp( async fn response_from_lsp(
self, self,
message: Option<Vec<lsp2::InlayHint>>, message: Option<Vec<lsp2::InlayHint>>,
project: Handle<Project>, project: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
server_id: LanguageServerId, server_id: LanguageServerId,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> anyhow::Result<Vec<InlayHint>> { ) -> anyhow::Result<Vec<InlayHint>> {
@ -2286,8 +2286,8 @@ impl LspCommand for InlayHints {
async fn from_proto( async fn from_proto(
message: proto::InlayHints, message: proto::InlayHints,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> Result<Self> { ) -> Result<Self> {
let start = message let start = message
@ -2326,8 +2326,8 @@ impl LspCommand for InlayHints {
async fn response_from_proto( async fn response_from_proto(
self, self,
message: proto::InlayHintsResponse, message: proto::InlayHintsResponse,
_: Handle<Project>, _: Model<Project>,
buffer: Handle<Buffer>, buffer: Model<Buffer>,
mut cx: AsyncAppContext, mut cx: AsyncAppContext,
) -> anyhow::Result<Vec<InlayHint>> { ) -> anyhow::Result<Vec<InlayHint>> {
buffer buffer

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
use crate::Project; use crate::Project;
use gpui2::{AnyWindowHandle, Context, Handle, ModelContext, WeakHandle}; use gpui2::{AnyWindowHandle, Context, Model, ModelContext, WeakModel};
use settings2::Settings; use settings2::Settings;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use terminal2::{ use terminal2::{
@ -11,7 +11,7 @@ use terminal2::{
use std::os::unix::ffi::OsStrExt; use std::os::unix::ffi::OsStrExt;
pub struct Terminals { pub struct Terminals {
pub(crate) local_handles: Vec<WeakHandle<terminal2::Terminal>>, pub(crate) local_handles: Vec<WeakModel<terminal2::Terminal>>,
} }
impl Project { impl Project {
@ -20,7 +20,7 @@ impl Project {
working_directory: Option<PathBuf>, working_directory: Option<PathBuf>,
window: AnyWindowHandle, window: AnyWindowHandle,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> anyhow::Result<Handle<Terminal>> { ) -> anyhow::Result<Model<Terminal>> {
if self.is_remote() { if self.is_remote() {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"creating terminals as a guest is not supported yet" "creating terminals as a guest is not supported yet"
@ -40,7 +40,7 @@ impl Project {
|_, _| todo!("color_for_index"), |_, _| todo!("color_for_index"),
) )
.map(|builder| { .map(|builder| {
let terminal_handle = cx.entity(|cx| builder.subscribe(cx)); let terminal_handle = cx.build_model(|cx| builder.subscribe(cx));
self.terminals self.terminals
.local_handles .local_handles
@ -108,7 +108,7 @@ impl Project {
fn activate_python_virtual_environment( fn activate_python_virtual_environment(
&mut self, &mut self,
activate_script: Option<PathBuf>, activate_script: Option<PathBuf>,
terminal_handle: &Handle<Terminal>, terminal_handle: &Model<Terminal>,
cx: &mut ModelContext<Project>, cx: &mut ModelContext<Project>,
) { ) {
if let Some(activate_script) = activate_script { if let Some(activate_script) = activate_script {
@ -121,7 +121,7 @@ impl Project {
} }
} }
pub fn local_terminal_handles(&self) -> &Vec<WeakHandle<terminal2::Terminal>> { pub fn local_terminal_handles(&self) -> &Vec<WeakModel<terminal2::Terminal>> {
&self.terminals.local_handles &self.terminals.local_handles
} }
} }

View File

@ -6,7 +6,7 @@ use anyhow::{anyhow, Context as _, Result};
use client2::{proto, Client}; use client2::{proto, Client};
use clock::ReplicaId; use clock::ReplicaId;
use collections::{HashMap, HashSet, VecDeque}; use collections::{HashMap, HashSet, VecDeque};
use fs::{ use fs2::{
repository::{GitFileStatus, GitRepository, RepoPath}, repository::{GitFileStatus, GitRepository, RepoPath},
Fs, Fs,
}; };
@ -22,7 +22,7 @@ use futures::{
use fuzzy2::CharBag; use fuzzy2::CharBag;
use git::{DOT_GIT, GITIGNORE}; use git::{DOT_GIT, GITIGNORE};
use gpui2::{ use gpui2::{
AppContext, AsyncAppContext, Context, EventEmitter, Executor, Handle, ModelContext, Task, AppContext, AsyncAppContext, Context, EventEmitter, Executor, Model, ModelContext, Task,
}; };
use language2::{ use language2::{
proto::{ proto::{
@ -292,7 +292,7 @@ impl Worktree {
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
next_entry_id: Arc<AtomicUsize>, next_entry_id: Arc<AtomicUsize>,
cx: &mut AsyncAppContext, cx: &mut AsyncAppContext,
) -> Result<Handle<Self>> { ) -> Result<Model<Self>> {
// After determining whether the root entry is a file or a directory, populate the // After determining whether the root entry is a file or a directory, populate the
// snapshot's "root name", which will be used for the purpose of fuzzy matching. // snapshot's "root name", which will be used for the purpose of fuzzy matching.
let abs_path = path.into(); let abs_path = path.into();
@ -301,7 +301,7 @@ impl Worktree {
.await .await
.context("failed to stat worktree path")?; .context("failed to stat worktree path")?;
cx.entity(move |cx: &mut ModelContext<Worktree>| { cx.build_model(move |cx: &mut ModelContext<Worktree>| {
let root_name = abs_path let root_name = abs_path
.file_name() .file_name()
.map_or(String::new(), |f| f.to_string_lossy().to_string()); .map_or(String::new(), |f| f.to_string_lossy().to_string());
@ -406,8 +406,8 @@ impl Worktree {
worktree: proto::WorktreeMetadata, worktree: proto::WorktreeMetadata,
client: Arc<Client>, client: Arc<Client>,
cx: &mut AppContext, cx: &mut AppContext,
) -> Handle<Self> { ) -> Model<Self> {
cx.entity(|cx: &mut ModelContext<Self>| { cx.build_model(|cx: &mut ModelContext<Self>| {
let snapshot = Snapshot { let snapshot = Snapshot {
id: WorktreeId(worktree.id as usize), id: WorktreeId(worktree.id as usize),
abs_path: Arc::from(PathBuf::from(worktree.abs_path)), abs_path: Arc::from(PathBuf::from(worktree.abs_path)),
@ -593,7 +593,7 @@ impl LocalWorktree {
id: u64, id: u64,
path: &Path, path: &Path,
cx: &mut ModelContext<Worktree>, cx: &mut ModelContext<Worktree>,
) -> Task<Result<Handle<Buffer>>> { ) -> Task<Result<Model<Buffer>>> {
let path = Arc::from(path); let path = Arc::from(path);
cx.spawn(move |this, mut cx| async move { cx.spawn(move |this, mut cx| async move {
let (file, contents, diff_base) = this let (file, contents, diff_base) = this
@ -603,7 +603,7 @@ impl LocalWorktree {
.executor() .executor()
.spawn(async move { text::Buffer::new(0, id, contents) }) .spawn(async move { text::Buffer::new(0, id, contents) })
.await; .await;
cx.entity(|_| Buffer::build(text_buffer, diff_base, Some(Arc::new(file)))) cx.build_model(|_| Buffer::build(text_buffer, diff_base, Some(Arc::new(file))))
}) })
} }
@ -920,7 +920,7 @@ impl LocalWorktree {
pub fn save_buffer( pub fn save_buffer(
&self, &self,
buffer_handle: Handle<Buffer>, buffer_handle: Model<Buffer>,
path: Arc<Path>, path: Arc<Path>,
has_changed_file: bool, has_changed_file: bool,
cx: &mut ModelContext<Worktree>, cx: &mut ModelContext<Worktree>,
@ -1331,7 +1331,7 @@ impl RemoteWorktree {
pub fn save_buffer( pub fn save_buffer(
&self, &self,
buffer_handle: Handle<Buffer>, buffer_handle: Model<Buffer>,
cx: &mut ModelContext<Worktree>, cx: &mut ModelContext<Worktree>,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
let buffer = buffer_handle.read(cx); let buffer = buffer_handle.read(cx);
@ -2577,7 +2577,7 @@ impl fmt::Debug for Snapshot {
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub struct File { pub struct File {
pub worktree: Handle<Worktree>, pub worktree: Model<Worktree>,
pub path: Arc<Path>, pub path: Arc<Path>,
pub mtime: SystemTime, pub mtime: SystemTime,
pub(crate) entry_id: ProjectEntryId, pub(crate) entry_id: ProjectEntryId,
@ -2701,7 +2701,7 @@ impl language2::LocalFile for File {
} }
impl File { impl File {
pub fn for_entry(entry: Entry, worktree: Handle<Worktree>) -> Arc<Self> { pub fn for_entry(entry: Entry, worktree: Model<Worktree>) -> Arc<Self> {
Arc::new(Self { Arc::new(Self {
worktree, worktree,
path: entry.path.clone(), path: entry.path.clone(),
@ -2714,7 +2714,7 @@ impl File {
pub fn from_proto( pub fn from_proto(
proto: rpc2::proto::File, proto: rpc2::proto::File,
worktree: Handle<Worktree>, worktree: Model<Worktree>,
cx: &AppContext, cx: &AppContext,
) -> Result<Self> { ) -> Result<Self> {
let worktree_id = worktree let worktree_id = worktree
@ -2815,7 +2815,7 @@ pub type UpdatedGitRepositoriesSet = Arc<[(Arc<Path>, GitRepositoryChange)]>;
impl Entry { impl Entry {
fn new( fn new(
path: Arc<Path>, path: Arc<Path>,
metadata: &fs::Metadata, metadata: &fs2::Metadata,
next_entry_id: &AtomicUsize, next_entry_id: &AtomicUsize,
root_char_bag: CharBag, root_char_bag: CharBag,
) -> Self { ) -> Self {

View File

@ -42,6 +42,7 @@ sha1 = "0.10.5"
ndarray = { version = "0.15.0" } ndarray = { version = "0.15.0" }
[dev-dependencies] [dev-dependencies]
ai = { path = "../ai", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] } language = { path = "../language", features = ["test-support"] }

View File

@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize, pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>, finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>, finished_files_rx: channel::Receiver<FileToEmbed>,
api_key: Option<String>,
} }
#[derive(Clone)] #[derive(Clone)]
@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
} }
impl EmbeddingQueue { impl EmbeddingQueue {
pub fn new( pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
embedding_provider: Arc<dyn EmbeddingProvider>,
executor: Arc<Background>,
api_key: Option<String>,
) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded(); let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self { Self {
embedding_provider, embedding_provider,
@ -64,14 +59,9 @@ impl EmbeddingQueue {
pending_batch_token_count: 0, pending_batch_token_count: 0,
finished_files_tx, finished_files_tx,
finished_files_rx, finished_files_rx,
api_key,
} }
} }
pub fn set_api_key(&mut self, api_key: Option<String>) {
self.api_key = api_key
}
pub fn push(&mut self, file: FileToEmbed) { pub fn push(&mut self, file: FileToEmbed) {
if file.spans.is_empty() { if file.spans.is_empty() {
self.finished_files_tx.try_send(file).unwrap(); self.finished_files_tx.try_send(file).unwrap();
@ -118,7 +108,6 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone(); let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone(); let embedding_provider = self.embedding_provider.clone();
let api_key = self.api_key.clone();
self.executor self.executor
.spawn(async move { .spawn(async move {
@ -143,7 +132,7 @@ impl EmbeddingQueue {
return; return;
}; };
match embedding_provider.embed_batch(spans, api_key).await { match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => { Ok(embeddings) => {
let mut embeddings = embeddings.into_iter(); let mut embeddings = embeddings.into_iter();
for fragment in batch { for fragment in batch {

View File

@ -1,4 +1,7 @@
use ai::embedding::{Embedding, EmbeddingProvider}; use ai::{
embedding::{Embedding, EmbeddingProvider},
models::TruncationDirection,
};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use language::{Grammar, Language}; use language::{Grammar, Language};
use rusqlite::{ use rusqlite::{
@ -108,7 +111,14 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref()) .replace("<language>", language_name.as_ref())
.replace("<item>", &content); .replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str()); let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span); let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span { Ok(vec![Span {
range: 0..content.len(), range: 0..content.len(),
content: document_span, content: document_span,
@ -131,7 +141,15 @@ impl CodeContextRetriever {
) )
.replace("<item>", &content); .replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str()); let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span { Ok(vec![Span {
range: 0..content.len(), range: 0..content.len(),
content: document_span, content: document_span,
@ -222,8 +240,13 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref()) .replace("<language>", language_name.as_ref())
.replace("item", &span.content); .replace("item", &span.content);
let (document_content, token_count) = let model = self.embedding_provider.base_model();
self.embedding_provider.truncate(&document_content); let document_content = model.truncate(
&document_content,
model.capacity()?,
TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_content)?;
span.content = document_content; span.content = document_content;
span.token_count = token_count; span.token_count = token_count;

View File

@ -7,7 +7,8 @@ pub mod semantic_index_settings;
mod semantic_index_tests; mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings; use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings}; use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase; use db::VectorDatabase;
@ -88,7 +89,7 @@ pub fn init(
let semantic_index = SemanticIndex::new( let semantic_index = SemanticIndex::new(
fs, fs,
db_file_path, db_file_path,
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())), Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
language_registry, language_registry,
cx.clone(), cx.clone(),
) )
@ -123,8 +124,6 @@ pub struct SemanticIndex {
_embedding_task: Task<()>, _embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>, _parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>, projects: HashMap<WeakModelHandle<Project>, ProjectState>,
api_key: Option<String>,
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
} }
struct ProjectState { struct ProjectState {
@ -278,18 +277,18 @@ impl SemanticIndex {
} }
} }
pub fn authenticate(&mut self, cx: &AppContext) { pub fn authenticate(&mut self, cx: &AppContext) -> bool {
if self.api_key.is_none() { if !self.embedding_provider.has_credentials() {
self.api_key = self.embedding_provider.retrieve_credentials(cx); self.embedding_provider.retrieve_credentials(cx);
} else {
self.embedding_queue return true;
.lock()
.set_api_key(self.api_key.clone());
} }
self.embedding_provider.has_credentials()
} }
pub fn is_authenticated(&self) -> bool { pub fn is_authenticated(&self) -> bool {
self.api_key.is_some() self.embedding_provider.has_credentials()
} }
pub fn enabled(cx: &AppContext) -> bool { pub fn enabled(cx: &AppContext) -> bool {
@ -339,7 +338,7 @@ impl SemanticIndex {
Ok(cx.add_model(|cx| { Ok(cx.add_model(|cx| {
let t0 = Instant::now(); let t0 = Instant::now();
let embedding_queue = let embedding_queue =
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None); EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
let _embedding_task = cx.background().spawn({ let _embedding_task = cx.background().spawn({
let embedded_files = embedding_queue.finished_files(); let embedded_files = embedding_queue.finished_files();
let db = db.clone(); let db = db.clone();
@ -404,8 +403,6 @@ impl SemanticIndex {
_embedding_task, _embedding_task,
_parsing_files_tasks, _parsing_files_tasks,
projects: Default::default(), projects: Default::default(),
api_key: None,
embedding_queue
} }
})) }))
} }
@ -720,13 +717,13 @@ impl SemanticIndex {
let index = self.index_project(project.clone(), cx); let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone(); let embedding_provider = self.embedding_provider.clone();
let api_key = self.api_key.clone();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
index.await?; index.await?;
let t0 = Instant::now(); let t0 = Instant::now();
let query = embedding_provider let query = embedding_provider
.embed_batch(vec![query], api_key) .embed_batch(vec![query])
.await? .await?
.pop() .pop()
.ok_or_else(|| anyhow!("could not embed query"))?; .ok_or_else(|| anyhow!("could not embed query"))?;
@ -944,7 +941,6 @@ impl SemanticIndex {
let fs = self.fs.clone(); let fs = self.fs.clone();
let db_path = self.db.path().clone(); let db_path = self.db.path().clone();
let background = cx.background().clone(); let background = cx.background().clone();
let api_key = self.api_key.clone();
cx.background().spawn(async move { cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?; let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new(); let mut results = Vec::<SearchResult>::new();
@ -959,15 +955,10 @@ impl SemanticIndex {
.parse_file_with_template(None, &snapshot.text(), language) .parse_file_with_template(None, &snapshot.text(), language)
.log_err() .log_err()
.unwrap_or_default(); .unwrap_or_default();
if Self::embed_spans( if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
&mut spans, .await
embedding_provider.as_ref(), .log_err()
&db, .is_some()
api_key.clone(),
)
.await
.log_err()
.is_some()
{ {
for span in spans { for span in spans {
let similarity = span.embedding.unwrap().similarity(&query); let similarity = span.embedding.unwrap().similarity(&query);
@ -1007,9 +998,8 @@ impl SemanticIndex {
project: ModelHandle<Project>, project: ModelHandle<Project>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
if self.api_key.is_none() { if !self.is_authenticated() {
self.authenticate(cx); if !self.authenticate(cx) {
if self.api_key.is_none() {
return Task::ready(Err(anyhow!("user is not authenticated"))); return Task::ready(Err(anyhow!("user is not authenticated")));
} }
} }
@ -1192,7 +1182,6 @@ impl SemanticIndex {
spans: &mut [Span], spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider, embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase, db: &VectorDatabase,
api_key: Option<String>,
) -> Result<()> { ) -> Result<()> {
let mut batch = Vec::new(); let mut batch = Vec::new();
let mut batch_tokens = 0; let mut batch_tokens = 0;
@ -1215,7 +1204,7 @@ impl SemanticIndex {
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() { if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch), api_key.clone()) .embed_batch(mem::take(&mut batch))
.await?; .await?;
embeddings.extend(batch_embeddings); embeddings.extend(batch_embeddings);
batch_tokens = 0; batch_tokens = 0;
@ -1227,7 +1216,7 @@ impl SemanticIndex {
if !batch.is_empty() { if !batch.is_empty() {
let batch_embeddings = embedding_provider let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch), api_key) .embed_batch(mem::take(&mut batch))
.await?; .await?;
embeddings.extend(batch_embeddings); embeddings.extend(batch_embeddings);

View File

@ -4,10 +4,9 @@ use crate::{
semantic_index_settings::SemanticIndexSettings, semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
}; };
use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider}; use ai::test::FakeEmbeddingProvider;
use anyhow::Result;
use async_trait::async_trait; use gpui::{executor::Deterministic, Task, TestAppContext};
use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset}; use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use parking_lot::Mutex; use parking_lot::Mutex;
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
@ -15,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
use rand::{rngs::StdRng, Rng}; use rand::{rngs::StdRng, Rng};
use serde_json::json; use serde_json::json;
use settings::SettingsStore; use settings::SettingsStore;
use std::{ use std::{path::Path, sync::Arc, time::SystemTime};
path::Path,
sync::{
atomic::{self, AtomicUsize},
Arc,
},
time::{Instant, SystemTime},
};
use unindent::Unindent; use unindent::Unindent;
use util::RandomCharIter; use util::RandomCharIter;
@ -228,7 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None); let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
for file in &files { for file in &files {
queue.push(file.clone()); queue.push(file.clone());
} }
@ -280,7 +272,7 @@ fn assert_search_results(
#[gpui::test] #[gpui::test]
async fn test_code_context_retrieval_rust() { async fn test_code_context_retrieval_rust() {
let language = rust_lang(); let language = rust_lang();
let embedding_provider = Arc::new(DummyEmbeddings {}); let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider); let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = " let text = "
@ -382,7 +374,7 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test] #[gpui::test]
async fn test_code_context_retrieval_json() { async fn test_code_context_retrieval_json() {
let language = json_lang(); let language = json_lang();
let embedding_provider = Arc::new(DummyEmbeddings {}); let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider); let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#" let text = r#"
@ -466,7 +458,7 @@ fn assert_documents_eq(
#[gpui::test] #[gpui::test]
async fn test_code_context_retrieval_javascript() { async fn test_code_context_retrieval_javascript() {
let language = js_lang(); let language = js_lang();
let embedding_provider = Arc::new(DummyEmbeddings {}); let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider); let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = " let text = "
@ -565,7 +557,7 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test] #[gpui::test]
async fn test_code_context_retrieval_lua() { async fn test_code_context_retrieval_lua() {
let language = lua_lang(); let language = lua_lang();
let embedding_provider = Arc::new(DummyEmbeddings {}); let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider); let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#" let text = r#"
@ -639,7 +631,7 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test] #[gpui::test]
async fn test_code_context_retrieval_elixir() { async fn test_code_context_retrieval_elixir() {
let language = elixir_lang(); let language = elixir_lang();
let embedding_provider = Arc::new(DummyEmbeddings {}); let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider); let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#" let text = r#"
@ -756,7 +748,7 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test] #[gpui::test]
async fn test_code_context_retrieval_cpp() { async fn test_code_context_retrieval_cpp() {
let language = cpp_lang(); let language = cpp_lang();
let embedding_provider = Arc::new(DummyEmbeddings {}); let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider); let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = " let text = "
@ -909,7 +901,7 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test] #[gpui::test]
async fn test_code_context_retrieval_ruby() { async fn test_code_context_retrieval_ruby() {
let language = ruby_lang(); let language = ruby_lang();
let embedding_provider = Arc::new(DummyEmbeddings {}); let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider); let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#" let text = r#"
@ -1100,7 +1092,7 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test] #[gpui::test]
async fn test_code_context_retrieval_php() { async fn test_code_context_retrieval_php() {
let language = php_lang(); let language = php_lang();
let embedding_provider = Arc::new(DummyEmbeddings {}); let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider); let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#" let text = r#"
@ -1248,65 +1240,6 @@ async fn test_code_context_retrieval_php() {
); );
} }
/// An embedding provider for tests: produces cheap, deterministic embeddings
/// and counts how many spans have been embedded.
#[derive(Default)]
struct FakeEmbeddingProvider {
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    /// Total number of spans passed to `embed_batch` so far.
    fn embedding_count(&self) -> usize {
        self.embedding_count.load(atomic::Ordering::SeqCst)
    }

    /// Deterministically embeds `span` as an L2-normalized 26-bin histogram
    /// of its ASCII letters. Each bin is seeded with 1.0, so the norm is
    /// never zero and the empty string still yields a valid unit vector.
    fn embed_sync(&self, span: &str) -> Embedding {
        let mut histogram = vec![1.0_f32; 26];
        for ch in span.chars() {
            let lower = ch.to_ascii_lowercase();
            if ('a'..='z').contains(&lower) {
                histogram[lower as usize - 'a' as usize] += 1.0;
            }
        }

        let norm = histogram.iter().map(|x| x * x).sum::<f32>().sqrt();
        for value in &mut histogram {
            *value /= norm;
        }
        histogram.into()
    }
}

#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
        Some("Fake Credentials".to_string())
    }

    fn truncate(&self, span: &str) -> (String, usize) {
        (span.to_string(), 1)
    }

    fn max_tokens_per_batch(&self) -> usize {
        200
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        None
    }

    /// Embeds every span synchronously and bumps the shared counter by the
    /// batch size before returning.
    async fn embed_batch(
        &self,
        spans: Vec<String>,
        _api_key: Option<String>,
    ) -> Result<Vec<Embedding>> {
        self.embedding_count
            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
        let embeddings = spans.iter().map(|span| self.embed_sync(span)).collect();
        Ok(embeddings)
    }
}
fn js_lang() -> Arc<Language> { fn js_lang() -> Arc<Language> {
Arc::new( Arc::new(
Language::new( Language::new(

View File

@ -1,9 +1,11 @@
mod colors;
mod focus; mod focus;
mod kitchen_sink; mod kitchen_sink;
mod scroll; mod scroll;
mod text; mod text;
mod z_index; mod z_index;
pub use colors::*;
pub use focus::*; pub use focus::*;
pub use kitchen_sink::*; pub use kitchen_sink::*;
pub use scroll::*; pub use scroll::*;

View File

@ -0,0 +1,38 @@
use crate::story::Story;
use gpui2::{px, Div, Render};
use ui::prelude::*;
/// A storybook page that renders every default color scale as a labeled row
/// of twelve step swatches.
pub struct ColorsStory;

impl Render for ColorsStory {
    type Element = Div<Self>;

    fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
        let color_scales = theme2::default_color_scales();

        // Build one row per scale up front: the scale name, then one swatch
        // per one-based step 1..=12.
        let rows: Vec<_> = color_scales
            .into_iter()
            .map(|(name, scale)| {
                let label = div()
                    .w(px(75.))
                    .line_height(px(24.))
                    .child(name.to_string());
                let swatches = div().flex().gap_1().children(
                    (1..=12).map(|step| div().flex().size_6().bg(scale.step(cx, step))),
                );
                div().flex().child(label).child(swatches)
            })
            .collect();

        Story::container(cx)
            .child(Story::title(cx, "Colors"))
            .child(
                div()
                    .id("colors")
                    .flex()
                    .flex_col()
                    .gap_1()
                    .overflow_y_scroll()
                    .text_color(gpui2::white())
                    .children(rows),
            )
    }
}

View File

@ -1,9 +1,9 @@
use crate::themes::rose_pine;
use gpui2::{ use gpui2::{
div, Focusable, KeyBinding, ParentElement, StatelessInteractive, Styled, View, VisualContext, div, Div, FocusEnabled, Focusable, KeyBinding, ParentElement, Render, StatefulInteraction,
WindowContext, StatelessInteractive, Styled, View, VisualContext, WindowContext,
}; };
use serde::Deserialize; use serde::Deserialize;
use theme2::theme;
#[derive(Clone, Default, PartialEq, Deserialize)] #[derive(Clone, Default, PartialEq, Deserialize)]
struct ActionA; struct ActionA;
@ -14,12 +14,10 @@ struct ActionB;
#[derive(Clone, Default, PartialEq, Deserialize)] #[derive(Clone, Default, PartialEq, Deserialize)]
struct ActionC; struct ActionC;
pub struct FocusStory { pub struct FocusStory {}
text: View<()>,
}
impl FocusStory { impl FocusStory {
pub fn view(cx: &mut WindowContext) -> View<()> { pub fn view(cx: &mut WindowContext) -> View<Self> {
cx.bind_keys([ cx.bind_keys([
KeyBinding::new("cmd-a", ActionA, Some("parent")), KeyBinding::new("cmd-a", ActionA, Some("parent")),
KeyBinding::new("cmd-a", ActionB, Some("child-1")), KeyBinding::new("cmd-a", ActionB, Some("child-1")),
@ -27,91 +25,92 @@ impl FocusStory {
]); ]);
cx.register_action_type::<ActionA>(); cx.register_action_type::<ActionA>();
cx.register_action_type::<ActionB>(); cx.register_action_type::<ActionB>();
let theme = rose_pine();
let color_1 = theme.lowest.negative.default.foreground; cx.build_view(move |cx| Self {})
let color_2 = theme.lowest.positive.default.foreground; }
let color_3 = theme.lowest.warning.default.foreground; }
let color_4 = theme.lowest.accent.default.foreground;
let color_5 = theme.lowest.variant.default.foreground; impl Render for FocusStory {
let color_6 = theme.highest.negative.default.foreground; type Element = Div<Self, StatefulInteraction<Self>, FocusEnabled<Self>>;
fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
let theme = theme(cx);
let color_1 = theme.git_created;
let color_2 = theme.git_modified;
let color_3 = theme.git_deleted;
let color_4 = theme.git_conflict;
let color_5 = theme.git_ignored;
let color_6 = theme.git_renamed;
let child_1 = cx.focus_handle(); let child_1 = cx.focus_handle();
let child_2 = cx.focus_handle(); let child_2 = cx.focus_handle();
cx.build_view( div()
|_| (), .id("parent")
move |_, cx| { .focusable()
.context("parent")
.on_action(|_, action: &ActionA, phase, cx| {
println!("Action A dispatched on parent during {:?}", phase);
})
.on_action(|_, action: &ActionB, phase, cx| {
println!("Action B dispatched on parent during {:?}", phase);
})
.on_focus(|_, _, _| println!("Parent focused"))
.on_blur(|_, _, _| println!("Parent blurred"))
.on_focus_in(|_, _, _| println!("Parent focus_in"))
.on_focus_out(|_, _, _| println!("Parent focus_out"))
.on_key_down(|_, event, phase, _| {
println!("Key down on parent {:?} {:?}", phase, event)
})
.on_key_up(|_, event, phase, _| println!("Key up on parent {:?} {:?}", phase, event))
.size_full()
.bg(color_1)
.focus(|style| style.bg(color_2))
.focus_in(|style| style.bg(color_3))
.child(
div() div()
.id("parent") .track_focus(&child_1)
.focusable() .context("child-1")
.context("parent")
.on_action(|_, action: &ActionA, phase, cx| {
println!("Action A dispatched on parent during {:?}", phase);
})
.on_action(|_, action: &ActionB, phase, cx| { .on_action(|_, action: &ActionB, phase, cx| {
println!("Action B dispatched on parent during {:?}", phase); println!("Action B dispatched on child 1 during {:?}", phase);
}) })
.on_focus(|_, _, _| println!("Parent focused")) .w_full()
.on_blur(|_, _, _| println!("Parent blurred")) .h_6()
.on_focus_in(|_, _, _| println!("Parent focus_in")) .bg(color_4)
.on_focus_out(|_, _, _| println!("Parent focus_out")) .focus(|style| style.bg(color_5))
.in_focus(|style| style.bg(color_6))
.on_focus(|_, _, _| println!("Child 1 focused"))
.on_blur(|_, _, _| println!("Child 1 blurred"))
.on_focus_in(|_, _, _| println!("Child 1 focus_in"))
.on_focus_out(|_, _, _| println!("Child 1 focus_out"))
.on_key_down(|_, event, phase, _| { .on_key_down(|_, event, phase, _| {
println!("Key down on parent {:?} {:?}", phase, event) println!("Key down on child 1 {:?} {:?}", phase, event)
}) })
.on_key_up(|_, event, phase, _| { .on_key_up(|_, event, phase, _| {
println!("Key up on parent {:?} {:?}", phase, event) println!("Key up on child 1 {:?} {:?}", phase, event)
}) })
.size_full() .child("Child 1"),
.bg(color_1) )
.focus(|style| style.bg(color_2)) .child(
.focus_in(|style| style.bg(color_3)) div()
.child( .track_focus(&child_2)
div() .context("child-2")
.track_focus(&child_1) .on_action(|_, action: &ActionC, phase, cx| {
.context("child-1") println!("Action C dispatched on child 2 during {:?}", phase);
.on_action(|_, action: &ActionB, phase, cx| { })
println!("Action B dispatched on child 1 during {:?}", phase); .w_full()
}) .h_6()
.w_full() .bg(color_4)
.h_6() .on_focus(|_, _, _| println!("Child 2 focused"))
.bg(color_4) .on_blur(|_, _, _| println!("Child 2 blurred"))
.focus(|style| style.bg(color_5)) .on_focus_in(|_, _, _| println!("Child 2 focus_in"))
.in_focus(|style| style.bg(color_6)) .on_focus_out(|_, _, _| println!("Child 2 focus_out"))
.on_focus(|_, _, _| println!("Child 1 focused")) .on_key_down(|_, event, phase, _| {
.on_blur(|_, _, _| println!("Child 1 blurred")) println!("Key down on child 2 {:?} {:?}", phase, event)
.on_focus_in(|_, _, _| println!("Child 1 focus_in")) })
.on_focus_out(|_, _, _| println!("Child 1 focus_out")) .on_key_up(|_, event, phase, _| {
.on_key_down(|_, event, phase, _| { println!("Key up on child 2 {:?} {:?}", phase, event)
println!("Key down on child 1 {:?} {:?}", phase, event) })
}) .child("Child 2"),
.on_key_up(|_, event, phase, _| { )
println!("Key up on child 1 {:?} {:?}", phase, event)
})
.child("Child 1"),
)
.child(
div()
.track_focus(&child_2)
.context("child-2")
.on_action(|_, action: &ActionC, phase, cx| {
println!("Action C dispatched on child 2 during {:?}", phase);
})
.w_full()
.h_6()
.bg(color_4)
.on_focus(|_, _, _| println!("Child 2 focused"))
.on_blur(|_, _, _| println!("Child 2 blurred"))
.on_focus_in(|_, _, _| println!("Child 2 focus_in"))
.on_focus_out(|_, _, _| println!("Child 2 focus_out"))
.on_key_down(|_, event, phase, _| {
println!("Key down on child 2 {:?} {:?}", phase, event)
})
.on_key_up(|_, event, phase, _| {
println!("Key up on child 2 {:?} {:?}", phase, event)
})
.child("Child 2"),
)
},
)
} }
} }

View File

@ -1,26 +1,23 @@
use gpui2::{AppContext, Context, View}; use crate::{
story::Story,
story_selector::{ComponentStory, ElementStory},
};
use gpui2::{Div, Render, StatefulInteraction, View, VisualContext};
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use ui::prelude::*; use ui::prelude::*;
use crate::story::Story; pub struct KitchenSinkStory;
use crate::story_selector::{ComponentStory, ElementStory};
pub struct KitchenSinkStory {}
impl KitchenSinkStory { impl KitchenSinkStory {
pub fn new() -> Self { pub fn view(cx: &mut WindowContext) -> View<Self> {
Self {} cx.build_view(|cx| Self)
} }
}
pub fn view(cx: &mut AppContext) -> View<Self> { impl Render for KitchenSinkStory {
{ type Element = Div<Self, StatefulInteraction<Self>>;
let state = cx.entity(|cx| Self::new());
let render = Self::render;
View::for_handle(state, render)
}
}
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Component<Self> { fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
let element_stories = ElementStory::iter() let element_stories = ElementStory::iter()
.map(|selector| selector.story(cx)) .map(|selector| selector.story(cx))
.collect::<Vec<_>>(); .collect::<Vec<_>>();

View File

@ -1,58 +1,54 @@
use crate::themes::rose_pine;
use gpui2::{ use gpui2::{
div, px, Component, ParentElement, SharedString, Styled, View, VisualContext, WindowContext, div, px, Component, Div, ParentElement, Render, SharedString, StatefulInteraction, Styled,
View, VisualContext, WindowContext,
}; };
use theme2::theme;
pub struct ScrollStory { pub struct ScrollStory;
text: View<()>,
}
impl ScrollStory { impl ScrollStory {
pub fn view(cx: &mut WindowContext) -> View<()> { pub fn view(cx: &mut WindowContext) -> View<ScrollStory> {
let theme = rose_pine(); cx.build_view(|cx| ScrollStory)
{
cx.build_view(|cx| (), move |_, cx| checkerboard(1))
}
} }
} }
fn checkerboard<S>(depth: usize) -> impl Component<S> impl Render for ScrollStory {
where type Element = Div<Self, StatefulInteraction<Self>>;
S: 'static + Send + Sync,
{
let theme = rose_pine();
let color_1 = theme.lowest.positive.default.background;
let color_2 = theme.lowest.warning.default.background;
div() fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
.id("parent") let theme = theme(cx);
.bg(theme.lowest.base.default.background) let color_1 = theme.git_created;
.size_full() let color_2 = theme.git_modified;
.overflow_scroll()
.children((0..10).map(|row| { div()
div() .id("parent")
.w(px(1000.)) .bg(theme.background)
.h(px(100.)) .size_full()
.flex() .overflow_scroll()
.flex_row() .children((0..10).map(|row| {
.children((0..10).map(|column| { div()
let id = SharedString::from(format!("{}, {}", row, column)); .w(px(1000.))
let bg = if row % 2 == column % 2 { .h(px(100.))
color_1 .flex()
} else { .flex_row()
color_2 .children((0..10).map(|column| {
}; let id = SharedString::from(format!("{}, {}", row, column));
div().id(id).bg(bg).size(px(100. / depth as f32)).when( let bg = if row % 2 == column % 2 {
row >= 5 && column >= 5, color_1
|d| { } else {
d.overflow_scroll() color_2
.child(div().size(px(50.)).bg(color_1)) };
.child(div().size(px(50.)).bg(color_2)) div().id(id).bg(bg).size(px(100. as f32)).when(
.child(div().size(px(50.)).bg(color_1)) row >= 5 && column >= 5,
.child(div().size(px(50.)).bg(color_2)) |d| {
}, d.overflow_scroll()
) .child(div().size(px(50.)).bg(color_1))
})) .child(div().size(px(50.)).bg(color_2))
})) .child(div().size(px(50.)).bg(color_1))
.child(div().size(px(50.)).bg(color_2))
},
)
}))
}))
}
} }

View File

@ -1,20 +1,21 @@
use gpui2::{div, white, ParentElement, Styled, View, VisualContext, WindowContext}; use gpui2::{div, white, Div, ParentElement, Render, Styled, View, VisualContext, WindowContext};
pub struct TextStory { pub struct TextStory;
text: View<()>,
}
impl TextStory { impl TextStory {
pub fn view(cx: &mut WindowContext) -> View<()> { pub fn view(cx: &mut WindowContext) -> View<Self> {
cx.build_view(|cx| (), |_, cx| { cx.build_view(|cx| Self)
div() }
.size_full() }
.bg(white())
.child(concat!( impl Render for TextStory {
"The quick brown fox jumps over the lazy dog. ", type Element = Div<Self>;
"Meanwhile, the lazy dog decided it was time for a change. ",
"He started daily workout routines, ate healthier and became the fastest dog in town.", fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
)) div().size_full().bg(white()).child(concat!(
}) "The quick brown fox jumps over the lazy dog. ",
"Meanwhile, the lazy dog decided it was time for a change. ",
"He started daily workout routines, ate healthier and became the fastest dog in town.",
))
} }
} }

View File

@ -1,15 +1,16 @@
use gpui2::{px, rgb, Div, Hsla}; use gpui2::{px, rgb, Div, Hsla, Render};
use ui::prelude::*; use ui::prelude::*;
use crate::story::Story; use crate::story::Story;
/// A reimplementation of the MDN `z-index` example, found here: /// A reimplementation of the MDN `z-index` example, found here:
/// [https://developer.mozilla.org/en-US/docs/Web/CSS/z-index](https://developer.mozilla.org/en-US/docs/Web/CSS/z-index). /// [https://developer.mozilla.org/en-US/docs/Web/CSS/z-index](https://developer.mozilla.org/en-US/docs/Web/CSS/z-index).
#[derive(Component)]
pub struct ZIndexStory; pub struct ZIndexStory;
impl ZIndexStory { impl Render for ZIndexStory {
fn render<V: 'static>(self, _view: &mut V, cx: &mut ViewContext<V>) -> impl Component<V> { type Element = Div<Self>;
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
Story::container(cx) Story::container(cx)
.child(Story::title(cx, "z-index")) .child(Story::title(cx, "z-index"))
.child( .child(

View File

@ -7,13 +7,14 @@ use clap::builder::PossibleValue;
use clap::ValueEnum; use clap::ValueEnum;
use gpui2::{AnyView, VisualContext}; use gpui2::{AnyView, VisualContext};
use strum::{EnumIter, EnumString, IntoEnumIterator}; use strum::{EnumIter, EnumString, IntoEnumIterator};
use ui::prelude::*; use ui::{prelude::*, AvatarStory, ButtonStory, DetailsStory, IconStory, InputStory, LabelStory};
#[derive(Debug, PartialEq, Eq, Clone, Copy, strum::Display, EnumString, EnumIter)] #[derive(Debug, PartialEq, Eq, Clone, Copy, strum::Display, EnumString, EnumIter)]
#[strum(serialize_all = "snake_case")] #[strum(serialize_all = "snake_case")]
pub enum ElementStory { pub enum ElementStory {
Avatar, Avatar,
Button, Button,
Colors,
Details, Details,
Focus, Focus,
Icon, Icon,
@ -27,18 +28,17 @@ pub enum ElementStory {
impl ElementStory { impl ElementStory {
pub fn story(&self, cx: &mut WindowContext) -> AnyView { pub fn story(&self, cx: &mut WindowContext) -> AnyView {
match self { match self {
Self::Avatar => { cx.build_view(|cx| (), |_, _| ui::AvatarStory.render()) }.into_any(), Self::Colors => cx.build_view(|_| ColorsStory).into_any(),
Self::Button => { cx.build_view(|cx| (), |_, _| ui::ButtonStory.render()) }.into_any(), Self::Avatar => cx.build_view(|_| AvatarStory).into_any(),
Self::Details => { Self::Button => cx.build_view(|_| ButtonStory).into_any(),
{ cx.build_view(|cx| (), |_, _| ui::DetailsStory.render()) }.into_any() Self::Details => cx.build_view(|_| DetailsStory).into_any(),
}
Self::Focus => FocusStory::view(cx).into_any(), Self::Focus => FocusStory::view(cx).into_any(),
Self::Icon => { cx.build_view(|cx| (), |_, _| ui::IconStory.render()) }.into_any(), Self::Icon => cx.build_view(|_| IconStory).into_any(),
Self::Input => { cx.build_view(|cx| (), |_, _| ui::InputStory.render()) }.into_any(), Self::Input => cx.build_view(|_| InputStory).into_any(),
Self::Label => { cx.build_view(|cx| (), |_, _| ui::LabelStory.render()) }.into_any(), Self::Label => cx.build_view(|_| LabelStory).into_any(),
Self::Scroll => ScrollStory::view(cx).into_any(), Self::Scroll => ScrollStory::view(cx).into_any(),
Self::Text => TextStory::view(cx).into_any(), Self::Text => TextStory::view(cx).into_any(),
Self::ZIndex => { cx.build_view(|cx| (), |_, _| ZIndexStory.render()) }.into_any(), Self::ZIndex => cx.build_view(|_| ZIndexStory).into_any(),
} }
} }
} }
@ -77,69 +77,31 @@ pub enum ComponentStory {
impl ComponentStory { impl ComponentStory {
pub fn story(&self, cx: &mut WindowContext) -> AnyView { pub fn story(&self, cx: &mut WindowContext) -> AnyView {
match self { match self {
Self::AssistantPanel => { Self::AssistantPanel => cx.build_view(|_| ui::AssistantPanelStory).into_any(),
{ cx.build_view(|cx| (), |_, _| ui::AssistantPanelStory.render()) }.into_any() Self::Buffer => cx.build_view(|_| ui::BufferStory).into_any(),
} Self::Breadcrumb => cx.build_view(|_| ui::BreadcrumbStory).into_any(),
Self::Buffer => { cx.build_view(|cx| (), |_, _| ui::BufferStory.render()) }.into_any(), Self::ChatPanel => cx.build_view(|_| ui::ChatPanelStory).into_any(),
Self::Breadcrumb => { Self::CollabPanel => cx.build_view(|_| ui::CollabPanelStory).into_any(),
{ cx.build_view(|cx| (), |_, _| ui::BreadcrumbStory.render()) }.into_any() Self::CommandPalette => cx.build_view(|_| ui::CommandPaletteStory).into_any(),
} Self::ContextMenu => cx.build_view(|_| ui::ContextMenuStory).into_any(),
Self::ChatPanel => { Self::Facepile => cx.build_view(|_| ui::FacepileStory).into_any(),
{ cx.build_view(|cx| (), |_, _| ui::ChatPanelStory.render()) }.into_any() Self::Keybinding => cx.build_view(|_| ui::KeybindingStory).into_any(),
} Self::LanguageSelector => cx.build_view(|_| ui::LanguageSelectorStory).into_any(),
Self::CollabPanel => { Self::MultiBuffer => cx.build_view(|_| ui::MultiBufferStory).into_any(),
{ cx.build_view(|cx| (), |_, _| ui::CollabPanelStory.render()) }.into_any() Self::NotificationsPanel => cx.build_view(|cx| ui::NotificationsPanelStory).into_any(),
} Self::Palette => cx.build_view(|cx| ui::PaletteStory).into_any(),
Self::CommandPalette => { Self::Panel => cx.build_view(|cx| ui::PanelStory).into_any(),
{ cx.build_view(|cx| (), |_, _| ui::CommandPaletteStory.render()) }.into_any() Self::ProjectPanel => cx.build_view(|_| ui::ProjectPanelStory).into_any(),
} Self::RecentProjects => cx.build_view(|_| ui::RecentProjectsStory).into_any(),
Self::ContextMenu => { Self::Tab => cx.build_view(|_| ui::TabStory).into_any(),
{ cx.build_view(|cx| (), |_, _| ui::ContextMenuStory.render()) }.into_any() Self::TabBar => cx.build_view(|_| ui::TabBarStory).into_any(),
} Self::Terminal => cx.build_view(|_| ui::TerminalStory).into_any(),
Self::Facepile => { Self::ThemeSelector => cx.build_view(|_| ui::ThemeSelectorStory).into_any(),
{ cx.build_view(|cx| (), |_, _| ui::FacepileStory.render()) }.into_any() Self::Toast => cx.build_view(|_| ui::ToastStory).into_any(),
} Self::Toolbar => cx.build_view(|_| ui::ToolbarStory).into_any(),
Self::Keybinding => { Self::TrafficLights => cx.build_view(|_| ui::TrafficLightsStory).into_any(),
{ cx.build_view(|cx| (), |_, _| ui::KeybindingStory.render()) }.into_any() Self::Copilot => cx.build_view(|_| ui::CopilotModalStory).into_any(),
}
Self::LanguageSelector => {
{ cx.build_view(|cx| (), |_, _| ui::LanguageSelectorStory.render()) }.into_any()
}
Self::MultiBuffer => {
{ cx.build_view(|cx| (), |_, _| ui::MultiBufferStory.render()) }.into_any()
}
Self::NotificationsPanel => {
{ cx.build_view(|cx| (), |_, _| ui::NotificationsPanelStory.render()) }.into_any()
}
Self::Palette => {
{ cx.build_view(|cx| (), |_, _| ui::PaletteStory.render()) }.into_any()
}
Self::Panel => { cx.build_view(|cx| (), |_, _| ui::PanelStory.render()) }.into_any(),
Self::ProjectPanel => {
{ cx.build_view(|cx| (), |_, _| ui::ProjectPanelStory.render()) }.into_any()
}
Self::RecentProjects => {
{ cx.build_view(|cx| (), |_, _| ui::RecentProjectsStory.render()) }.into_any()
}
Self::Tab => { cx.build_view(|cx| (), |_, _| ui::TabStory.render()) }.into_any(),
Self::TabBar => { cx.build_view(|cx| (), |_, _| ui::TabBarStory.render()) }.into_any(),
Self::Terminal => {
{ cx.build_view(|cx| (), |_, _| ui::TerminalStory.render()) }.into_any()
}
Self::ThemeSelector => {
{ cx.build_view(|cx| (), |_, _| ui::ThemeSelectorStory.render()) }.into_any()
}
Self::TitleBar => ui::TitleBarStory::view(cx).into_any(), Self::TitleBar => ui::TitleBarStory::view(cx).into_any(),
Self::Toast => { cx.build_view(|cx| (), |_, _| ui::ToastStory.render()) }.into_any(),
Self::Toolbar => {
{ cx.build_view(|cx| (), |_, _| ui::ToolbarStory.render()) }.into_any()
}
Self::TrafficLights => {
{ cx.build_view(|cx| (), |_, _| ui::TrafficLightsStory.render()) }.into_any()
}
Self::Copilot => {
{ cx.build_view(|cx| (), |_, _| ui::CopilotModalStory.render()) }.into_any()
}
Self::Workspace => ui::WorkspaceStory::view(cx).into_any(), Self::Workspace => ui::WorkspaceStory::view(cx).into_any(),
} }
} }

View File

@ -4,21 +4,20 @@ mod assets;
mod stories; mod stories;
mod story; mod story;
mod story_selector; mod story_selector;
mod themes;
use std::sync::Arc; use std::sync::Arc;
use clap::Parser; use clap::Parser;
use gpui2::{ use gpui2::{
div, px, size, AnyView, AppContext, Bounds, ViewContext, VisualContext, WindowBounds, div, px, size, AnyView, AppContext, Bounds, Div, Render, ViewContext, VisualContext,
WindowOptions, WindowBounds, WindowOptions,
}; };
use log::LevelFilter; use log::LevelFilter;
use settings2::{default_settings, Settings, SettingsStore}; use settings2::{default_settings, Settings, SettingsStore};
use simplelog::SimpleLogger; use simplelog::SimpleLogger;
use story_selector::ComponentStory; use story_selector::ComponentStory;
use theme2::{ThemeRegistry, ThemeSettings}; use theme2::{ThemeRegistry, ThemeSettings};
use ui::{prelude::*, themed}; use ui::prelude::*;
use crate::assets::Assets; use crate::assets::Assets;
use crate::story_selector::StorySelector; use crate::story_selector::StorySelector;
@ -50,7 +49,6 @@ fn main() {
let story_selector = args.story.clone(); let story_selector = args.story.clone();
let theme_name = args.theme.unwrap_or("One Dark".to_string()); let theme_name = args.theme.unwrap_or("One Dark".to_string());
let theme = themes::load_theme(theme_name.clone()).unwrap();
let asset_source = Arc::new(Assets); let asset_source = Arc::new(Assets);
gpui2::App::production(asset_source).run(move |cx| { gpui2::App::production(asset_source).run(move |cx| {
@ -84,12 +82,7 @@ fn main() {
}), }),
..Default::default() ..Default::default()
}, },
move |cx| { move |cx| cx.build_view(|cx| StoryWrapper::new(selector.story(cx))),
cx.build_view(
|cx| StoryWrapper::new(selector.story(cx), theme),
StoryWrapper::render,
)
},
); );
cx.activate(true); cx.activate(true);
@ -99,22 +92,23 @@ fn main() {
#[derive(Clone)] #[derive(Clone)]
pub struct StoryWrapper { pub struct StoryWrapper {
story: AnyView, story: AnyView,
theme: Theme,
} }
impl StoryWrapper { impl StoryWrapper {
pub(crate) fn new(story: AnyView, theme: Theme) -> Self { pub(crate) fn new(story: AnyView) -> Self {
Self { story, theme } Self { story }
} }
}
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Component<Self> { impl Render for StoryWrapper {
themed(self.theme.clone(), cx, |cx| { type Element = Div<Self>;
div()
.flex() fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
.flex_col() div()
.size_full() .flex()
.child(self.story.clone()) .flex_col()
}) .size_full()
.child(self.story.clone())
} }
} }

View File

@ -1,30 +0,0 @@
mod rose_pine;
pub use rose_pine::*;
use anyhow::{Context, Result};
use gpui2::serde_json;
use serde::Deserialize;
use ui::Theme;
use crate::assets::Assets;
/// The on-disk shape of a legacy theme file. Only the embedded `base_theme`
/// JSON value is extracted; any other fields in the file are ignored by
/// serde's default unknown-field handling.
#[derive(Deserialize)]
struct LegacyTheme {
    // Raw JSON of the actual theme, deserialized into `Theme` in a second pass.
    pub base_theme: serde_json::Value,
}
/// Loads the [`Theme`] with the given name.
///
/// Looks up `themes/{name}.json` in the embedded [`Assets`], parses it as a
/// [`LegacyTheme`] wrapper, then deserializes the wrapped `base_theme` value.
///
/// # Errors
///
/// Returns an error if the theme file is missing, is not valid UTF-8, or if
/// either JSON parsing step fails.
pub fn load_theme(name: String) -> Result<Theme> {
    let theme_contents = Assets::get(&format!("themes/{name}.json"))
        .with_context(|| format!("theme file not found: '{name}'"))?;
    let legacy_theme: LegacyTheme =
        serde_json::from_str(std::str::from_utf8(&theme_contents.data)?)
            .context("failed to parse legacy theme")?;
    // `legacy_theme` is owned and dropped right after this call, so the
    // wrapped value can be consumed directly — the previous `.clone()` of the
    // whole JSON tree was a redundant allocation.
    let theme: Theme = serde_json::from_value(legacy_theme.base_theme)
        .context("failed to parse `base_theme`")?;
    Ok(theme)
}

File diff suppressed because it is too large Load Diff

2118
crates/theme2/src/default.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,4 @@
use crate::{ use crate::{themes, Theme, ThemeMetadata};
themes::{one_dark, rose_pine, rose_pine_dawn, rose_pine_moon, sandcastle},
Theme, ThemeMetadata,
};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use gpui2::SharedString; use gpui2::SharedString;
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
@ -41,11 +38,45 @@ impl Default for ThemeRegistry {
}; };
this.insert_themes([ this.insert_themes([
one_dark(), themes::andromeda(),
rose_pine(), themes::atelier_cave_dark(),
rose_pine_dawn(), themes::atelier_cave_light(),
rose_pine_moon(), themes::atelier_dune_dark(),
sandcastle(), themes::atelier_dune_light(),
themes::atelier_estuary_dark(),
themes::atelier_estuary_light(),
themes::atelier_forest_dark(),
themes::atelier_forest_light(),
themes::atelier_heath_dark(),
themes::atelier_heath_light(),
themes::atelier_lakeside_dark(),
themes::atelier_lakeside_light(),
themes::atelier_plateau_dark(),
themes::atelier_plateau_light(),
themes::atelier_savanna_dark(),
themes::atelier_savanna_light(),
themes::atelier_seaside_dark(),
themes::atelier_seaside_light(),
themes::atelier_sulphurpool_dark(),
themes::atelier_sulphurpool_light(),
themes::ayu_dark(),
themes::ayu_light(),
themes::ayu_mirage(),
themes::gruvbox_dark(),
themes::gruvbox_dark_hard(),
themes::gruvbox_dark_soft(),
themes::gruvbox_light(),
themes::gruvbox_light_hard(),
themes::gruvbox_light_soft(),
themes::one_dark(),
themes::one_light(),
themes::rose_pine(),
themes::rose_pine_dawn(),
themes::rose_pine_moon(),
themes::sandcastle(),
themes::solarized_dark(),
themes::solarized_light(),
themes::summercamp(),
]); ]);
this this

164
crates/theme2/src/scale.rs Normal file
View File

@ -0,0 +1,164 @@
use gpui2::{AppContext, Hsla};
use indexmap::IndexMap;
use crate::{theme, Appearance};
/// The name of one of the color scales used to build themes.
///
/// Fieldless enum, so `Clone` and `Copy` are derived: names can be passed by
/// value freely (e.g. as map keys or in iterator chains) without clones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ColorScaleName {
    Gray,
    Mauve,
    Slate,
    Sage,
    Olive,
    Sand,
    Gold,
    Bronze,
    Brown,
    Yellow,
    Amber,
    Orange,
    Tomato,
    Red,
    Ruby,
    Crimson,
    Pink,
    Plum,
    Purple,
    Violet,
    Iris,
    Indigo,
    Blue,
    Cyan,
    Teal,
    Jade,
    Green,
    Grass,
    Lime,
    Mint,
    Sky,
    Black,
    White,
}

impl std::fmt::Display for ColorScaleName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every display name matches its variant name exactly; write the
        // static string directly instead of routing it through `write!`'s
        // formatting machinery.
        let name = match self {
            Self::Gray => "Gray",
            Self::Mauve => "Mauve",
            Self::Slate => "Slate",
            Self::Sage => "Sage",
            Self::Olive => "Olive",
            Self::Sand => "Sand",
            Self::Gold => "Gold",
            Self::Bronze => "Bronze",
            Self::Brown => "Brown",
            Self::Yellow => "Yellow",
            Self::Amber => "Amber",
            Self::Orange => "Orange",
            Self::Tomato => "Tomato",
            Self::Red => "Red",
            Self::Ruby => "Ruby",
            Self::Crimson => "Crimson",
            Self::Pink => "Pink",
            Self::Plum => "Plum",
            Self::Purple => "Purple",
            Self::Violet => "Violet",
            Self::Iris => "Iris",
            Self::Indigo => "Indigo",
            Self::Blue => "Blue",
            Self::Cyan => "Cyan",
            Self::Teal => "Teal",
            Self::Jade => "Jade",
            Self::Green => "Green",
            Self::Grass => "Grass",
            Self::Lime => "Lime",
            Self::Mint => "Mint",
            Self::Sky => "Sky",
            Self::Black => "Black",
            Self::White => "White",
        };
        f.write_str(name)
    }
}
/// Twelve colors forming one ramp of a scale. Stored zero-based, but
/// addressed one-based through [`ColorScaleStep`].
pub type ColorScale = [Hsla; 12];

/// All color scale sets, keyed by name. `IndexMap` preserves insertion order.
pub type ColorScales = IndexMap<ColorScaleName, ColorScaleSet>;

/// A one-based step in a [`ColorScale`].
pub type ColorScaleStep = usize;
pub struct ColorScaleSet {
name: ColorScaleName,
light: ColorScale,
dark: ColorScale,
light_alpha: ColorScale,
dark_alpha: ColorScale,
}
impl ColorScaleSet {
    /// Builds a set from its four variants.
    ///
    /// Note: the parameter order is `(light, light_alpha, dark, dark_alpha)`,
    /// which differs from the struct's field declaration order — keep call
    /// sites aligned with the parameter names, not the field order.
    pub fn new(
        name: ColorScaleName,
        light: ColorScale,
        light_alpha: ColorScale,
        dark: ColorScale,
        dark_alpha: ColorScale,
    ) -> Self {
        Self {
            name,
            light,
            light_alpha,
            dark,
            dark_alpha,
        }
    }

    /// The scale's display name.
    pub fn name(&self) -> String {
        format!("{}", self.name)
    }

    /// Color at one-based `step` for the light appearance.
    ///
    /// Panics if `step` is 0 (usize underflow) or greater than 12.
    pub fn light(&self, step: ColorScaleStep) -> Hsla {
        let index = step - 1;
        self.light[index]
    }

    /// Alpha color at one-based `step` for the light appearance.
    ///
    /// Panics if `step` is 0 or greater than 12.
    pub fn light_alpha(&self, step: ColorScaleStep) -> Hsla {
        let index = step - 1;
        self.light_alpha[index]
    }

    /// Color at one-based `step` for the dark appearance.
    ///
    /// Panics if `step` is 0 or greater than 12.
    pub fn dark(&self, step: ColorScaleStep) -> Hsla {
        let index = step - 1;
        self.dark[index]
    }

    /// Alpha color at one-based `step` for the dark appearance.
    ///
    /// Panics if `step` is 0 or greater than 12.
    pub fn dark_alpha(&self, step: ColorScaleStep) -> Hsla {
        let index = step - 1;
        self.dark_alpha[index]
    }

    /// Resolves the current appearance from the active theme's metadata.
    fn current_appearance(cx: &AppContext) -> Appearance {
        match theme(cx).metadata.is_light {
            true => Appearance::Light,
            false => Appearance::Dark,
        }
    }

    /// Color at one-based `step`, picking light or dark to match the
    /// currently active theme.
    pub fn step(&self, cx: &AppContext, step: ColorScaleStep) -> Hsla {
        if matches!(Self::current_appearance(cx), Appearance::Light) {
            self.light(step)
        } else {
            self.dark(step)
        }
    }

    /// Alpha color at one-based `step`, picking light or dark to match the
    /// currently active theme.
    pub fn step_alpha(&self, cx: &AppContext, step: ColorScaleStep) -> Hsla {
        if matches!(Self::current_appearance(cx), Appearance::Light) {
            self.light_alpha(step)
        } else {
            self.dark_alpha(step)
        }
    }
}

View File

@ -1,14 +1,24 @@
mod default;
mod registry;
mod scale;
mod settings;
mod themes;

pub use default::*;
pub use registry::*;
pub use scale::*;
pub use settings::*;

use gpui2::{AppContext, HighlightStyle, Hsla, SharedString};
use settings2::Settings;
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq)]
pub enum Appearance {
    Light,
    Dark,
}

pub fn init(cx: &mut AppContext) {
    cx.set_global(ThemeRegistry::default());
    ThemeSettings::register(cx);
@ -18,6 +28,10 @@ pub fn active_theme<'a>(cx: &'a AppContext) -> &'a Arc<Theme> {
    &ThemeSettings::get_global(cx).active_theme
}

pub fn theme(cx: &AppContext) -> Arc<Theme> {
    active_theme(cx).clone()
}

pub struct Theme {
    pub metadata: ThemeMetadata,

View File

@ -0,0 +1,130 @@
use gpui2::rgba;
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
/// The "Andromeda" theme (dark variant, per `is_light: false`).
///
/// Pure palette data: every field is a hard-coded `rgba` value. If this file
/// is generated from an upstream theme definition, edit the source of
/// generation rather than this function — NOTE(review): confirm provenance.
pub fn andromeda() -> Theme {
    Theme {
        metadata: ThemeMetadata {
            name: "Andromeda".into(),
            is_light: false,
        },
        transparent: rgba(0x00000000).into(),
        mac_os_traffic_light_red: rgba(0xec695eff).into(),
        mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
        mac_os_traffic_light_green: rgba(0x61c553ff).into(),
        border: rgba(0x2b2f38ff).into(),
        border_variant: rgba(0x2b2f38ff).into(),
        border_focused: rgba(0x183934ff).into(),
        border_transparent: rgba(0x00000000).into(),
        elevated_surface: rgba(0x262933ff).into(),
        surface: rgba(0x21242bff).into(),
        background: rgba(0x262933ff).into(),
        filled_element: rgba(0x262933ff).into(),
        filled_element_hover: rgba(0xffffff1e).into(),
        filled_element_active: rgba(0xffffff28).into(),
        filled_element_selected: rgba(0x12231fff).into(),
        filled_element_disabled: rgba(0x00000000).into(),
        ghost_element: rgba(0x00000000).into(),
        ghost_element_hover: rgba(0xffffff14).into(),
        ghost_element_active: rgba(0xffffff1e).into(),
        ghost_element_selected: rgba(0x12231fff).into(),
        ghost_element_disabled: rgba(0x00000000).into(),
        text: rgba(0xf7f7f8ff).into(),
        text_muted: rgba(0xaca8aeff).into(),
        text_placeholder: rgba(0xf82871ff).into(),
        text_disabled: rgba(0x6b6b73ff).into(),
        text_accent: rgba(0x10a793ff).into(),
        icon_muted: rgba(0xaca8aeff).into(),
        // Syntax highlighting colors, keyed by highlight name.
        syntax: SyntaxTheme {
            highlights: vec![
                ("emphasis".into(), rgba(0x10a793ff).into()),
                ("punctuation.bracket".into(), rgba(0xd8d5dbff).into()),
                ("attribute".into(), rgba(0x10a793ff).into()),
                ("variable".into(), rgba(0xf7f7f8ff).into()),
                ("predictive".into(), rgba(0x315f70ff).into()),
                ("property".into(), rgba(0x10a793ff).into()),
                ("variant".into(), rgba(0x10a793ff).into()),
                ("embedded".into(), rgba(0xf7f7f8ff).into()),
                ("string.special".into(), rgba(0xf29c14ff).into()),
                ("keyword".into(), rgba(0x10a793ff).into()),
                ("tag".into(), rgba(0x10a793ff).into()),
                ("enum".into(), rgba(0xf29c14ff).into()),
                ("link_text".into(), rgba(0xf29c14ff).into()),
                ("primary".into(), rgba(0xf7f7f8ff).into()),
                ("punctuation".into(), rgba(0xd8d5dbff).into()),
                ("punctuation.special".into(), rgba(0xd8d5dbff).into()),
                ("function".into(), rgba(0xfee56cff).into()),
                ("number".into(), rgba(0x96df71ff).into()),
                ("preproc".into(), rgba(0xf7f7f8ff).into()),
                ("operator".into(), rgba(0xf29c14ff).into()),
                ("constructor".into(), rgba(0x10a793ff).into()),
                ("string.escape".into(), rgba(0xafabb1ff).into()),
                ("string.special.symbol".into(), rgba(0xf29c14ff).into()),
                ("string".into(), rgba(0xf29c14ff).into()),
                ("comment".into(), rgba(0xafabb1ff).into()),
                ("hint".into(), rgba(0x618399ff).into()),
                ("type".into(), rgba(0x08e7c5ff).into()),
                ("label".into(), rgba(0x10a793ff).into()),
                ("comment.doc".into(), rgba(0xafabb1ff).into()),
                ("text.literal".into(), rgba(0xf29c14ff).into()),
                ("constant".into(), rgba(0x96df71ff).into()),
                ("string.regex".into(), rgba(0xf29c14ff).into()),
                ("emphasis.strong".into(), rgba(0x10a793ff).into()),
                ("title".into(), rgba(0xf7f7f8ff).into()),
                ("punctuation.delimiter".into(), rgba(0xd8d5dbff).into()),
                ("link_uri".into(), rgba(0x96df71ff).into()),
                ("boolean".into(), rgba(0x96df71ff).into()),
                ("punctuation.list_marker".into(), rgba(0xd8d5dbff).into()),
            ],
        },
        status_bar: rgba(0x262933ff).into(),
        title_bar: rgba(0x262933ff).into(),
        toolbar: rgba(0x1e2025ff).into(),
        tab_bar: rgba(0x21242bff).into(),
        editor: rgba(0x1e2025ff).into(),
        editor_subheader: rgba(0x21242bff).into(),
        editor_active_line: rgba(0x21242bff).into(),
        terminal: rgba(0x1e2025ff).into(),
        image_fallback_background: rgba(0x262933ff).into(),
        // Git status colors.
        git_created: rgba(0x96df71ff).into(),
        git_modified: rgba(0x10a793ff).into(),
        git_deleted: rgba(0xf82871ff).into(),
        git_conflict: rgba(0xfee56cff).into(),
        git_ignored: rgba(0x6b6b73ff).into(),
        git_renamed: rgba(0xfee56cff).into(),
        // Cursor/selection colors for up to 8 players.
        players: [
            PlayerTheme {
                cursor: rgba(0x10a793ff).into(),
                selection: rgba(0x10a7933d).into(),
            },
            PlayerTheme {
                cursor: rgba(0x96df71ff).into(),
                selection: rgba(0x96df713d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xc74cecff).into(),
                selection: rgba(0xc74cec3d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xf29c14ff).into(),
                selection: rgba(0xf29c143d).into(),
            },
            PlayerTheme {
                cursor: rgba(0x893ea6ff).into(),
                selection: rgba(0x893ea63d).into(),
            },
            PlayerTheme {
                cursor: rgba(0x08e7c5ff).into(),
                selection: rgba(0x08e7c53d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xf82871ff).into(),
                selection: rgba(0xf828713d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xfee56cff).into(),
                selection: rgba(0xfee56c3d).into(),
            },
        ],
    }
}

View File

@ -0,0 +1,136 @@
use gpui2::rgba;
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
/// The "Atelier Cave Dark" theme (dark variant, per `is_light: false`).
///
/// Pure palette data: every field is a hard-coded `rgba` value. If this file
/// is generated from an upstream theme definition, edit the source of
/// generation rather than this function — NOTE(review): confirm provenance.
pub fn atelier_cave_dark() -> Theme {
    Theme {
        metadata: ThemeMetadata {
            name: "Atelier Cave Dark".into(),
            is_light: false,
        },
        transparent: rgba(0x00000000).into(),
        mac_os_traffic_light_red: rgba(0xec695eff).into(),
        mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
        mac_os_traffic_light_green: rgba(0x61c553ff).into(),
        border: rgba(0x56505eff).into(),
        border_variant: rgba(0x56505eff).into(),
        border_focused: rgba(0x222953ff).into(),
        border_transparent: rgba(0x00000000).into(),
        elevated_surface: rgba(0x3a353fff).into(),
        surface: rgba(0x221f26ff).into(),
        background: rgba(0x3a353fff).into(),
        filled_element: rgba(0x3a353fff).into(),
        filled_element_hover: rgba(0xffffff1e).into(),
        filled_element_active: rgba(0xffffff28).into(),
        filled_element_selected: rgba(0x161a35ff).into(),
        filled_element_disabled: rgba(0x00000000).into(),
        ghost_element: rgba(0x00000000).into(),
        ghost_element_hover: rgba(0xffffff14).into(),
        ghost_element_active: rgba(0xffffff1e).into(),
        ghost_element_selected: rgba(0x161a35ff).into(),
        ghost_element_disabled: rgba(0x00000000).into(),
        text: rgba(0xefecf4ff).into(),
        text_muted: rgba(0x898591ff).into(),
        text_placeholder: rgba(0xbe4677ff).into(),
        text_disabled: rgba(0x756f7eff).into(),
        text_accent: rgba(0x566ddaff).into(),
        icon_muted: rgba(0x898591ff).into(),
        // Syntax highlighting colors, keyed by highlight name.
        syntax: SyntaxTheme {
            highlights: vec![
                ("comment.doc".into(), rgba(0x8b8792ff).into()),
                ("tag".into(), rgba(0x566ddaff).into()),
                ("link_text".into(), rgba(0xaa563bff).into()),
                ("constructor".into(), rgba(0x566ddaff).into()),
                ("punctuation".into(), rgba(0xe2dfe7ff).into()),
                ("punctuation.special".into(), rgba(0xbf3fbfff).into()),
                ("string.special.symbol".into(), rgba(0x299292ff).into()),
                ("string.escape".into(), rgba(0x8b8792ff).into()),
                ("emphasis".into(), rgba(0x566ddaff).into()),
                ("type".into(), rgba(0xa06d3aff).into()),
                ("punctuation.delimiter".into(), rgba(0x8b8792ff).into()),
                ("variant".into(), rgba(0xa06d3aff).into()),
                ("variable.special".into(), rgba(0x9559e7ff).into()),
                ("text.literal".into(), rgba(0xaa563bff).into()),
                ("punctuation.list_marker".into(), rgba(0xe2dfe7ff).into()),
                ("comment".into(), rgba(0x655f6dff).into()),
                ("function.method".into(), rgba(0x576cdbff).into()),
                ("property".into(), rgba(0xbe4677ff).into()),
                ("operator".into(), rgba(0x8b8792ff).into()),
                ("emphasis.strong".into(), rgba(0x566ddaff).into()),
                ("label".into(), rgba(0x566ddaff).into()),
                ("enum".into(), rgba(0xaa563bff).into()),
                ("number".into(), rgba(0xaa563bff).into()),
                ("primary".into(), rgba(0xe2dfe7ff).into()),
                ("keyword".into(), rgba(0x9559e7ff).into()),
                (
                    "function.special.definition".into(),
                    rgba(0xa06d3aff).into(),
                ),
                ("punctuation.bracket".into(), rgba(0x8b8792ff).into()),
                ("constant".into(), rgba(0x2b9292ff).into()),
                ("string.special".into(), rgba(0xbf3fbfff).into()),
                ("title".into(), rgba(0xefecf4ff).into()),
                ("preproc".into(), rgba(0xefecf4ff).into()),
                ("link_uri".into(), rgba(0x2b9292ff).into()),
                ("string".into(), rgba(0x299292ff).into()),
                ("embedded".into(), rgba(0xefecf4ff).into()),
                ("hint".into(), rgba(0x706897ff).into()),
                ("boolean".into(), rgba(0x2b9292ff).into()),
                ("variable".into(), rgba(0xe2dfe7ff).into()),
                ("predictive".into(), rgba(0x615787ff).into()),
                ("string.regex".into(), rgba(0x388bc6ff).into()),
                ("function".into(), rgba(0x576cdbff).into()),
                ("attribute".into(), rgba(0x566ddaff).into()),
            ],
        },
        status_bar: rgba(0x3a353fff).into(),
        title_bar: rgba(0x3a353fff).into(),
        toolbar: rgba(0x19171cff).into(),
        tab_bar: rgba(0x221f26ff).into(),
        editor: rgba(0x19171cff).into(),
        editor_subheader: rgba(0x221f26ff).into(),
        editor_active_line: rgba(0x221f26ff).into(),
        terminal: rgba(0x19171cff).into(),
        image_fallback_background: rgba(0x3a353fff).into(),
        // Git status colors.
        git_created: rgba(0x2b9292ff).into(),
        git_modified: rgba(0x566ddaff).into(),
        git_deleted: rgba(0xbe4677ff).into(),
        git_conflict: rgba(0xa06d3aff).into(),
        git_ignored: rgba(0x756f7eff).into(),
        git_renamed: rgba(0xa06d3aff).into(),
        // Cursor/selection colors for up to 8 players.
        players: [
            PlayerTheme {
                cursor: rgba(0x566ddaff).into(),
                selection: rgba(0x566dda3d).into(),
            },
            PlayerTheme {
                cursor: rgba(0x2b9292ff).into(),
                selection: rgba(0x2b92923d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xbf41bfff).into(),
                selection: rgba(0xbf41bf3d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xaa563bff).into(),
                selection: rgba(0xaa563b3d).into(),
            },
            PlayerTheme {
                cursor: rgba(0x955ae6ff).into(),
                selection: rgba(0x955ae63d).into(),
            },
            PlayerTheme {
                cursor: rgba(0x3a8bc6ff).into(),
                selection: rgba(0x3a8bc63d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xbe4677ff).into(),
                selection: rgba(0xbe46773d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xa06d3aff).into(),
                selection: rgba(0xa06d3a3d).into(),
            },
        ],
    }
}

View File

@ -0,0 +1,136 @@
use gpui2::rgba;
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
/// The "Atelier Cave Light" theme (light variant, per `is_light: true`).
///
/// Pure palette data: every field is a hard-coded `rgba` value. If this file
/// is generated from an upstream theme definition, edit the source of
/// generation rather than this function — NOTE(review): confirm provenance.
pub fn atelier_cave_light() -> Theme {
    Theme {
        metadata: ThemeMetadata {
            name: "Atelier Cave Light".into(),
            is_light: true,
        },
        transparent: rgba(0x00000000).into(),
        mac_os_traffic_light_red: rgba(0xec695eff).into(),
        mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
        mac_os_traffic_light_green: rgba(0x61c553ff).into(),
        border: rgba(0x8f8b96ff).into(),
        border_variant: rgba(0x8f8b96ff).into(),
        border_focused: rgba(0xc8c7f2ff).into(),
        border_transparent: rgba(0x00000000).into(),
        elevated_surface: rgba(0xbfbcc5ff).into(),
        surface: rgba(0xe6e3ebff).into(),
        background: rgba(0xbfbcc5ff).into(),
        filled_element: rgba(0xbfbcc5ff).into(),
        filled_element_hover: rgba(0xffffff1e).into(),
        filled_element_active: rgba(0xffffff28).into(),
        filled_element_selected: rgba(0xe1e0f9ff).into(),
        filled_element_disabled: rgba(0x00000000).into(),
        ghost_element: rgba(0x00000000).into(),
        ghost_element_hover: rgba(0xffffff14).into(),
        ghost_element_active: rgba(0xffffff1e).into(),
        ghost_element_selected: rgba(0xe1e0f9ff).into(),
        ghost_element_disabled: rgba(0x00000000).into(),
        text: rgba(0x19171cff).into(),
        text_muted: rgba(0x5a5462ff).into(),
        text_placeholder: rgba(0xbd4677ff).into(),
        text_disabled: rgba(0x6e6876ff).into(),
        text_accent: rgba(0x586cdaff).into(),
        icon_muted: rgba(0x5a5462ff).into(),
        // Syntax highlighting colors, keyed by highlight name.
        syntax: SyntaxTheme {
            highlights: vec![
                ("link_text".into(), rgba(0xaa573cff).into()),
                ("string".into(), rgba(0x299292ff).into()),
                ("emphasis".into(), rgba(0x586cdaff).into()),
                ("label".into(), rgba(0x586cdaff).into()),
                ("property".into(), rgba(0xbe4677ff).into()),
                ("emphasis.strong".into(), rgba(0x586cdaff).into()),
                ("constant".into(), rgba(0x2b9292ff).into()),
                (
                    "function.special.definition".into(),
                    rgba(0xa06d3aff).into(),
                ),
                ("embedded".into(), rgba(0x19171cff).into()),
                ("punctuation.special".into(), rgba(0xbf3fbfff).into()),
                ("function".into(), rgba(0x576cdbff).into()),
                ("tag".into(), rgba(0x586cdaff).into()),
                ("number".into(), rgba(0xaa563bff).into()),
                ("primary".into(), rgba(0x26232aff).into()),
                ("text.literal".into(), rgba(0xaa573cff).into()),
                ("variant".into(), rgba(0xa06d3aff).into()),
                ("type".into(), rgba(0xa06d3aff).into()),
                ("punctuation".into(), rgba(0x26232aff).into()),
                ("string.escape".into(), rgba(0x585260ff).into()),
                ("keyword".into(), rgba(0x9559e7ff).into()),
                ("title".into(), rgba(0x19171cff).into()),
                ("constructor".into(), rgba(0x586cdaff).into()),
                ("punctuation.list_marker".into(), rgba(0x26232aff).into()),
                ("string.special".into(), rgba(0xbf3fbfff).into()),
                ("operator".into(), rgba(0x585260ff).into()),
                ("function.method".into(), rgba(0x576cdbff).into()),
                ("link_uri".into(), rgba(0x2b9292ff).into()),
                ("variable.special".into(), rgba(0x9559e7ff).into()),
                ("hint".into(), rgba(0x776d9dff).into()),
                ("punctuation.bracket".into(), rgba(0x585260ff).into()),
                ("string.special.symbol".into(), rgba(0x299292ff).into()),
                ("predictive".into(), rgba(0x887fafff).into()),
                ("attribute".into(), rgba(0x586cdaff).into()),
                ("enum".into(), rgba(0xaa573cff).into()),
                ("preproc".into(), rgba(0x19171cff).into()),
                ("boolean".into(), rgba(0x2b9292ff).into()),
                ("variable".into(), rgba(0x26232aff).into()),
                ("comment.doc".into(), rgba(0x585260ff).into()),
                ("string.regex".into(), rgba(0x388bc6ff).into()),
                ("punctuation.delimiter".into(), rgba(0x585260ff).into()),
                ("comment".into(), rgba(0x7d7787ff).into()),
            ],
        },
        status_bar: rgba(0xbfbcc5ff).into(),
        title_bar: rgba(0xbfbcc5ff).into(),
        toolbar: rgba(0xefecf4ff).into(),
        tab_bar: rgba(0xe6e3ebff).into(),
        editor: rgba(0xefecf4ff).into(),
        editor_subheader: rgba(0xe6e3ebff).into(),
        editor_active_line: rgba(0xe6e3ebff).into(),
        terminal: rgba(0xefecf4ff).into(),
        image_fallback_background: rgba(0xbfbcc5ff).into(),
        // Git status colors.
        git_created: rgba(0x2b9292ff).into(),
        git_modified: rgba(0x586cdaff).into(),
        git_deleted: rgba(0xbd4677ff).into(),
        git_conflict: rgba(0xa06e3bff).into(),
        git_ignored: rgba(0x6e6876ff).into(),
        git_renamed: rgba(0xa06e3bff).into(),
        // Cursor/selection colors for up to 8 players.
        players: [
            PlayerTheme {
                cursor: rgba(0x586cdaff).into(),
                selection: rgba(0x586cda3d).into(),
            },
            PlayerTheme {
                cursor: rgba(0x2b9292ff).into(),
                selection: rgba(0x2b92923d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xbf41bfff).into(),
                selection: rgba(0xbf41bf3d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xaa573cff).into(),
                selection: rgba(0xaa573c3d).into(),
            },
            PlayerTheme {
                cursor: rgba(0x955ae6ff).into(),
                selection: rgba(0x955ae63d).into(),
            },
            PlayerTheme {
                cursor: rgba(0x3a8bc6ff).into(),
                selection: rgba(0x3a8bc63d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xbd4677ff).into(),
                selection: rgba(0xbd46773d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xa06e3bff).into(),
                selection: rgba(0xa06e3b3d).into(),
            },
        ],
    }
}

View File

@ -0,0 +1,136 @@
use gpui2::rgba;
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
/// The "Atelier Dune Dark" theme (dark variant, per `is_light: false`).
///
/// Pure palette data: every field is a hard-coded `rgba` value. If this file
/// is generated from an upstream theme definition, edit the source of
/// generation rather than this function — NOTE(review): confirm provenance.
pub fn atelier_dune_dark() -> Theme {
    Theme {
        metadata: ThemeMetadata {
            name: "Atelier Dune Dark".into(),
            is_light: false,
        },
        transparent: rgba(0x00000000).into(),
        mac_os_traffic_light_red: rgba(0xec695eff).into(),
        mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
        mac_os_traffic_light_green: rgba(0x61c553ff).into(),
        border: rgba(0x6c695cff).into(),
        border_variant: rgba(0x6c695cff).into(),
        border_focused: rgba(0x262f56ff).into(),
        border_transparent: rgba(0x00000000).into(),
        elevated_surface: rgba(0x45433bff).into(),
        surface: rgba(0x262622ff).into(),
        background: rgba(0x45433bff).into(),
        filled_element: rgba(0x45433bff).into(),
        filled_element_hover: rgba(0xffffff1e).into(),
        filled_element_active: rgba(0xffffff28).into(),
        filled_element_selected: rgba(0x171e38ff).into(),
        filled_element_disabled: rgba(0x00000000).into(),
        ghost_element: rgba(0x00000000).into(),
        ghost_element_hover: rgba(0xffffff14).into(),
        ghost_element_active: rgba(0xffffff1e).into(),
        ghost_element_selected: rgba(0x171e38ff).into(),
        ghost_element_disabled: rgba(0x00000000).into(),
        text: rgba(0xfefbecff).into(),
        text_muted: rgba(0xa4a08bff).into(),
        text_placeholder: rgba(0xd73837ff).into(),
        text_disabled: rgba(0x8f8b77ff).into(),
        text_accent: rgba(0x6684e0ff).into(),
        icon_muted: rgba(0xa4a08bff).into(),
        // Syntax highlighting colors, keyed by highlight name.
        syntax: SyntaxTheme {
            highlights: vec![
                ("constructor".into(), rgba(0x6684e0ff).into()),
                ("punctuation".into(), rgba(0xe8e4cfff).into()),
                ("punctuation.delimiter".into(), rgba(0xa6a28cff).into()),
                ("string.special".into(), rgba(0xd43451ff).into()),
                ("string.escape".into(), rgba(0xa6a28cff).into()),
                ("comment".into(), rgba(0x7d7a68ff).into()),
                ("enum".into(), rgba(0xb65611ff).into()),
                ("variable.special".into(), rgba(0xb854d4ff).into()),
                ("primary".into(), rgba(0xe8e4cfff).into()),
                ("comment.doc".into(), rgba(0xa6a28cff).into()),
                ("label".into(), rgba(0x6684e0ff).into()),
                ("operator".into(), rgba(0xa6a28cff).into()),
                ("string".into(), rgba(0x5fac38ff).into()),
                ("variant".into(), rgba(0xae9512ff).into()),
                ("variable".into(), rgba(0xe8e4cfff).into()),
                ("function.method".into(), rgba(0x6583e1ff).into()),
                (
                    "function.special.definition".into(),
                    rgba(0xae9512ff).into(),
                ),
                ("string.regex".into(), rgba(0x1ead82ff).into()),
                ("emphasis.strong".into(), rgba(0x6684e0ff).into()),
                ("punctuation.special".into(), rgba(0xd43451ff).into()),
                ("punctuation.bracket".into(), rgba(0xa6a28cff).into()),
                ("link_text".into(), rgba(0xb65611ff).into()),
                ("link_uri".into(), rgba(0x5fac39ff).into()),
                ("boolean".into(), rgba(0x5fac39ff).into()),
                ("hint".into(), rgba(0xb17272ff).into()),
                ("tag".into(), rgba(0x6684e0ff).into()),
                ("function".into(), rgba(0x6583e1ff).into()),
                ("title".into(), rgba(0xfefbecff).into()),
                ("property".into(), rgba(0xd73737ff).into()),
                ("type".into(), rgba(0xae9512ff).into()),
                ("constant".into(), rgba(0x5fac39ff).into()),
                ("attribute".into(), rgba(0x6684e0ff).into()),
                ("predictive".into(), rgba(0x9c6262ff).into()),
                ("string.special.symbol".into(), rgba(0x5fac38ff).into()),
                ("punctuation.list_marker".into(), rgba(0xe8e4cfff).into()),
                ("emphasis".into(), rgba(0x6684e0ff).into()),
                ("keyword".into(), rgba(0xb854d4ff).into()),
                ("text.literal".into(), rgba(0xb65611ff).into()),
                ("number".into(), rgba(0xb65610ff).into()),
                ("preproc".into(), rgba(0xfefbecff).into()),
                ("embedded".into(), rgba(0xfefbecff).into()),
            ],
        },
        status_bar: rgba(0x45433bff).into(),
        title_bar: rgba(0x45433bff).into(),
        toolbar: rgba(0x20201dff).into(),
        tab_bar: rgba(0x262622ff).into(),
        editor: rgba(0x20201dff).into(),
        editor_subheader: rgba(0x262622ff).into(),
        editor_active_line: rgba(0x262622ff).into(),
        terminal: rgba(0x20201dff).into(),
        image_fallback_background: rgba(0x45433bff).into(),
        // Git status colors.
        git_created: rgba(0x5fac39ff).into(),
        git_modified: rgba(0x6684e0ff).into(),
        git_deleted: rgba(0xd73837ff).into(),
        git_conflict: rgba(0xae9414ff).into(),
        git_ignored: rgba(0x8f8b77ff).into(),
        git_renamed: rgba(0xae9414ff).into(),
        // Cursor/selection colors for up to 8 players.
        players: [
            PlayerTheme {
                cursor: rgba(0x6684e0ff).into(),
                selection: rgba(0x6684e03d).into(),
            },
            PlayerTheme {
                cursor: rgba(0x5fac39ff).into(),
                selection: rgba(0x5fac393d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xd43651ff).into(),
                selection: rgba(0xd436513d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xb65611ff).into(),
                selection: rgba(0xb656113d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xb854d3ff).into(),
                selection: rgba(0xb854d33d).into(),
            },
            PlayerTheme {
                cursor: rgba(0x20ad83ff).into(),
                selection: rgba(0x20ad833d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xd73837ff).into(),
                selection: rgba(0xd738373d).into(),
            },
            PlayerTheme {
                cursor: rgba(0xae9414ff).into(),
                selection: rgba(0xae94143d).into(),
            },
        ],
    }
}

Some files were not shown because too many files have changed in this diff Show More