Merge branch 'zed2' into zed2-workspace

This commit is contained in:
Mikayla 2023-10-30 16:53:21 -07:00
commit 3dadfb8ba8
No known key found for this signature in database
185 changed files with 12176 additions and 4816 deletions

42
Cargo.lock generated
View File

@ -108,6 +108,33 @@ dependencies = [
"util",
]
[[package]]
name = "ai2"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"bincode",
"futures 0.3.28",
"gpui2",
"isahc",
"language2",
"lazy_static",
"log",
"matrixmultiply",
"ordered-float 2.10.0",
"parking_lot 0.11.2",
"parse_duration",
"postage",
"rand 0.8.5",
"regex",
"rusqlite",
"serde",
"serde_json",
"tiktoken-rs",
"util",
]
[[package]]
name = "alacritty_config"
version = "0.1.2-dev"
@ -1138,7 +1165,7 @@ dependencies = [
"audio2",
"client2",
"collections",
"fs",
"fs2",
"futures 0.3.28",
"gpui2",
"language2",
@ -4795,6 +4822,13 @@ dependencies = [
"gpui",
]
[[package]]
name = "menu2"
version = "0.1.0"
dependencies = [
"gpui2",
]
[[package]]
name = "metal"
version = "0.21.0"
@ -6000,7 +6034,7 @@ dependencies = [
"anyhow",
"client2",
"collections",
"fs",
"fs2",
"futures 0.3.28",
"gpui2",
"language2",
@ -6167,7 +6201,7 @@ dependencies = [
"ctor",
"db2",
"env_logger 0.9.3",
"fs",
"fs2",
"fsevent",
"futures 0.3.28",
"fuzzy2",
@ -8740,6 +8774,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"clap 4.4.4",
"convert_case 0.6.0",
"gpui2",
"log",
"rust-embed",
@ -10932,6 +10967,7 @@ dependencies = [
name = "zed2"
version = "0.109.0"
dependencies = [
"ai2",
"anyhow",
"async-compression",
"async-recursion 0.3.2",

View File

@ -59,6 +59,7 @@ members = [
"crates/lsp2",
"crates/media",
"crates/menu",
"crates/menu2",
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",

38
crates/Cargo.toml Normal file
View File

@ -0,0 +1,38 @@
[package]
name = "ai"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/ai.rs"
doctest = false
[features]
test-support = []
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
language = { path = "../language" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
lazy_static.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
isahc.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
postage.workspace = true
rand.workspace = true
log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }

View File

@ -8,6 +8,9 @@ publish = false
path = "src/ai.rs"
doctest = false
[features]
test-support = []
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }

View File

@ -1,4 +1,8 @@
pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
pub mod templates;
pub mod prompts;
pub mod providers;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

15
crates/ai/src/auth.rs Normal file
View File

@ -0,0 +1,15 @@
use gpui::AppContext;
#[derive(Clone, Debug)]
pub enum ProviderCredential {
Credentials { api_key: String },
NoCredentials,
NotNeeded,
}
pub trait CredentialProvider: Send + Sync {
fn has_credentials(&self) -> bool;
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
fn delete_credentials(&self, cx: &AppContext);
}

View File

@ -1,214 +1,23 @@
use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui::executor::Background;
use isahc::{http::StatusCode, Request, RequestExt};
use serde::{Deserialize, Serialize};
use std::{
fmt::{self, Display},
io,
sync::Arc,
};
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
use crate::{auth::CredentialProvider, models::LanguageModel};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>;
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<OpenAIUsage>,
}
pub async fn stream_completion(
api_key: String,
executor: Arc<Background>,
mut request: OpenAIRequest,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
request.stream = true;
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = serde_json::to_string(&request)?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAIResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(&data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAIResponse {
error: OpenAIError,
}
#[derive(Deserialize)]
struct OpenAIError {
message: String,
}
match serde_json::from_str::<OpenAIResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
pub trait CompletionProvider {
pub trait CompletionProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
fn complete(
&self,
prompt: OpenAIRequest,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn box_clone(&self) -> Box<dyn CompletionProvider>;
}
pub struct OpenAICompletionProvider {
api_key: String,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
Self { api_key, executor }
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn complete(
&self,
prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
impl Clone for Box<dyn CompletionProvider> {
fn clone(&self) -> Box<dyn CompletionProvider> {
self.box_clone()
}
}

View File

@ -1,32 +1,13 @@
use anyhow::{anyhow, Result};
use std::time::Instant;
use anyhow::Result;
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parse_duration::parse;
use postage::watch;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::completion::OPENAI_API_URL;
lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
use crate::auth::CredentialProvider;
use crate::models::LanguageModel;
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
@ -87,301 +68,14 @@ impl Embedding {
}
}
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
model: &'static str,
input: Vec<&'a str>,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
data: Vec<OpenAIEmbedding>,
usage: OpenAIEmbeddingUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
embedding: Vec<f32>,
index: usize,
object: String,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
prompt_tokens: usize,
total_tokens: usize,
}
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
async fn embed_batch(
&self,
spans: Vec<String>,
api_key: Option<String>,
) -> Result<Vec<Embedding>>;
pub trait EmbeddingProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn truncate(&self, span: &str) -> (String, usize);
fn rate_limit_expiration(&self) -> Option<Instant>;
}
pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
Some("Dummy API KEY".to_string())
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(
&self,
spans: Vec<String>,
_api_key: Option<String>,
) -> Result<Vec<Embedding>> {
// 1024 is the OpenAI Embeddings size for ada models.
// the model we will likely be starting with.
let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
return Ok(vec![dummy_vec; spans.len()]);
}
fn max_tokens_per_batch(&self) -> usize {
OPENAI_INPUT_LIMIT
}
fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
let token_count = tokens.len();
let output = if token_count > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
new_input.ok().unwrap_or_else(|| span.to_string())
} else {
span.to_string()
};
(output, tokens.len())
}
}
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
OpenAIEmbeddings {
client,
executor,
rate_limit_count_rx,
rate_limit_count_tx,
}
}
fn resolve_rate_limit(&self) {
let reset_time = *self.rate_limit_count_tx.lock().borrow();
if let Some(reset_time) = reset_time {
if Instant::now() >= reset_time {
*self.rate_limit_count_tx.lock().borrow_mut() = None
}
}
log::trace!(
"resolving reset time: {:?}",
*self.rate_limit_count_tx.lock().borrow()
);
}
fn update_reset_time(&self, reset_time: Instant) {
let original_time = *self.rate_limit_count_tx.lock().borrow();
let updated_time = if let Some(original_time) = original_time {
if reset_time < original_time {
Some(reset_time)
} else {
Some(original_time)
}
} else {
Some(reset_time)
};
log::trace!("updating rate limit time: {:?}", updated_time);
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
}
async fn send_request(
&self,
api_key: &str,
spans: Vec<&str>,
request_timeout: u64,
) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAIEmbeddingRequest {
input: spans.clone(),
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
Ok(self.client.send(request).await?)
}
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
Some(api_key)
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
{
String::from_utf8(api_key).log_err()
} else {
None
}
}
fn max_tokens_per_batch(&self) -> usize {
50000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
*self.rate_limit_count_rx.borrow()
}
fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
OPENAI_BPE_TOKENIZER
.decode(tokens.clone())
.ok()
.unwrap_or_else(|| span.to_string())
} else {
span.to_string()
};
(output, tokens.len())
}
async fn embed_batch(
&self,
spans: Vec<String>,
api_key: Option<String>,
) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
let Some(api_key) = api_key else {
return Err(anyhow!("no open ai key provided"));
};
let mut request_number = 0;
let mut rate_limiting = false;
let mut request_timeout: u64 = 15;
let mut response: Response<AsyncBody>;
while request_number < MAX_RETRIES {
response = self
.send_request(
&api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
.await?;
request_number += 1;
match response.status() {
StatusCode::REQUEST_TIMEOUT => {
request_timeout += 5;
}
StatusCode::OK => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
log::trace!(
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
// If we complete a request successfully that was previously rate_limited
// resolve the rate limit
if rate_limiting {
self.resolve_rate_limit()
}
return Ok(response
.data
.into_iter()
.map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
rate_limiting = true;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let delay_duration = {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
if let Some(time_to_reset) =
response.headers().get("x-ratelimit-reset-tokens")
{
if let Ok(time_str) = time_to_reset.to_str() {
parse(time_str).unwrap_or(delay)
} else {
delay
}
} else {
delay
}
};
// If we've previously rate limited, increment the duration but not the count
let reset_time = Instant::now().add(delay_duration);
self.update_reset_time(reset_time);
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
);
self.executor.timer(delay_duration).await;
}
_ => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"open ai bad request: {:?} {:?}",
&response.status(),
body
));
}
}
}
Err(anyhow!("openai max retries"))
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,66 +1,16 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;
pub enum TruncationDirection {
Start,
End,
}
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}
pub struct OpenAILanguageModel {
name: String,
bpe: Option<CoreBPE>,
}
impl OpenAILanguageModel {
pub fn load(model_name: &str) -> Self {
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
OpenAILanguageModel {
name: model_name.to_string(),
bpe,
}
}
}
impl LanguageModel for OpenAILanguageModel {
fn name(&self) -> String {
self.name.clone()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
if let Some(bpe) = &self.bpe {
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
bpe.decode(tokens[..length].to_vec())
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
bpe.decode(tokens[length..].to_vec())
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
}
}

View File

@ -6,7 +6,7 @@ use language::BufferSnapshot;
use util::ResultExt;
use crate::models::LanguageModel;
use crate::templates::repository_context::PromptCodeSnippet;
use crate::prompts::repository_context::PromptCodeSnippet;
pub(crate) enum PromptFileType {
Text,
@ -125,6 +125,9 @@ impl PromptChain {
#[cfg(test)]
pub(crate) mod tests {
use crate::models::TruncationDirection;
use crate::test::FakeLanguageModel;
use super::*;
#[test]
@ -141,7 +144,11 @@ pub(crate) mod tests {
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(&content, max_token_length)?;
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
@ -162,7 +169,11 @@ pub(crate) mod tests {
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(&content, max_token_length)?;
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
@ -171,38 +182,7 @@ pub(crate) mod tests {
}
}
#[derive(Clone)]
struct DummyLanguageModel {
capacity: usize,
}
impl LanguageModel for DummyLanguageModel {
fn name(&self) -> String {
"dummy".to_string()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
}
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
anyhow::Ok(
content.chars().collect::<Vec<char>>()[..length]
.into_iter()
.collect::<String>(),
)
}
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
anyhow::Ok(
content.chars().collect::<Vec<char>>()[length..]
.into_iter()
.collect::<String>(),
)
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(self.capacity)
}
}
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@ -238,7 +218,7 @@ pub(crate) mod tests {
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@ -275,7 +255,7 @@ pub(crate) mod tests {
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let capacity = 20;
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@ -311,7 +291,7 @@ pub(crate) mod tests {
// Change Ordering of Prompts Based on Priority
let capacity = 120;
let reserved_tokens = 10;
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,

View File

@ -3,8 +3,9 @@ use language::BufferSnapshot;
use language::ToOffset;
use crate::models::LanguageModel;
use crate::templates::base::PromptArguments;
use crate::templates::base::PromptTemplate;
use crate::models::TruncationDirection;
use crate::prompts::base::PromptArguments;
use crate::prompts::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;
@ -70,8 +71,9 @@ fn retrieve_context(
};
let truncated_start_window =
model.truncate_start(&start_window, start_goal_tokens)?;
let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
let truncated_end_window =
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
writeln!(
prompt,
"{truncated_start_window}{selected_window}{truncated_end_window}"
@ -89,7 +91,7 @@ fn retrieve_context(
if let Some(max_token_count) = max_token_count {
if model.count_tokens(&prompt)? > max_token_count {
truncated = true;
prompt = model.truncate(&prompt, max_token_count)?;
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
}
}
}
@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args.model.truncate(&prompt, max_tokens)?;
prompt = args
.model
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
}
let token_count = args.model.count_tokens(&prompt)?;

View File

@ -1,4 +1,4 @@
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow;
use std::fmt::Write;
@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args.model.truncate(&prompt, max_tokens)?;
prompt = args.model.truncate(
&prompt,
max_tokens,
crate::models::TruncationDirection::End,
)?;
}
let token_count = args.model.count_tokens(&prompt)?;

View File

@ -1,4 +1,4 @@
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;
pub struct EngineerPreamble {}

View File

@ -1,4 +1,4 @@
use crate::templates::base::{PromptArguments, PromptTemplate};
use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};

View File

@ -0,0 +1 @@
pub mod open_ai;

View File

@ -0,0 +1,298 @@
use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui::{executor::Background, AppContext};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
env,
fmt::{self, Display},
io,
sync::Arc,
};
use util::ResultExt;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
}
impl CompletionRequest for OpenAIRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<OpenAIUsage>,
}
pub async fn stream_completion(
credential: ProviderCredential,
executor: Arc<Background>,
request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
let api_key = match credential {
ProviderCredential::Credentials { api_key } => api_key,
_ => {
return Err(anyhow!("no credentials provider for completion"));
}
};
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = request.data()?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAIResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(&data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAIResponse {
error: OpenAIError,
}
#[derive(Deserialize)]
struct OpenAIError {
message: String,
}
match serde_json::from_str::<OpenAIResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
#[derive(Clone)]
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
let model = OpenAILanguageModel::load(model_name);
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
Self {
model,
credential,
executor,
}
}
}
impl CredentialProvider for OpenAICompletionProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
let mut credential = self.credential.write();
match *credential {
ProviderCredential::Credentials { .. } => {
return credential.clone();
}
_ => {
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
*credential = ProviderCredential::Credentials { api_key };
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
*credential = ProviderCredential::Credentials { api_key };
}
} else {
};
}
}
credential.clone()
}
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
match credential.clone() {
ProviderCredential::Credentials { api_key } => {
cx.platform()
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
}
*self.credential.write() = credential;
}
fn delete_credentials(&self, cx: &AppContext) {
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
*self.credential.write() = ProviderCredential::NoCredentials;
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
// which is currently model based, due to the langauge model.
// At some point in the future we should rectify this.
let credential = self.credential.read().clone();
let request = stream_completion(credential, self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

View File

@ -0,0 +1,306 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
use gpui::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;
use crate::providers::open_ai::OPENAI_API_URL;
lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
model: &'static str,
input: Vec<&'a str>,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
data: Vec<OpenAIEmbedding>,
usage: OpenAIEmbeddingUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
embedding: Vec<f32>,
index: usize,
object: String,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
prompt_tokens: usize,
total_tokens: usize,
}
impl OpenAIEmbeddingProvider {
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
let model = OpenAILanguageModel::load("text-embedding-ada-002");
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
OpenAIEmbeddingProvider {
model,
credential,
client,
executor,
rate_limit_count_rx,
rate_limit_count_tx,
}
}
fn get_api_key(&self) -> Result<String> {
match self.credential.read().clone() {
ProviderCredential::Credentials { api_key } => Ok(api_key),
_ => Err(anyhow!("api credentials not provided")),
}
}
fn resolve_rate_limit(&self) {
let reset_time = *self.rate_limit_count_tx.lock().borrow();
if let Some(reset_time) = reset_time {
if Instant::now() >= reset_time {
*self.rate_limit_count_tx.lock().borrow_mut() = None
}
}
log::trace!(
"resolving reset time: {:?}",
*self.rate_limit_count_tx.lock().borrow()
);
}
fn update_reset_time(&self, reset_time: Instant) {
let original_time = *self.rate_limit_count_tx.lock().borrow();
let updated_time = if let Some(original_time) = original_time {
if reset_time < original_time {
Some(reset_time)
} else {
Some(original_time)
}
} else {
Some(reset_time)
};
log::trace!("updating rate limit time: {:?}", updated_time);
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
}
async fn send_request(
&self,
api_key: &str,
spans: Vec<&str>,
request_timeout: u64,
) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAIEmbeddingRequest {
input: spans.clone(),
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
Ok(self.client.send(request).await?)
}
}
impl CredentialProvider for OpenAIEmbeddingProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
let mut credential = self.credential.write();
match *credential {
ProviderCredential::Credentials { .. } => {
return credential.clone();
}
_ => {
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
*credential = ProviderCredential::Credentials { api_key };
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
*credential = ProviderCredential::Credentials { api_key };
}
} else {
};
}
}
credential.clone()
}
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
match credential.clone() {
ProviderCredential::Credentials { api_key } => {
cx.platform()
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
}
*self.credential.write() = credential;
}
fn delete_credentials(&self, cx: &AppContext) {
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
*self.credential.write() = ProviderCredential::NoCredentials;
}
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn max_tokens_per_batch(&self) -> usize {
50000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
*self.rate_limit_count_rx.borrow()
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
let api_key = self.get_api_key()?;
let mut request_number = 0;
let mut rate_limiting = false;
let mut request_timeout: u64 = 15;
let mut response: Response<AsyncBody>;
while request_number < MAX_RETRIES {
response = self
.send_request(
&api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
.await?;
request_number += 1;
match response.status() {
StatusCode::REQUEST_TIMEOUT => {
request_timeout += 5;
}
StatusCode::OK => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
log::trace!(
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
// If we complete a request successfully that was previously rate_limited
// resolve the rate limit
if rate_limiting {
self.resolve_rate_limit()
}
return Ok(response
.data
.into_iter()
.map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
rate_limiting = true;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let delay_duration = {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
if let Some(time_to_reset) =
response.headers().get("x-ratelimit-reset-tokens")
{
if let Ok(time_str) = time_to_reset.to_str() {
parse(time_str).unwrap_or(delay)
} else {
delay
}
} else {
delay
}
};
// If we've previously rate limited, increment the duration but not the count
let reset_time = Instant::now().add(delay_duration);
self.update_reset_time(reset_time);
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
);
self.executor.timer(delay_duration).await;
}
_ => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"open ai bad request: {:?} {:?}",
&response.status(),
body
));
}
}
}
Err(anyhow!("openai max retries"))
}
}

View File

@ -0,0 +1,9 @@
pub mod completion;
pub mod embedding;
pub mod model;
pub use completion::*;
pub use embedding::*;
pub use model::OpenAILanguageModel;
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";

View File

@ -0,0 +1,57 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;
use crate::models::{LanguageModel, TruncationDirection};
#[derive(Clone)]
pub struct OpenAILanguageModel {
name: String,
bpe: Option<CoreBPE>,
}
impl OpenAILanguageModel {
pub fn load(model_name: &str) -> Self {
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
OpenAILanguageModel {
name: model_name.to_string(),
bpe,
}
}
}
impl LanguageModel for OpenAILanguageModel {
fn name(&self) -> String {
self.name.clone()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
if let Some(bpe) = &self.bpe {
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
match direction {
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
}
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
}
}

View File

@ -0,0 +1,11 @@
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}

191
crates/ai/src/test.rs Normal file
View File

@ -0,0 +1,191 @@
use std::{
sync::atomic::{self, AtomicUsize, Ordering},
time::Instant,
};
use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::AppContext;
use parking_lot::Mutex;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
};
#[derive(Clone)]
pub struct FakeLanguageModel {
pub capacity: usize,
}
impl LanguageModel for FakeLanguageModel {
fn name(&self) -> String {
"dummy".to_string()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
println!("TRYING TO TRUNCATE: {:?}", length.clone());
if length > self.count_tokens(content)? {
println!("NOT TRUNCATING");
return anyhow::Ok(content.to_string());
}
anyhow::Ok(match direction {
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
.into_iter()
.collect::<String>(),
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
.into_iter()
.collect::<String>(),
})
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(self.capacity)
}
}
pub struct FakeEmbeddingProvider {
pub embedding_count: AtomicUsize,
}
impl Clone for FakeEmbeddingProvider {
fn clone(&self) -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
}
}
}
impl Default for FakeEmbeddingProvider {
fn default() -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::default(),
}
}
}
impl FakeEmbeddingProvider {
pub fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
pub fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result.into()
}
}
impl CredentialProvider for FakeEmbeddingProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
fn delete_credentials(&self, _cx: &AppContext) {}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
Box::new(FakeLanguageModel { capacity: 1000 })
}
fn max_tokens_per_batch(&self) -> usize {
1000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
pub struct FakeCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl Clone for FakeCompletionProvider {
fn clone(&self) -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
}
impl FakeCompletionProvider {
pub fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
pub fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
pub fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
impl CredentialProvider for FakeCompletionProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
fn delete_credentials(&self, _cx: &AppContext) {}
}
impl CompletionProvider for FakeCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

38
crates/ai2/Cargo.toml Normal file
View File

@ -0,0 +1,38 @@
[package]
name = "ai2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/ai2.rs"
doctest = false
[features]
test-support = []
[dependencies]
gpui2 = { path = "../gpui2" }
util = { path = "../util" }
language2 = { path = "../language2" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
lazy_static.workspace = true
ordered-float.workspace = true
parking_lot.workspace = true
isahc.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
postage.workspace = true
rand.workspace = true
log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
[dev-dependencies]
gpui2 = { path = "../gpui2", features = ["test-support"] }

8
crates/ai2/src/ai2.rs Normal file
View File

@ -0,0 +1,8 @@
pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
pub mod prompts;
pub mod providers;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

17
crates/ai2/src/auth.rs Normal file
View File

@ -0,0 +1,17 @@
use async_trait::async_trait;
use gpui2::AppContext;
#[derive(Clone, Debug)]
pub enum ProviderCredential {
Credentials { api_key: String },
NoCredentials,
NotNeeded,
}
#[async_trait]
pub trait CredentialProvider: Send + Sync {
fn has_credentials(&self) -> bool;
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
async fn delete_credentials(&self, cx: &mut AppContext);
}

View File

@ -0,0 +1,23 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
use crate::{auth::CredentialProvider, models::LanguageModel};
pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>;
}
pub trait CompletionProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn box_clone(&self) -> Box<dyn CompletionProvider>;
}
impl Clone for Box<dyn CompletionProvider> {
fn clone(&self) -> Box<dyn CompletionProvider> {
self.box_clone()
}
}

123
crates/ai2/src/embedding.rs Normal file
View File

@ -0,0 +1,123 @@
use std::time::Instant;
use anyhow::Result;
use async_trait::async_trait;
use ordered_float::OrderedFloat;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use crate::auth::CredentialProvider;
use crate::models::LanguageModel;
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
// This is needed for semantic index functionality
// Unfortunately it has to live wherever the "Embedding" struct is created.
// Keeping this in here though, introduces a 'rusqlite' dependency into AI
// which is less than ideal
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
Ok(Embedding(embedding.unwrap()))
}
}
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
impl From<Vec<f32>> for Embedding {
fn from(value: Vec<f32>) -> Self {
Embedding(value)
}
}
impl Embedding {
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
let len = self.0.len();
assert_eq!(len, other.0.len());
let mut result = 0.0;
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
self.0.as_ptr(),
len as isize,
1,
other.0.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
OrderedFloat(result)
}
}
#[async_trait]
pub trait EmbeddingProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn rate_limit_expiration(&self) -> Option<Instant>;
}
#[cfg(test)]
mod tests {
use super::*;
use rand::prelude::*;
#[gpui2::test]
fn test_similarity(mut rng: StdRng) {
assert_eq!(
Embedding::from(vec![1., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
0.
);
assert_eq!(
Embedding::from(vec![2., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
6.
);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
let a = Embedding::from(a);
let b = Embedding::from(b);
assert_eq!(
round_to_decimals(a.similarity(&b), 1),
round_to_decimals(reference_dot(&a.0, &b.0), 1)
);
}
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
}
}
}

16
crates/ai2/src/models.rs Normal file
View File

@ -0,0 +1,16 @@
pub enum TruncationDirection {
Start,
End,
}
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}

View File

@ -0,0 +1,330 @@
use std::cmp::Reverse;
use std::ops::Range;
use std::sync::Arc;
use language2::BufferSnapshot;
use util::ResultExt;
use crate::models::LanguageModel;
use crate::prompts::repository_context::PromptCodeSnippet;
pub(crate) enum PromptFileType {
Text,
Code,
}
// TODO: Set this up to manage for defaults well
pub struct PromptArguments {
pub model: Arc<dyn LanguageModel>,
pub user_prompt: Option<String>,
pub language_name: Option<String>,
pub project_name: Option<String>,
pub snippets: Vec<PromptCodeSnippet>,
pub reserved_tokens: usize,
pub buffer: Option<BufferSnapshot>,
pub selected_range: Option<Range<usize>>,
}
impl PromptArguments {
pub(crate) fn get_file_type(&self) -> PromptFileType {
if self
.language_name
.as_ref()
.and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
.unwrap_or(true)
{
PromptFileType::Code
} else {
PromptFileType::Text
}
}
}
pub trait PromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)>;
}
#[repr(i8)]
#[derive(PartialEq, Eq, Ord)]
pub enum PromptPriority {
Mandatory, // Ignores truncation
Ordered { order: usize }, // Truncates based on priority
}
impl PartialOrd for PromptPriority {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
match (self, other) {
(Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
(Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
(Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
(Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
}
}
}
pub struct PromptChain {
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
impl PromptChain {
pub fn new(
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
) -> Self {
PromptChain { args, templates }
}
pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
// Argsort based on Prompt Priority
let seperator = "\n";
let seperator_tokens = self.args.model.count_tokens(seperator)?;
let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
// If Truncate
let mut tokens_outstanding = if truncate {
Some(self.args.model.capacity()? - self.args.reserved_tokens)
} else {
None
};
let mut prompts = vec!["".to_string(); sorted_indices.len()];
for idx in sorted_indices {
let (_, template) = &self.templates[idx];
if let Some((template_prompt, prompt_token_count)) =
template.generate(&self.args, tokens_outstanding).log_err()
{
if template_prompt != "" {
prompts[idx] = template_prompt;
if let Some(remaining_tokens) = tokens_outstanding {
let new_tokens = prompt_token_count + seperator_tokens;
tokens_outstanding = if remaining_tokens > new_tokens {
Some(remaining_tokens - new_tokens)
} else {
Some(0)
};
}
}
}
}
prompts.retain(|x| x != "");
let full_prompt = prompts.join(seperator);
let total_token_count = self.args.model.count_tokens(&full_prompt)?;
anyhow::Ok((prompts.join(seperator), total_token_count))
}
}
#[cfg(test)]
pub(crate) mod tests {
use crate::models::TruncationDirection;
use crate::test::FakeLanguageModel;
use super::*;
#[test]
pub fn test_prompt_chain() {
struct TestPromptTemplate {}
impl PromptTemplate for TestPromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
struct TestLowPriorityTemplate {}
impl PromptTemplate for TestLowPriorityTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a low priority test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let capacity = 20;
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 2 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(prompt, "This is a test promp".to_string());
assert_eq!(token_count, capacity);
// Change Ordering of Prompts Based on Priority
let capacity = 120;
let reserved_tokens = 10;
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Mandatory,
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(
prompt,
"This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
.to_string()
);
assert_eq!(token_count, capacity - reserved_tokens);
}
}

View File

@ -0,0 +1,164 @@
use anyhow::anyhow;
use language2::BufferSnapshot;
use language2::ToOffset;
use crate::models::LanguageModel;
use crate::models::TruncationDirection;
use crate::prompts::base::PromptArguments;
use crate::prompts::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;
fn retrieve_context(
buffer: &BufferSnapshot,
selected_range: &Option<Range<usize>>,
model: Arc<dyn LanguageModel>,
max_token_count: Option<usize>,
) -> anyhow::Result<(String, usize, bool)> {
let mut prompt = String::new();
let mut truncated = false;
if let Some(selected_range) = selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
let start_window = buffer.text_for_range(0..start).collect::<String>();
let mut selected_window = String::new();
if start == end {
write!(selected_window, "<|START|>").unwrap();
} else {
write!(selected_window, "<|START|").unwrap();
}
write!(
selected_window,
"{}",
buffer.text_for_range(start..end).collect::<String>()
)
.unwrap();
if start != end {
write!(selected_window, "|END|>").unwrap();
}
let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
if let Some(max_token_count) = max_token_count {
let selected_tokens = model.count_tokens(&selected_window)?;
if selected_tokens > max_token_count {
return Err(anyhow!(
"selected range is greater than model context window, truncation not possible"
));
};
let mut remaining_tokens = max_token_count - selected_tokens;
let start_window_tokens = model.count_tokens(&start_window)?;
let end_window_tokens = model.count_tokens(&end_window)?;
let outside_tokens = start_window_tokens + end_window_tokens;
if outside_tokens > remaining_tokens {
let (start_goal_tokens, end_goal_tokens) =
if start_window_tokens < end_window_tokens {
let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
remaining_tokens -= start_goal_tokens;
let end_goal_tokens = remaining_tokens.min(end_window_tokens);
(start_goal_tokens, end_goal_tokens)
} else {
let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
remaining_tokens -= end_goal_tokens;
let start_goal_tokens = remaining_tokens.min(start_window_tokens);
(start_goal_tokens, end_goal_tokens)
};
let truncated_start_window =
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
let truncated_end_window =
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
writeln!(
prompt,
"{truncated_start_window}{selected_window}{truncated_end_window}"
)
.unwrap();
truncated = true;
} else {
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
}
} else {
// If we dont have a selected range, include entire file.
writeln!(prompt, "{}", &buffer.text()).unwrap();
// Dumb truncation strategy
if let Some(max_token_count) = max_token_count {
if model.count_tokens(&prompt)? > max_token_count {
truncated = true;
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
}
}
}
}
let token_count = model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count, truncated))
}
pub struct FileContext {}
impl PromptTemplate for FileContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
if let Some(buffer) = &args.buffer {
let mut prompt = String::new();
// Add Initial Preamble
// TODO: Do we want to add the path in here?
writeln!(
prompt,
"The file you are currently working on has the following content:"
)
.unwrap();
let language_name = args
.language_name
.clone()
.unwrap_or("".to_string())
.to_lowercase();
let (context, _, truncated) = retrieve_context(
buffer,
&args.selected_range,
args.model.clone(),
max_token_length,
)?;
writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
if truncated {
writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
}
if let Some(selected_range) = &args.selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
if start == end {
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
} else {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args
.model
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
} else {
Err(anyhow!("no buffer provided to retrieve file context from"))
}
}
}

View File

@ -0,0 +1,99 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow;
use std::fmt::Write;
pub fn capitalize(s: &str) -> String {
let mut c = s.chars();
match c.next() {
None => String::new(),
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
}
}
pub struct GenerateInlineContent {}
impl PromptTemplate for GenerateInlineContent {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let Some(user_prompt) = &args.user_prompt else {
return Err(anyhow!("user prompt not provided"));
};
let file_type = args.get_file_type();
let content_type = match &file_type {
PromptFileType::Code => "code",
PromptFileType::Text => "text",
};
let mut prompt = String::new();
if let Some(selected_range) = &args.selected_range {
if selected_range.start == selected_range.end {
writeln!(
prompt,
"Assume the cursor is located where the `<|START|>` span is."
)
.unwrap();
writeln!(
prompt,
"{} can't be replaced, so assume your answer will be inserted at the cursor.",
capitalize(content_type)
)
.unwrap();
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}",
)
.unwrap();
} else {
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
}
} else {
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}"
)
.unwrap();
}
if let Some(language_name) = &args.language_name {
writeln!(
prompt,
"Your answer MUST always and only be valid {}.",
language_name
)
.unwrap();
}
writeln!(prompt, "Never make remarks about the output.").unwrap();
writeln!(
prompt,
"Do not return anything else, except the generated {content_type}."
)
.unwrap();
match file_type {
PromptFileType::Code => {
// writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
}
_ => {}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args.model.truncate(
&prompt,
max_tokens,
crate::models::TruncationDirection::End,
)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}

View File

@ -0,0 +1,5 @@
pub mod base;
pub mod file_context;
pub mod generate;
pub mod preamble;
pub mod repository_context;

View File

@ -0,0 +1,52 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;
pub struct EngineerPreamble {}
impl PromptTemplate for EngineerPreamble {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut prompts = Vec::new();
match args.get_file_type() {
PromptFileType::Code => {
prompts.push(format!(
"You are an expert {}engineer.",
args.language_name.clone().unwrap_or("".to_string()) + " "
));
}
PromptFileType::Text => {
prompts.push("You are an expert engineer.".to_string());
}
}
if let Some(project_name) = args.project_name.clone() {
prompts.push(format!(
"You are currently working inside the '{project_name}' project in code editor Zed."
));
}
if let Some(mut remaining_tokens) = max_token_length {
let mut prompt = String::new();
let mut total_count = 0;
for prompt_piece in prompts {
let prompt_token_count =
args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
if remaining_tokens > prompt_token_count {
writeln!(prompt, "{prompt_piece}").unwrap();
remaining_tokens -= prompt_token_count;
total_count += prompt_token_count;
}
}
anyhow::Ok((prompt, total_count))
} else {
let prompt = prompts.join("\n");
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}
}

View File

@ -0,0 +1,98 @@
use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
use gpui2::{AsyncAppContext, Model};
use language2::{Anchor, Buffer};
#[derive(Clone)]
pub struct PromptCodeSnippet {
path: Option<PathBuf>,
language_name: Option<String>,
content: String,
}
impl PromptCodeSnippet {
pub fn new(
buffer: Model<Buffer>,
range: Range<Anchor>,
cx: &mut AsyncAppContext,
) -> anyhow::Result<Self> {
let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
let content = snapshot.text_for_range(range.clone()).collect::<String>();
let language_name = buffer
.language()
.and_then(|language| Some(language.name().to_string().to_lowercase()));
let file_path = buffer
.file()
.and_then(|file| Some(file.path().to_path_buf()));
(content, language_name, file_path)
})?;
anyhow::Ok(PromptCodeSnippet {
path: file_path,
language_name,
content,
})
}
}
impl ToString for PromptCodeSnippet {
fn to_string(&self) -> String {
let path = self
.path
.as_ref()
.and_then(|path| Some(path.to_string_lossy().to_string()))
.unwrap_or("".to_string());
let language_name = self.language_name.clone().unwrap_or("".to_string());
let content = self.content.clone();
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
}
}
pub struct RepositoryContext {}
impl PromptTemplate for RepositoryContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
let mut prompt = String::new();
let mut remaining_tokens = max_token_length.clone();
let seperator_token_length = args.model.count_tokens("\n")?;
for snippet in &args.snippets {
let mut snippet_prompt = template.to_string();
let content = snippet.to_string();
writeln!(snippet_prompt, "{content}").unwrap();
let token_count = args.model.count_tokens(&snippet_prompt)?;
if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
if let Some(tokens_left) = remaining_tokens {
if tokens_left >= token_count {
writeln!(prompt, "{snippet_prompt}").unwrap();
remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
{
Some(tokens_left - token_count - seperator_token_length)
} else {
Some(0)
};
}
} else {
writeln!(prompt, "{snippet_prompt}").unwrap();
}
}
}
let total_token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, total_token_count))
}
}

View File

@ -0,0 +1 @@
pub mod open_ai;

View File

@ -0,0 +1,306 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui2::{AppContext, Executor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
env,
fmt::{self, Display},
io,
sync::Arc,
};
use util::ResultExt;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
}
impl CompletionRequest for OpenAIRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<OpenAIUsage>,
}
pub async fn stream_completion(
credential: ProviderCredential,
executor: Arc<Executor>,
request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
let api_key = match credential {
ProviderCredential::Credentials { api_key } => api_key,
_ => {
return Err(anyhow!("no credentials provider for completion"));
}
};
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
let json_data = request.data()?;
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAIResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(&data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAIResponse {
error: OpenAIError,
}
#[derive(Deserialize)]
struct OpenAIError {
message: String,
}
match serde_json::from_str::<OpenAIResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
#[derive(Clone)]
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
executor: Arc<Executor>,
}
impl OpenAICompletionProvider {
pub fn new(model_name: &str, executor: Arc<Executor>) -> Self {
let model = OpenAILanguageModel::load(model_name);
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
Self {
model,
credential,
executor,
}
}
}
#[async_trait]
impl CredentialProvider for OpenAICompletionProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
let existing_credential = self.credential.read().clone();
let retrieved_credential = cx
.run_on_main(move |cx| match existing_credential {
ProviderCredential::Credentials { .. } => {
return existing_credential.clone();
}
_ => {
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
return ProviderCredential::Credentials { api_key };
}
if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
return ProviderCredential::Credentials { api_key };
} else {
return ProviderCredential::NoCredentials;
}
} else {
return ProviderCredential::NoCredentials;
}
}
})
.await;
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
*self.credential.write() = credential.clone();
let credential = credential.clone();
cx.run_on_main(move |cx| match credential {
ProviderCredential::Credentials { api_key } => {
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
})
.await;
}
async fn delete_credentials(&self, cx: &mut AppContext) {
cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
.await;
*self.credential.write() = ProviderCredential::NoCredentials;
}
}
impl CompletionProvider for OpenAICompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
// which is currently model based, due to the langauge model.
// At some point in the future we should rectify this.
let credential = self.credential.read().clone();
let request = stream_completion(credential, self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

View File

@ -0,0 +1,313 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui2::Executor;
use gpui2::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;
use crate::providers::open_ai::OPENAI_API_URL;
lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Executor>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
#[derive(Serialize)]
struct OpenAIEmbeddingRequest<'a> {
model: &'static str,
input: Vec<&'a str>,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
data: Vec<OpenAIEmbedding>,
usage: OpenAIEmbeddingUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIEmbedding {
embedding: Vec<f32>,
index: usize,
object: String,
}
#[derive(Deserialize)]
struct OpenAIEmbeddingUsage {
prompt_tokens: usize,
total_tokens: usize,
}
impl OpenAIEmbeddingProvider {
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Executor>) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
let model = OpenAILanguageModel::load("text-embedding-ada-002");
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
OpenAIEmbeddingProvider {
model,
credential,
client,
executor,
rate_limit_count_rx,
rate_limit_count_tx,
}
}
fn get_api_key(&self) -> Result<String> {
match self.credential.read().clone() {
ProviderCredential::Credentials { api_key } => Ok(api_key),
_ => Err(anyhow!("api credentials not provided")),
}
}
fn resolve_rate_limit(&self) {
let reset_time = *self.rate_limit_count_tx.lock().borrow();
if let Some(reset_time) = reset_time {
if Instant::now() >= reset_time {
*self.rate_limit_count_tx.lock().borrow_mut() = None
}
}
log::trace!(
"resolving reset time: {:?}",
*self.rate_limit_count_tx.lock().borrow()
);
}
fn update_reset_time(&self, reset_time: Instant) {
let original_time = *self.rate_limit_count_tx.lock().borrow();
let updated_time = if let Some(original_time) = original_time {
if reset_time < original_time {
Some(reset_time)
} else {
Some(original_time)
}
} else {
Some(reset_time)
};
log::trace!("updating rate limit time: {:?}", updated_time);
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
}
async fn send_request(
&self,
api_key: &str,
spans: Vec<&str>,
request_timeout: u64,
) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAIEmbeddingRequest {
input: spans.clone(),
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
Ok(self.client.send(request).await?)
}
}
#[async_trait]
impl CredentialProvider for OpenAIEmbeddingProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
let existing_credential = self.credential.read().clone();
let retrieved_credential = cx
.run_on_main(move |cx| match existing_credential {
ProviderCredential::Credentials { .. } => {
return existing_credential.clone();
}
_ => {
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
return ProviderCredential::Credentials { api_key };
}
if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
return ProviderCredential::Credentials { api_key };
} else {
return ProviderCredential::NoCredentials;
}
} else {
return ProviderCredential::NoCredentials;
}
}
})
.await;
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
*self.credential.write() = credential.clone();
let credential = credential.clone();
cx.run_on_main(move |cx| match credential {
ProviderCredential::Credentials { api_key } => {
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
})
.await;
}
async fn delete_credentials(&self, cx: &mut AppContext) {
cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
.await;
*self.credential.write() = ProviderCredential::NoCredentials;
}
}
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn max_tokens_per_batch(&self) -> usize {
50000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
*self.rate_limit_count_rx.borrow()
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
let api_key = self.get_api_key()?;
let mut request_number = 0;
let mut rate_limiting = false;
let mut request_timeout: u64 = 15;
let mut response: Response<AsyncBody>;
while request_number < MAX_RETRIES {
response = self
.send_request(
&api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
.await?;
request_number += 1;
match response.status() {
StatusCode::REQUEST_TIMEOUT => {
request_timeout += 5;
}
StatusCode::OK => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
log::trace!(
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
// If we complete a request successfully that was previously rate_limited
// resolve the rate limit
if rate_limiting {
self.resolve_rate_limit()
}
return Ok(response
.data
.into_iter()
.map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
rate_limiting = true;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let delay_duration = {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
if let Some(time_to_reset) =
response.headers().get("x-ratelimit-reset-tokens")
{
if let Ok(time_str) = time_to_reset.to_str() {
parse(time_str).unwrap_or(delay)
} else {
delay
}
} else {
delay
}
};
// If we've previously rate limited, increment the duration but not the count
let reset_time = Instant::now().add(delay_duration);
self.update_reset_time(reset_time);
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
);
self.executor.timer(delay_duration).await;
}
_ => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"open ai bad request: {:?} {:?}",
&response.status(),
body
));
}
}
}
Err(anyhow!("openai max retries"))
}
}

View File

@ -0,0 +1,9 @@
pub mod completion;
pub mod embedding;
pub mod model;
pub use completion::*;
pub use embedding::*;
pub use model::OpenAILanguageModel;
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";

View File

@ -0,0 +1,57 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;
use crate::models::{LanguageModel, TruncationDirection};
#[derive(Clone)]
pub struct OpenAILanguageModel {
name: String,
bpe: Option<CoreBPE>,
}
impl OpenAILanguageModel {
pub fn load(model_name: &str) -> Self {
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
OpenAILanguageModel {
name: model_name.to_string(),
bpe,
}
}
}
impl LanguageModel for OpenAILanguageModel {
fn name(&self) -> String {
self.name.clone()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
if let Some(bpe) = &self.bpe {
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
match direction {
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
}
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
}
}

View File

@ -0,0 +1,11 @@
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}

193
crates/ai2/src/test.rs Normal file
View File

@ -0,0 +1,193 @@
use std::{
sync::atomic::{self, AtomicUsize, Ordering},
time::Instant,
};
use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui2::AppContext;
use parking_lot::Mutex;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
};
#[derive(Clone)]
pub struct FakeLanguageModel {
pub capacity: usize,
}
impl LanguageModel for FakeLanguageModel {
fn name(&self) -> String {
"dummy".to_string()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
println!("TRYING TO TRUNCATE: {:?}", length.clone());
if length > self.count_tokens(content)? {
println!("NOT TRUNCATING");
return anyhow::Ok(content.to_string());
}
anyhow::Ok(match direction {
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
.into_iter()
.collect::<String>(),
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
.into_iter()
.collect::<String>(),
})
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(self.capacity)
}
}
pub struct FakeEmbeddingProvider {
pub embedding_count: AtomicUsize,
}
impl Clone for FakeEmbeddingProvider {
fn clone(&self) -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
}
}
}
impl Default for FakeEmbeddingProvider {
fn default() -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::default(),
}
}
}
impl FakeEmbeddingProvider {
pub fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
pub fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result.into()
}
}
#[async_trait]
impl CredentialProvider for FakeEmbeddingProvider {
fn has_credentials(&self) -> bool {
true
}
async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
async fn delete_credentials(&self, _cx: &mut AppContext) {}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
Box::new(FakeLanguageModel { capacity: 1000 })
}
fn max_tokens_per_batch(&self) -> usize {
1000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
pub struct FakeCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl Clone for FakeCompletionProvider {
fn clone(&self) -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
}
impl FakeCompletionProvider {
pub fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
pub fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
pub fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
#[async_trait]
impl CredentialProvider for FakeCompletionProvider {
fn has_credentials(&self) -> bool {
true
}
async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
async fn delete_credentials(&self, _cx: &mut AppContext) {}
}
impl CompletionProvider for FakeCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

View File

@ -45,6 +45,7 @@ tiktoken-rs = "0.5"
[dev-dependencies]
editor = { path = "../editor", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
ai = { path = "../ai", features = ["test-support"]}
ctor.workspace = true
env_logger.workspace = true

View File

@ -4,7 +4,7 @@ mod codegen;
mod prompts;
mod streaming_diff;
use ai::completion::Role;
use ai::providers::open_ai::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAIModel;

View File

@ -5,12 +5,14 @@ use crate::{
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
use ai::{
completion::{
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
},
templates::repository_context::PromptCodeSnippet,
auth::ProviderCredential,
completion::{CompletionProvider, CompletionRequest},
providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
};
use ai::prompts::repository_context::PromptCodeSnippet;
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
@ -43,8 +45,8 @@ use search::BufferSearchBar;
use semantic_index::{SemanticIndex, SemanticIndexStatus};
use settings::SettingsStore;
use std::{
cell::{Cell, RefCell},
cmp, env,
cell::Cell,
cmp,
fmt::Write,
iter,
ops::Range,
@ -97,8 +99,8 @@ pub fn init(cx: &mut AppContext) {
cx.capture_action(ConversationEditor::copy);
cx.add_action(ConversationEditor::split);
cx.capture_action(ConversationEditor::cycle_message_role);
cx.add_action(AssistantPanel::save_api_key);
cx.add_action(AssistantPanel::reset_api_key);
cx.add_action(AssistantPanel::save_credentials);
cx.add_action(AssistantPanel::reset_credentials);
cx.add_action(AssistantPanel::toggle_zoom);
cx.add_action(AssistantPanel::deploy);
cx.add_action(AssistantPanel::select_next_match);
@ -140,9 +142,8 @@ pub struct AssistantPanel {
zoomed: bool,
has_focus: bool,
toolbar: ViewHandle<Toolbar>,
api_key: Rc<RefCell<Option<String>>>,
completion_provider: Box<dyn CompletionProvider>,
api_key_editor: Option<ViewHandle<Editor>>,
has_read_credentials: bool,
languages: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
subscriptions: Vec<Subscription>,
@ -202,6 +203,11 @@ impl AssistantPanel {
});
let semantic_index = SemanticIndex::global(cx);
// Defaulting currently to GPT4, allow for this to be set via config.
let completion_provider = Box::new(OpenAICompletionProvider::new(
"gpt-4",
cx.background().clone(),
));
let mut this = Self {
workspace: workspace_handle,
@ -213,9 +219,8 @@ impl AssistantPanel {
zoomed: false,
has_focus: false,
toolbar,
api_key: Rc::new(RefCell::new(None)),
completion_provider,
api_key_editor: None,
has_read_credentials: false,
languages: workspace.app_state().languages.clone(),
fs: workspace.app_state().fs.clone(),
width: None,
@ -254,10 +259,7 @@ impl AssistantPanel {
cx: &mut ViewContext<Workspace>,
) {
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
if this
.update(cx, |assistant, cx| assistant.load_api_key(cx))
.is_some()
{
if this.update(cx, |assistant, _| assistant.has_credentials()) {
this
} else {
workspace.focus_panel::<AssistantPanel>(cx);
@ -289,12 +291,6 @@ impl AssistantPanel {
cx: &mut ViewContext<Self>,
project: &ModelHandle<Project>,
) {
let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
api_key
} else {
return;
};
let selection = editor.read(cx).selections.newest_anchor().clone();
if selection.start.excerpt_id != selection.end.excerpt_id {
return;
@ -325,10 +321,13 @@ impl AssistantPanel {
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let provider = Arc::new(OpenAICompletionProvider::new(
api_key,
"gpt-4",
cx.background().clone(),
));
// Retrieve Credentials Authenticates the Provider
// provider.retrieve_credentials(cx);
let codegen = cx.add_model(|cx| {
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
});
@ -745,13 +744,14 @@ impl AssistantPanel {
content: prompt,
});
let request = OpenAIRequest {
let request = Box::new(OpenAIRequest {
model: model.full_name().into(),
messages,
stream: true,
stop: vec!["|END|>".to_string()],
temperature,
};
});
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
anyhow::Ok(())
})
@ -811,7 +811,7 @@ impl AssistantPanel {
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
let editor = cx.add_view(|cx| {
ConversationEditor::new(
self.api_key.clone(),
self.completion_provider.clone(),
self.languages.clone(),
self.fs.clone(),
self.workspace.clone(),
@ -870,17 +870,19 @@ impl AssistantPanel {
}
}
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
if let Some(api_key) = self
.api_key_editor
.as_ref()
.map(|editor| editor.read(cx).text(cx))
{
if !api_key.is_empty() {
cx.platform()
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
*self.api_key.borrow_mut() = Some(api_key);
let credential = ProviderCredential::Credentials {
api_key: api_key.clone(),
};
self.completion_provider.save_credentials(cx, credential);
self.api_key_editor.take();
cx.focus_self();
cx.notify();
@ -890,9 +892,8 @@ impl AssistantPanel {
}
}
fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
self.api_key.take();
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
self.completion_provider.delete_credentials(cx);
self.api_key_editor = Some(build_api_key_editor(cx));
cx.focus_self();
cx.notify();
@ -1151,13 +1152,12 @@ impl AssistantPanel {
let fs = self.fs.clone();
let workspace = self.workspace.clone();
let api_key = self.api_key.clone();
let languages = self.languages.clone();
cx.spawn(|this, mut cx| async move {
let saved_conversation = fs.load(&path).await?;
let saved_conversation = serde_json::from_str(&saved_conversation)?;
let conversation = cx.add_model(|cx| {
Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
});
this.update(&mut cx, |this, cx| {
// If, by the time we've loaded the conversation, the user has already opened
@ -1181,30 +1181,12 @@ impl AssistantPanel {
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
}
fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
if self.api_key.borrow().is_none() && !self.has_read_credentials {
self.has_read_credentials = true;
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
Some(api_key)
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
{
String::from_utf8(api_key).log_err()
} else {
None
};
if let Some(api_key) = api_key {
*self.api_key.borrow_mut() = Some(api_key);
} else if self.api_key_editor.is_none() {
self.api_key_editor = Some(build_api_key_editor(cx));
cx.notify();
}
}
fn has_credentials(&mut self) -> bool {
self.completion_provider.has_credentials()
}
self.api_key.borrow().clone()
fn load_credentials(&mut self, cx: &mut ViewContext<Self>) {
self.completion_provider.retrieve_credentials(cx);
}
}
@ -1389,7 +1371,7 @@ impl Panel for AssistantPanel {
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
if active {
self.load_api_key(cx);
self.load_credentials(cx);
if self.editors.is_empty() {
self.new_conversation(cx);
@ -1454,10 +1436,10 @@ struct Conversation {
token_count: Option<usize>,
max_token_count: usize,
pending_token_count: Task<Option<()>>,
api_key: Rc<RefCell<Option<String>>>,
pending_save: Task<Result<()>>,
path: Option<PathBuf>,
_subscriptions: Vec<Subscription>,
completion_provider: Box<dyn CompletionProvider>,
}
impl Entity for Conversation {
@ -1466,9 +1448,9 @@ impl Entity for Conversation {
impl Conversation {
fn new(
api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
completion_provider: Box<dyn CompletionProvider>,
) -> Self {
let markdown = language_registry.language_for_name("Markdown");
let buffer = cx.add_model(|cx| {
@ -1507,8 +1489,8 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: None,
api_key,
buffer,
completion_provider,
};
let message = MessageAnchor {
id: MessageId(post_inc(&mut this.next_message_id.0)),
@ -1554,7 +1536,6 @@ impl Conversation {
fn deserialize(
saved_conversation: SavedConversation,
path: PathBuf,
api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
@ -1563,6 +1544,10 @@ impl Conversation {
None => Some(Uuid::new_v4().to_string()),
};
let model = saved_conversation.model;
let completion_provider: Box<dyn CompletionProvider> = Box::new(
OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
);
completion_provider.retrieve_credentials(cx);
let markdown = language_registry.language_for_name("Markdown");
let mut message_anchors = Vec::new();
let mut next_message_id = MessageId(0);
@ -1609,8 +1594,8 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: Some(path),
api_key,
buffer,
completion_provider,
};
this.count_remaining_tokens(cx);
this
@ -1731,11 +1716,11 @@ impl Conversation {
}
if should_assist {
let Some(api_key) = self.api_key.borrow().clone() else {
if !self.completion_provider.has_credentials() {
return Default::default();
};
}
let request = OpenAIRequest {
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
model: self.model.full_name().to_string(),
messages: self
.messages(cx)
@ -1745,9 +1730,9 @@ impl Conversation {
stream: true,
stop: vec![],
temperature: 1.0,
};
});
let stream = stream_completion(api_key, cx.background().clone(), request);
let stream = self.completion_provider.complete(request);
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@ -1765,33 +1750,28 @@ impl Conversation {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
this.upgrade(&cx)
.ok_or_else(|| anyhow!("conversation was dropped"))?
.update(&mut cx, |this, cx| {
let text: Arc<str> = choice.delta.content?.into();
let message_ix =
this.message_anchors.iter().position(|message| {
message.id == assistant_message_id
})?;
this.buffer.update(cx, |buffer, cx| {
let offset = this.message_anchors[message_ix + 1..]
.iter()
.find(|message| message.start.is_valid(buffer))
.map_or(buffer.len(), |message| {
message
.start
.to_offset(buffer)
.saturating_sub(1)
});
buffer.edit([(offset..offset, text)], None, cx);
});
cx.emit(ConversationEvent::StreamedCompletion);
let text = message?;
Some(())
this.upgrade(&cx)
.ok_or_else(|| anyhow!("conversation was dropped"))?
.update(&mut cx, |this, cx| {
let message_ix = this
.message_anchors
.iter()
.position(|message| message.id == assistant_message_id)?;
this.buffer.update(cx, |buffer, cx| {
let offset = this.message_anchors[message_ix + 1..]
.iter()
.find(|message| message.start.is_valid(buffer))
.map_or(buffer.len(), |message| {
message.start.to_offset(buffer).saturating_sub(1)
});
buffer.edit([(offset..offset, text)], None, cx);
});
}
cx.emit(ConversationEvent::StreamedCompletion);
Some(())
});
smol::future::yield_now().await;
}
@ -2013,57 +1993,54 @@ impl Conversation {
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
if self.message_anchors.len() >= 2 && self.summary.is_none() {
let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key {
let messages = self
.messages(cx)
.take(2)
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
.chain(Some(RequestMessage {
role: Role::User,
content:
"Summarize the conversation into a short title without punctuation"
.into(),
}));
let request = OpenAIRequest {
model: self.model.full_name().to_string(),
messages: messages.collect(),
stream: true,
stop: vec![],
temperature: 1.0,
};
let stream = stream_completion(api_key, cx.background().clone(), request);
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
let text = choice.delta.content.unwrap_or_default();
this.update(&mut cx, |this, cx| {
this.summary
.get_or_insert(Default::default())
.text
.push_str(&text);
cx.emit(ConversationEvent::SummaryChanged);
});
}
}
this.update(&mut cx, |this, cx| {
if let Some(summary) = this.summary.as_mut() {
summary.done = true;
cx.emit(ConversationEvent::SummaryChanged);
}
});
anyhow::Ok(())
}
.log_err()
});
if !self.completion_provider.has_credentials() {
return;
}
let messages = self
.messages(cx)
.take(2)
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
.chain(Some(RequestMessage {
role: Role::User,
content: "Summarize the conversation into a short title without punctuation"
.into(),
}));
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
model: self.model.full_name().to_string(),
messages: messages.collect(),
stream: true,
stop: vec![],
temperature: 1.0,
});
let stream = self.completion_provider.complete(request);
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let text = message?;
this.update(&mut cx, |this, cx| {
this.summary
.get_or_insert(Default::default())
.text
.push_str(&text);
cx.emit(ConversationEvent::SummaryChanged);
});
}
this.update(&mut cx, |this, cx| {
if let Some(summary) = this.summary.as_mut() {
summary.done = true;
cx.emit(ConversationEvent::SummaryChanged);
}
});
anyhow::Ok(())
}
.log_err()
});
}
}
@ -2224,13 +2201,14 @@ struct ConversationEditor {
impl ConversationEditor {
fn new(
api_key: Rc<RefCell<Option<String>>>,
completion_provider: Box<dyn CompletionProvider>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
workspace: WeakViewHandle<Workspace>,
cx: &mut ViewContext<Self>,
) -> Self {
let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
let conversation =
cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider));
Self::for_conversation(conversation, fs, workspace, cx)
}
@ -3419,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
mod tests {
use super::*;
use crate::MessageId;
use ai::test::FakeCompletionProvider;
use gpui::AppContext;
#[gpui::test]
@ -3426,7 +3405,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -3554,7 +3535,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -3650,7 +3633,8 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -3732,8 +3716,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation =
cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_0 = conversation.read(cx).message_anchors[0].id;
let message_1 = conversation.update(cx, |conversation, cx| {
@ -3770,7 +3755,6 @@ mod tests {
Conversation::deserialize(
conversation.read(cx).serialize(cx),
Default::default(),
Default::default(),
registry.clone(),
cx,
)

View File

@ -1,5 +1,5 @@
use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::{CompletionProvider, OpenAIRequest};
use ai::completion::{CompletionProvider, CompletionRequest};
use anyhow::Result;
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
@ -96,7 +96,7 @@ impl Codegen {
self.error.as_ref()
}
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
@ -336,17 +336,25 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
use futures::{
future::BoxFuture,
stream::{self, BoxStream},
};
use ai::test::FakeCompletionProvider;
use futures::stream::{self};
use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use parking_lot::Mutex;
use rand::prelude::*;
use serde::Serialize;
use settings::SettingsStore;
use smol::future::FutureExt;
#[derive(Serialize)]
pub struct DummyCompletionRequest {
pub name: String,
}
impl CompletionRequest for DummyCompletionRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(
@ -372,7 +380,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let provider = Arc::new(TestCompletionProvider::new());
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@ -381,7 +389,11 @@ mod tests {
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
" let mut x = 0;\n",
@ -434,7 +446,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
let provider = Arc::new(TestCompletionProvider::new());
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@ -443,7 +455,11 @@ mod tests {
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"t mut x = 0;\n",
@ -496,7 +512,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
let provider = Arc::new(TestCompletionProvider::new());
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@ -505,7 +521,11 @@ mod tests {
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"let mut x = 0;\n",
@ -593,38 +613,6 @@ mod tests {
}
}
struct TestCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl TestCompletionProvider {
fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
impl CompletionProvider for TestCompletionProvider {
fn complete(
&self,
_prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {

View File

@ -1,9 +1,10 @@
use ai::models::{LanguageModel, OpenAILanguageModel};
use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::templates::file_context::FileContext;
use ai::templates::generate::GenerateInlineContent;
use ai::templates::preamble::EngineerPreamble;
use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
use ai::models::LanguageModel;
use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::prompts::file_context::FileContext;
use ai::prompts::generate::GenerateInlineContent;
use ai::prompts::preamble::EngineerPreamble;
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
use ai::providers::open_ai::OpenAILanguageModel;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp::{self, Reverse};
use std::ops::Range;

View File

@ -25,7 +25,7 @@ collections = { path = "../collections" }
gpui2 = { path = "../gpui2" }
log.workspace = true
live_kit_client = { path = "../live_kit_client" }
fs = { path = "../fs" }
fs2 = { path = "../fs2" }
language2 = { path = "../language2" }
media = { path = "../media" }
project2 = { path = "../project2" }
@ -43,7 +43,7 @@ serde_derive.workspace = true
[dev-dependencies]
client2 = { path = "../client2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
fs2 = { path = "../fs2", features = ["test-support"] }
language2 = { path = "../language2", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui2 = { path = "../gpui2", features = ["test-support"] }

View File

@ -12,8 +12,8 @@ use client2::{
use collections::HashSet;
use futures::{future::Shared, FutureExt};
use gpui2::{
AppContext, AsyncAppContext, Context, EventEmitter, Handle, ModelContext, Subscription, Task,
WeakHandle,
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task,
WeakModel,
};
use postage::watch;
use project2::Project;
@ -23,10 +23,10 @@ use std::sync::Arc;
pub use participant::ParticipantLocation;
pub use room::Room;
pub fn init(client: Arc<Client>, user_store: Handle<UserStore>, cx: &mut AppContext) {
pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
CallSettings::register(cx);
let active_call = cx.entity(|cx| ActiveCall::new(client, user_store, cx));
let active_call = cx.build_model(|cx| ActiveCall::new(client, user_store, cx));
cx.set_global(active_call);
}
@ -40,16 +40,16 @@ pub struct IncomingCall {
/// Singleton global maintaining the user's participation in a room across workspaces.
pub struct ActiveCall {
room: Option<(Handle<Room>, Vec<Subscription>)>,
pending_room_creation: Option<Shared<Task<Result<Handle<Room>, Arc<anyhow::Error>>>>>,
location: Option<WeakHandle<Project>>,
room: Option<(Model<Room>, Vec<Subscription>)>,
pending_room_creation: Option<Shared<Task<Result<Model<Room>, Arc<anyhow::Error>>>>>,
location: Option<WeakModel<Project>>,
pending_invites: HashSet<u64>,
incoming_call: (
watch::Sender<Option<IncomingCall>>,
watch::Receiver<Option<IncomingCall>>,
),
client: Arc<Client>,
user_store: Handle<UserStore>,
user_store: Model<UserStore>,
_subscriptions: Vec<client2::Subscription>,
}
@ -58,11 +58,7 @@ impl EventEmitter for ActiveCall {
}
impl ActiveCall {
fn new(
client: Arc<Client>,
user_store: Handle<UserStore>,
cx: &mut ModelContext<Self>,
) -> Self {
fn new(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut ModelContext<Self>) -> Self {
Self {
room: None,
pending_room_creation: None,
@ -84,7 +80,7 @@ impl ActiveCall {
}
async fn handle_incoming_call(
this: Handle<Self>,
this: Model<Self>,
envelope: TypedEnvelope<proto::IncomingCall>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -112,7 +108,7 @@ impl ActiveCall {
}
async fn handle_call_canceled(
this: Handle<Self>,
this: Model<Self>,
envelope: TypedEnvelope<proto::CallCanceled>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -129,14 +125,14 @@ impl ActiveCall {
Ok(())
}
pub fn global(cx: &AppContext) -> Handle<Self> {
cx.global::<Handle<Self>>().clone()
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<Model<Self>>().clone()
}
pub fn invite(
&mut self,
called_user_id: u64,
initial_project: Option<Handle<Project>>,
initial_project: Option<Model<Project>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if !self.pending_invites.insert(called_user_id) {
@ -291,7 +287,7 @@ impl ActiveCall {
&mut self,
channel_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<Handle<Room>>> {
) -> Task<Result<Model<Room>>> {
if let Some(room) = self.room().cloned() {
if room.read(cx).channel_id() == Some(channel_id) {
return Task::ready(Ok(room));
@ -327,7 +323,7 @@ impl ActiveCall {
pub fn share_project(
&mut self,
project: Handle<Project>,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<u64>> {
if let Some((room, _)) = self.room.as_ref() {
@ -340,7 +336,7 @@ impl ActiveCall {
pub fn unshare_project(
&mut self,
project: Handle<Project>,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Result<()> {
if let Some((room, _)) = self.room.as_ref() {
@ -351,13 +347,13 @@ impl ActiveCall {
}
}
pub fn location(&self) -> Option<&WeakHandle<Project>> {
pub fn location(&self) -> Option<&WeakModel<Project>> {
self.location.as_ref()
}
pub fn set_location(
&mut self,
project: Option<&Handle<Project>>,
project: Option<&Model<Project>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if project.is_some() || !*ZED_ALWAYS_ACTIVE {
@ -371,7 +367,7 @@ impl ActiveCall {
fn set_room(
&mut self,
room: Option<Handle<Room>>,
room: Option<Model<Room>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if room.as_ref() != self.room.as_ref().map(|room| &room.0) {
@ -407,7 +403,7 @@ impl ActiveCall {
}
}
pub fn room(&self) -> Option<&Handle<Room>> {
pub fn room(&self) -> Option<&Model<Room>> {
self.room.as_ref().map(|(room, _)| room)
}

View File

@ -1,7 +1,7 @@
use anyhow::{anyhow, Result};
use client2::ParticipantIndex;
use client2::{proto, User};
use gpui2::WeakHandle;
use gpui2::WeakModel;
pub use live_kit_client::Frame;
use project2::Project;
use std::{fmt, sync::Arc};
@ -33,7 +33,7 @@ impl ParticipantLocation {
#[derive(Clone, Default)]
pub struct LocalParticipant {
pub projects: Vec<proto::ParticipantProject>,
pub active_project: Option<WeakHandle<Project>>,
pub active_project: Option<WeakModel<Project>>,
}
#[derive(Clone, Debug)]

View File

@ -13,10 +13,10 @@ use client2::{
Client, ParticipantIndex, TypedEnvelope, User, UserStore,
};
use collections::{BTreeMap, HashMap, HashSet};
use fs::Fs;
use fs2::Fs;
use futures::{FutureExt, StreamExt};
use gpui2::{
AppContext, AsyncAppContext, Context, EventEmitter, Handle, ModelContext, Task, WeakHandle,
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel,
};
use language2::LanguageRegistry;
use live_kit_client::{LocalTrackPublication, RemoteAudioTrackUpdate, RemoteVideoTrackUpdate};
@ -61,8 +61,8 @@ pub struct Room {
channel_id: Option<u64>,
// live_kit: Option<LiveKitRoom>,
status: RoomStatus,
shared_projects: HashSet<WeakHandle<Project>>,
joined_projects: HashSet<WeakHandle<Project>>,
shared_projects: HashSet<WeakModel<Project>>,
joined_projects: HashSet<WeakModel<Project>>,
local_participant: LocalParticipant,
remote_participants: BTreeMap<u64, RemoteParticipant>,
pending_participants: Vec<Arc<User>>,
@ -70,7 +70,7 @@ pub struct Room {
pending_call_count: usize,
leave_when_empty: bool,
client: Arc<Client>,
user_store: Handle<UserStore>,
user_store: Model<UserStore>,
follows_by_leader_id_project_id: HashMap<(PeerId, u64), Vec<PeerId>>,
client_subscriptions: Vec<client2::Subscription>,
_subscriptions: Vec<gpui2::Subscription>,
@ -111,7 +111,7 @@ impl Room {
channel_id: Option<u64>,
live_kit_connection_info: Option<proto::LiveKitConnectionInfo>,
client: Arc<Client>,
user_store: Handle<UserStore>,
user_store: Model<UserStore>,
cx: &mut ModelContext<Self>,
) -> Self {
todo!()
@ -237,15 +237,15 @@ impl Room {
pub(crate) fn create(
called_user_id: u64,
initial_project: Option<Handle<Project>>,
initial_project: Option<Model<Project>>,
client: Arc<Client>,
user_store: Handle<UserStore>,
user_store: Model<UserStore>,
cx: &mut AppContext,
) -> Task<Result<Handle<Self>>> {
) -> Task<Result<Model<Self>>> {
cx.spawn(move |mut cx| async move {
let response = client.request(proto::CreateRoom {}).await?;
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
let room = cx.entity(|cx| {
let room = cx.build_model(|cx| {
Self::new(
room_proto.id,
None,
@ -283,9 +283,9 @@ impl Room {
pub(crate) fn join_channel(
channel_id: u64,
client: Arc<Client>,
user_store: Handle<UserStore>,
user_store: Model<UserStore>,
cx: &mut AppContext,
) -> Task<Result<Handle<Self>>> {
) -> Task<Result<Model<Self>>> {
cx.spawn(move |cx| async move {
Self::from_join_response(
client.request(proto::JoinChannel { channel_id }).await?,
@ -299,9 +299,9 @@ impl Room {
pub(crate) fn join(
call: &IncomingCall,
client: Arc<Client>,
user_store: Handle<UserStore>,
user_store: Model<UserStore>,
cx: &mut AppContext,
) -> Task<Result<Handle<Self>>> {
) -> Task<Result<Model<Self>>> {
let id = call.room_id;
cx.spawn(move |cx| async move {
Self::from_join_response(
@ -343,11 +343,11 @@ impl Room {
fn from_join_response(
response: proto::JoinRoomResponse,
client: Arc<Client>,
user_store: Handle<UserStore>,
user_store: Model<UserStore>,
mut cx: AsyncAppContext,
) -> Result<Handle<Self>> {
) -> Result<Model<Self>> {
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
let room = cx.entity(|cx| {
let room = cx.build_model(|cx| {
Self::new(
room_proto.id,
response.channel_id,
@ -424,7 +424,7 @@ impl Room {
}
async fn maintain_connection(
this: WeakHandle<Self>,
this: WeakModel<Self>,
client: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
@ -661,7 +661,7 @@ impl Room {
}
async fn handle_room_updated(
this: Handle<Self>,
this: Model<Self>,
envelope: TypedEnvelope<proto::RoomUpdated>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -1101,7 +1101,7 @@ impl Room {
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Handle<Project>>> {
) -> Task<Result<Model<Project>>> {
let client = self.client.clone();
let user_store = self.user_store.clone();
cx.emit(Event::RemoteProjectJoined { project_id: id });
@ -1125,7 +1125,7 @@ impl Room {
pub(crate) fn share_project(
&mut self,
project: Handle<Project>,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<u64>> {
if let Some(project_id) = project.read(cx).remote_id() {
@ -1161,7 +1161,7 @@ impl Room {
pub(crate) fn unshare_project(
&mut self,
project: Handle<Project>,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Result<()> {
let project_id = match project.read(cx).remote_id() {
@ -1175,7 +1175,7 @@ impl Room {
pub(crate) fn set_location(
&mut self,
project: Option<&Handle<Project>>,
project: Option<&Model<Project>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if self.status.is_offline() {

View File

@ -14,8 +14,8 @@ use futures::{
future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _, TryStreamExt,
};
use gpui2::{
serde_json, AnyHandle, AnyWeakHandle, AppContext, AsyncAppContext, Handle, SemanticVersion,
Task, WeakHandle,
serde_json, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Model, SemanticVersion, Task,
WeakModel,
};
use lazy_static::lazy_static;
use parking_lot::RwLock;
@ -227,7 +227,7 @@ struct ClientState {
_reconnect_task: Option<Task<()>>,
reconnect_interval: Duration,
entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
models_by_message_type: HashMap<TypeId, AnyWeakHandle>,
models_by_message_type: HashMap<TypeId, AnyWeakModel>,
entity_types_by_message_type: HashMap<TypeId, TypeId>,
#[allow(clippy::type_complexity)]
message_handlers: HashMap<
@ -236,7 +236,7 @@ struct ClientState {
dyn Send
+ Sync
+ Fn(
AnyHandle,
AnyModel,
Box<dyn AnyTypedEnvelope>,
&Arc<Client>,
AsyncAppContext,
@ -246,7 +246,7 @@ struct ClientState {
}
enum WeakSubscriber {
Entity { handle: AnyWeakHandle },
Entity { handle: AnyWeakModel },
Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
@ -314,7 +314,7 @@ impl<T> PendingEntitySubscription<T>
where
T: 'static + Send,
{
pub fn set_model(mut self, model: &Handle<T>, cx: &mut AsyncAppContext) -> Subscription {
pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
self.consumed = true;
let mut state = self.client.state.write();
let id = (TypeId::of::<T>(), self.remote_id);
@ -552,13 +552,13 @@ impl Client {
#[track_caller]
pub fn add_message_handler<M, E, H, F>(
self: &Arc<Self>,
entity: WeakHandle<E>,
entity: WeakModel<E>,
handler: H,
) -> Subscription
where
M: EnvelopedMessage,
E: 'static + Send,
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
F: 'static + Future<Output = Result<()>> + Send,
{
let message_type_id = TypeId::of::<M>();
@ -594,13 +594,13 @@ impl Client {
pub fn add_request_handler<M, E, H, F>(
self: &Arc<Self>,
model: WeakHandle<E>,
model: WeakModel<E>,
handler: H,
) -> Subscription
where
M: RequestMessage,
E: 'static + Send,
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
F: 'static + Future<Output = Result<M::Response>> + Send,
{
self.add_message_handler(model, move |handle, envelope, this, cx| {
@ -616,7 +616,7 @@ impl Client {
where
M: EntityMessage,
E: 'static + Send,
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
F: 'static + Future<Output = Result<()>> + Send,
{
self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
@ -628,7 +628,7 @@ impl Client {
where
M: EntityMessage,
E: 'static + Send,
H: 'static + Send + Sync + Fn(AnyHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
H: 'static + Send + Sync + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
F: 'static + Future<Output = Result<()>> + Send,
{
let model_type_id = TypeId::of::<E>();
@ -667,7 +667,7 @@ impl Client {
where
M: EntityMessage + RequestMessage,
E: 'static + Send,
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
F: 'static + Future<Output = Result<M::Response>> + Send,
{
self.add_model_message_handler(move |entity, envelope, client, cx| {
@ -1546,7 +1546,7 @@ mod tests {
let (done_tx1, mut done_rx1) = smol::channel::unbounded();
let (done_tx2, mut done_rx2) = smol::channel::unbounded();
client.add_model_message_handler(
move |model: Handle<Model>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
match model.update(&mut cx, |model, _| model.id).unwrap() {
1 => done_tx1.try_send(()).unwrap(),
2 => done_tx2.try_send(()).unwrap(),
@ -1555,15 +1555,15 @@ mod tests {
async { Ok(()) }
},
);
let model1 = cx.entity(|_| Model {
let model1 = cx.build_model(|_| TestModel {
id: 1,
subscription: None,
});
let model2 = cx.entity(|_| Model {
let model2 = cx.build_model(|_| TestModel {
id: 2,
subscription: None,
});
let model3 = cx.entity(|_| Model {
let model3 = cx.build_model(|_| TestModel {
id: 3,
subscription: None,
});
@ -1596,7 +1596,7 @@ mod tests {
let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
let server = FakeServer::for_client(user_id, &client, cx).await;
let model = cx.entity(|_| Model::default());
let model = cx.build_model(|_| TestModel::default());
let (done_tx1, _done_rx1) = smol::channel::unbounded();
let (done_tx2, mut done_rx2) = smol::channel::unbounded();
let subscription1 = client.add_message_handler(
@ -1624,11 +1624,11 @@ mod tests {
let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
let server = FakeServer::for_client(user_id, &client, cx).await;
let model = cx.entity(|_| Model::default());
let model = cx.build_model(|_| TestModel::default());
let (done_tx, mut done_rx) = smol::channel::unbounded();
let subscription = client.add_message_handler(
model.clone().downgrade(),
move |model: Handle<Model>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
model
.update(&mut cx, |model, _| model.subscription.take())
.unwrap();
@ -1644,7 +1644,7 @@ mod tests {
}
#[derive(Default)]
struct Model {
struct TestModel {
id: usize,
subscription: Option<Subscription>,
}

View File

@ -5,7 +5,9 @@ use parking_lot::Mutex;
use serde::Serialize;
use settings2::Settings;
use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration};
use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt};
use sysinfo::{
CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt,
};
use tempfile::NamedTempFile;
use util::http::HttpClient;
use util::{channel::ReleaseChannel, TryFutureExt};
@ -161,8 +163,16 @@ impl Telemetry {
let this = self.clone();
cx.spawn(|cx| async move {
let mut system = System::new_all();
system.refresh_all();
// Avoiding calling `System::new_all()`, as there have been crashes related to it
let refresh_kind = RefreshKind::new()
.with_memory() // For memory usage
.with_processes(ProcessRefreshKind::everything()) // For process usage
.with_cpu(CpuRefreshKind::everything()); // For core count
let mut system = System::new_with_specifics(refresh_kind);
// Avoiding calling `refresh_all()`, just update what we need
system.refresh_specifics(refresh_kind);
loop {
// Waiting some amount of time before the first query is important to get a reasonable value
@ -170,8 +180,7 @@ impl Telemetry {
const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60);
smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await;
system.refresh_memory();
system.refresh_processes();
system.refresh_specifics(refresh_kind);
let current_process = Pid::from_u32(std::process::id());
let Some(process) = system.processes().get(&current_process) else {

View File

@ -1,7 +1,7 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{anyhow, Result};
use futures::{stream::BoxStream, StreamExt};
use gpui2::{Context, Executor, Handle, TestAppContext};
use gpui2::{Context, Executor, Model, TestAppContext};
use parking_lot::Mutex;
use rpc2::{
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
@ -194,9 +194,9 @@ impl FakeServer {
&self,
client: Arc<Client>,
cx: &mut TestAppContext,
) -> Handle<UserStore> {
) -> Model<UserStore> {
let http_client = FakeHttpClient::with_404_response();
let user_store = cx.entity(|cx| UserStore::new(client, http_client, cx));
let user_store = cx.build_model(|cx| UserStore::new(client, http_client, cx));
assert_eq!(
self.receive::<proto::GetUsers>()
.await

View File

@ -3,7 +3,7 @@ use anyhow::{anyhow, Context, Result};
use collections::{hash_map::Entry, HashMap, HashSet};
use feature_flags2::FeatureFlagAppExt;
use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt};
use gpui2::{AsyncAppContext, EventEmitter, Handle, ImageData, ModelContext, Task};
use gpui2::{AsyncAppContext, EventEmitter, ImageData, Model, ModelContext, Task};
use postage::{sink::Sink, watch};
use rpc2::proto::{RequestMessage, UsersResponse};
use std::sync::{Arc, Weak};
@ -213,7 +213,7 @@ impl UserStore {
}
async fn handle_update_invite_info(
this: Handle<Self>,
this: Model<Self>,
message: TypedEnvelope<proto::UpdateInviteInfo>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -229,7 +229,7 @@ impl UserStore {
}
async fn handle_show_contacts(
this: Handle<Self>,
this: Model<Self>,
_: TypedEnvelope<proto::ShowContacts>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -243,7 +243,7 @@ impl UserStore {
}
async fn handle_update_contacts(
this: Handle<Self>,
this: Model<Self>,
message: TypedEnvelope<proto::UpdateContacts>,
_: Arc<Client>,
mut cx: AsyncAppContext,
@ -690,7 +690,7 @@ impl User {
impl Contact {
async fn from_proto(
contact: proto::Contact,
user_store: &Handle<UserStore>,
user_store: &Model<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<Self> {
let user = user_store

View File

@ -7,8 +7,8 @@ use async_tar::Archive;
use collections::{HashMap, HashSet};
use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
use gpui2::{
AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Handle, ModelContext, Task,
WeakHandle,
AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Model, ModelContext, Task,
WeakModel,
};
use language2::{
language_settings::{all_language_settings, language_settings},
@ -49,7 +49,7 @@ pub fn init(
node_runtime: Arc<dyn NodeRuntime>,
cx: &mut AppContext,
) {
let copilot = cx.entity({
let copilot = cx.build_model({
let node_runtime = node_runtime.clone();
move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
});
@ -183,7 +183,7 @@ struct RegisteredBuffer {
impl RegisteredBuffer {
fn report_changes(
&mut self,
buffer: &Handle<Buffer>,
buffer: &Model<Buffer>,
cx: &mut ModelContext<Copilot>,
) -> oneshot::Receiver<(i32, BufferSnapshot)> {
let (done_tx, done_rx) = oneshot::channel();
@ -278,7 +278,7 @@ pub struct Copilot {
http: Arc<dyn HttpClient>,
node_runtime: Arc<dyn NodeRuntime>,
server: CopilotServer,
buffers: HashSet<WeakHandle<Buffer>>,
buffers: HashSet<WeakModel<Buffer>>,
server_id: LanguageServerId,
_subscription: gpui2::Subscription,
}
@ -292,9 +292,9 @@ impl EventEmitter for Copilot {
}
impl Copilot {
pub fn global(cx: &AppContext) -> Option<Handle<Self>> {
if cx.has_global::<Handle<Self>>() {
Some(cx.global::<Handle<Self>>().clone())
pub fn global(cx: &AppContext) -> Option<Model<Self>> {
if cx.has_global::<Model<Self>>() {
Some(cx.global::<Model<Self>>().clone())
} else {
None
}
@ -383,7 +383,7 @@ impl Copilot {
new_server_id: LanguageServerId,
http: Arc<dyn HttpClient>,
node_runtime: Arc<dyn NodeRuntime>,
this: WeakHandle<Self>,
this: WeakModel<Self>,
mut cx: AsyncAppContext,
) -> impl Future<Output = ()> {
async move {
@ -590,7 +590,7 @@ impl Copilot {
}
}
pub fn register_buffer(&mut self, buffer: &Handle<Buffer>, cx: &mut ModelContext<Self>) {
pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
let weak_buffer = buffer.downgrade();
self.buffers.insert(weak_buffer.clone());
@ -646,7 +646,7 @@ impl Copilot {
fn handle_buffer_event(
&mut self,
buffer: Handle<Buffer>,
buffer: Model<Buffer>,
event: &language2::Event,
cx: &mut ModelContext<Self>,
) -> Result<()> {
@ -706,7 +706,7 @@ impl Copilot {
Ok(())
}
fn unregister_buffer(&mut self, buffer: &WeakHandle<Buffer>) {
fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
if let Ok(server) = self.server.as_running() {
if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
server
@ -723,7 +723,7 @@ impl Copilot {
pub fn completions<T>(
&mut self,
buffer: &Handle<Buffer>,
buffer: &Model<Buffer>,
position: T,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Completion>>>
@ -735,7 +735,7 @@ impl Copilot {
pub fn completions_cycling<T>(
&mut self,
buffer: &Handle<Buffer>,
buffer: &Model<Buffer>,
position: T,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Completion>>>
@ -792,7 +792,7 @@ impl Copilot {
fn request_completions<R, T>(
&mut self,
buffer: &Handle<Buffer>,
buffer: &Model<Buffer>,
position: T,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Completion>>>
@ -926,7 +926,7 @@ fn id_for_language(language: Option<&Arc<Language>>) -> String {
}
}
fn uri_for_buffer(buffer: &Handle<Buffer>, cx: &AppContext) -> lsp2::Url {
fn uri_for_buffer(buffer: &Model<Buffer>, cx: &AppContext) -> lsp2::Url {
if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
lsp2::Url::from_file_path(file.abs_path(cx)).unwrap()
} else {

View File

@ -967,7 +967,6 @@ impl CompletionsMenu {
self.selected_item -= 1;
} else {
self.selected_item = self.matches.len() - 1;
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
}
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
self.attempt_resolve_selected_completion_documentation(project, cx);
@ -1538,7 +1537,6 @@ impl CodeActionsMenu {
self.selected_item -= 1;
} else {
self.selected_item = self.actions.len() - 1;
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
}
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
cx.notify();
@ -1547,11 +1545,10 @@ impl CodeActionsMenu {
fn select_next(&mut self, cx: &mut ViewContext<Editor>) {
if self.selected_item + 1 < self.actions.len() {
self.selected_item += 1;
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
} else {
self.selected_item = 0;
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
}
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
cx.notify();
}

View File

@ -16,7 +16,7 @@ use crate::{
current_platform, image_cache::ImageCache, Action, AnyBox, AnyView, AnyWindowHandle,
AppMetadata, AssetSource, ClipboardItem, Context, DispatchPhase, DisplayId, Executor,
FocusEvent, FocusHandle, FocusId, KeyBinding, Keymap, LayoutId, MainThread, MainThreadOnly,
Pixels, Platform, Point, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
Pixels, Platform, Point, Render, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
TextStyle, TextStyleRefinement, TextSystem, View, ViewContext, Window, WindowContext,
WindowHandle, WindowId,
};
@ -309,10 +309,17 @@ impl AppContext {
update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R,
) -> Result<R>
where
V: 'static,
V: 'static + Send,
{
self.update_window(handle.any_handle, |cx| {
let root_view = cx.window.root_view.as_ref().unwrap().downcast().unwrap();
let root_view = cx
.window
.root_view
.as_ref()
.unwrap()
.clone()
.downcast()
.unwrap();
root_view.update(cx, update)
})
}
@ -685,7 +692,7 @@ impl AppContext {
pub fn observe_release<E: 'static>(
&mut self,
handle: &Handle<E>,
handle: &Model<E>,
mut on_release: impl FnMut(&mut E, &mut AppContext) + Send + 'static,
) -> Subscription {
self.release_listeners.insert(
@ -750,35 +757,35 @@ impl AppContext {
}
impl Context for AppContext {
type EntityContext<'a, T> = ModelContext<'a, T>;
type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = T;
/// Build an entity that is owned by the application. The given function will be invoked with
/// a `ModelContext` and must return an object representing the entity. A `Handle` will be returned
/// a `ModelContext` and must return an object representing the entity. A `Model` will be returned
/// which can be used to access the entity in a context.
fn entity<T: 'static + Send>(
fn build_model<T: 'static + Send>(
&mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
) -> Handle<T> {
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Model<T> {
self.update(|cx| {
let slot = cx.entities.reserve();
let entity = build_entity(&mut ModelContext::mutable(cx, slot.downgrade()));
let entity = build_model(&mut ModelContext::mutable(cx, slot.downgrade()));
cx.entities.insert(slot, entity)
})
}
/// Update the entity referenced by the given handle. The function is passed a mutable reference to the
/// Update the entity referenced by the given model. The function is passed a mutable reference to the
/// entity along with a `ModelContext` for the entity.
fn update_entity<T: 'static, R>(
&mut self,
handle: &Handle<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
model: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> R {
self.update(|cx| {
let mut entity = cx.entities.lease(handle);
let mut entity = cx.entities.lease(model);
let result = update(
&mut entity,
&mut ModelContext::mutable(cx, handle.downgrade()),
&mut ModelContext::mutable(cx, model.downgrade()),
);
cx.entities.end_lease(entity);
result
@ -861,10 +868,17 @@ impl MainThread<AppContext> {
update: impl FnOnce(&mut V, &mut MainThread<ViewContext<'_, '_, V>>) -> R,
) -> Result<R>
where
V: 'static,
V: 'static + Send,
{
self.update_window(handle.any_handle, |cx| {
let root_view = cx.window.root_view.as_ref().unwrap().downcast().unwrap();
let root_view = cx
.window
.root_view
.as_ref()
.unwrap()
.clone()
.downcast()
.unwrap();
root_view.update(cx, update)
})
}
@ -872,7 +886,7 @@ impl MainThread<AppContext> {
/// Opens a new window with the given option and the root view returned by the given function.
/// The function is invoked with a `WindowContext`, which can be used to interact with window-specific
/// functionality.
pub fn open_window<V: 'static>(
pub fn open_window<V: Render>(
&mut self,
options: crate::WindowOptions,
build_root_view: impl FnOnce(&mut WindowContext) -> View<V> + Send + 'static,
@ -955,10 +969,8 @@ impl<G: 'static> DerefMut for GlobalLease<G> {
/// Contains state associated with an active drag operation, started by dragging an element
/// within the window or by dragging into the app from the underlying platform.
pub(crate) struct AnyDrag {
pub drag_handle_view: Option<AnyView>,
pub view: AnyView,
pub cursor_offset: Point<Pixels>,
pub state: AnyBox,
pub state_type: TypeId,
}
#[cfg(test)]

View File

@ -1,6 +1,6 @@
use crate::{
AnyWindowHandle, AppContext, Component, Context, Executor, Handle, MainThread, ModelContext,
Result, Task, View, ViewContext, VisualContext, WindowContext, WindowHandle,
AnyWindowHandle, AppContext, Context, Executor, MainThread, Model, ModelContext, Result, Task,
View, ViewContext, VisualContext, WindowContext, WindowHandle,
};
use anyhow::Context as _;
use derive_more::{Deref, DerefMut};
@ -14,25 +14,25 @@ pub struct AsyncAppContext {
}
impl Context for AsyncAppContext {
type EntityContext<'a, T> = ModelContext<'a, T>;
type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = Result<T>;
fn entity<T: 'static>(
fn build_model<T: 'static>(
&mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
) -> Self::Result<Handle<T>>
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Model<T>>
where
T: 'static + Send,
{
let app = self.app.upgrade().context("app was released")?;
let mut lock = app.lock(); // Need this to compile
Ok(lock.entity(build_entity))
Ok(lock.build_model(build_model))
}
fn update_entity<T: 'static, R>(
&mut self,
handle: &Handle<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R> {
let app = self.app.upgrade().context("app was released")?;
let mut lock = app.lock(); // Need this to compile
@ -84,7 +84,7 @@ impl AsyncAppContext {
update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R,
) -> Result<R>
where
V: 'static,
V: 'static + Send,
{
let app = self.app.upgrade().context("app was released")?;
let mut app_context = app.lock();
@ -234,24 +234,24 @@ impl AsyncWindowContext {
}
impl Context for AsyncWindowContext {
type EntityContext<'a, T> = ModelContext<'a, T>;
type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = Result<T>;
fn entity<T>(
fn build_model<T>(
&mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
) -> Result<Handle<T>>
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Result<Model<T>>
where
T: 'static + Send,
{
self.app
.update_window(self.window, |cx| cx.entity(build_entity))
.update_window(self.window, |cx| cx.build_model(build_model))
}
fn update_entity<T: 'static, R>(
&mut self,
handle: &Handle<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Result<R> {
self.app
.update_window(self.window, |cx| cx.update_entity(handle, update))
@ -261,17 +261,15 @@ impl Context for AsyncWindowContext {
impl VisualContext for AsyncWindowContext {
type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>;
fn build_view<E, V>(
fn build_view<V>(
&mut self,
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
) -> Self::Result<View<V>>
where
E: Component<V>,
V: 'static + Send,
{
self.app
.update_window(self.window, |cx| cx.build_view(build_entity, render))
.update_window(self.window, |cx| cx.build_view(build_view_state))
}
fn update_view<V: 'static, R>(

View File

@ -1,4 +1,4 @@
use crate::{AnyBox, AppContext, Context, EntityHandle};
use crate::{AnyBox, AppContext, Context};
use anyhow::{anyhow, Result};
use derive_more::{Deref, DerefMut};
use parking_lot::{RwLock, RwLockUpgradableReadGuard};
@ -53,29 +53,29 @@ impl EntityMap {
/// Reserve a slot for an entity, which you can subsequently use with `insert`.
pub fn reserve<T: 'static>(&self) -> Slot<T> {
let id = self.ref_counts.write().counts.insert(1.into());
Slot(Handle::new(id, Arc::downgrade(&self.ref_counts)))
Slot(Model::new(id, Arc::downgrade(&self.ref_counts)))
}
/// Insert an entity into a slot obtained by calling `reserve`.
pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Handle<T>
pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Model<T>
where
T: 'static + Send,
{
let handle = slot.0;
self.entities.insert(handle.entity_id, Box::new(entity));
handle
let model = slot.0;
self.entities.insert(model.entity_id, Box::new(entity));
model
}
/// Move an entity to the stack.
pub fn lease<'a, T>(&mut self, handle: &'a Handle<T>) -> Lease<'a, T> {
self.assert_valid_context(handle);
pub fn lease<'a, T>(&mut self, model: &'a Model<T>) -> Lease<'a, T> {
self.assert_valid_context(model);
let entity = Some(
self.entities
.remove(handle.entity_id)
.remove(model.entity_id)
.expect("Circular entity lease. Is the entity already being updated?"),
);
Lease {
handle,
model,
entity,
entity_type: PhantomData,
}
@ -84,18 +84,18 @@ impl EntityMap {
/// Return an entity after moving it to the stack.
pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
self.entities
.insert(lease.handle.entity_id, lease.entity.take().unwrap());
.insert(lease.model.entity_id, lease.entity.take().unwrap());
}
pub fn read<T: 'static>(&self, handle: &Handle<T>) -> &T {
self.assert_valid_context(handle);
self.entities[handle.entity_id].downcast_ref().unwrap()
pub fn read<T: 'static>(&self, model: &Model<T>) -> &T {
self.assert_valid_context(model);
self.entities[model.entity_id].downcast_ref().unwrap()
}
fn assert_valid_context(&self, handle: &AnyHandle) {
fn assert_valid_context(&self, model: &AnyModel) {
debug_assert!(
Weak::ptr_eq(&handle.entity_map, &Arc::downgrade(&self.ref_counts)),
"used a handle with the wrong context"
Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)),
"used a model with the wrong context"
);
}
@ -115,7 +115,7 @@ impl EntityMap {
pub struct Lease<'a, T> {
entity: Option<AnyBox>,
pub handle: &'a Handle<T>,
pub model: &'a Model<T>,
entity_type: PhantomData<T>,
}
@ -143,15 +143,15 @@ impl<'a, T> Drop for Lease<'a, T> {
}
#[derive(Deref, DerefMut)]
pub struct Slot<T>(Handle<T>);
pub struct Slot<T>(Model<T>);
pub struct AnyHandle {
pub struct AnyModel {
pub(crate) entity_id: EntityId,
entity_type: TypeId,
pub(crate) entity_type: TypeId,
entity_map: Weak<RwLock<EntityRefCounts>>,
}
impl AnyHandle {
impl AnyModel {
fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
Self {
entity_id: id,
@ -164,18 +164,18 @@ impl AnyHandle {
self.entity_id
}
pub fn downgrade(&self) -> AnyWeakHandle {
AnyWeakHandle {
pub fn downgrade(&self) -> AnyWeakModel {
AnyWeakModel {
entity_id: self.entity_id,
entity_type: self.entity_type,
entity_ref_counts: self.entity_map.clone(),
}
}
pub fn downcast<T: 'static>(&self) -> Option<Handle<T>> {
pub fn downcast<T: 'static>(&self) -> Option<Model<T>> {
if TypeId::of::<T>() == self.entity_type {
Some(Handle {
any_handle: self.clone(),
Some(Model {
any_model: self.clone(),
entity_type: PhantomData,
})
} else {
@ -184,16 +184,16 @@ impl AnyHandle {
}
}
impl Clone for AnyHandle {
impl Clone for AnyModel {
fn clone(&self) -> Self {
if let Some(entity_map) = self.entity_map.upgrade() {
let entity_map = entity_map.read();
let count = entity_map
.counts
.get(self.entity_id)
.expect("detected over-release of a handle");
.expect("detected over-release of a model");
let prev_count = count.fetch_add(1, SeqCst);
assert_ne!(prev_count, 0, "Detected over-release of a handle.");
assert_ne!(prev_count, 0, "Detected over-release of a model.");
}
Self {
@ -204,16 +204,16 @@ impl Clone for AnyHandle {
}
}
impl Drop for AnyHandle {
impl Drop for AnyModel {
fn drop(&mut self) {
if let Some(entity_map) = self.entity_map.upgrade() {
let entity_map = entity_map.upgradable_read();
let count = entity_map
.counts
.get(self.entity_id)
.expect("Detected over-release of a handle.");
.expect("Detected over-release of a model.");
let prev_count = count.fetch_sub(1, SeqCst);
assert_ne!(prev_count, 0, "Detected over-release of a handle.");
assert_ne!(prev_count, 0, "Detected over-release of a model.");
if prev_count == 1 {
// We were the last reference to this entity, so we can remove it.
let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
@ -223,60 +223,65 @@ impl Drop for AnyHandle {
}
}
impl<T> From<Handle<T>> for AnyHandle {
fn from(handle: Handle<T>) -> Self {
handle.any_handle
impl<T> From<Model<T>> for AnyModel {
fn from(model: Model<T>) -> Self {
model.any_model
}
}
impl Hash for AnyHandle {
impl Hash for AnyModel {
fn hash<H: Hasher>(&self, state: &mut H) {
self.entity_id.hash(state);
}
}
impl PartialEq for AnyHandle {
impl PartialEq for AnyModel {
fn eq(&self, other: &Self) -> bool {
self.entity_id == other.entity_id
}
}
impl Eq for AnyHandle {}
impl Eq for AnyModel {}
#[derive(Deref, DerefMut)]
pub struct Handle<T> {
pub struct Model<T> {
#[deref]
#[deref_mut]
any_handle: AnyHandle,
entity_type: PhantomData<T>,
pub(crate) any_model: AnyModel,
pub(crate) entity_type: PhantomData<T>,
}
unsafe impl<T> Send for Handle<T> {}
unsafe impl<T> Sync for Handle<T> {}
unsafe impl<T> Send for Model<T> {}
unsafe impl<T> Sync for Model<T> {}
impl<T: 'static> Handle<T> {
impl<T: 'static> Model<T> {
fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
where
T: 'static,
{
Self {
any_handle: AnyHandle::new(id, TypeId::of::<T>(), entity_map),
any_model: AnyModel::new(id, TypeId::of::<T>(), entity_map),
entity_type: PhantomData,
}
}
pub fn downgrade(&self) -> WeakHandle<T> {
WeakHandle {
any_handle: self.any_handle.downgrade(),
pub fn downgrade(&self) -> WeakModel<T> {
WeakModel {
any_model: self.any_model.downgrade(),
entity_type: self.entity_type,
}
}
/// Convert this into a dynamically typed model.
pub fn into_any(self) -> AnyModel {
self.any_model
}
pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
cx.entities.read(self)
}
/// Update the entity referenced by this handle with the given function.
/// Update the entity referenced by this model with the given function.
///
/// The update function receives a context appropriate for its environment.
/// When updating in an `AppContext`, it receives a `ModelContext`.
@ -284,7 +289,7 @@ impl<T: 'static> Handle<T> {
pub fn update<C, R>(
&self,
cx: &mut C,
update: impl FnOnce(&mut T, &mut C::EntityContext<'_, T>) -> R,
update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R,
) -> C::Result<R>
where
C: Context,
@ -293,73 +298,54 @@ impl<T: 'static> Handle<T> {
}
}
impl<T> Clone for Handle<T> {
impl<T> Clone for Model<T> {
fn clone(&self) -> Self {
Self {
any_handle: self.any_handle.clone(),
any_model: self.any_model.clone(),
entity_type: self.entity_type,
}
}
}
impl<T> std::fmt::Debug for Handle<T> {
impl<T> std::fmt::Debug for Model<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Handle {{ entity_id: {:?}, entity_type: {:?} }}",
self.any_handle.entity_id,
"Model {{ entity_id: {:?}, entity_type: {:?} }}",
self.any_model.entity_id,
type_name::<T>()
)
}
}
impl<T> Hash for Handle<T> {
impl<T> Hash for Model<T> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.any_handle.hash(state);
self.any_model.hash(state);
}
}
impl<T> PartialEq for Handle<T> {
impl<T> PartialEq for Model<T> {
fn eq(&self, other: &Self) -> bool {
self.any_handle == other.any_handle
self.any_model == other.any_model
}
}
impl<T> Eq for Handle<T> {}
impl<T> Eq for Model<T> {}
impl<T> PartialEq<WeakHandle<T>> for Handle<T> {
fn eq(&self, other: &WeakHandle<T>) -> bool {
self.entity_id == other.entity_id
}
}
impl<T: 'static> EntityHandle<T> for Handle<T> {
type Weak = WeakHandle<T>;
fn entity_id(&self) -> EntityId {
self.entity_id
}
fn downgrade(&self) -> Self::Weak {
self.downgrade()
}
fn upgrade_from(weak: &Self::Weak) -> Option<Self>
where
Self: Sized,
{
weak.upgrade()
impl<T> PartialEq<WeakModel<T>> for Model<T> {
fn eq(&self, other: &WeakModel<T>) -> bool {
self.entity_id() == other.entity_id()
}
}
#[derive(Clone)]
pub struct AnyWeakHandle {
pub struct AnyWeakModel {
pub(crate) entity_id: EntityId,
entity_type: TypeId,
entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
}
impl AnyWeakHandle {
impl AnyWeakModel {
pub fn entity_id(&self) -> EntityId {
self.entity_id
}
@ -373,14 +359,14 @@ impl AnyWeakHandle {
ref_count > 0
}
pub fn upgrade(&self) -> Option<AnyHandle> {
pub fn upgrade(&self) -> Option<AnyModel> {
let entity_map = self.entity_ref_counts.upgrade()?;
entity_map
.read()
.counts
.get(self.entity_id)?
.fetch_add(1, SeqCst);
Some(AnyHandle {
Some(AnyModel {
entity_id: self.entity_id,
entity_type: self.entity_type,
entity_map: self.entity_ref_counts.clone(),
@ -388,55 +374,55 @@ impl AnyWeakHandle {
}
}
impl<T> From<WeakHandle<T>> for AnyWeakHandle {
fn from(handle: WeakHandle<T>) -> Self {
handle.any_handle
impl<T> From<WeakModel<T>> for AnyWeakModel {
fn from(model: WeakModel<T>) -> Self {
model.any_model
}
}
impl Hash for AnyWeakHandle {
impl Hash for AnyWeakModel {
fn hash<H: Hasher>(&self, state: &mut H) {
self.entity_id.hash(state);
}
}
impl PartialEq for AnyWeakHandle {
impl PartialEq for AnyWeakModel {
fn eq(&self, other: &Self) -> bool {
self.entity_id == other.entity_id
}
}
impl Eq for AnyWeakHandle {}
impl Eq for AnyWeakModel {}
#[derive(Deref, DerefMut)]
pub struct WeakHandle<T> {
pub struct WeakModel<T> {
#[deref]
#[deref_mut]
any_handle: AnyWeakHandle,
any_model: AnyWeakModel,
entity_type: PhantomData<T>,
}
unsafe impl<T> Send for WeakHandle<T> {}
unsafe impl<T> Sync for WeakHandle<T> {}
unsafe impl<T> Send for WeakModel<T> {}
unsafe impl<T> Sync for WeakModel<T> {}
impl<T> Clone for WeakHandle<T> {
impl<T> Clone for WeakModel<T> {
fn clone(&self) -> Self {
Self {
any_handle: self.any_handle.clone(),
any_model: self.any_model.clone(),
entity_type: self.entity_type,
}
}
}
impl<T: 'static> WeakHandle<T> {
pub fn upgrade(&self) -> Option<Handle<T>> {
Some(Handle {
any_handle: self.any_handle.upgrade()?,
impl<T: 'static> WeakModel<T> {
pub fn upgrade(&self) -> Option<Model<T>> {
Some(Model {
any_model: self.any_model.upgrade()?,
entity_type: self.entity_type,
})
}
/// Update the entity referenced by this handle with the given function if
/// Update the entity referenced by this model with the given function if
/// the referenced entity still exists. Returns an error if the entity has
/// been released.
///
@ -446,7 +432,7 @@ impl<T: 'static> WeakHandle<T> {
pub fn update<C, R>(
&self,
cx: &mut C,
update: impl FnOnce(&mut T, &mut C::EntityContext<'_, T>) -> R,
update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R,
) -> Result<R>
where
C: Context,
@ -460,22 +446,22 @@ impl<T: 'static> WeakHandle<T> {
}
}
impl<T> Hash for WeakHandle<T> {
impl<T> Hash for WeakModel<T> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.any_handle.hash(state);
self.any_model.hash(state);
}
}
impl<T> PartialEq for WeakHandle<T> {
impl<T> PartialEq for WeakModel<T> {
fn eq(&self, other: &Self) -> bool {
self.any_handle == other.any_handle
self.any_model == other.any_model
}
}
impl<T> Eq for WeakHandle<T> {}
impl<T> Eq for WeakModel<T> {}
impl<T> PartialEq<Handle<T>> for WeakHandle<T> {
fn eq(&self, other: &Handle<T>) -> bool {
self.entity_id == other.entity_id
impl<T> PartialEq<Model<T>> for WeakModel<T> {
fn eq(&self, other: &Model<T>) -> bool {
self.entity_id() == other.entity_id()
}
}

View File

@ -1,6 +1,6 @@
use crate::{
AppContext, AsyncAppContext, Context, Effect, EntityId, EventEmitter, Handle, MainThread,
Reference, Subscription, Task, WeakHandle,
AppContext, AsyncAppContext, Context, Effect, EntityId, EventEmitter, MainThread, Model,
Reference, Subscription, Task, WeakModel,
};
use derive_more::{Deref, DerefMut};
use futures::FutureExt;
@ -15,11 +15,11 @@ pub struct ModelContext<'a, T> {
#[deref]
#[deref_mut]
app: Reference<'a, AppContext>,
model_state: WeakHandle<T>,
model_state: WeakModel<T>,
}
impl<'a, T: 'static> ModelContext<'a, T> {
pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakHandle<T>) -> Self {
pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakModel<T>) -> Self {
Self {
app: Reference::Mutable(app),
model_state,
@ -30,20 +30,20 @@ impl<'a, T: 'static> ModelContext<'a, T> {
self.model_state.entity_id
}
pub fn handle(&self) -> Handle<T> {
pub fn handle(&self) -> Model<T> {
self.weak_handle()
.upgrade()
.expect("The entity must be alive if we have a model context")
}
pub fn weak_handle(&self) -> WeakHandle<T> {
pub fn weak_handle(&self) -> WeakModel<T> {
self.model_state.clone()
}
pub fn observe<T2: 'static>(
&mut self,
handle: &Handle<T2>,
mut on_notify: impl FnMut(&mut T, Handle<T2>, &mut ModelContext<'_, T>) + Send + 'static,
handle: &Model<T2>,
mut on_notify: impl FnMut(&mut T, Model<T2>, &mut ModelContext<'_, T>) + Send + 'static,
) -> Subscription
where
T: 'static + Send,
@ -65,10 +65,8 @@ impl<'a, T: 'static> ModelContext<'a, T> {
pub fn subscribe<E: 'static + EventEmitter>(
&mut self,
handle: &Handle<E>,
mut on_event: impl FnMut(&mut T, Handle<E>, &E::Event, &mut ModelContext<'_, T>)
+ Send
+ 'static,
handle: &Model<E>,
mut on_event: impl FnMut(&mut T, Model<E>, &E::Event, &mut ModelContext<'_, T>) + Send + 'static,
) -> Subscription
where
T: 'static + Send,
@ -107,7 +105,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
pub fn observe_release<E: 'static>(
&mut self,
handle: &Handle<E>,
handle: &Model<E>,
mut on_release: impl FnMut(&mut T, &mut E, &mut ModelContext<'_, T>) + Send + 'static,
) -> Subscription
where
@ -182,7 +180,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
pub fn spawn<Fut, R>(
&self,
f: impl FnOnce(WeakHandle<T>, AsyncAppContext) -> Fut + Send + 'static,
f: impl FnOnce(WeakModel<T>, AsyncAppContext) -> Fut + Send + 'static,
) -> Task<R>
where
T: 'static,
@ -195,7 +193,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
pub fn spawn_on_main<Fut, R>(
&self,
f: impl FnOnce(WeakHandle<T>, MainThread<AsyncAppContext>) -> Fut + Send + 'static,
f: impl FnOnce(WeakModel<T>, MainThread<AsyncAppContext>) -> Fut + Send + 'static,
) -> Task<R>
where
Fut: Future<Output = R> + 'static,
@ -220,23 +218,23 @@ where
}
impl<'a, T> Context for ModelContext<'a, T> {
type EntityContext<'b, U> = ModelContext<'b, U>;
type ModelContext<'b, U> = ModelContext<'b, U>;
type Result<U> = U;
fn entity<U>(
fn build_model<U>(
&mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, U>) -> U,
) -> Handle<U>
build_model: impl FnOnce(&mut Self::ModelContext<'_, U>) -> U,
) -> Model<U>
where
U: 'static + Send,
{
self.app.entity(build_entity)
self.app.build_model(build_model)
}
fn update_entity<U: 'static, R>(
&mut self,
handle: &Handle<U>,
update: impl FnOnce(&mut U, &mut Self::EntityContext<'_, U>) -> R,
handle: &Model<U>,
update: impl FnOnce(&mut U, &mut Self::ModelContext<'_, U>) -> R,
) -> R {
self.app.update_entity(handle, update)
}

View File

@ -1,5 +1,5 @@
use crate::{
AnyWindowHandle, AppContext, AsyncAppContext, Context, Executor, Handle, MainThread,
AnyWindowHandle, AppContext, AsyncAppContext, Context, Executor, MainThread, Model,
ModelContext, Result, Task, TestDispatcher, TestPlatform, WindowContext,
};
use parking_lot::Mutex;
@ -12,24 +12,24 @@ pub struct TestAppContext {
}
impl Context for TestAppContext {
type EntityContext<'a, T> = ModelContext<'a, T>;
type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = T;
fn entity<T: 'static>(
fn build_model<T: 'static>(
&mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
) -> Self::Result<Handle<T>>
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Model<T>>
where
T: 'static + Send,
{
let mut lock = self.app.lock();
lock.entity(build_entity)
lock.build_model(build_model)
}
fn update_entity<T: 'static, R>(
&mut self,
handle: &Handle<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R> {
let mut lock = self.app.lock();
lock.update_entity(handle, update)

View File

@ -4,7 +4,7 @@ pub(crate) use smallvec::SmallVec;
use std::{any::Any, mem};
pub trait Element<V: 'static> {
type ElementState: 'static;
type ElementState: 'static + Send;
fn id(&self) -> Option<ElementId>;

View File

@ -70,33 +70,31 @@ use taffy::TaffyLayoutEngine;
type AnyBox = Box<dyn Any + Send>;
pub trait Context {
type EntityContext<'a, T>;
type ModelContext<'a, T>;
type Result<T>;
fn entity<T>(
fn build_model<T>(
&mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
) -> Self::Result<Handle<T>>
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Model<T>>
where
T: 'static + Send;
fn update_entity<T: 'static, R>(
&mut self,
handle: &Handle<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R>;
}
pub trait VisualContext: Context {
type ViewContext<'a, 'w, V>;
fn build_view<E, V>(
fn build_view<V>(
&mut self,
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
) -> Self::Result<View<V>>
where
E: Component<V>,
V: 'static + Send;
fn update_view<V: 'static, R>(
@ -140,37 +138,37 @@ impl<T> DerefMut for MainThread<T> {
}
impl<C: Context> Context for MainThread<C> {
type EntityContext<'a, T> = MainThread<C::EntityContext<'a, T>>;
type ModelContext<'a, T> = MainThread<C::ModelContext<'a, T>>;
type Result<T> = C::Result<T>;
fn entity<T>(
fn build_model<T>(
&mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
) -> Self::Result<Handle<T>>
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Self::Result<Model<T>>
where
T: 'static + Send,
{
self.0.entity(|cx| {
self.0.build_model(|cx| {
let cx = unsafe {
mem::transmute::<
&mut C::EntityContext<'_, T>,
&mut MainThread<C::EntityContext<'_, T>>,
&mut C::ModelContext<'_, T>,
&mut MainThread<C::ModelContext<'_, T>>,
>(cx)
};
build_entity(cx)
build_model(cx)
})
}
fn update_entity<T: 'static, R>(
&mut self,
handle: &Handle<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
handle: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> Self::Result<R> {
self.0.update_entity(handle, |entity, cx| {
let cx = unsafe {
mem::transmute::<
&mut C::EntityContext<'_, T>,
&mut MainThread<C::EntityContext<'_, T>>,
&mut C::ModelContext<'_, T>,
&mut MainThread<C::ModelContext<'_, T>>,
>(cx)
};
update(entity, cx)
@ -181,27 +179,22 @@ impl<C: Context> Context for MainThread<C> {
impl<C: VisualContext> VisualContext for MainThread<C> {
type ViewContext<'a, 'w, V> = MainThread<C::ViewContext<'a, 'w, V>>;
fn build_view<E, V>(
fn build_view<V>(
&mut self,
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
) -> Self::Result<View<V>>
where
E: Component<V>,
V: 'static + Send,
{
self.0.build_view(
|cx| {
let cx = unsafe {
mem::transmute::<
&mut C::ViewContext<'_, '_, V>,
&mut MainThread<C::ViewContext<'_, '_, V>>,
>(cx)
};
build_entity(cx)
},
render,
)
self.0.build_view(|cx| {
let cx = unsafe {
mem::transmute::<
&mut C::ViewContext<'_, '_, V>,
&mut MainThread<C::ViewContext<'_, '_, V>>,
>(cx)
};
build_view_state(cx)
})
}
fn update_view<V: 'static, R>(

View File

@ -1,7 +1,7 @@
use crate::{
point, px, Action, AnyBox, AnyDrag, AppContext, BorrowWindow, Bounds, Component,
DispatchContext, DispatchPhase, Element, ElementId, FocusHandle, KeyMatch, Keystroke,
Modifiers, Overflow, Pixels, Point, SharedString, Size, Style, StyleRefinement, View,
div, point, px, Action, AnyDrag, AnyView, AppContext, BorrowWindow, Bounds, Component,
DispatchContext, DispatchPhase, Div, Element, ElementId, FocusHandle, KeyMatch, Keystroke,
Modifiers, Overflow, Pixels, Point, Render, SharedString, Size, Style, StyleRefinement, View,
ViewContext,
};
use collections::HashMap;
@ -258,17 +258,17 @@ pub trait StatelessInteractive<V: 'static>: Element<V> {
self
}
fn on_drop<S: 'static>(
fn on_drop<W: 'static + Send>(
mut self,
listener: impl Fn(&mut V, S, &mut ViewContext<V>) + Send + 'static,
listener: impl Fn(&mut V, View<W>, &mut ViewContext<V>) + Send + 'static,
) -> Self
where
Self: Sized,
{
self.stateless_interaction().drop_listeners.push((
TypeId::of::<S>(),
Box::new(move |view, drag_state, cx| {
listener(view, *drag_state.downcast().unwrap(), cx);
TypeId::of::<W>(),
Box::new(move |view, dragged_view, cx| {
listener(view, dragged_view.downcast().unwrap(), cx);
}),
));
self
@ -314,36 +314,22 @@ pub trait StatefulInteractive<V: 'static>: StatelessInteractive<V> {
self
}
fn on_drag<S, R, E>(
fn on_drag<W>(
mut self,
listener: impl Fn(&mut V, &mut ViewContext<V>) -> Drag<S, R, V, E> + Send + 'static,
listener: impl Fn(&mut V, &mut ViewContext<V>) -> View<W> + Send + 'static,
) -> Self
where
Self: Sized,
S: Any + Send,
R: Fn(&mut V, &mut ViewContext<V>) -> E,
R: 'static + Send,
E: Component<V>,
W: 'static + Send + Render,
{
debug_assert!(
self.stateful_interaction().drag_listener.is_none(),
"calling on_drag more than once on the same element is not supported"
);
self.stateful_interaction().drag_listener =
Some(Box::new(move |view_state, cursor_offset, cx| {
let drag = listener(view_state, cx);
let drag_handle_view = Some(
View::for_handle(cx.handle().upgrade().unwrap(), move |view_state, cx| {
(drag.render_drag_handle)(view_state, cx)
})
.into_any(),
);
AnyDrag {
drag_handle_view,
cursor_offset,
state: Box::new(drag.state),
state_type: TypeId::of::<S>(),
}
Some(Box::new(move |view_state, cursor_offset, cx| AnyDrag {
view: listener(view_state, cx).into_any(),
cursor_offset,
}));
self
}
@ -412,7 +398,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
if let Some(drag) = cx.active_drag.take() {
for (state_type, group_drag_style) in &self.as_stateless().group_drag_over_styles {
if let Some(group_bounds) = GroupBounds::get(&group_drag_style.group, cx) {
if *state_type == drag.state_type
if *state_type == drag.view.entity_type()
&& group_bounds.contains_point(&mouse_position)
{
style.refine(&group_drag_style.style);
@ -421,7 +407,8 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
}
for (state_type, drag_over_style) in &self.as_stateless().drag_over_styles {
if *state_type == drag.state_type && bounds.contains_point(&mouse_position) {
if *state_type == drag.view.entity_type() && bounds.contains_point(&mouse_position)
{
style.refine(drag_over_style);
}
}
@ -509,7 +496,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
cx.on_mouse_event(move |view, event: &MouseUpEvent, phase, cx| {
if phase == DispatchPhase::Bubble && bounds.contains_point(&event.position) {
if let Some(drag_state_type) =
cx.active_drag.as_ref().map(|drag| drag.state_type)
cx.active_drag.as_ref().map(|drag| drag.view.entity_type())
{
for (drop_state_type, listener) in &drop_listeners {
if *drop_state_type == drag_state_type {
@ -517,7 +504,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
.active_drag
.take()
.expect("checked for type drag state type above");
listener(view, drag.state, cx);
listener(view, drag.view.clone(), cx);
cx.notify();
cx.stop_propagation();
}
@ -685,7 +672,7 @@ impl<V> From<ElementId> for StatefulInteraction<V> {
}
}
type DropListener<V> = dyn Fn(&mut V, AnyBox, &mut ViewContext<V>) + 'static + Send;
type DropListener<V> = dyn Fn(&mut V, AnyView, &mut ViewContext<V>) + 'static + Send;
pub struct StatelessInteraction<V> {
pub dispatch_context: DispatchContext,
@ -866,7 +853,7 @@ pub struct Drag<S, R, V, E>
where
R: Fn(&mut V, &mut ViewContext<V>) -> E,
V: 'static,
E: Component<V>,
E: Component<()>,
{
pub state: S,
pub render_drag_handle: R,
@ -877,7 +864,7 @@ impl<S, R, V, E> Drag<S, R, V, E>
where
R: Fn(&mut V, &mut ViewContext<V>) -> E,
V: 'static,
E: Component<V>,
E: Component<()>,
{
pub fn new(state: S, render_drag_handle: R) -> Self {
Drag {
@ -888,6 +875,10 @@ where
}
}
// impl<S, R, V, E> Render for Drag<S, R, V, E> {
// // fn render(&mut self, cx: ViewContext<Self>) ->
// }
#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug)]
pub enum MouseButton {
Left,
@ -995,6 +986,14 @@ impl Deref for MouseExitEvent {
#[derive(Debug, Clone, Default)]
pub struct ExternalPaths(pub(crate) SmallVec<[PathBuf; 2]>);
impl Render for ExternalPaths {
type Element = Div<Self>;
fn render(&mut self, _: &mut ViewContext<Self>) -> Self::Element {
div() // Intentionally left empty because the platform will render icons for the dragged files
}
}
#[derive(Debug, Clone)]
pub enum FileDropEvent {
Entered {

View File

@ -1,45 +1,35 @@
use crate::{
AnyBox, AnyElement, AppContext, AvailableSpace, BorrowWindow, Bounds, Component, Element,
ElementId, EntityHandle, EntityId, Flatten, Handle, LayoutId, Pixels, Size, ViewContext,
VisualContext, WeakHandle, WindowContext,
AnyBox, AnyElement, AnyModel, AppContext, AvailableSpace, BorrowWindow, Bounds, Component,
Element, ElementId, EntityHandle, EntityId, Flatten, LayoutId, Model, Pixels, Size,
ViewContext, VisualContext, WeakModel, WindowContext,
};
use anyhow::{Context, Result};
use parking_lot::Mutex;
use std::{
any::Any,
any::{Any, TypeId},
marker::PhantomData,
sync::{Arc, Weak},
sync::Arc,
};
pub struct View<V> {
pub(crate) state: Handle<V>,
render: Arc<Mutex<dyn Fn(&mut V, &mut ViewContext<V>) -> AnyElement<V> + Send + 'static>>,
pub trait Render: 'static + Sized {
type Element: Element<Self> + 'static + Send;
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element;
}
impl<V: 'static> View<V> {
pub fn for_handle<E>(
state: Handle<V>,
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
) -> View<V>
where
E: Component<V>,
{
View {
state,
render: Arc::new(Mutex::new(
move |state: &mut V, cx: &mut ViewContext<'_, '_, V>| render(state, cx).render(),
)),
}
}
pub struct View<V> {
pub(crate) model: Model<V>,
}
impl<V: Render> View<V> {
pub fn into_any(self) -> AnyView {
AnyView(Arc::new(self))
}
}
impl<V: 'static> View<V> {
pub fn downgrade(&self) -> WeakView<V> {
WeakView {
state: self.state.downgrade(),
render: Arc::downgrade(&self.render),
model: self.model.downgrade(),
}
}
@ -55,20 +45,19 @@ impl<V: 'static> View<V> {
}
pub fn read<'a>(&self, cx: &'a AppContext) -> &'a V {
cx.entities.read(&self.state)
self.model.read(cx)
}
}
impl<V> Clone for View<V> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
render: self.render.clone(),
model: self.model.clone(),
}
}
}
impl<V: 'static, ParentViewState: 'static> Component<ParentViewState> for View<V> {
impl<V: Render, ParentViewState: 'static> Component<ParentViewState> for View<V> {
fn render(self) -> AnyElement<ParentViewState> {
AnyElement::new(EraseViewState {
view: self,
@ -77,11 +66,14 @@ impl<V: 'static, ParentViewState: 'static> Component<ParentViewState> for View<V
}
}
impl<V: 'static> Element<()> for View<V> {
impl<V> Element<()> for View<V>
where
V: Render,
{
type ElementState = AnyElement<V>;
fn id(&self) -> Option<ElementId> {
Some(ElementId::View(self.state.entity_id))
fn id(&self) -> Option<crate::ElementId> {
Some(ElementId::View(self.model.entity_id))
}
fn initialize(
@ -91,7 +83,7 @@ impl<V: 'static> Element<()> for View<V> {
cx: &mut ViewContext<()>,
) -> Self::ElementState {
self.update(cx, |state, cx| {
let mut any_element = (self.render.lock())(state, cx);
let mut any_element = AnyElement::new(state.render(cx));
any_element.initialize(state, cx);
any_element
})
@ -121,7 +113,7 @@ impl<T: 'static> EntityHandle<T> for View<T> {
type Weak = WeakView<T>;
fn entity_id(&self) -> EntityId {
self.state.entity_id
self.model.entity_id
}
fn downgrade(&self) -> Self::Weak {
@ -137,15 +129,13 @@ impl<T: 'static> EntityHandle<T> for View<T> {
}
pub struct WeakView<V> {
pub(crate) state: WeakHandle<V>,
render: Weak<Mutex<dyn Fn(&mut V, &mut ViewContext<V>) -> AnyElement<V> + Send + 'static>>,
pub(crate) model: WeakModel<V>,
}
impl<V: 'static> WeakView<V> {
pub fn upgrade(&self) -> Option<View<V>> {
let state = self.state.upgrade()?;
let render = self.render.upgrade()?;
Some(View { state, render })
let model = self.model.upgrade()?;
Some(View { model })
}
pub fn update<C, R>(
@ -165,8 +155,7 @@ impl<V: 'static> WeakView<V> {
impl<V> Clone for WeakView<V> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
render: self.render.clone(),
model: self.model.clone(),
}
}
}
@ -178,13 +167,13 @@ struct EraseViewState<V, ParentV> {
unsafe impl<V, ParentV> Send for EraseViewState<V, ParentV> {}
impl<V: 'static, ParentV: 'static> Component<ParentV> for EraseViewState<V, ParentV> {
impl<V: Render, ParentV: 'static> Component<ParentV> for EraseViewState<V, ParentV> {
fn render(self) -> AnyElement<ParentV> {
AnyElement::new(self)
}
}
impl<V: 'static, ParentV: 'static> Element<ParentV> for EraseViewState<V, ParentV> {
impl<V: Render, ParentV: 'static> Element<ParentV> for EraseViewState<V, ParentV> {
type ElementState = AnyBox;
fn id(&self) -> Option<ElementId> {
@ -221,30 +210,43 @@ impl<V: 'static, ParentV: 'static> Element<ParentV> for EraseViewState<V, Parent
}
trait ViewObject: Send + Sync {
fn entity_type(&self) -> TypeId;
fn entity_id(&self) -> EntityId;
fn model(&self) -> AnyModel;
fn initialize(&self, cx: &mut WindowContext) -> AnyBox;
fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId;
fn paint(&self, bounds: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext);
fn as_any(&self) -> &dyn Any;
}
impl<V: 'static> ViewObject for View<V> {
impl<V> ViewObject for View<V>
where
V: Render,
{
fn entity_type(&self) -> TypeId {
TypeId::of::<V>()
}
fn entity_id(&self) -> EntityId {
self.state.entity_id
self.model.entity_id
}
fn model(&self) -> AnyModel {
self.model.clone().into_any()
}
fn initialize(&self, cx: &mut WindowContext) -> AnyBox {
cx.with_element_id(self.state.entity_id, |_global_id, cx| {
cx.with_element_id(self.model.entity_id, |_global_id, cx| {
self.update(cx, |state, cx| {
let mut any_element = Box::new((self.render.lock())(state, cx));
let mut any_element = Box::new(AnyElement::new(state.render(cx)));
any_element.initialize(state, cx);
any_element as AnyBox
any_element
})
})
}
fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId {
cx.with_element_id(self.state.entity_id, |_global_id, cx| {
cx.with_element_id(self.model.entity_id, |_global_id, cx| {
self.update(cx, |state, cx| {
let element = element.downcast_mut::<AnyElement<V>>().unwrap();
element.layout(state, cx)
@ -253,7 +255,7 @@ impl<V: 'static> ViewObject for View<V> {
}
fn paint(&self, _: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext) {
cx.with_element_id(self.state.entity_id, |_global_id, cx| {
cx.with_element_id(self.model.entity_id, |_global_id, cx| {
self.update(cx, |state, cx| {
let element = element.downcast_mut::<AnyElement<V>>().unwrap();
element.paint(state, cx);
@ -270,8 +272,12 @@ impl<V: 'static> ViewObject for View<V> {
pub struct AnyView(Arc<dyn ViewObject>);
impl AnyView {
pub fn downcast<V: 'static>(&self) -> Option<View<V>> {
self.0.as_any().downcast_ref().cloned()
pub fn downcast<V: 'static + Send>(self) -> Option<View<V>> {
self.0.model().downcast().map(|model| View { model })
}
pub(crate) fn entity_type(&self) -> TypeId {
self.0.entity_type()
}
pub(crate) fn draw(&self, available_space: Size<AvailableSpace>, cx: &mut WindowContext) {
@ -343,6 +349,18 @@ impl<ParentV: 'static> Component<ParentV> for EraseAnyViewState<ParentV> {
}
}
impl<T, E> Render for T
where
T: 'static + FnMut(&mut WindowContext) -> E,
E: 'static + Send + Element<T>,
{
type Element = E;
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
(self)(cx)
}
}
impl<ParentV: 'static> Element<ParentV> for EraseAnyViewState<ParentV> {
type ElementState = AnyBox;

View File

@ -1,14 +1,14 @@
use crate::{
px, size, Action, AnyBox, AnyDrag, AnyView, AppContext, AsyncWindowContext, AvailableSpace,
Bounds, BoxShadow, Context, Corners, DevicePixels, DispatchContext, DisplayId, Edges, Effect,
EntityHandle, EntityId, EventEmitter, ExternalPaths, FileDropEvent, FocusEvent, FontId,
GlobalElementId, GlyphId, Handle, Hsla, ImageData, InputEvent, IsZero, KeyListener, KeyMatch,
KeyMatcher, Keystroke, LayoutId, MainThread, MainThreadOnly, ModelContext, Modifiers,
MonochromeSprite, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Path, Pixels,
PlatformAtlas, PlatformWindow, Point, PolychromeSprite, Quad, Reference, RenderGlyphParams,
RenderImageParams, RenderSvgParams, ScaledPixels, SceneBuilder, Shadow, SharedString, Size,
Style, Subscription, TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext,
WeakHandle, WeakView, WindowOptions, SUBPIXEL_VARIANTS,
EntityHandle, EntityId, EventEmitter, FileDropEvent, FocusEvent, FontId, GlobalElementId,
GlyphId, Hsla, ImageData, InputEvent, IsZero, KeyListener, KeyMatch, KeyMatcher, Keystroke,
LayoutId, MainThread, MainThreadOnly, Model, ModelContext, Modifiers, MonochromeSprite,
MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Path, Pixels, PlatformAtlas,
PlatformWindow, Point, PolychromeSprite, Quad, Reference, RenderGlyphParams, RenderImageParams,
RenderSvgParams, ScaledPixels, SceneBuilder, Shadow, SharedString, Size, Style, Subscription,
TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext, WeakModel, WeakView,
WindowOptions, SUBPIXEL_VARIANTS,
};
use anyhow::Result;
use collections::HashMap;
@ -918,15 +918,13 @@ impl<'a, 'w> WindowContext<'a, 'w> {
root_view.draw(available_space, cx);
});
if let Some(mut active_drag) = self.app.active_drag.take() {
if let Some(active_drag) = self.app.active_drag.take() {
self.stack(1, |cx| {
let offset = cx.mouse_position() - active_drag.cursor_offset;
cx.with_element_offset(Some(offset), |cx| {
let available_space =
size(AvailableSpace::MinContent, AvailableSpace::MinContent);
if let Some(drag_handle_view) = &mut active_drag.drag_handle_view {
drag_handle_view.draw(available_space, cx);
}
active_drag.view.draw(available_space, cx);
cx.active_drag = Some(active_drag);
});
});
@ -994,12 +992,12 @@ impl<'a, 'w> WindowContext<'a, 'w> {
InputEvent::FileDrop(file_drop) => match file_drop {
FileDropEvent::Entered { position, files } => {
self.window.mouse_position = position;
self.active_drag.get_or_insert_with(|| AnyDrag {
drag_handle_view: None,
cursor_offset: position,
state: Box::new(files),
state_type: TypeId::of::<ExternalPaths>(),
});
if self.active_drag.is_none() {
self.active_drag = Some(AnyDrag {
view: self.build_view(|_| files).into_any(),
cursor_offset: position,
});
}
InputEvent::MouseDown(MouseDownEvent {
position,
button: MouseButton::Left,
@ -1267,30 +1265,30 @@ impl<'a, 'w> WindowContext<'a, 'w> {
}
impl Context for WindowContext<'_, '_> {
type EntityContext<'a, T> = ModelContext<'a, T>;
type ModelContext<'a, T> = ModelContext<'a, T>;
type Result<T> = T;
fn entity<T>(
fn build_model<T>(
&mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
) -> Handle<T>
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Model<T>
where
T: 'static + Send,
{
let slot = self.app.entities.reserve();
let entity = build_entity(&mut ModelContext::mutable(&mut *self.app, slot.downgrade()));
self.entities.insert(slot, entity)
let model = build_model(&mut ModelContext::mutable(&mut *self.app, slot.downgrade()));
self.entities.insert(slot, model)
}
fn update_entity<T: 'static, R>(
&mut self,
handle: &Handle<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
model: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> R {
let mut entity = self.entities.lease(handle);
let mut entity = self.entities.lease(model);
let result = update(
&mut *entity,
&mut ModelContext::mutable(&mut *self.app, handle.downgrade()),
&mut ModelContext::mutable(&mut *self.app, model.downgrade()),
);
self.entities.end_lease(entity);
result
@ -1300,21 +1298,17 @@ impl Context for WindowContext<'_, '_> {
impl VisualContext for WindowContext<'_, '_> {
type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>;
/// Builds a new view in the current window. The first argument is a function that builds
/// an entity representing the view's state. It is invoked with a `ViewContext` that provides
/// entity-specific access to the window and application state during construction. The second
/// argument is a render function that returns a component based on the view's state.
fn build_view<E, V>(
fn build_view<V>(
&mut self,
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
) -> Self::Result<View<V>>
where
E: crate::Component<V>,
V: 'static + Send,
{
let slot = self.app.entities.reserve();
let view = View::for_handle(slot.clone(), render);
let view = View {
model: slot.clone(),
};
let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade());
let entity = build_view_state(&mut cx);
self.entities.insert(slot, entity);
@ -1327,7 +1321,7 @@ impl VisualContext for WindowContext<'_, '_> {
view: &View<T>,
update: impl FnOnce(&mut T, &mut Self::ViewContext<'_, '_, T>) -> R,
) -> Self::Result<R> {
let mut lease = self.app.entities.lease(&view.state);
let mut lease = self.app.entities.lease(&view.model);
let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade());
let result = update(&mut *lease, &mut cx);
cx.app.entities.end_lease(lease);
@ -1582,8 +1576,8 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
self.view.clone()
}
pub fn handle(&self) -> WeakHandle<V> {
self.view.state.clone()
pub fn model(&self) -> WeakModel<V> {
self.view.model.clone()
}
pub fn stack<R>(&mut self, order: u32, f: impl FnOnce(&mut Self) -> R) -> R {
@ -1603,8 +1597,8 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
pub fn observe<E>(
&mut self,
handle: &Handle<E>,
mut on_notify: impl FnMut(&mut V, Handle<E>, &mut ViewContext<'_, '_, V>) + Send + 'static,
handle: &Model<E>,
mut on_notify: impl FnMut(&mut V, Model<E>, &mut ViewContext<'_, '_, V>) + Send + 'static,
) -> Subscription
where
E: 'static,
@ -1665,7 +1659,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
) -> Subscription {
let window_handle = self.window.handle;
self.app.release_listeners.insert(
self.view.state.entity_id,
self.view.model.entity_id,
Box::new(move |this, cx| {
let this = this.downcast_mut().expect("invalid entity type");
// todo!("are we okay with silently swallowing the error?")
@ -1676,7 +1670,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
pub fn observe_release<T: 'static>(
&mut self,
handle: &Handle<T>,
handle: &Model<T>,
mut on_release: impl FnMut(&mut V, &mut T, &mut ViewContext<'_, '_, V>) + Send + 'static,
) -> Subscription
where
@ -1698,7 +1692,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
pub fn notify(&mut self) {
self.window_cx.notify();
self.window_cx.app.push_effect(Effect::Notify {
emitter: self.view.state.entity_id,
emitter: self.view.model.entity_id,
});
}
@ -1878,7 +1872,7 @@ where
V::Event: Any + Send,
{
pub fn emit(&mut self, event: V::Event) {
let emitter = self.view.state.entity_id;
let emitter = self.view.model.entity_id;
self.app.push_effect(Effect::Emit {
emitter,
event: Box::new(event),
@ -1897,41 +1891,36 @@ impl<'a, 'w, V: 'static> MainThread<ViewContext<'a, 'w, V>> {
}
impl<'a, 'w, V> Context for ViewContext<'a, 'w, V> {
type EntityContext<'b, U> = ModelContext<'b, U>;
type ModelContext<'b, U> = ModelContext<'b, U>;
type Result<U> = U;
fn entity<T>(
fn build_model<T>(
&mut self,
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
) -> Handle<T>
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
) -> Model<T>
where
T: 'static + Send,
{
self.window_cx.entity(build_entity)
self.window_cx.build_model(build_model)
}
fn update_entity<T: 'static, R>(
&mut self,
handle: &Handle<T>,
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
model: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> R {
self.window_cx.update_entity(handle, update)
self.window_cx.update_entity(model, update)
}
}
impl<V: 'static> VisualContext for ViewContext<'_, '_, V> {
type ViewContext<'a, 'w, V2> = ViewContext<'a, 'w, V2>;
fn build_view<E, V2>(
fn build_view<W: 'static + Send>(
&mut self,
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V2>) -> V2,
render: impl Fn(&mut V2, &mut ViewContext<'_, '_, V2>) -> E + Send + 'static,
) -> Self::Result<View<V2>>
where
E: crate::Component<V2>,
V2: 'static + Send,
{
self.window_cx.build_view(build_entity, render)
build_view: impl FnOnce(&mut Self::ViewContext<'_, '_, W>) -> W,
) -> Self::Result<View<W>> {
self.window_cx.build_view(build_view)
}
fn update_view<V2: 'static, R>(

View File

@ -5,7 +5,7 @@ use crate::language_settings::{
use crate::Buffer;
use clock::ReplicaId;
use collections::BTreeMap;
use gpui2::{AppContext, Handle};
use gpui2::{AppContext, Model};
use gpui2::{Context, TestAppContext};
use indoc::indoc;
use proto::deserialize_operation;
@ -42,7 +42,7 @@ fn init_logger() {
fn test_line_endings(cx: &mut gpui2::AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "one\r\ntwo\rthree")
.with_language(Arc::new(rust_lang()), cx);
assert_eq!(buffer.text(), "one\ntwo\nthree");
@ -138,8 +138,8 @@ fn test_edit_events(cx: &mut gpui2::AppContext) {
let buffer_1_events = Arc::new(Mutex::new(Vec::new()));
let buffer_2_events = Arc::new(Mutex::new(Vec::new()));
let buffer1 = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), "abcdef"));
let buffer2 = cx.entity(|cx| Buffer::new(1, cx.entity_id().as_u64(), "abcdef"));
let buffer1 = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), "abcdef"));
let buffer2 = cx.build_model(|cx| Buffer::new(1, cx.entity_id().as_u64(), "abcdef"));
let buffer1_ops = Arc::new(Mutex::new(Vec::new()));
buffer1.update(cx, {
let buffer1_ops = buffer1_ops.clone();
@ -218,7 +218,7 @@ fn test_edit_events(cx: &mut gpui2::AppContext) {
#[gpui2::test]
async fn test_apply_diff(cx: &mut TestAppContext) {
let text = "a\nbb\nccc\ndddd\neeeee\nffffff\n";
let buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
let buffer = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
let anchor = buffer.update(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)));
let text = "a\nccc\ndddd\nffffff\n";
@ -250,7 +250,7 @@ async fn test_normalize_whitespace(cx: &mut gpui2::TestAppContext) {
]
.join("\n");
let buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
let buffer = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
// Spawn a task to format the buffer's whitespace.
// Pause so that the foratting task starts running.
@ -314,7 +314,7 @@ async fn test_normalize_whitespace(cx: &mut gpui2::TestAppContext) {
#[gpui2::test]
async fn test_reparse(cx: &mut gpui2::TestAppContext) {
let text = "fn a() {}";
let buffer = cx.entity(|cx| {
let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
});
@ -442,7 +442,7 @@ async fn test_reparse(cx: &mut gpui2::TestAppContext) {
#[gpui2::test]
async fn test_resetting_language(cx: &mut gpui2::TestAppContext) {
let buffer = cx.entity(|cx| {
let buffer = cx.build_model(|cx| {
let mut buffer =
Buffer::new(0, cx.entity_id().as_u64(), "{}").with_language(Arc::new(rust_lang()), cx);
buffer.set_sync_parse_timeout(Duration::ZERO);
@ -492,7 +492,7 @@ async fn test_outline(cx: &mut gpui2::TestAppContext) {
"#
.unindent();
let buffer = cx.entity(|cx| {
let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
});
let outline = buffer
@ -578,7 +578,7 @@ async fn test_outline_nodes_with_newlines(cx: &mut gpui2::TestAppContext) {
"#
.unindent();
let buffer = cx.entity(|cx| {
let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
});
let outline = buffer
@ -616,7 +616,7 @@ async fn test_outline_with_extra_context(cx: &mut gpui2::TestAppContext) {
"#
.unindent();
let buffer = cx.entity(|cx| {
let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(language), cx)
});
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
@ -660,7 +660,7 @@ async fn test_symbols_containing(cx: &mut gpui2::TestAppContext) {
"#
.unindent();
let buffer = cx.entity(|cx| {
let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
});
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
@ -881,7 +881,7 @@ fn test_enclosing_bracket_ranges_where_brackets_are_not_outermost_children(cx: &
#[gpui2::test]
fn test_range_for_syntax_ancestor(cx: &mut AppContext) {
cx.entity(|cx| {
cx.build_model(|cx| {
let text = "fn a() { b(|c| {}) }";
let buffer =
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
@ -922,7 +922,7 @@ fn test_range_for_syntax_ancestor(cx: &mut AppContext) {
fn test_autoindent_with_soft_tabs(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let text = "fn a() {}";
let mut buffer =
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
@ -965,7 +965,7 @@ fn test_autoindent_with_hard_tabs(cx: &mut AppContext) {
settings.defaults.hard_tabs = Some(true);
});
cx.entity(|cx| {
cx.build_model(|cx| {
let text = "fn a() {}";
let mut buffer =
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
@ -1006,7 +1006,7 @@ fn test_autoindent_with_hard_tabs(cx: &mut AppContext) {
fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let entity_id = cx.entity_id();
let mut buffer = Buffer::new(
0,
@ -1080,7 +1080,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppC
buffer
});
cx.entity(|cx| {
cx.build_model(|cx| {
eprintln!("second buffer: {:?}", cx.entity_id());
let mut buffer = Buffer::new(
@ -1147,7 +1147,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppC
fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let mut buffer = Buffer::new(
0,
cx.entity_id().as_u64(),
@ -1209,7 +1209,7 @@ fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut Ap
fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let mut buffer = Buffer::new(
0,
cx.entity_id().as_u64(),
@ -1266,7 +1266,7 @@ fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut AppContext) {
fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let text = "a\nb";
let mut buffer =
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
@ -1284,7 +1284,7 @@ fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut AppContext) {
fn test_autoindent_multi_line_insertion(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let text = "
const a: usize = 1;
fn b() {
@ -1326,7 +1326,7 @@ fn test_autoindent_multi_line_insertion(cx: &mut AppContext) {
fn test_autoindent_block_mode(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let text = r#"
fn a() {
b();
@ -1410,7 +1410,7 @@ fn test_autoindent_block_mode(cx: &mut AppContext) {
fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let text = r#"
fn a() {
if b() {
@ -1490,7 +1490,7 @@ fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut AppContex
fn test_autoindent_language_without_indents_query(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let text = "
* one
- a
@ -1559,7 +1559,7 @@ fn test_autoindent_with_injected_languages(cx: &mut AppContext) {
language_registry.add(html_language.clone());
language_registry.add(javascript_language.clone());
cx.entity(|cx| {
cx.build_model(|cx| {
let (text, ranges) = marked_text_ranges(
&"
<div>ˇ
@ -1610,7 +1610,7 @@ fn test_autoindent_query_with_outdent_captures(cx: &mut AppContext) {
settings.defaults.tab_size = Some(2.try_into().unwrap());
});
cx.entity(|cx| {
cx.build_model(|cx| {
let mut buffer =
Buffer::new(0, cx.entity_id().as_u64(), "").with_language(Arc::new(ruby_lang()), cx);
@ -1653,7 +1653,7 @@ fn test_autoindent_query_with_outdent_captures(cx: &mut AppContext) {
fn test_language_scope_at_with_javascript(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let language = Language::new(
LanguageConfig {
name: "JavaScript".into(),
@ -1742,7 +1742,7 @@ fn test_language_scope_at_with_javascript(cx: &mut AppContext) {
fn test_language_scope_at_with_rust(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let language = Language::new(
LanguageConfig {
name: "Rust".into(),
@ -1810,7 +1810,7 @@ fn test_language_scope_at_with_rust(cx: &mut AppContext) {
fn test_language_scope_at_with_combined_injections(cx: &mut AppContext) {
init_settings(cx, |_| {});
cx.entity(|cx| {
cx.build_model(|cx| {
let text = r#"
<ol>
<% people.each do |person| %>
@ -1858,7 +1858,7 @@ fn test_language_scope_at_with_combined_injections(cx: &mut AppContext) {
fn test_serialization(cx: &mut gpui2::AppContext) {
let mut now = Instant::now();
let buffer1 = cx.entity(|cx| {
let buffer1 = cx.build_model(|cx| {
let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "abc");
buffer.edit([(3..3, "D")], None, cx);
@ -1881,7 +1881,7 @@ fn test_serialization(cx: &mut gpui2::AppContext) {
let ops = cx
.executor()
.block(buffer1.read(cx).serialize_ops(None, cx));
let buffer2 = cx.entity(|cx| {
let buffer2 = cx.build_model(|cx| {
let mut buffer = Buffer::from_proto(1, state, None).unwrap();
buffer
.apply_ops(
@ -1914,10 +1914,11 @@ fn test_random_collaboration(cx: &mut AppContext, mut rng: StdRng) {
let mut replica_ids = Vec::new();
let mut buffers = Vec::new();
let network = Arc::new(Mutex::new(Network::new(rng.clone())));
let base_buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), base_text.as_str()));
let base_buffer =
cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), base_text.as_str()));
for i in 0..rng.gen_range(min_peers..=max_peers) {
let buffer = cx.entity(|cx| {
let buffer = cx.build_model(|cx| {
let state = base_buffer.read(cx).to_proto();
let ops = cx
.executor()
@ -2034,7 +2035,7 @@ fn test_random_collaboration(cx: &mut AppContext, mut rng: StdRng) {
new_replica_id,
replica_id
);
new_buffer = Some(cx.entity(|cx| {
new_buffer = Some(cx.build_model(|cx| {
let mut new_buffer =
Buffer::from_proto(new_replica_id, old_buffer_state, None).unwrap();
new_buffer
@ -2396,7 +2397,7 @@ fn javascript_lang() -> Language {
.unwrap()
}
fn get_tree_sexp(buffer: &Handle<Buffer>, cx: &mut gpui2::TestAppContext) -> String {
fn get_tree_sexp(buffer: &Model<Buffer>, cx: &mut gpui2::TestAppContext) -> String {
buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
let layers = snapshot.syntax.layers(buffer.as_text_snapshot());
@ -2412,7 +2413,7 @@ fn assert_bracket_pairs(
cx: &mut AppContext,
) {
let (expected_text, selection_ranges) = marked_text_ranges(selection_text, false);
let buffer = cx.entity(|cx| {
let buffer = cx.build_model(|cx| {
Buffer::new(0, cx.entity_id().as_u64(), expected_text.clone())
.with_language(Arc::new(language), cx)
});

12
crates/menu2/Cargo.toml Normal file
View File

@ -0,0 +1,12 @@
[package]
name = "menu2"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/menu2.rs"
doctest = false
[dependencies]
gpui2 = { path = "../gpui2" }

25
crates/menu2/src/menu2.rs Normal file
View File

@ -0,0 +1,25 @@
// todo!(use actions! macro)
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Cancel;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Confirm;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SecondaryConfirm;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SelectPrev;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SelectNext;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SelectFirst;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SelectLast;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ShowContextMenu;

View File

@ -16,7 +16,7 @@ client2 = { path = "../client2" }
collections = { path = "../collections"}
language2 = { path = "../language2" }
gpui2 = { path = "../gpui2" }
fs = { path = "../fs" }
fs2 = { path = "../fs2" }
lsp2 = { path = "../lsp2" }
node_runtime = { path = "../node_runtime"}
util = { path = "../util" }
@ -32,4 +32,4 @@ parking_lot.workspace = true
[dev-dependencies]
language2 = { path = "../language2", features = ["test-support"] }
gpui2 = { path = "../gpui2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
fs2 = { path = "../fs2", features = ["test-support"] }

View File

@ -1,7 +1,7 @@
use anyhow::Context;
use collections::{HashMap, HashSet};
use fs::Fs;
use gpui2::{AsyncAppContext, Handle};
use fs2::Fs;
use gpui2::{AsyncAppContext, Model};
use language2::{language_settings::language_settings, Buffer, BundledFormatter, Diff};
use lsp2::{LanguageServer, LanguageServerId};
use node_runtime::NodeRuntime;
@ -183,7 +183,7 @@ impl Prettier {
pub async fn format(
&self,
buffer: &Handle<Buffer>,
buffer: &Model<Buffer>,
buffer_path: Option<PathBuf>,
cx: &mut AsyncAppContext,
) -> anyhow::Result<Diff> {

View File

@ -25,7 +25,7 @@ client2 = { path = "../client2" }
clock = { path = "../clock" }
collections = { path = "../collections" }
db2 = { path = "../db2" }
fs = { path = "../fs" }
fs2 = { path = "../fs2" }
fsevent = { path = "../fsevent" }
fuzzy2 = { path = "../fuzzy2" }
git = { path = "../git" }
@ -71,7 +71,7 @@ pretty_assertions.workspace = true
client2 = { path = "../client2", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
db2 = { path = "../db2", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
fs2 = { path = "../fs2", features = ["test-support"] }
gpui2 = { path = "../gpui2", features = ["test-support"] }
language2 = { path = "../language2", features = ["test-support"] }
lsp2 = { path = "../lsp2", features = ["test-support"] }

View File

@ -7,7 +7,7 @@ use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use client2::proto::{self, PeerId};
use futures::future;
use gpui2::{AppContext, AsyncAppContext, Handle};
use gpui2::{AppContext, AsyncAppContext, Model};
use language2::{
language_settings::{language_settings, InlayHintKind},
point_from_lsp, point_to_lsp,
@ -53,8 +53,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
async fn response_from_lsp(
self,
message: <Self::LspRequest as lsp2::request::Request>::Result,
project: Handle<Project>,
buffer: Handle<Buffer>,
project: Model<Project>,
buffer: Model<Buffer>,
server_id: LanguageServerId,
cx: AsyncAppContext,
) -> Result<Self::Response>;
@ -63,8 +63,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
async fn from_proto(
message: Self::ProtoRequest,
project: Handle<Project>,
buffer: Handle<Buffer>,
project: Model<Project>,
buffer: Model<Buffer>,
cx: AsyncAppContext,
) -> Result<Self>;
@ -79,8 +79,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
async fn response_from_proto(
self,
message: <Self::ProtoRequest as proto::RequestMessage>::Response,
project: Handle<Project>,
buffer: Handle<Buffer>,
project: Model<Project>,
buffer: Model<Buffer>,
cx: AsyncAppContext,
) -> Result<Self::Response>;
@ -180,8 +180,8 @@ impl LspCommand for PrepareRename {
async fn response_from_lsp(
self,
message: Option<lsp2::PrepareRenameResponse>,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
_: LanguageServerId,
mut cx: AsyncAppContext,
) -> Result<Option<Range<Anchor>>> {
@ -215,8 +215,8 @@ impl LspCommand for PrepareRename {
async fn from_proto(
message: proto::PrepareRename,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let position = message
@ -256,8 +256,8 @@ impl LspCommand for PrepareRename {
async fn response_from_proto(
self,
message: proto::PrepareRenameResponse,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Option<Range<Anchor>>> {
if message.can_rename {
@ -307,8 +307,8 @@ impl LspCommand for PerformRename {
async fn response_from_lsp(
self,
message: Option<lsp2::WorkspaceEdit>,
project: Handle<Project>,
buffer: Handle<Buffer>,
project: Model<Project>,
buffer: Model<Buffer>,
server_id: LanguageServerId,
mut cx: AsyncAppContext,
) -> Result<ProjectTransaction> {
@ -343,8 +343,8 @@ impl LspCommand for PerformRename {
async fn from_proto(
message: proto::PerformRename,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let position = message
@ -379,8 +379,8 @@ impl LspCommand for PerformRename {
async fn response_from_proto(
self,
message: proto::PerformRenameResponse,
project: Handle<Project>,
_: Handle<Buffer>,
project: Model<Project>,
_: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<ProjectTransaction> {
let message = message
@ -426,8 +426,8 @@ impl LspCommand for GetDefinition {
async fn response_from_lsp(
self,
message: Option<lsp2::GotoDefinitionResponse>,
project: Handle<Project>,
buffer: Handle<Buffer>,
project: Model<Project>,
buffer: Model<Buffer>,
server_id: LanguageServerId,
cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> {
@ -447,8 +447,8 @@ impl LspCommand for GetDefinition {
async fn from_proto(
message: proto::GetDefinition,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let position = message
@ -479,8 +479,8 @@ impl LspCommand for GetDefinition {
async fn response_from_proto(
self,
message: proto::GetDefinitionResponse,
project: Handle<Project>,
_: Handle<Buffer>,
project: Model<Project>,
_: Model<Buffer>,
cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> {
location_links_from_proto(message.links, project, cx).await
@ -527,8 +527,8 @@ impl LspCommand for GetTypeDefinition {
async fn response_from_lsp(
self,
message: Option<lsp2::GotoTypeDefinitionResponse>,
project: Handle<Project>,
buffer: Handle<Buffer>,
project: Model<Project>,
buffer: Model<Buffer>,
server_id: LanguageServerId,
cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> {
@ -548,8 +548,8 @@ impl LspCommand for GetTypeDefinition {
async fn from_proto(
message: proto::GetTypeDefinition,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let position = message
@ -580,8 +580,8 @@ impl LspCommand for GetTypeDefinition {
async fn response_from_proto(
self,
message: proto::GetTypeDefinitionResponse,
project: Handle<Project>,
_: Handle<Buffer>,
project: Model<Project>,
_: Model<Buffer>,
cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> {
location_links_from_proto(message.links, project, cx).await
@ -593,8 +593,8 @@ impl LspCommand for GetTypeDefinition {
}
fn language_server_for_buffer(
project: &Handle<Project>,
buffer: &Handle<Buffer>,
project: &Model<Project>,
buffer: &Model<Buffer>,
server_id: LanguageServerId,
cx: &mut AsyncAppContext,
) -> Result<(Arc<CachedLspAdapter>, Arc<LanguageServer>)> {
@ -609,7 +609,7 @@ fn language_server_for_buffer(
async fn location_links_from_proto(
proto_links: Vec<proto::LocationLink>,
project: Handle<Project>,
project: Model<Project>,
mut cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> {
let mut links = Vec::new();
@ -671,8 +671,8 @@ async fn location_links_from_proto(
async fn location_links_from_lsp(
message: Option<lsp2::GotoDefinitionResponse>,
project: Handle<Project>,
buffer: Handle<Buffer>,
project: Model<Project>,
buffer: Model<Buffer>,
server_id: LanguageServerId,
mut cx: AsyncAppContext,
) -> Result<Vec<LocationLink>> {
@ -814,8 +814,8 @@ impl LspCommand for GetReferences {
async fn response_from_lsp(
self,
locations: Option<Vec<lsp2::Location>>,
project: Handle<Project>,
buffer: Handle<Buffer>,
project: Model<Project>,
buffer: Model<Buffer>,
server_id: LanguageServerId,
mut cx: AsyncAppContext,
) -> Result<Vec<Location>> {
@ -868,8 +868,8 @@ impl LspCommand for GetReferences {
async fn from_proto(
message: proto::GetReferences,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let position = message
@ -910,8 +910,8 @@ impl LspCommand for GetReferences {
async fn response_from_proto(
self,
message: proto::GetReferencesResponse,
project: Handle<Project>,
_: Handle<Buffer>,
project: Model<Project>,
_: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Vec<Location>> {
let mut locations = Vec::new();
@ -977,8 +977,8 @@ impl LspCommand for GetDocumentHighlights {
async fn response_from_lsp(
self,
lsp_highlights: Option<Vec<lsp2::DocumentHighlight>>,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
_: LanguageServerId,
mut cx: AsyncAppContext,
) -> Result<Vec<DocumentHighlight>> {
@ -1016,8 +1016,8 @@ impl LspCommand for GetDocumentHighlights {
async fn from_proto(
message: proto::GetDocumentHighlights,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let position = message
@ -1060,8 +1060,8 @@ impl LspCommand for GetDocumentHighlights {
async fn response_from_proto(
self,
message: proto::GetDocumentHighlightsResponse,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Vec<DocumentHighlight>> {
let mut highlights = Vec::new();
@ -1123,8 +1123,8 @@ impl LspCommand for GetHover {
async fn response_from_lsp(
self,
message: Option<lsp2::Hover>,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
_: LanguageServerId,
mut cx: AsyncAppContext,
) -> Result<Self::Response> {
@ -1206,8 +1206,8 @@ impl LspCommand for GetHover {
async fn from_proto(
message: Self::ProtoRequest,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let position = message
@ -1272,8 +1272,8 @@ impl LspCommand for GetHover {
async fn response_from_proto(
self,
message: proto::GetHoverResponse,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self::Response> {
let contents: Vec<_> = message
@ -1341,8 +1341,8 @@ impl LspCommand for GetCompletions {
async fn response_from_lsp(
self,
completions: Option<lsp2::CompletionResponse>,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
server_id: LanguageServerId,
mut cx: AsyncAppContext,
) -> Result<Vec<Completion>> {
@ -1484,8 +1484,8 @@ impl LspCommand for GetCompletions {
async fn from_proto(
message: proto::GetCompletions,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let version = deserialize_version(&message.version);
@ -1523,8 +1523,8 @@ impl LspCommand for GetCompletions {
async fn response_from_proto(
self,
message: proto::GetCompletionsResponse,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Vec<Completion>> {
buffer
@ -1589,8 +1589,8 @@ impl LspCommand for GetCodeActions {
async fn response_from_lsp(
self,
actions: Option<lsp2::CodeActionResponse>,
_: Handle<Project>,
_: Handle<Buffer>,
_: Model<Project>,
_: Model<Buffer>,
server_id: LanguageServerId,
_: AsyncAppContext,
) -> Result<Vec<CodeAction>> {
@ -1623,8 +1623,8 @@ impl LspCommand for GetCodeActions {
async fn from_proto(
message: proto::GetCodeActions,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let start = message
@ -1663,8 +1663,8 @@ impl LspCommand for GetCodeActions {
async fn response_from_proto(
self,
message: proto::GetCodeActionsResponse,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Vec<CodeAction>> {
buffer
@ -1726,8 +1726,8 @@ impl LspCommand for OnTypeFormatting {
async fn response_from_lsp(
self,
message: Option<Vec<lsp2::TextEdit>>,
project: Handle<Project>,
buffer: Handle<Buffer>,
project: Model<Project>,
buffer: Model<Buffer>,
server_id: LanguageServerId,
mut cx: AsyncAppContext,
) -> Result<Option<Transaction>> {
@ -1763,8 +1763,8 @@ impl LspCommand for OnTypeFormatting {
async fn from_proto(
message: proto::OnTypeFormatting,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let position = message
@ -1805,8 +1805,8 @@ impl LspCommand for OnTypeFormatting {
async fn response_from_proto(
self,
message: proto::OnTypeFormattingResponse,
_: Handle<Project>,
_: Handle<Buffer>,
_: Model<Project>,
_: Model<Buffer>,
_: AsyncAppContext,
) -> Result<Option<Transaction>> {
let Some(transaction) = message.transaction else {
@ -1825,7 +1825,7 @@ impl LspCommand for OnTypeFormatting {
impl InlayHints {
pub async fn lsp_to_project_hint(
lsp_hint: lsp2::InlayHint,
buffer_handle: &Handle<Buffer>,
buffer_handle: &Model<Buffer>,
server_id: LanguageServerId,
resolve_state: ResolveState,
force_no_type_left_padding: bool,
@ -2230,8 +2230,8 @@ impl LspCommand for InlayHints {
async fn response_from_lsp(
self,
message: Option<Vec<lsp2::InlayHint>>,
project: Handle<Project>,
buffer: Handle<Buffer>,
project: Model<Project>,
buffer: Model<Buffer>,
server_id: LanguageServerId,
mut cx: AsyncAppContext,
) -> anyhow::Result<Vec<InlayHint>> {
@ -2286,8 +2286,8 @@ impl LspCommand for InlayHints {
async fn from_proto(
message: proto::InlayHints,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let start = message
@ -2326,8 +2326,8 @@ impl LspCommand for InlayHints {
async fn response_from_proto(
self,
message: proto::InlayHintsResponse,
_: Handle<Project>,
buffer: Handle<Buffer>,
_: Model<Project>,
buffer: Model<Buffer>,
mut cx: AsyncAppContext,
) -> anyhow::Result<Vec<InlayHint>> {
buffer

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
use crate::Project;
use gpui2::{AnyWindowHandle, Context, Handle, ModelContext, WeakHandle};
use gpui2::{AnyWindowHandle, Context, Model, ModelContext, WeakModel};
use settings2::Settings;
use std::path::{Path, PathBuf};
use terminal2::{
@ -11,7 +11,7 @@ use terminal2::{
use std::os::unix::ffi::OsStrExt;
pub struct Terminals {
pub(crate) local_handles: Vec<WeakHandle<terminal2::Terminal>>,
pub(crate) local_handles: Vec<WeakModel<terminal2::Terminal>>,
}
impl Project {
@ -20,7 +20,7 @@ impl Project {
working_directory: Option<PathBuf>,
window: AnyWindowHandle,
cx: &mut ModelContext<Self>,
) -> anyhow::Result<Handle<Terminal>> {
) -> anyhow::Result<Model<Terminal>> {
if self.is_remote() {
return Err(anyhow::anyhow!(
"creating terminals as a guest is not supported yet"
@ -40,7 +40,7 @@ impl Project {
|_, _| todo!("color_for_index"),
)
.map(|builder| {
let terminal_handle = cx.entity(|cx| builder.subscribe(cx));
let terminal_handle = cx.build_model(|cx| builder.subscribe(cx));
self.terminals
.local_handles
@ -108,7 +108,7 @@ impl Project {
fn activate_python_virtual_environment(
&mut self,
activate_script: Option<PathBuf>,
terminal_handle: &Handle<Terminal>,
terminal_handle: &Model<Terminal>,
cx: &mut ModelContext<Project>,
) {
if let Some(activate_script) = activate_script {
@ -121,7 +121,7 @@ impl Project {
}
}
pub fn local_terminal_handles(&self) -> &Vec<WeakHandle<terminal2::Terminal>> {
pub fn local_terminal_handles(&self) -> &Vec<WeakModel<terminal2::Terminal>> {
&self.terminals.local_handles
}
}

View File

@ -6,7 +6,7 @@ use anyhow::{anyhow, Context as _, Result};
use client2::{proto, Client};
use clock::ReplicaId;
use collections::{HashMap, HashSet, VecDeque};
use fs::{
use fs2::{
repository::{GitFileStatus, GitRepository, RepoPath},
Fs,
};
@ -22,7 +22,7 @@ use futures::{
use fuzzy2::CharBag;
use git::{DOT_GIT, GITIGNORE};
use gpui2::{
AppContext, AsyncAppContext, Context, EventEmitter, Executor, Handle, ModelContext, Task,
AppContext, AsyncAppContext, Context, EventEmitter, Executor, Model, ModelContext, Task,
};
use language2::{
proto::{
@ -292,7 +292,7 @@ impl Worktree {
fs: Arc<dyn Fs>,
next_entry_id: Arc<AtomicUsize>,
cx: &mut AsyncAppContext,
) -> Result<Handle<Self>> {
) -> Result<Model<Self>> {
// After determining whether the root entry is a file or a directory, populate the
// snapshot's "root name", which will be used for the purpose of fuzzy matching.
let abs_path = path.into();
@ -301,7 +301,7 @@ impl Worktree {
.await
.context("failed to stat worktree path")?;
cx.entity(move |cx: &mut ModelContext<Worktree>| {
cx.build_model(move |cx: &mut ModelContext<Worktree>| {
let root_name = abs_path
.file_name()
.map_or(String::new(), |f| f.to_string_lossy().to_string());
@ -406,8 +406,8 @@ impl Worktree {
worktree: proto::WorktreeMetadata,
client: Arc<Client>,
cx: &mut AppContext,
) -> Handle<Self> {
cx.entity(|cx: &mut ModelContext<Self>| {
) -> Model<Self> {
cx.build_model(|cx: &mut ModelContext<Self>| {
let snapshot = Snapshot {
id: WorktreeId(worktree.id as usize),
abs_path: Arc::from(PathBuf::from(worktree.abs_path)),
@ -593,7 +593,7 @@ impl LocalWorktree {
id: u64,
path: &Path,
cx: &mut ModelContext<Worktree>,
) -> Task<Result<Handle<Buffer>>> {
) -> Task<Result<Model<Buffer>>> {
let path = Arc::from(path);
cx.spawn(move |this, mut cx| async move {
let (file, contents, diff_base) = this
@ -603,7 +603,7 @@ impl LocalWorktree {
.executor()
.spawn(async move { text::Buffer::new(0, id, contents) })
.await;
cx.entity(|_| Buffer::build(text_buffer, diff_base, Some(Arc::new(file))))
cx.build_model(|_| Buffer::build(text_buffer, diff_base, Some(Arc::new(file))))
})
}
@ -920,7 +920,7 @@ impl LocalWorktree {
pub fn save_buffer(
&self,
buffer_handle: Handle<Buffer>,
buffer_handle: Model<Buffer>,
path: Arc<Path>,
has_changed_file: bool,
cx: &mut ModelContext<Worktree>,
@ -1331,7 +1331,7 @@ impl RemoteWorktree {
pub fn save_buffer(
&self,
buffer_handle: Handle<Buffer>,
buffer_handle: Model<Buffer>,
cx: &mut ModelContext<Worktree>,
) -> Task<Result<()>> {
let buffer = buffer_handle.read(cx);
@ -2577,7 +2577,7 @@ impl fmt::Debug for Snapshot {
#[derive(Clone, PartialEq)]
pub struct File {
pub worktree: Handle<Worktree>,
pub worktree: Model<Worktree>,
pub path: Arc<Path>,
pub mtime: SystemTime,
pub(crate) entry_id: ProjectEntryId,
@ -2701,7 +2701,7 @@ impl language2::LocalFile for File {
}
impl File {
pub fn for_entry(entry: Entry, worktree: Handle<Worktree>) -> Arc<Self> {
pub fn for_entry(entry: Entry, worktree: Model<Worktree>) -> Arc<Self> {
Arc::new(Self {
worktree,
path: entry.path.clone(),
@ -2714,7 +2714,7 @@ impl File {
pub fn from_proto(
proto: rpc2::proto::File,
worktree: Handle<Worktree>,
worktree: Model<Worktree>,
cx: &AppContext,
) -> Result<Self> {
let worktree_id = worktree
@ -2815,7 +2815,7 @@ pub type UpdatedGitRepositoriesSet = Arc<[(Arc<Path>, GitRepositoryChange)]>;
impl Entry {
fn new(
path: Arc<Path>,
metadata: &fs::Metadata,
metadata: &fs2::Metadata,
next_entry_id: &AtomicUsize,
root_char_bag: CharBag,
) -> Self {

View File

@ -42,6 +42,7 @@ sha1 = "0.10.5"
ndarray = { version = "0.15.0" }
[dev-dependencies]
ai = { path = "../ai", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }

View File

@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
api_key: Option<String>,
}
#[derive(Clone)]
@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
}
impl EmbeddingQueue {
pub fn new(
embedding_provider: Arc<dyn EmbeddingProvider>,
executor: Arc<Background>,
api_key: Option<String>,
) -> Self {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
@ -64,14 +59,9 @@ impl EmbeddingQueue {
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
api_key,
}
}
pub fn set_api_key(&mut self, api_key: Option<String>) {
self.api_key = api_key
}
pub fn push(&mut self, file: FileToEmbed) {
if file.spans.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
@ -118,7 +108,6 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
let api_key = self.api_key.clone();
self.executor
.spawn(async move {
@ -143,7 +132,7 @@ impl EmbeddingQueue {
return;
};
match embedding_provider.embed_batch(spans, api_key).await {
match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {

View File

@ -1,4 +1,7 @@
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::{
embedding::{Embedding, EmbeddingProvider},
models::TruncationDirection,
};
use anyhow::{anyhow, Result};
use language::{Grammar, Language};
use rusqlite::{
@ -108,7 +111,14 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
@ -131,7 +141,15 @@ impl CodeContextRetriever {
)
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
let model = self.embedding_provider.base_model();
let document_span = model.truncate(
&document_span,
model.capacity()?,
ai::models::TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_span)?;
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
@ -222,8 +240,13 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("item", &span.content);
let (document_content, token_count) =
self.embedding_provider.truncate(&document_content);
let model = self.embedding_provider.base_model();
let document_content = model.truncate(
&document_content,
model.capacity()?,
TruncationDirection::End,
)?;
let token_count = model.count_tokens(&document_content)?;
span.content = document_content;
span.token_count = token_count;

View File

@ -7,7 +7,8 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
@ -88,7 +89,7 @@ pub fn init(
let semantic_index = SemanticIndex::new(
fs,
db_file_path,
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
language_registry,
cx.clone(),
)
@ -123,8 +124,6 @@ pub struct SemanticIndex {
_embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
api_key: Option<String>,
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
}
struct ProjectState {
@ -278,18 +277,18 @@ impl SemanticIndex {
}
}
pub fn authenticate(&mut self, cx: &AppContext) {
if self.api_key.is_none() {
self.api_key = self.embedding_provider.retrieve_credentials(cx);
self.embedding_queue
.lock()
.set_api_key(self.api_key.clone());
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
if !self.embedding_provider.has_credentials() {
self.embedding_provider.retrieve_credentials(cx);
} else {
return true;
}
self.embedding_provider.has_credentials()
}
pub fn is_authenticated(&self) -> bool {
self.api_key.is_some()
self.embedding_provider.has_credentials()
}
pub fn enabled(cx: &AppContext) -> bool {
@ -339,7 +338,7 @@ impl SemanticIndex {
Ok(cx.add_model(|cx| {
let t0 = Instant::now();
let embedding_queue =
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
let _embedding_task = cx.background().spawn({
let embedded_files = embedding_queue.finished_files();
let db = db.clone();
@ -404,8 +403,6 @@ impl SemanticIndex {
_embedding_task,
_parsing_files_tasks,
projects: Default::default(),
api_key: None,
embedding_queue
}
}))
}
@ -720,13 +717,13 @@ impl SemanticIndex {
let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone();
let api_key = self.api_key.clone();
cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
let query = embedding_provider
.embed_batch(vec![query], api_key)
.embed_batch(vec![query])
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
@ -944,7 +941,6 @@ impl SemanticIndex {
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
let api_key = self.api_key.clone();
cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
@ -959,15 +955,10 @@ impl SemanticIndex {
.parse_file_with_template(None, &snapshot.text(), language)
.log_err()
.unwrap_or_default();
if Self::embed_spans(
&mut spans,
embedding_provider.as_ref(),
&db,
api_key.clone(),
)
.await
.log_err()
.is_some()
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
.await
.log_err()
.is_some()
{
for span in spans {
let similarity = span.embedding.unwrap().similarity(&query);
@ -1007,9 +998,8 @@ impl SemanticIndex {
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if self.api_key.is_none() {
self.authenticate(cx);
if self.api_key.is_none() {
if !self.is_authenticated() {
if !self.authenticate(cx) {
return Task::ready(Err(anyhow!("user is not authenticated")));
}
}
@ -1192,7 +1182,6 @@ impl SemanticIndex {
spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase,
api_key: Option<String>,
) -> Result<()> {
let mut batch = Vec::new();
let mut batch_tokens = 0;
@ -1215,7 +1204,7 @@ impl SemanticIndex {
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch), api_key.clone())
.embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
batch_tokens = 0;
@ -1227,7 +1216,7 @@ impl SemanticIndex {
if !batch.is_empty() {
let batch_embeddings = embedding_provider
.embed_batch(mem::take(&mut batch), api_key)
.embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);

View File

@ -4,10 +4,9 @@ use crate::{
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
use ai::test::FakeEmbeddingProvider;
use gpui::{executor::Deterministic, Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
@ -15,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
use rand::{rngs::StdRng, Rng};
use serde_json::json;
use settings::SettingsStore;
use std::{
path::Path,
sync::{
atomic::{self, AtomicUsize},
Arc,
},
time::{Instant, SystemTime},
};
use std::{path::Path, sync::Arc, time::SystemTime};
use unindent::Unindent;
use util::RandomCharIter;
@ -228,7 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
for file in &files {
queue.push(file.clone());
}
@ -280,7 +272,7 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@ -382,7 +374,7 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -466,7 +458,7 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@ -565,7 +557,7 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -639,7 +631,7 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -756,7 +748,7 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@ -909,7 +901,7 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -1100,7 +1092,7 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
let embedding_provider = Arc::new(DummyEmbeddings {});
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@ -1248,65 +1240,6 @@ async fn test_code_context_retrieval_php() {
);
}
#[derive(Default)]
struct FakeEmbeddingProvider {
embedding_count: AtomicUsize,
}
impl FakeEmbeddingProvider {
fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result.into()
}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
Some("Fake Credentials".to_string())
}
fn truncate(&self, span: &str) -> (String, usize) {
(span.to_string(), 1)
}
fn max_tokens_per_batch(&self) -> usize {
200
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(
&self,
spans: Vec<String>,
_api_key: Option<String>,
) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
fn js_lang() -> Arc<Language> {
Arc::new(
Language::new(

View File

@ -1,9 +1,11 @@
mod colors;
mod focus;
mod kitchen_sink;
mod scroll;
mod text;
mod z_index;
pub use colors::*;
pub use focus::*;
pub use kitchen_sink::*;
pub use scroll::*;

View File

@ -0,0 +1,38 @@
use crate::story::Story;
use gpui2::{px, Div, Render};
use ui::prelude::*;
pub struct ColorsStory;
impl Render for ColorsStory {
type Element = Div<Self>;
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
let color_scales = theme2::default_color_scales();
Story::container(cx)
.child(Story::title(cx, "Colors"))
.child(
div()
.id("colors")
.flex()
.flex_col()
.gap_1()
.overflow_y_scroll()
.text_color(gpui2::white())
.children(color_scales.into_iter().map(|(name, scale)| {
div()
.flex()
.child(
div()
.w(px(75.))
.line_height(px(24.))
.child(name.to_string()),
)
.child(div().flex().gap_1().children(
(1..=12).map(|step| div().flex().size_6().bg(scale.step(cx, step))),
))
})),
)
}
}

View File

@ -1,9 +1,9 @@
use crate::themes::rose_pine;
use gpui2::{
div, Focusable, KeyBinding, ParentElement, StatelessInteractive, Styled, View, VisualContext,
WindowContext,
div, Div, FocusEnabled, Focusable, KeyBinding, ParentElement, Render, StatefulInteraction,
StatelessInteractive, Styled, View, VisualContext, WindowContext,
};
use serde::Deserialize;
use theme2::theme;
#[derive(Clone, Default, PartialEq, Deserialize)]
struct ActionA;
@ -14,12 +14,10 @@ struct ActionB;
#[derive(Clone, Default, PartialEq, Deserialize)]
struct ActionC;
pub struct FocusStory {
text: View<()>,
}
pub struct FocusStory {}
impl FocusStory {
pub fn view(cx: &mut WindowContext) -> View<()> {
pub fn view(cx: &mut WindowContext) -> View<Self> {
cx.bind_keys([
KeyBinding::new("cmd-a", ActionA, Some("parent")),
KeyBinding::new("cmd-a", ActionB, Some("child-1")),
@ -27,91 +25,92 @@ impl FocusStory {
]);
cx.register_action_type::<ActionA>();
cx.register_action_type::<ActionB>();
let theme = rose_pine();
let color_1 = theme.lowest.negative.default.foreground;
let color_2 = theme.lowest.positive.default.foreground;
let color_3 = theme.lowest.warning.default.foreground;
let color_4 = theme.lowest.accent.default.foreground;
let color_5 = theme.lowest.variant.default.foreground;
let color_6 = theme.highest.negative.default.foreground;
cx.build_view(move |cx| Self {})
}
}
impl Render for FocusStory {
type Element = Div<Self, StatefulInteraction<Self>, FocusEnabled<Self>>;
fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
let theme = theme(cx);
let color_1 = theme.git_created;
let color_2 = theme.git_modified;
let color_3 = theme.git_deleted;
let color_4 = theme.git_conflict;
let color_5 = theme.git_ignored;
let color_6 = theme.git_renamed;
let child_1 = cx.focus_handle();
let child_2 = cx.focus_handle();
cx.build_view(
|_| (),
move |_, cx| {
div()
.id("parent")
.focusable()
.context("parent")
.on_action(|_, action: &ActionA, phase, cx| {
println!("Action A dispatched on parent during {:?}", phase);
})
.on_action(|_, action: &ActionB, phase, cx| {
println!("Action B dispatched on parent during {:?}", phase);
})
.on_focus(|_, _, _| println!("Parent focused"))
.on_blur(|_, _, _| println!("Parent blurred"))
.on_focus_in(|_, _, _| println!("Parent focus_in"))
.on_focus_out(|_, _, _| println!("Parent focus_out"))
.on_key_down(|_, event, phase, _| {
println!("Key down on parent {:?} {:?}", phase, event)
})
.on_key_up(|_, event, phase, _| println!("Key up on parent {:?} {:?}", phase, event))
.size_full()
.bg(color_1)
.focus(|style| style.bg(color_2))
.focus_in(|style| style.bg(color_3))
.child(
div()
.id("parent")
.focusable()
.context("parent")
.on_action(|_, action: &ActionA, phase, cx| {
println!("Action A dispatched on parent during {:?}", phase);
})
.track_focus(&child_1)
.context("child-1")
.on_action(|_, action: &ActionB, phase, cx| {
println!("Action B dispatched on parent during {:?}", phase);
println!("Action B dispatched on child 1 during {:?}", phase);
})
.on_focus(|_, _, _| println!("Parent focused"))
.on_blur(|_, _, _| println!("Parent blurred"))
.on_focus_in(|_, _, _| println!("Parent focus_in"))
.on_focus_out(|_, _, _| println!("Parent focus_out"))
.w_full()
.h_6()
.bg(color_4)
.focus(|style| style.bg(color_5))
.in_focus(|style| style.bg(color_6))
.on_focus(|_, _, _| println!("Child 1 focused"))
.on_blur(|_, _, _| println!("Child 1 blurred"))
.on_focus_in(|_, _, _| println!("Child 1 focus_in"))
.on_focus_out(|_, _, _| println!("Child 1 focus_out"))
.on_key_down(|_, event, phase, _| {
println!("Key down on parent {:?} {:?}", phase, event)
println!("Key down on child 1 {:?} {:?}", phase, event)
})
.on_key_up(|_, event, phase, _| {
println!("Key up on parent {:?} {:?}", phase, event)
println!("Key up on child 1 {:?} {:?}", phase, event)
})
.size_full()
.bg(color_1)
.focus(|style| style.bg(color_2))
.focus_in(|style| style.bg(color_3))
.child(
div()
.track_focus(&child_1)
.context("child-1")
.on_action(|_, action: &ActionB, phase, cx| {
println!("Action B dispatched on child 1 during {:?}", phase);
})
.w_full()
.h_6()
.bg(color_4)
.focus(|style| style.bg(color_5))
.in_focus(|style| style.bg(color_6))
.on_focus(|_, _, _| println!("Child 1 focused"))
.on_blur(|_, _, _| println!("Child 1 blurred"))
.on_focus_in(|_, _, _| println!("Child 1 focus_in"))
.on_focus_out(|_, _, _| println!("Child 1 focus_out"))
.on_key_down(|_, event, phase, _| {
println!("Key down on child 1 {:?} {:?}", phase, event)
})
.on_key_up(|_, event, phase, _| {
println!("Key up on child 1 {:?} {:?}", phase, event)
})
.child("Child 1"),
)
.child(
div()
.track_focus(&child_2)
.context("child-2")
.on_action(|_, action: &ActionC, phase, cx| {
println!("Action C dispatched on child 2 during {:?}", phase);
})
.w_full()
.h_6()
.bg(color_4)
.on_focus(|_, _, _| println!("Child 2 focused"))
.on_blur(|_, _, _| println!("Child 2 blurred"))
.on_focus_in(|_, _, _| println!("Child 2 focus_in"))
.on_focus_out(|_, _, _| println!("Child 2 focus_out"))
.on_key_down(|_, event, phase, _| {
println!("Key down on child 2 {:?} {:?}", phase, event)
})
.on_key_up(|_, event, phase, _| {
println!("Key up on child 2 {:?} {:?}", phase, event)
})
.child("Child 2"),
)
},
)
.child("Child 1"),
)
.child(
div()
.track_focus(&child_2)
.context("child-2")
.on_action(|_, action: &ActionC, phase, cx| {
println!("Action C dispatched on child 2 during {:?}", phase);
})
.w_full()
.h_6()
.bg(color_4)
.on_focus(|_, _, _| println!("Child 2 focused"))
.on_blur(|_, _, _| println!("Child 2 blurred"))
.on_focus_in(|_, _, _| println!("Child 2 focus_in"))
.on_focus_out(|_, _, _| println!("Child 2 focus_out"))
.on_key_down(|_, event, phase, _| {
println!("Key down on child 2 {:?} {:?}", phase, event)
})
.on_key_up(|_, event, phase, _| {
println!("Key up on child 2 {:?} {:?}", phase, event)
})
.child("Child 2"),
)
}
}

View File

@ -1,26 +1,23 @@
use gpui2::{AppContext, Context, View};
use crate::{
story::Story,
story_selector::{ComponentStory, ElementStory},
};
use gpui2::{Div, Render, StatefulInteraction, View, VisualContext};
use strum::IntoEnumIterator;
use ui::prelude::*;
use crate::story::Story;
use crate::story_selector::{ComponentStory, ElementStory};
pub struct KitchenSinkStory {}
pub struct KitchenSinkStory;
impl KitchenSinkStory {
pub fn new() -> Self {
Self {}
pub fn view(cx: &mut WindowContext) -> View<Self> {
cx.build_view(|cx| Self)
}
}
pub fn view(cx: &mut AppContext) -> View<Self> {
{
let state = cx.entity(|cx| Self::new());
let render = Self::render;
View::for_handle(state, render)
}
}
impl Render for KitchenSinkStory {
type Element = Div<Self, StatefulInteraction<Self>>;
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Component<Self> {
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
let element_stories = ElementStory::iter()
.map(|selector| selector.story(cx))
.collect::<Vec<_>>();

View File

@ -1,58 +1,54 @@
use crate::themes::rose_pine;
use gpui2::{
div, px, Component, ParentElement, SharedString, Styled, View, VisualContext, WindowContext,
div, px, Component, Div, ParentElement, Render, SharedString, StatefulInteraction, Styled,
View, VisualContext, WindowContext,
};
use theme2::theme;
pub struct ScrollStory {
text: View<()>,
}
pub struct ScrollStory;
impl ScrollStory {
pub fn view(cx: &mut WindowContext) -> View<()> {
let theme = rose_pine();
{
cx.build_view(|cx| (), move |_, cx| checkerboard(1))
}
pub fn view(cx: &mut WindowContext) -> View<ScrollStory> {
cx.build_view(|cx| ScrollStory)
}
}
fn checkerboard<S>(depth: usize) -> impl Component<S>
where
S: 'static + Send + Sync,
{
let theme = rose_pine();
let color_1 = theme.lowest.positive.default.background;
let color_2 = theme.lowest.warning.default.background;
impl Render for ScrollStory {
type Element = Div<Self, StatefulInteraction<Self>>;
div()
.id("parent")
.bg(theme.lowest.base.default.background)
.size_full()
.overflow_scroll()
.children((0..10).map(|row| {
div()
.w(px(1000.))
.h(px(100.))
.flex()
.flex_row()
.children((0..10).map(|column| {
let id = SharedString::from(format!("{}, {}", row, column));
let bg = if row % 2 == column % 2 {
color_1
} else {
color_2
};
div().id(id).bg(bg).size(px(100. / depth as f32)).when(
row >= 5 && column >= 5,
|d| {
d.overflow_scroll()
.child(div().size(px(50.)).bg(color_1))
.child(div().size(px(50.)).bg(color_2))
.child(div().size(px(50.)).bg(color_1))
.child(div().size(px(50.)).bg(color_2))
},
)
}))
}))
fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
let theme = theme(cx);
let color_1 = theme.git_created;
let color_2 = theme.git_modified;
div()
.id("parent")
.bg(theme.background)
.size_full()
.overflow_scroll()
.children((0..10).map(|row| {
div()
.w(px(1000.))
.h(px(100.))
.flex()
.flex_row()
.children((0..10).map(|column| {
let id = SharedString::from(format!("{}, {}", row, column));
let bg = if row % 2 == column % 2 {
color_1
} else {
color_2
};
div().id(id).bg(bg).size(px(100. as f32)).when(
row >= 5 && column >= 5,
|d| {
d.overflow_scroll()
.child(div().size(px(50.)).bg(color_1))
.child(div().size(px(50.)).bg(color_2))
.child(div().size(px(50.)).bg(color_1))
.child(div().size(px(50.)).bg(color_2))
},
)
}))
}))
}
}

View File

@ -1,20 +1,21 @@
use gpui2::{div, white, ParentElement, Styled, View, VisualContext, WindowContext};
use gpui2::{div, white, Div, ParentElement, Render, Styled, View, VisualContext, WindowContext};
pub struct TextStory {
text: View<()>,
}
pub struct TextStory;
impl TextStory {
pub fn view(cx: &mut WindowContext) -> View<()> {
cx.build_view(|cx| (), |_, cx| {
div()
.size_full()
.bg(white())
.child(concat!(
"The quick brown fox jumps over the lazy dog. ",
"Meanwhile, the lazy dog decided it was time for a change. ",
"He started daily workout routines, ate healthier and became the fastest dog in town.",
))
})
pub fn view(cx: &mut WindowContext) -> View<Self> {
cx.build_view(|cx| Self)
}
}
impl Render for TextStory {
type Element = Div<Self>;
fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
div().size_full().bg(white()).child(concat!(
"The quick brown fox jumps over the lazy dog. ",
"Meanwhile, the lazy dog decided it was time for a change. ",
"He started daily workout routines, ate healthier and became the fastest dog in town.",
))
}
}

View File

@ -1,15 +1,16 @@
use gpui2::{px, rgb, Div, Hsla};
use gpui2::{px, rgb, Div, Hsla, Render};
use ui::prelude::*;
use crate::story::Story;
/// A reimplementation of the MDN `z-index` example, found here:
/// [https://developer.mozilla.org/en-US/docs/Web/CSS/z-index](https://developer.mozilla.org/en-US/docs/Web/CSS/z-index).
#[derive(Component)]
pub struct ZIndexStory;
impl ZIndexStory {
fn render<V: 'static>(self, _view: &mut V, cx: &mut ViewContext<V>) -> impl Component<V> {
impl Render for ZIndexStory {
type Element = Div<Self>;
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
Story::container(cx)
.child(Story::title(cx, "z-index"))
.child(

View File

@ -7,13 +7,14 @@ use clap::builder::PossibleValue;
use clap::ValueEnum;
use gpui2::{AnyView, VisualContext};
use strum::{EnumIter, EnumString, IntoEnumIterator};
use ui::prelude::*;
use ui::{prelude::*, AvatarStory, ButtonStory, DetailsStory, IconStory, InputStory, LabelStory};
#[derive(Debug, PartialEq, Eq, Clone, Copy, strum::Display, EnumString, EnumIter)]
#[strum(serialize_all = "snake_case")]
pub enum ElementStory {
Avatar,
Button,
Colors,
Details,
Focus,
Icon,
@ -27,18 +28,17 @@ pub enum ElementStory {
impl ElementStory {
pub fn story(&self, cx: &mut WindowContext) -> AnyView {
match self {
Self::Avatar => { cx.build_view(|cx| (), |_, _| ui::AvatarStory.render()) }.into_any(),
Self::Button => { cx.build_view(|cx| (), |_, _| ui::ButtonStory.render()) }.into_any(),
Self::Details => {
{ cx.build_view(|cx| (), |_, _| ui::DetailsStory.render()) }.into_any()
}
Self::Colors => cx.build_view(|_| ColorsStory).into_any(),
Self::Avatar => cx.build_view(|_| AvatarStory).into_any(),
Self::Button => cx.build_view(|_| ButtonStory).into_any(),
Self::Details => cx.build_view(|_| DetailsStory).into_any(),
Self::Focus => FocusStory::view(cx).into_any(),
Self::Icon => { cx.build_view(|cx| (), |_, _| ui::IconStory.render()) }.into_any(),
Self::Input => { cx.build_view(|cx| (), |_, _| ui::InputStory.render()) }.into_any(),
Self::Label => { cx.build_view(|cx| (), |_, _| ui::LabelStory.render()) }.into_any(),
Self::Icon => cx.build_view(|_| IconStory).into_any(),
Self::Input => cx.build_view(|_| InputStory).into_any(),
Self::Label => cx.build_view(|_| LabelStory).into_any(),
Self::Scroll => ScrollStory::view(cx).into_any(),
Self::Text => TextStory::view(cx).into_any(),
Self::ZIndex => { cx.build_view(|cx| (), |_, _| ZIndexStory.render()) }.into_any(),
Self::ZIndex => cx.build_view(|_| ZIndexStory).into_any(),
}
}
}
@ -77,69 +77,31 @@ pub enum ComponentStory {
impl ComponentStory {
pub fn story(&self, cx: &mut WindowContext) -> AnyView {
match self {
Self::AssistantPanel => {
{ cx.build_view(|cx| (), |_, _| ui::AssistantPanelStory.render()) }.into_any()
}
Self::Buffer => { cx.build_view(|cx| (), |_, _| ui::BufferStory.render()) }.into_any(),
Self::Breadcrumb => {
{ cx.build_view(|cx| (), |_, _| ui::BreadcrumbStory.render()) }.into_any()
}
Self::ChatPanel => {
{ cx.build_view(|cx| (), |_, _| ui::ChatPanelStory.render()) }.into_any()
}
Self::CollabPanel => {
{ cx.build_view(|cx| (), |_, _| ui::CollabPanelStory.render()) }.into_any()
}
Self::CommandPalette => {
{ cx.build_view(|cx| (), |_, _| ui::CommandPaletteStory.render()) }.into_any()
}
Self::ContextMenu => {
{ cx.build_view(|cx| (), |_, _| ui::ContextMenuStory.render()) }.into_any()
}
Self::Facepile => {
{ cx.build_view(|cx| (), |_, _| ui::FacepileStory.render()) }.into_any()
}
Self::Keybinding => {
{ cx.build_view(|cx| (), |_, _| ui::KeybindingStory.render()) }.into_any()
}
Self::LanguageSelector => {
{ cx.build_view(|cx| (), |_, _| ui::LanguageSelectorStory.render()) }.into_any()
}
Self::MultiBuffer => {
{ cx.build_view(|cx| (), |_, _| ui::MultiBufferStory.render()) }.into_any()
}
Self::NotificationsPanel => {
{ cx.build_view(|cx| (), |_, _| ui::NotificationsPanelStory.render()) }.into_any()
}
Self::Palette => {
{ cx.build_view(|cx| (), |_, _| ui::PaletteStory.render()) }.into_any()
}
Self::Panel => { cx.build_view(|cx| (), |_, _| ui::PanelStory.render()) }.into_any(),
Self::ProjectPanel => {
{ cx.build_view(|cx| (), |_, _| ui::ProjectPanelStory.render()) }.into_any()
}
Self::RecentProjects => {
{ cx.build_view(|cx| (), |_, _| ui::RecentProjectsStory.render()) }.into_any()
}
Self::Tab => { cx.build_view(|cx| (), |_, _| ui::TabStory.render()) }.into_any(),
Self::TabBar => { cx.build_view(|cx| (), |_, _| ui::TabBarStory.render()) }.into_any(),
Self::Terminal => {
{ cx.build_view(|cx| (), |_, _| ui::TerminalStory.render()) }.into_any()
}
Self::ThemeSelector => {
{ cx.build_view(|cx| (), |_, _| ui::ThemeSelectorStory.render()) }.into_any()
}
Self::AssistantPanel => cx.build_view(|_| ui::AssistantPanelStory).into_any(),
Self::Buffer => cx.build_view(|_| ui::BufferStory).into_any(),
Self::Breadcrumb => cx.build_view(|_| ui::BreadcrumbStory).into_any(),
Self::ChatPanel => cx.build_view(|_| ui::ChatPanelStory).into_any(),
Self::CollabPanel => cx.build_view(|_| ui::CollabPanelStory).into_any(),
Self::CommandPalette => cx.build_view(|_| ui::CommandPaletteStory).into_any(),
Self::ContextMenu => cx.build_view(|_| ui::ContextMenuStory).into_any(),
Self::Facepile => cx.build_view(|_| ui::FacepileStory).into_any(),
Self::Keybinding => cx.build_view(|_| ui::KeybindingStory).into_any(),
Self::LanguageSelector => cx.build_view(|_| ui::LanguageSelectorStory).into_any(),
Self::MultiBuffer => cx.build_view(|_| ui::MultiBufferStory).into_any(),
Self::NotificationsPanel => cx.build_view(|cx| ui::NotificationsPanelStory).into_any(),
Self::Palette => cx.build_view(|cx| ui::PaletteStory).into_any(),
Self::Panel => cx.build_view(|cx| ui::PanelStory).into_any(),
Self::ProjectPanel => cx.build_view(|_| ui::ProjectPanelStory).into_any(),
Self::RecentProjects => cx.build_view(|_| ui::RecentProjectsStory).into_any(),
Self::Tab => cx.build_view(|_| ui::TabStory).into_any(),
Self::TabBar => cx.build_view(|_| ui::TabBarStory).into_any(),
Self::Terminal => cx.build_view(|_| ui::TerminalStory).into_any(),
Self::ThemeSelector => cx.build_view(|_| ui::ThemeSelectorStory).into_any(),
Self::Toast => cx.build_view(|_| ui::ToastStory).into_any(),
Self::Toolbar => cx.build_view(|_| ui::ToolbarStory).into_any(),
Self::TrafficLights => cx.build_view(|_| ui::TrafficLightsStory).into_any(),
Self::Copilot => cx.build_view(|_| ui::CopilotModalStory).into_any(),
Self::TitleBar => ui::TitleBarStory::view(cx).into_any(),
Self::Toast => { cx.build_view(|cx| (), |_, _| ui::ToastStory.render()) }.into_any(),
Self::Toolbar => {
{ cx.build_view(|cx| (), |_, _| ui::ToolbarStory.render()) }.into_any()
}
Self::TrafficLights => {
{ cx.build_view(|cx| (), |_, _| ui::TrafficLightsStory.render()) }.into_any()
}
Self::Copilot => {
{ cx.build_view(|cx| (), |_, _| ui::CopilotModalStory.render()) }.into_any()
}
Self::Workspace => ui::WorkspaceStory::view(cx).into_any(),
}
}

View File

@ -4,21 +4,20 @@ mod assets;
mod stories;
mod story;
mod story_selector;
mod themes;
use std::sync::Arc;
use clap::Parser;
use gpui2::{
div, px, size, AnyView, AppContext, Bounds, ViewContext, VisualContext, WindowBounds,
WindowOptions,
div, px, size, AnyView, AppContext, Bounds, Div, Render, ViewContext, VisualContext,
WindowBounds, WindowOptions,
};
use log::LevelFilter;
use settings2::{default_settings, Settings, SettingsStore};
use simplelog::SimpleLogger;
use story_selector::ComponentStory;
use theme2::{ThemeRegistry, ThemeSettings};
use ui::{prelude::*, themed};
use ui::prelude::*;
use crate::assets::Assets;
use crate::story_selector::StorySelector;
@ -50,7 +49,6 @@ fn main() {
let story_selector = args.story.clone();
let theme_name = args.theme.unwrap_or("One Dark".to_string());
let theme = themes::load_theme(theme_name.clone()).unwrap();
let asset_source = Arc::new(Assets);
gpui2::App::production(asset_source).run(move |cx| {
@ -84,12 +82,7 @@ fn main() {
}),
..Default::default()
},
move |cx| {
cx.build_view(
|cx| StoryWrapper::new(selector.story(cx), theme),
StoryWrapper::render,
)
},
move |cx| cx.build_view(|cx| StoryWrapper::new(selector.story(cx))),
);
cx.activate(true);
@ -99,22 +92,23 @@ fn main() {
#[derive(Clone)]
pub struct StoryWrapper {
story: AnyView,
theme: Theme,
}
impl StoryWrapper {
pub(crate) fn new(story: AnyView, theme: Theme) -> Self {
Self { story, theme }
pub(crate) fn new(story: AnyView) -> Self {
Self { story }
}
}
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Component<Self> {
themed(self.theme.clone(), cx, |cx| {
div()
.flex()
.flex_col()
.size_full()
.child(self.story.clone())
})
impl Render for StoryWrapper {
type Element = Div<Self>;
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
div()
.flex()
.flex_col()
.size_full()
.child(self.story.clone())
}
}

View File

@ -1,30 +0,0 @@
mod rose_pine;
pub use rose_pine::*;
use anyhow::{Context, Result};
use gpui2::serde_json;
use serde::Deserialize;
use ui::Theme;
use crate::assets::Assets;
#[derive(Deserialize)]
struct LegacyTheme {
pub base_theme: serde_json::Value,
}
/// Loads the [`Theme`] with the given name.
pub fn load_theme(name: String) -> Result<Theme> {
let theme_contents = Assets::get(&format!("themes/{name}.json"))
.with_context(|| format!("theme file not found: '{name}'"))?;
let legacy_theme: LegacyTheme =
serde_json::from_str(std::str::from_utf8(&theme_contents.data)?)
.context("failed to parse legacy theme")?;
let theme: Theme = serde_json::from_value(legacy_theme.base_theme.clone())
.context("failed to parse `base_theme`")?;
Ok(theme)
}

File diff suppressed because it is too large Load Diff

2118
crates/theme2/src/default.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,4 @@
use crate::{
themes::{one_dark, rose_pine, rose_pine_dawn, rose_pine_moon, sandcastle},
Theme, ThemeMetadata,
};
use crate::{themes, Theme, ThemeMetadata};
use anyhow::{anyhow, Result};
use gpui2::SharedString;
use std::{collections::HashMap, sync::Arc};
@ -41,11 +38,45 @@ impl Default for ThemeRegistry {
};
this.insert_themes([
one_dark(),
rose_pine(),
rose_pine_dawn(),
rose_pine_moon(),
sandcastle(),
themes::andromeda(),
themes::atelier_cave_dark(),
themes::atelier_cave_light(),
themes::atelier_dune_dark(),
themes::atelier_dune_light(),
themes::atelier_estuary_dark(),
themes::atelier_estuary_light(),
themes::atelier_forest_dark(),
themes::atelier_forest_light(),
themes::atelier_heath_dark(),
themes::atelier_heath_light(),
themes::atelier_lakeside_dark(),
themes::atelier_lakeside_light(),
themes::atelier_plateau_dark(),
themes::atelier_plateau_light(),
themes::atelier_savanna_dark(),
themes::atelier_savanna_light(),
themes::atelier_seaside_dark(),
themes::atelier_seaside_light(),
themes::atelier_sulphurpool_dark(),
themes::atelier_sulphurpool_light(),
themes::ayu_dark(),
themes::ayu_light(),
themes::ayu_mirage(),
themes::gruvbox_dark(),
themes::gruvbox_dark_hard(),
themes::gruvbox_dark_soft(),
themes::gruvbox_light(),
themes::gruvbox_light_hard(),
themes::gruvbox_light_soft(),
themes::one_dark(),
themes::one_light(),
themes::rose_pine(),
themes::rose_pine_dawn(),
themes::rose_pine_moon(),
themes::sandcastle(),
themes::solarized_dark(),
themes::solarized_light(),
themes::summercamp(),
]);
this

164
crates/theme2/src/scale.rs Normal file
View File

@ -0,0 +1,164 @@
use gpui2::{AppContext, Hsla};
use indexmap::IndexMap;
use crate::{theme, Appearance};
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ColorScaleName {
Gray,
Mauve,
Slate,
Sage,
Olive,
Sand,
Gold,
Bronze,
Brown,
Yellow,
Amber,
Orange,
Tomato,
Red,
Ruby,
Crimson,
Pink,
Plum,
Purple,
Violet,
Iris,
Indigo,
Blue,
Cyan,
Teal,
Jade,
Green,
Grass,
Lime,
Mint,
Sky,
Black,
White,
}
impl std::fmt::Display for ColorScaleName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
Self::Gray => "Gray",
Self::Mauve => "Mauve",
Self::Slate => "Slate",
Self::Sage => "Sage",
Self::Olive => "Olive",
Self::Sand => "Sand",
Self::Gold => "Gold",
Self::Bronze => "Bronze",
Self::Brown => "Brown",
Self::Yellow => "Yellow",
Self::Amber => "Amber",
Self::Orange => "Orange",
Self::Tomato => "Tomato",
Self::Red => "Red",
Self::Ruby => "Ruby",
Self::Crimson => "Crimson",
Self::Pink => "Pink",
Self::Plum => "Plum",
Self::Purple => "Purple",
Self::Violet => "Violet",
Self::Iris => "Iris",
Self::Indigo => "Indigo",
Self::Blue => "Blue",
Self::Cyan => "Cyan",
Self::Teal => "Teal",
Self::Jade => "Jade",
Self::Green => "Green",
Self::Grass => "Grass",
Self::Lime => "Lime",
Self::Mint => "Mint",
Self::Sky => "Sky",
Self::Black => "Black",
Self::White => "White",
}
)
}
}
pub type ColorScale = [Hsla; 12];
pub type ColorScales = IndexMap<ColorScaleName, ColorScaleSet>;
/// A one-based step in a [`ColorScale`].
pub type ColorScaleStep = usize;
pub struct ColorScaleSet {
name: ColorScaleName,
light: ColorScale,
dark: ColorScale,
light_alpha: ColorScale,
dark_alpha: ColorScale,
}
impl ColorScaleSet {
pub fn new(
name: ColorScaleName,
light: ColorScale,
light_alpha: ColorScale,
dark: ColorScale,
dark_alpha: ColorScale,
) -> Self {
Self {
name,
light,
light_alpha,
dark,
dark_alpha,
}
}
pub fn name(&self) -> String {
self.name.to_string()
}
pub fn light(&self, step: ColorScaleStep) -> Hsla {
self.light[step - 1]
}
pub fn light_alpha(&self, step: ColorScaleStep) -> Hsla {
self.light_alpha[step - 1]
}
pub fn dark(&self, step: ColorScaleStep) -> Hsla {
self.dark[step - 1]
}
pub fn dark_alpha(&self, step: ColorScaleStep) -> Hsla {
self.dark_alpha[step - 1]
}
fn current_appearance(cx: &AppContext) -> Appearance {
let theme = theme(cx);
if theme.metadata.is_light {
Appearance::Light
} else {
Appearance::Dark
}
}
pub fn step(&self, cx: &AppContext, step: ColorScaleStep) -> Hsla {
let appearance = Self::current_appearance(cx);
match appearance {
Appearance::Light => self.light(step),
Appearance::Dark => self.dark(step),
}
}
pub fn step_alpha(&self, cx: &AppContext, step: ColorScaleStep) -> Hsla {
let appearance = Self::current_appearance(cx);
match appearance {
Appearance::Light => self.light_alpha(step),
Appearance::Dark => self.dark_alpha(step),
}
}
}

View File

@ -1,14 +1,24 @@
mod default;
mod registry;
mod scale;
mod settings;
mod themes;
pub use default::*;
pub use registry::*;
pub use scale::*;
pub use settings::*;
use gpui2::{AppContext, HighlightStyle, Hsla, SharedString};
use settings2::Settings;
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
pub enum Appearance {
Light,
Dark,
}
pub fn init(cx: &mut AppContext) {
cx.set_global(ThemeRegistry::default());
ThemeSettings::register(cx);
@ -18,6 +28,10 @@ pub fn active_theme<'a>(cx: &'a AppContext) -> &'a Arc<Theme> {
&ThemeSettings::get_global(cx).active_theme
}
pub fn theme(cx: &AppContext) -> Arc<Theme> {
active_theme(cx).clone()
}
pub struct Theme {
pub metadata: ThemeMetadata,

View File

@ -0,0 +1,130 @@
use gpui2::rgba;
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
pub fn andromeda() -> Theme {
Theme {
metadata: ThemeMetadata {
name: "Andromeda".into(),
is_light: false,
},
transparent: rgba(0x00000000).into(),
mac_os_traffic_light_red: rgba(0xec695eff).into(),
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
border: rgba(0x2b2f38ff).into(),
border_variant: rgba(0x2b2f38ff).into(),
border_focused: rgba(0x183934ff).into(),
border_transparent: rgba(0x00000000).into(),
elevated_surface: rgba(0x262933ff).into(),
surface: rgba(0x21242bff).into(),
background: rgba(0x262933ff).into(),
filled_element: rgba(0x262933ff).into(),
filled_element_hover: rgba(0xffffff1e).into(),
filled_element_active: rgba(0xffffff28).into(),
filled_element_selected: rgba(0x12231fff).into(),
filled_element_disabled: rgba(0x00000000).into(),
ghost_element: rgba(0x00000000).into(),
ghost_element_hover: rgba(0xffffff14).into(),
ghost_element_active: rgba(0xffffff1e).into(),
ghost_element_selected: rgba(0x12231fff).into(),
ghost_element_disabled: rgba(0x00000000).into(),
text: rgba(0xf7f7f8ff).into(),
text_muted: rgba(0xaca8aeff).into(),
text_placeholder: rgba(0xf82871ff).into(),
text_disabled: rgba(0x6b6b73ff).into(),
text_accent: rgba(0x10a793ff).into(),
icon_muted: rgba(0xaca8aeff).into(),
syntax: SyntaxTheme {
highlights: vec![
("emphasis".into(), rgba(0x10a793ff).into()),
("punctuation.bracket".into(), rgba(0xd8d5dbff).into()),
("attribute".into(), rgba(0x10a793ff).into()),
("variable".into(), rgba(0xf7f7f8ff).into()),
("predictive".into(), rgba(0x315f70ff).into()),
("property".into(), rgba(0x10a793ff).into()),
("variant".into(), rgba(0x10a793ff).into()),
("embedded".into(), rgba(0xf7f7f8ff).into()),
("string.special".into(), rgba(0xf29c14ff).into()),
("keyword".into(), rgba(0x10a793ff).into()),
("tag".into(), rgba(0x10a793ff).into()),
("enum".into(), rgba(0xf29c14ff).into()),
("link_text".into(), rgba(0xf29c14ff).into()),
("primary".into(), rgba(0xf7f7f8ff).into()),
("punctuation".into(), rgba(0xd8d5dbff).into()),
("punctuation.special".into(), rgba(0xd8d5dbff).into()),
("function".into(), rgba(0xfee56cff).into()),
("number".into(), rgba(0x96df71ff).into()),
("preproc".into(), rgba(0xf7f7f8ff).into()),
("operator".into(), rgba(0xf29c14ff).into()),
("constructor".into(), rgba(0x10a793ff).into()),
("string.escape".into(), rgba(0xafabb1ff).into()),
("string.special.symbol".into(), rgba(0xf29c14ff).into()),
("string".into(), rgba(0xf29c14ff).into()),
("comment".into(), rgba(0xafabb1ff).into()),
("hint".into(), rgba(0x618399ff).into()),
("type".into(), rgba(0x08e7c5ff).into()),
("label".into(), rgba(0x10a793ff).into()),
("comment.doc".into(), rgba(0xafabb1ff).into()),
("text.literal".into(), rgba(0xf29c14ff).into()),
("constant".into(), rgba(0x96df71ff).into()),
("string.regex".into(), rgba(0xf29c14ff).into()),
("emphasis.strong".into(), rgba(0x10a793ff).into()),
("title".into(), rgba(0xf7f7f8ff).into()),
("punctuation.delimiter".into(), rgba(0xd8d5dbff).into()),
("link_uri".into(), rgba(0x96df71ff).into()),
("boolean".into(), rgba(0x96df71ff).into()),
("punctuation.list_marker".into(), rgba(0xd8d5dbff).into()),
],
},
status_bar: rgba(0x262933ff).into(),
title_bar: rgba(0x262933ff).into(),
toolbar: rgba(0x1e2025ff).into(),
tab_bar: rgba(0x21242bff).into(),
editor: rgba(0x1e2025ff).into(),
editor_subheader: rgba(0x21242bff).into(),
editor_active_line: rgba(0x21242bff).into(),
terminal: rgba(0x1e2025ff).into(),
image_fallback_background: rgba(0x262933ff).into(),
git_created: rgba(0x96df71ff).into(),
git_modified: rgba(0x10a793ff).into(),
git_deleted: rgba(0xf82871ff).into(),
git_conflict: rgba(0xfee56cff).into(),
git_ignored: rgba(0x6b6b73ff).into(),
git_renamed: rgba(0xfee56cff).into(),
players: [
PlayerTheme {
cursor: rgba(0x10a793ff).into(),
selection: rgba(0x10a7933d).into(),
},
PlayerTheme {
cursor: rgba(0x96df71ff).into(),
selection: rgba(0x96df713d).into(),
},
PlayerTheme {
cursor: rgba(0xc74cecff).into(),
selection: rgba(0xc74cec3d).into(),
},
PlayerTheme {
cursor: rgba(0xf29c14ff).into(),
selection: rgba(0xf29c143d).into(),
},
PlayerTheme {
cursor: rgba(0x893ea6ff).into(),
selection: rgba(0x893ea63d).into(),
},
PlayerTheme {
cursor: rgba(0x08e7c5ff).into(),
selection: rgba(0x08e7c53d).into(),
},
PlayerTheme {
cursor: rgba(0xf82871ff).into(),
selection: rgba(0xf828713d).into(),
},
PlayerTheme {
cursor: rgba(0xfee56cff).into(),
selection: rgba(0xfee56c3d).into(),
},
],
}
}

View File

@ -0,0 +1,136 @@
use gpui2::rgba;
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
pub fn atelier_cave_dark() -> Theme {
Theme {
metadata: ThemeMetadata {
name: "Atelier Cave Dark".into(),
is_light: false,
},
transparent: rgba(0x00000000).into(),
mac_os_traffic_light_red: rgba(0xec695eff).into(),
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
border: rgba(0x56505eff).into(),
border_variant: rgba(0x56505eff).into(),
border_focused: rgba(0x222953ff).into(),
border_transparent: rgba(0x00000000).into(),
elevated_surface: rgba(0x3a353fff).into(),
surface: rgba(0x221f26ff).into(),
background: rgba(0x3a353fff).into(),
filled_element: rgba(0x3a353fff).into(),
filled_element_hover: rgba(0xffffff1e).into(),
filled_element_active: rgba(0xffffff28).into(),
filled_element_selected: rgba(0x161a35ff).into(),
filled_element_disabled: rgba(0x00000000).into(),
ghost_element: rgba(0x00000000).into(),
ghost_element_hover: rgba(0xffffff14).into(),
ghost_element_active: rgba(0xffffff1e).into(),
ghost_element_selected: rgba(0x161a35ff).into(),
ghost_element_disabled: rgba(0x00000000).into(),
text: rgba(0xefecf4ff).into(),
text_muted: rgba(0x898591ff).into(),
text_placeholder: rgba(0xbe4677ff).into(),
text_disabled: rgba(0x756f7eff).into(),
text_accent: rgba(0x566ddaff).into(),
icon_muted: rgba(0x898591ff).into(),
syntax: SyntaxTheme {
highlights: vec![
("comment.doc".into(), rgba(0x8b8792ff).into()),
("tag".into(), rgba(0x566ddaff).into()),
("link_text".into(), rgba(0xaa563bff).into()),
("constructor".into(), rgba(0x566ddaff).into()),
("punctuation".into(), rgba(0xe2dfe7ff).into()),
("punctuation.special".into(), rgba(0xbf3fbfff).into()),
("string.special.symbol".into(), rgba(0x299292ff).into()),
("string.escape".into(), rgba(0x8b8792ff).into()),
("emphasis".into(), rgba(0x566ddaff).into()),
("type".into(), rgba(0xa06d3aff).into()),
("punctuation.delimiter".into(), rgba(0x8b8792ff).into()),
("variant".into(), rgba(0xa06d3aff).into()),
("variable.special".into(), rgba(0x9559e7ff).into()),
("text.literal".into(), rgba(0xaa563bff).into()),
("punctuation.list_marker".into(), rgba(0xe2dfe7ff).into()),
("comment".into(), rgba(0x655f6dff).into()),
("function.method".into(), rgba(0x576cdbff).into()),
("property".into(), rgba(0xbe4677ff).into()),
("operator".into(), rgba(0x8b8792ff).into()),
("emphasis.strong".into(), rgba(0x566ddaff).into()),
("label".into(), rgba(0x566ddaff).into()),
("enum".into(), rgba(0xaa563bff).into()),
("number".into(), rgba(0xaa563bff).into()),
("primary".into(), rgba(0xe2dfe7ff).into()),
("keyword".into(), rgba(0x9559e7ff).into()),
(
"function.special.definition".into(),
rgba(0xa06d3aff).into(),
),
("punctuation.bracket".into(), rgba(0x8b8792ff).into()),
("constant".into(), rgba(0x2b9292ff).into()),
("string.special".into(), rgba(0xbf3fbfff).into()),
("title".into(), rgba(0xefecf4ff).into()),
("preproc".into(), rgba(0xefecf4ff).into()),
("link_uri".into(), rgba(0x2b9292ff).into()),
("string".into(), rgba(0x299292ff).into()),
("embedded".into(), rgba(0xefecf4ff).into()),
("hint".into(), rgba(0x706897ff).into()),
("boolean".into(), rgba(0x2b9292ff).into()),
("variable".into(), rgba(0xe2dfe7ff).into()),
("predictive".into(), rgba(0x615787ff).into()),
("string.regex".into(), rgba(0x388bc6ff).into()),
("function".into(), rgba(0x576cdbff).into()),
("attribute".into(), rgba(0x566ddaff).into()),
],
},
status_bar: rgba(0x3a353fff).into(),
title_bar: rgba(0x3a353fff).into(),
toolbar: rgba(0x19171cff).into(),
tab_bar: rgba(0x221f26ff).into(),
editor: rgba(0x19171cff).into(),
editor_subheader: rgba(0x221f26ff).into(),
editor_active_line: rgba(0x221f26ff).into(),
terminal: rgba(0x19171cff).into(),
image_fallback_background: rgba(0x3a353fff).into(),
git_created: rgba(0x2b9292ff).into(),
git_modified: rgba(0x566ddaff).into(),
git_deleted: rgba(0xbe4677ff).into(),
git_conflict: rgba(0xa06d3aff).into(),
git_ignored: rgba(0x756f7eff).into(),
git_renamed: rgba(0xa06d3aff).into(),
players: [
PlayerTheme {
cursor: rgba(0x566ddaff).into(),
selection: rgba(0x566dda3d).into(),
},
PlayerTheme {
cursor: rgba(0x2b9292ff).into(),
selection: rgba(0x2b92923d).into(),
},
PlayerTheme {
cursor: rgba(0xbf41bfff).into(),
selection: rgba(0xbf41bf3d).into(),
},
PlayerTheme {
cursor: rgba(0xaa563bff).into(),
selection: rgba(0xaa563b3d).into(),
},
PlayerTheme {
cursor: rgba(0x955ae6ff).into(),
selection: rgba(0x955ae63d).into(),
},
PlayerTheme {
cursor: rgba(0x3a8bc6ff).into(),
selection: rgba(0x3a8bc63d).into(),
},
PlayerTheme {
cursor: rgba(0xbe4677ff).into(),
selection: rgba(0xbe46773d).into(),
},
PlayerTheme {
cursor: rgba(0xa06d3aff).into(),
selection: rgba(0xa06d3a3d).into(),
},
],
}
}

View File

@ -0,0 +1,136 @@
use gpui2::rgba;
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
pub fn atelier_cave_light() -> Theme {
Theme {
metadata: ThemeMetadata {
name: "Atelier Cave Light".into(),
is_light: true,
},
transparent: rgba(0x00000000).into(),
mac_os_traffic_light_red: rgba(0xec695eff).into(),
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
border: rgba(0x8f8b96ff).into(),
border_variant: rgba(0x8f8b96ff).into(),
border_focused: rgba(0xc8c7f2ff).into(),
border_transparent: rgba(0x00000000).into(),
elevated_surface: rgba(0xbfbcc5ff).into(),
surface: rgba(0xe6e3ebff).into(),
background: rgba(0xbfbcc5ff).into(),
filled_element: rgba(0xbfbcc5ff).into(),
filled_element_hover: rgba(0xffffff1e).into(),
filled_element_active: rgba(0xffffff28).into(),
filled_element_selected: rgba(0xe1e0f9ff).into(),
filled_element_disabled: rgba(0x00000000).into(),
ghost_element: rgba(0x00000000).into(),
ghost_element_hover: rgba(0xffffff14).into(),
ghost_element_active: rgba(0xffffff1e).into(),
ghost_element_selected: rgba(0xe1e0f9ff).into(),
ghost_element_disabled: rgba(0x00000000).into(),
text: rgba(0x19171cff).into(),
text_muted: rgba(0x5a5462ff).into(),
text_placeholder: rgba(0xbd4677ff).into(),
text_disabled: rgba(0x6e6876ff).into(),
text_accent: rgba(0x586cdaff).into(),
icon_muted: rgba(0x5a5462ff).into(),
syntax: SyntaxTheme {
highlights: vec![
("link_text".into(), rgba(0xaa573cff).into()),
("string".into(), rgba(0x299292ff).into()),
("emphasis".into(), rgba(0x586cdaff).into()),
("label".into(), rgba(0x586cdaff).into()),
("property".into(), rgba(0xbe4677ff).into()),
("emphasis.strong".into(), rgba(0x586cdaff).into()),
("constant".into(), rgba(0x2b9292ff).into()),
(
"function.special.definition".into(),
rgba(0xa06d3aff).into(),
),
("embedded".into(), rgba(0x19171cff).into()),
("punctuation.special".into(), rgba(0xbf3fbfff).into()),
("function".into(), rgba(0x576cdbff).into()),
("tag".into(), rgba(0x586cdaff).into()),
("number".into(), rgba(0xaa563bff).into()),
("primary".into(), rgba(0x26232aff).into()),
("text.literal".into(), rgba(0xaa573cff).into()),
("variant".into(), rgba(0xa06d3aff).into()),
("type".into(), rgba(0xa06d3aff).into()),
("punctuation".into(), rgba(0x26232aff).into()),
("string.escape".into(), rgba(0x585260ff).into()),
("keyword".into(), rgba(0x9559e7ff).into()),
("title".into(), rgba(0x19171cff).into()),
("constructor".into(), rgba(0x586cdaff).into()),
("punctuation.list_marker".into(), rgba(0x26232aff).into()),
("string.special".into(), rgba(0xbf3fbfff).into()),
("operator".into(), rgba(0x585260ff).into()),
("function.method".into(), rgba(0x576cdbff).into()),
("link_uri".into(), rgba(0x2b9292ff).into()),
("variable.special".into(), rgba(0x9559e7ff).into()),
("hint".into(), rgba(0x776d9dff).into()),
("punctuation.bracket".into(), rgba(0x585260ff).into()),
("string.special.symbol".into(), rgba(0x299292ff).into()),
("predictive".into(), rgba(0x887fafff).into()),
("attribute".into(), rgba(0x586cdaff).into()),
("enum".into(), rgba(0xaa573cff).into()),
("preproc".into(), rgba(0x19171cff).into()),
("boolean".into(), rgba(0x2b9292ff).into()),
("variable".into(), rgba(0x26232aff).into()),
("comment.doc".into(), rgba(0x585260ff).into()),
("string.regex".into(), rgba(0x388bc6ff).into()),
("punctuation.delimiter".into(), rgba(0x585260ff).into()),
("comment".into(), rgba(0x7d7787ff).into()),
],
},
status_bar: rgba(0xbfbcc5ff).into(),
title_bar: rgba(0xbfbcc5ff).into(),
toolbar: rgba(0xefecf4ff).into(),
tab_bar: rgba(0xe6e3ebff).into(),
editor: rgba(0xefecf4ff).into(),
editor_subheader: rgba(0xe6e3ebff).into(),
editor_active_line: rgba(0xe6e3ebff).into(),
terminal: rgba(0xefecf4ff).into(),
image_fallback_background: rgba(0xbfbcc5ff).into(),
git_created: rgba(0x2b9292ff).into(),
git_modified: rgba(0x586cdaff).into(),
git_deleted: rgba(0xbd4677ff).into(),
git_conflict: rgba(0xa06e3bff).into(),
git_ignored: rgba(0x6e6876ff).into(),
git_renamed: rgba(0xa06e3bff).into(),
players: [
PlayerTheme {
cursor: rgba(0x586cdaff).into(),
selection: rgba(0x586cda3d).into(),
},
PlayerTheme {
cursor: rgba(0x2b9292ff).into(),
selection: rgba(0x2b92923d).into(),
},
PlayerTheme {
cursor: rgba(0xbf41bfff).into(),
selection: rgba(0xbf41bf3d).into(),
},
PlayerTheme {
cursor: rgba(0xaa573cff).into(),
selection: rgba(0xaa573c3d).into(),
},
PlayerTheme {
cursor: rgba(0x955ae6ff).into(),
selection: rgba(0x955ae63d).into(),
},
PlayerTheme {
cursor: rgba(0x3a8bc6ff).into(),
selection: rgba(0x3a8bc63d).into(),
},
PlayerTheme {
cursor: rgba(0xbd4677ff).into(),
selection: rgba(0xbd46773d).into(),
},
PlayerTheme {
cursor: rgba(0xa06e3bff).into(),
selection: rgba(0xa06e3b3d).into(),
},
],
}
}

View File

@ -0,0 +1,136 @@
use gpui2::rgba;
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
pub fn atelier_dune_dark() -> Theme {
Theme {
metadata: ThemeMetadata {
name: "Atelier Dune Dark".into(),
is_light: false,
},
transparent: rgba(0x00000000).into(),
mac_os_traffic_light_red: rgba(0xec695eff).into(),
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
border: rgba(0x6c695cff).into(),
border_variant: rgba(0x6c695cff).into(),
border_focused: rgba(0x262f56ff).into(),
border_transparent: rgba(0x00000000).into(),
elevated_surface: rgba(0x45433bff).into(),
surface: rgba(0x262622ff).into(),
background: rgba(0x45433bff).into(),
filled_element: rgba(0x45433bff).into(),
filled_element_hover: rgba(0xffffff1e).into(),
filled_element_active: rgba(0xffffff28).into(),
filled_element_selected: rgba(0x171e38ff).into(),
filled_element_disabled: rgba(0x00000000).into(),
ghost_element: rgba(0x00000000).into(),
ghost_element_hover: rgba(0xffffff14).into(),
ghost_element_active: rgba(0xffffff1e).into(),
ghost_element_selected: rgba(0x171e38ff).into(),
ghost_element_disabled: rgba(0x00000000).into(),
text: rgba(0xfefbecff).into(),
text_muted: rgba(0xa4a08bff).into(),
text_placeholder: rgba(0xd73837ff).into(),
text_disabled: rgba(0x8f8b77ff).into(),
text_accent: rgba(0x6684e0ff).into(),
icon_muted: rgba(0xa4a08bff).into(),
syntax: SyntaxTheme {
highlights: vec![
("constructor".into(), rgba(0x6684e0ff).into()),
("punctuation".into(), rgba(0xe8e4cfff).into()),
("punctuation.delimiter".into(), rgba(0xa6a28cff).into()),
("string.special".into(), rgba(0xd43451ff).into()),
("string.escape".into(), rgba(0xa6a28cff).into()),
("comment".into(), rgba(0x7d7a68ff).into()),
("enum".into(), rgba(0xb65611ff).into()),
("variable.special".into(), rgba(0xb854d4ff).into()),
("primary".into(), rgba(0xe8e4cfff).into()),
("comment.doc".into(), rgba(0xa6a28cff).into()),
("label".into(), rgba(0x6684e0ff).into()),
("operator".into(), rgba(0xa6a28cff).into()),
("string".into(), rgba(0x5fac38ff).into()),
("variant".into(), rgba(0xae9512ff).into()),
("variable".into(), rgba(0xe8e4cfff).into()),
("function.method".into(), rgba(0x6583e1ff).into()),
(
"function.special.definition".into(),
rgba(0xae9512ff).into(),
),
("string.regex".into(), rgba(0x1ead82ff).into()),
("emphasis.strong".into(), rgba(0x6684e0ff).into()),
("punctuation.special".into(), rgba(0xd43451ff).into()),
("punctuation.bracket".into(), rgba(0xa6a28cff).into()),
("link_text".into(), rgba(0xb65611ff).into()),
("link_uri".into(), rgba(0x5fac39ff).into()),
("boolean".into(), rgba(0x5fac39ff).into()),
("hint".into(), rgba(0xb17272ff).into()),
("tag".into(), rgba(0x6684e0ff).into()),
("function".into(), rgba(0x6583e1ff).into()),
("title".into(), rgba(0xfefbecff).into()),
("property".into(), rgba(0xd73737ff).into()),
("type".into(), rgba(0xae9512ff).into()),
("constant".into(), rgba(0x5fac39ff).into()),
("attribute".into(), rgba(0x6684e0ff).into()),
("predictive".into(), rgba(0x9c6262ff).into()),
("string.special.symbol".into(), rgba(0x5fac38ff).into()),
("punctuation.list_marker".into(), rgba(0xe8e4cfff).into()),
("emphasis".into(), rgba(0x6684e0ff).into()),
("keyword".into(), rgba(0xb854d4ff).into()),
("text.literal".into(), rgba(0xb65611ff).into()),
("number".into(), rgba(0xb65610ff).into()),
("preproc".into(), rgba(0xfefbecff).into()),
("embedded".into(), rgba(0xfefbecff).into()),
],
},
status_bar: rgba(0x45433bff).into(),
title_bar: rgba(0x45433bff).into(),
toolbar: rgba(0x20201dff).into(),
tab_bar: rgba(0x262622ff).into(),
editor: rgba(0x20201dff).into(),
editor_subheader: rgba(0x262622ff).into(),
editor_active_line: rgba(0x262622ff).into(),
terminal: rgba(0x20201dff).into(),
image_fallback_background: rgba(0x45433bff).into(),
git_created: rgba(0x5fac39ff).into(),
git_modified: rgba(0x6684e0ff).into(),
git_deleted: rgba(0xd73837ff).into(),
git_conflict: rgba(0xae9414ff).into(),
git_ignored: rgba(0x8f8b77ff).into(),
git_renamed: rgba(0xae9414ff).into(),
players: [
PlayerTheme {
cursor: rgba(0x6684e0ff).into(),
selection: rgba(0x6684e03d).into(),
},
PlayerTheme {
cursor: rgba(0x5fac39ff).into(),
selection: rgba(0x5fac393d).into(),
},
PlayerTheme {
cursor: rgba(0xd43651ff).into(),
selection: rgba(0xd436513d).into(),
},
PlayerTheme {
cursor: rgba(0xb65611ff).into(),
selection: rgba(0xb656113d).into(),
},
PlayerTheme {
cursor: rgba(0xb854d3ff).into(),
selection: rgba(0xb854d33d).into(),
},
PlayerTheme {
cursor: rgba(0x20ad83ff).into(),
selection: rgba(0x20ad833d).into(),
},
PlayerTheme {
cursor: rgba(0xd73837ff).into(),
selection: rgba(0xd738373d).into(),
},
PlayerTheme {
cursor: rgba(0xae9414ff).into(),
selection: rgba(0xae94143d).into(),
},
],
}
}

Some files were not shown because too many files have changed in this diff Show More