mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-07 20:39:04 +03:00
Merge branch 'zed2' into zed2-workspace
This commit is contained in:
commit
3dadfb8ba8
42
Cargo.lock
generated
42
Cargo.lock
generated
@ -108,6 +108,33 @@ dependencies = [
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ai2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
"isahc",
|
||||
"language2",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"matrixmultiply",
|
||||
"ordered-float 2.10.0",
|
||||
"parking_lot 0.11.2",
|
||||
"parse_duration",
|
||||
"postage",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"rusqlite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tiktoken-rs",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "alacritty_config"
|
||||
version = "0.1.2-dev"
|
||||
@ -1138,7 +1165,7 @@ dependencies = [
|
||||
"audio2",
|
||||
"client2",
|
||||
"collections",
|
||||
"fs",
|
||||
"fs2",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
"language2",
|
||||
@ -4795,6 +4822,13 @@ dependencies = [
|
||||
"gpui",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "menu2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"gpui2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metal"
|
||||
version = "0.21.0"
|
||||
@ -6000,7 +6034,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"client2",
|
||||
"collections",
|
||||
"fs",
|
||||
"fs2",
|
||||
"futures 0.3.28",
|
||||
"gpui2",
|
||||
"language2",
|
||||
@ -6167,7 +6201,7 @@ dependencies = [
|
||||
"ctor",
|
||||
"db2",
|
||||
"env_logger 0.9.3",
|
||||
"fs",
|
||||
"fs2",
|
||||
"fsevent",
|
||||
"futures 0.3.28",
|
||||
"fuzzy2",
|
||||
@ -8740,6 +8774,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"clap 4.4.4",
|
||||
"convert_case 0.6.0",
|
||||
"gpui2",
|
||||
"log",
|
||||
"rust-embed",
|
||||
@ -10932,6 +10967,7 @@ dependencies = [
|
||||
name = "zed2"
|
||||
version = "0.109.0"
|
||||
dependencies = [
|
||||
"ai2",
|
||||
"anyhow",
|
||||
"async-compression",
|
||||
"async-recursion 0.3.2",
|
||||
|
@ -59,6 +59,7 @@ members = [
|
||||
"crates/lsp2",
|
||||
"crates/media",
|
||||
"crates/menu",
|
||||
"crates/menu2",
|
||||
"crates/multi_buffer",
|
||||
"crates/node_runtime",
|
||||
"crates/notifications",
|
||||
|
38
crates/Cargo.toml
Normal file
38
crates/Cargo.toml
Normal file
@ -0,0 +1,38 @@
|
||||
[package]
|
||||
name = "ai"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/ai.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
|
||||
[dependencies]
|
||||
gpui = { path = "../gpui" }
|
||||
util = { path = "../util" }
|
||||
language = { path = "../language" }
|
||||
async-trait.workspace = true
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
lazy_static.workspace = true
|
||||
ordered-float.workspace = true
|
||||
parking_lot.workspace = true
|
||||
isahc.workspace = true
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
postage.workspace = true
|
||||
rand.workspace = true
|
||||
log.workspace = true
|
||||
parse_duration = "2.1.1"
|
||||
tiktoken-rs = "0.5.0"
|
||||
matrixmultiply = "0.3.7"
|
||||
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
|
||||
bincode = "1.3.3"
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
@ -8,6 +8,9 @@ publish = false
|
||||
path = "src/ai.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
|
||||
[dependencies]
|
||||
gpui = { path = "../gpui" }
|
||||
util = { path = "../util" }
|
||||
|
@ -1,4 +1,8 @@
|
||||
pub mod auth;
|
||||
pub mod completion;
|
||||
pub mod embedding;
|
||||
pub mod models;
|
||||
pub mod templates;
|
||||
pub mod prompts;
|
||||
pub mod providers;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
||||
|
15
crates/ai/src/auth.rs
Normal file
15
crates/ai/src/auth.rs
Normal file
@ -0,0 +1,15 @@
|
||||
use gpui::AppContext;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ProviderCredential {
|
||||
Credentials { api_key: String },
|
||||
NoCredentials,
|
||||
NotNeeded,
|
||||
}
|
||||
|
||||
pub trait CredentialProvider: Send + Sync {
|
||||
fn has_credentials(&self) -> bool;
|
||||
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
|
||||
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
|
||||
fn delete_credentials(&self, cx: &AppContext);
|
||||
}
|
@ -1,214 +1,23 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{
|
||||
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||
Stream, StreamExt,
|
||||
};
|
||||
use gpui::executor::Background;
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::{self, Display},
|
||||
io,
|
||||
sync::Arc,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use futures::{future::BoxFuture, stream::BoxStream};
|
||||
|
||||
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
||||
use crate::{auth::CredentialProvider, models::LanguageModel};
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
pub trait CompletionRequest: Send + Sync {
|
||||
fn data(&self) -> serde_json::Result<String>;
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn cycle(&mut self) {
|
||||
*self = match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "User"),
|
||||
Role::Assistant => write!(f, "Assistant"),
|
||||
Role::System => write!(f, "System"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
pub struct OpenAIRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ChatChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: ResponseMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIResponseStreamEvent {
|
||||
pub id: Option<String>,
|
||||
pub object: String,
|
||||
pub created: u32,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoiceDelta>,
|
||||
pub usage: Option<OpenAIUsage>,
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
api_key: String,
|
||||
executor: Arc<Background>,
|
||||
mut request: OpenAIRequest,
|
||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||
request.stream = true;
|
||||
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||
|
||||
let json_data = serde_json::to_string(&request)?;
|
||||
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(json_data)?
|
||||
.send_async()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status == StatusCode::OK {
|
||||
executor
|
||||
.spawn(async move {
|
||||
let mut lines = BufReader::new(response.body_mut()).lines();
|
||||
|
||||
fn parse_line(
|
||||
line: Result<String, io::Error>,
|
||||
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
||||
if let Some(data) = line?.strip_prefix("data: ") {
|
||||
let event = serde_json::from_str(&data)?;
|
||||
Ok(Some(event))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(line) = lines.next().await {
|
||||
if let Some(event) = parse_line(line).transpose() {
|
||||
let done = event.as_ref().map_or(false, |event| {
|
||||
event
|
||||
.choices
|
||||
.last()
|
||||
.map_or(false, |choice| choice.finish_reason.is_some())
|
||||
});
|
||||
if tx.unbounded_send(event).is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
Ok(rx)
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIResponse {
|
||||
error: OpenAIError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenAIResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {}",
|
||||
response.error.message,
|
||||
)),
|
||||
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CompletionProvider {
|
||||
pub trait CompletionProvider: CredentialProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: OpenAIRequest,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||
fn box_clone(&self) -> Box<dyn CompletionProvider>;
|
||||
}
|
||||
|
||||
pub struct OpenAICompletionProvider {
|
||||
api_key: String,
|
||||
executor: Arc<Background>,
|
||||
}
|
||||
|
||||
impl OpenAICompletionProvider {
|
||||
pub fn new(api_key: String, executor: Arc<Background>) -> Self {
|
||||
Self { api_key, executor }
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for OpenAICompletionProvider {
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: OpenAIRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
impl Clone for Box<dyn CompletionProvider> {
|
||||
fn clone(&self) -> Box<dyn CompletionProvider> {
|
||||
self.box_clone()
|
||||
}
|
||||
}
|
||||
|
@ -1,32 +1,13 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::AsyncReadExt;
|
||||
use gpui::executor::Background;
|
||||
use gpui::{serde_json, AppContext};
|
||||
use isahc::http::StatusCode;
|
||||
use isahc::prelude::Configurable;
|
||||
use isahc::{AsyncBody, Response};
|
||||
use lazy_static::lazy_static;
|
||||
use ordered_float::OrderedFloat;
|
||||
use parking_lot::Mutex;
|
||||
use parse_duration::parse;
|
||||
use postage::watch;
|
||||
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
||||
use rusqlite::ToSql;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
use std::ops::Add;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||
use util::http::{HttpClient, Request};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::completion::OPENAI_API_URL;
|
||||
|
||||
lazy_static! {
|
||||
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
|
||||
}
|
||||
use crate::auth::CredentialProvider;
|
||||
use crate::models::LanguageModel;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct Embedding(pub Vec<f32>);
|
||||
@ -87,301 +68,14 @@ impl Embedding {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAIEmbeddings {
|
||||
pub client: Arc<dyn HttpClient>,
|
||||
pub executor: Arc<Background>,
|
||||
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
||||
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OpenAIEmbeddingRequest<'a> {
|
||||
model: &'static str,
|
||||
input: Vec<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIEmbeddingResponse {
|
||||
data: Vec<OpenAIEmbedding>,
|
||||
usage: OpenAIEmbeddingUsage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAIEmbedding {
|
||||
embedding: Vec<f32>,
|
||||
index: usize,
|
||||
object: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIEmbeddingUsage {
|
||||
prompt_tokens: usize,
|
||||
total_tokens: usize,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait EmbeddingProvider: Sync + Send {
|
||||
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
|
||||
async fn embed_batch(
|
||||
&self,
|
||||
spans: Vec<String>,
|
||||
api_key: Option<String>,
|
||||
) -> Result<Vec<Embedding>>;
|
||||
pub trait EmbeddingProvider: CredentialProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
|
||||
fn max_tokens_per_batch(&self) -> usize;
|
||||
fn truncate(&self, span: &str) -> (String, usize);
|
||||
fn rate_limit_expiration(&self) -> Option<Instant>;
|
||||
}
|
||||
|
||||
pub struct DummyEmbeddings {}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for DummyEmbeddings {
|
||||
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
|
||||
Some("Dummy API KEY".to_string())
|
||||
}
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
None
|
||||
}
|
||||
async fn embed_batch(
|
||||
&self,
|
||||
spans: Vec<String>,
|
||||
_api_key: Option<String>,
|
||||
) -> Result<Vec<Embedding>> {
|
||||
// 1024 is the OpenAI Embeddings size for ada models.
|
||||
// the model we will likely be starting with.
|
||||
let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
|
||||
return Ok(vec![dummy_vec; spans.len()]);
|
||||
}
|
||||
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
OPENAI_INPUT_LIMIT
|
||||
}
|
||||
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
let token_count = tokens.len();
|
||||
let output = if token_count > OPENAI_INPUT_LIMIT {
|
||||
tokens.truncate(OPENAI_INPUT_LIMIT);
|
||||
let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
|
||||
new_input.ok().unwrap_or_else(|| span.to_string())
|
||||
} else {
|
||||
span.to_string()
|
||||
};
|
||||
|
||||
(output, tokens.len())
|
||||
}
|
||||
}
|
||||
|
||||
const OPENAI_INPUT_LIMIT: usize = 8190;
|
||||
|
||||
impl OpenAIEmbeddings {
|
||||
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
|
||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||
|
||||
OpenAIEmbeddings {
|
||||
client,
|
||||
executor,
|
||||
rate_limit_count_rx,
|
||||
rate_limit_count_tx,
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_rate_limit(&self) {
|
||||
let reset_time = *self.rate_limit_count_tx.lock().borrow();
|
||||
|
||||
if let Some(reset_time) = reset_time {
|
||||
if Instant::now() >= reset_time {
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = None
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!(
|
||||
"resolving reset time: {:?}",
|
||||
*self.rate_limit_count_tx.lock().borrow()
|
||||
);
|
||||
}
|
||||
|
||||
fn update_reset_time(&self, reset_time: Instant) {
|
||||
let original_time = *self.rate_limit_count_tx.lock().borrow();
|
||||
|
||||
let updated_time = if let Some(original_time) = original_time {
|
||||
if reset_time < original_time {
|
||||
Some(reset_time)
|
||||
} else {
|
||||
Some(original_time)
|
||||
}
|
||||
} else {
|
||||
Some(reset_time)
|
||||
};
|
||||
|
||||
log::trace!("updating rate limit time: {:?}", updated_time);
|
||||
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
|
||||
}
|
||||
async fn send_request(
|
||||
&self,
|
||||
api_key: &str,
|
||||
spans: Vec<&str>,
|
||||
request_timeout: u64,
|
||||
) -> Result<Response<AsyncBody>> {
|
||||
let request = Request::post("https://api.openai.com/v1/embeddings")
|
||||
.redirect_policy(isahc::config::RedirectPolicy::Follow)
|
||||
.timeout(Duration::from_secs(request_timeout))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(
|
||||
serde_json::to_string(&OpenAIEmbeddingRequest {
|
||||
input: spans.clone(),
|
||||
model: "text-embedding-ada-002",
|
||||
})
|
||||
.unwrap()
|
||||
.into(),
|
||||
)?;
|
||||
|
||||
Ok(self.client.send(request).await?)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for OpenAIEmbeddings {
|
||||
fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
|
||||
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
Some(api_key)
|
||||
} else if let Some((_, api_key)) = cx
|
||||
.platform()
|
||||
.read_credentials(OPENAI_API_URL)
|
||||
.log_err()
|
||||
.flatten()
|
||||
{
|
||||
String::from_utf8(api_key).log_err()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
50000
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
*self.rate_limit_count_rx.borrow()
|
||||
}
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
|
||||
tokens.truncate(OPENAI_INPUT_LIMIT);
|
||||
OPENAI_BPE_TOKENIZER
|
||||
.decode(tokens.clone())
|
||||
.ok()
|
||||
.unwrap_or_else(|| span.to_string())
|
||||
} else {
|
||||
span.to_string()
|
||||
};
|
||||
|
||||
(output, tokens.len())
|
||||
}
|
||||
|
||||
async fn embed_batch(
|
||||
&self,
|
||||
spans: Vec<String>,
|
||||
api_key: Option<String>,
|
||||
) -> Result<Vec<Embedding>> {
|
||||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||
const MAX_RETRIES: usize = 4;
|
||||
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(anyhow!("no open ai key provided"));
|
||||
};
|
||||
|
||||
let mut request_number = 0;
|
||||
let mut rate_limiting = false;
|
||||
let mut request_timeout: u64 = 15;
|
||||
let mut response: Response<AsyncBody>;
|
||||
while request_number < MAX_RETRIES {
|
||||
response = self
|
||||
.send_request(
|
||||
&api_key,
|
||||
spans.iter().map(|x| &**x).collect(),
|
||||
request_timeout,
|
||||
)
|
||||
.await?;
|
||||
|
||||
request_number += 1;
|
||||
|
||||
match response.status() {
|
||||
StatusCode::REQUEST_TIMEOUT => {
|
||||
request_timeout += 5;
|
||||
}
|
||||
StatusCode::OK => {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
|
||||
|
||||
log::trace!(
|
||||
"openai embedding completed. tokens: {:?}",
|
||||
response.usage.total_tokens
|
||||
);
|
||||
|
||||
// If we complete a request successfully that was previously rate_limited
|
||||
// resolve the rate limit
|
||||
if rate_limiting {
|
||||
self.resolve_rate_limit()
|
||||
}
|
||||
|
||||
return Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|embedding| Embedding::from(embedding.embedding))
|
||||
.collect());
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
rate_limiting = true;
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
let delay_duration = {
|
||||
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
|
||||
if let Some(time_to_reset) =
|
||||
response.headers().get("x-ratelimit-reset-tokens")
|
||||
{
|
||||
if let Ok(time_str) = time_to_reset.to_str() {
|
||||
parse(time_str).unwrap_or(delay)
|
||||
} else {
|
||||
delay
|
||||
}
|
||||
} else {
|
||||
delay
|
||||
}
|
||||
};
|
||||
|
||||
// If we've previously rate limited, increment the duration but not the count
|
||||
let reset_time = Instant::now().add(delay_duration);
|
||||
self.update_reset_time(reset_time);
|
||||
|
||||
log::trace!(
|
||||
"openai rate limiting: waiting {:?} until lifted",
|
||||
&delay_duration
|
||||
);
|
||||
|
||||
self.executor.timer(delay_duration).await;
|
||||
}
|
||||
_ => {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(
|
||||
"open ai bad request: {:?} {:?}",
|
||||
&response.status(),
|
||||
body
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(anyhow!("openai max retries"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -1,66 +1,16 @@
|
||||
use anyhow::anyhow;
|
||||
use tiktoken_rs::CoreBPE;
|
||||
use util::ResultExt;
|
||||
pub enum TruncationDirection {
|
||||
Start,
|
||||
End,
|
||||
}
|
||||
|
||||
pub trait LanguageModel {
|
||||
fn name(&self) -> String;
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
||||
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
|
||||
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String>;
|
||||
fn capacity(&self) -> anyhow::Result<usize>;
|
||||
}
|
||||
|
||||
pub struct OpenAILanguageModel {
|
||||
name: String,
|
||||
bpe: Option<CoreBPE>,
|
||||
}
|
||||
|
||||
impl OpenAILanguageModel {
|
||||
pub fn load(model_name: &str) -> Self {
|
||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
|
||||
OpenAILanguageModel {
|
||||
name: model_name.to_string(),
|
||||
bpe,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for OpenAILanguageModel {
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||
if let Some(bpe) = &self.bpe {
|
||||
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
|
||||
} else {
|
||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||
}
|
||||
}
|
||||
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
|
||||
if let Some(bpe) = &self.bpe {
|
||||
let tokens = bpe.encode_with_special_tokens(content);
|
||||
if tokens.len() > length {
|
||||
bpe.decode(tokens[..length].to_vec())
|
||||
} else {
|
||||
bpe.decode(tokens)
|
||||
}
|
||||
} else {
|
||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||
}
|
||||
}
|
||||
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
|
||||
if let Some(bpe) = &self.bpe {
|
||||
let tokens = bpe.encode_with_special_tokens(content);
|
||||
if tokens.len() > length {
|
||||
bpe.decode(tokens[length..].to_vec())
|
||||
} else {
|
||||
bpe.decode(tokens)
|
||||
}
|
||||
} else {
|
||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||
}
|
||||
}
|
||||
fn capacity(&self) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ use language::BufferSnapshot;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::models::LanguageModel;
|
||||
use crate::templates::repository_context::PromptCodeSnippet;
|
||||
use crate::prompts::repository_context::PromptCodeSnippet;
|
||||
|
||||
pub(crate) enum PromptFileType {
|
||||
Text,
|
||||
@ -125,6 +125,9 @@ impl PromptChain {
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use crate::models::TruncationDirection;
|
||||
use crate::test::FakeLanguageModel;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
@ -141,7 +144,11 @@ pub(crate) mod tests {
|
||||
let mut token_count = args.model.count_tokens(&content)?;
|
||||
if let Some(max_token_length) = max_token_length {
|
||||
if token_count > max_token_length {
|
||||
content = args.model.truncate(&content, max_token_length)?;
|
||||
content = args.model.truncate(
|
||||
&content,
|
||||
max_token_length,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
token_count = max_token_length;
|
||||
}
|
||||
}
|
||||
@ -162,7 +169,11 @@ pub(crate) mod tests {
|
||||
let mut token_count = args.model.count_tokens(&content)?;
|
||||
if let Some(max_token_length) = max_token_length {
|
||||
if token_count > max_token_length {
|
||||
content = args.model.truncate(&content, max_token_length)?;
|
||||
content = args.model.truncate(
|
||||
&content,
|
||||
max_token_length,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
token_count = max_token_length;
|
||||
}
|
||||
}
|
||||
@ -171,38 +182,7 @@ pub(crate) mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct DummyLanguageModel {
|
||||
capacity: usize,
|
||||
}
|
||||
|
||||
impl LanguageModel for DummyLanguageModel {
|
||||
fn name(&self) -> String {
|
||||
"dummy".to_string()
|
||||
}
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
||||
}
|
||||
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
|
||||
anyhow::Ok(
|
||||
content.chars().collect::<Vec<char>>()[..length]
|
||||
.into_iter()
|
||||
.collect::<String>(),
|
||||
)
|
||||
}
|
||||
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
|
||||
anyhow::Ok(
|
||||
content.chars().collect::<Vec<char>>()[length..]
|
||||
.into_iter()
|
||||
.collect::<String>(),
|
||||
)
|
||||
}
|
||||
fn capacity(&self) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(self.capacity)
|
||||
}
|
||||
}
|
||||
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
@ -238,7 +218,7 @@ pub(crate) mod tests {
|
||||
|
||||
// Testing with Truncation Off
|
||||
// Should ignore capacity and return all prompts
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
@ -275,7 +255,7 @@ pub(crate) mod tests {
|
||||
// Testing with Truncation Off
|
||||
// Should ignore capacity and return all prompts
|
||||
let capacity = 20;
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
@ -311,7 +291,7 @@ pub(crate) mod tests {
|
||||
// Change Ordering of Prompts Based on Priority
|
||||
let capacity = 120;
|
||||
let reserved_tokens = 10;
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
@ -3,8 +3,9 @@ use language::BufferSnapshot;
|
||||
use language::ToOffset;
|
||||
|
||||
use crate::models::LanguageModel;
|
||||
use crate::templates::base::PromptArguments;
|
||||
use crate::templates::base::PromptTemplate;
|
||||
use crate::models::TruncationDirection;
|
||||
use crate::prompts::base::PromptArguments;
|
||||
use crate::prompts::base::PromptTemplate;
|
||||
use std::fmt::Write;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
@ -70,8 +71,9 @@ fn retrieve_context(
|
||||
};
|
||||
|
||||
let truncated_start_window =
|
||||
model.truncate_start(&start_window, start_goal_tokens)?;
|
||||
let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
|
||||
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
|
||||
let truncated_end_window =
|
||||
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
|
||||
writeln!(
|
||||
prompt,
|
||||
"{truncated_start_window}{selected_window}{truncated_end_window}"
|
||||
@ -89,7 +91,7 @@ fn retrieve_context(
|
||||
if let Some(max_token_count) = max_token_count {
|
||||
if model.count_tokens(&prompt)? > max_token_count {
|
||||
truncated = true;
|
||||
prompt = model.truncate(&prompt, max_token_count)?;
|
||||
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
|
||||
|
||||
// Really dumb truncation strategy
|
||||
if let Some(max_tokens) = max_token_length {
|
||||
prompt = args.model.truncate(&prompt, max_tokens)?;
|
||||
prompt = args
|
||||
.model
|
||||
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
|
||||
}
|
||||
|
||||
let token_count = args.model.count_tokens(&prompt)?;
|
@ -1,4 +1,4 @@
|
||||
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||
use anyhow::anyhow;
|
||||
use std::fmt::Write;
|
||||
|
||||
@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
|
||||
|
||||
// Really dumb truncation strategy
|
||||
if let Some(max_tokens) = max_token_length {
|
||||
prompt = args.model.truncate(&prompt, max_tokens)?;
|
||||
prompt = args.model.truncate(
|
||||
&prompt,
|
||||
max_tokens,
|
||||
crate::models::TruncationDirection::End,
|
||||
)?;
|
||||
}
|
||||
|
||||
let token_count = args.model.count_tokens(&prompt)?;
|
@ -1,4 +1,4 @@
|
||||
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||
use std::fmt::Write;
|
||||
|
||||
pub struct EngineerPreamble {}
|
@ -1,4 +1,4 @@
|
||||
use crate::templates::base::{PromptArguments, PromptTemplate};
|
||||
use crate::prompts::base::{PromptArguments, PromptTemplate};
|
||||
use std::fmt::Write;
|
||||
use std::{ops::Range, path::PathBuf};
|
||||
|
1
crates/ai/src/providers/mod.rs
Normal file
1
crates/ai/src/providers/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
pub mod open_ai;
|
298
crates/ai/src/providers/open_ai/completion.rs
Normal file
298
crates/ai/src/providers/open_ai/completion.rs
Normal file
@ -0,0 +1,298 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{
|
||||
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||
Stream, StreamExt,
|
||||
};
|
||||
use gpui::{executor::Background, AppContext};
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
env,
|
||||
fmt::{self, Display},
|
||||
io,
|
||||
sync::Arc,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{
|
||||
auth::{CredentialProvider, ProviderCredential},
|
||||
completion::{CompletionProvider, CompletionRequest},
|
||||
models::LanguageModel,
|
||||
};
|
||||
|
||||
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn cycle(&mut self) {
|
||||
*self = match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "User"),
|
||||
Role::Assistant => write!(f, "Assistant"),
|
||||
Role::System => write!(f, "System"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
pub struct OpenAIRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
impl CompletionRequest for OpenAIRequest {
|
||||
fn data(&self) -> serde_json::Result<String> {
|
||||
serde_json::to_string(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ChatChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: ResponseMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIResponseStreamEvent {
|
||||
pub id: Option<String>,
|
||||
pub object: String,
|
||||
pub created: u32,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoiceDelta>,
|
||||
pub usage: Option<OpenAIUsage>,
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
credential: ProviderCredential,
|
||||
executor: Arc<Background>,
|
||||
request: Box<dyn CompletionRequest>,
|
||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||
let api_key = match credential {
|
||||
ProviderCredential::Credentials { api_key } => api_key,
|
||||
_ => {
|
||||
return Err(anyhow!("no credentials provider for completion"));
|
||||
}
|
||||
};
|
||||
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||
|
||||
let json_data = request.data()?;
|
||||
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(json_data)?
|
||||
.send_async()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status == StatusCode::OK {
|
||||
executor
|
||||
.spawn(async move {
|
||||
let mut lines = BufReader::new(response.body_mut()).lines();
|
||||
|
||||
fn parse_line(
|
||||
line: Result<String, io::Error>,
|
||||
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
||||
if let Some(data) = line?.strip_prefix("data: ") {
|
||||
let event = serde_json::from_str(&data)?;
|
||||
Ok(Some(event))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(line) = lines.next().await {
|
||||
if let Some(event) = parse_line(line).transpose() {
|
||||
let done = event.as_ref().map_or(false, |event| {
|
||||
event
|
||||
.choices
|
||||
.last()
|
||||
.map_or(false, |choice| choice.finish_reason.is_some())
|
||||
});
|
||||
if tx.unbounded_send(event).is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
Ok(rx)
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIResponse {
|
||||
error: OpenAIError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenAIResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {}",
|
||||
response.error.message,
|
||||
)),
|
||||
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAICompletionProvider {
|
||||
model: OpenAILanguageModel,
|
||||
credential: Arc<RwLock<ProviderCredential>>,
|
||||
executor: Arc<Background>,
|
||||
}
|
||||
|
||||
impl OpenAICompletionProvider {
|
||||
pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
|
||||
let model = OpenAILanguageModel::load(model_name);
|
||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||
Self {
|
||||
model,
|
||||
credential,
|
||||
executor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for OpenAICompletionProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
match *self.credential.read() {
|
||||
ProviderCredential::Credentials { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
|
||||
let mut credential = self.credential.write();
|
||||
match *credential {
|
||||
ProviderCredential::Credentials { .. } => {
|
||||
return credential.clone();
|
||||
}
|
||||
_ => {
|
||||
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
*credential = ProviderCredential::Credentials { api_key };
|
||||
} else if let Some((_, api_key)) = cx
|
||||
.platform()
|
||||
.read_credentials(OPENAI_API_URL)
|
||||
.log_err()
|
||||
.flatten()
|
||||
{
|
||||
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||
*credential = ProviderCredential::Credentials { api_key };
|
||||
}
|
||||
} else {
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
credential.clone()
|
||||
}
|
||||
|
||||
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
|
||||
match credential.clone() {
|
||||
ProviderCredential::Credentials { api_key } => {
|
||||
cx.platform()
|
||||
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||
.log_err();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
*self.credential.write() = credential;
|
||||
}
|
||||
fn delete_credentials(&self, cx: &AppContext) {
|
||||
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
|
||||
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for OpenAICompletionProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||
model
|
||||
}
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
|
||||
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
|
||||
// which is currently model based, due to the langauge model.
|
||||
// At some point in the future we should rectify this.
|
||||
let credential = self.credential.read().clone();
|
||||
let request = stream_completion(credential, self.executor.clone(), prompt);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||
Box::new((*self).clone())
|
||||
}
|
||||
}
|
306
crates/ai/src/providers/open_ai/embedding.rs
Normal file
306
crates/ai/src/providers/open_ai/embedding.rs
Normal file
@ -0,0 +1,306 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use futures::AsyncReadExt;
|
||||
use gpui::executor::Background;
|
||||
use gpui::{serde_json, AppContext};
|
||||
use isahc::http::StatusCode;
|
||||
use isahc::prelude::Configurable;
|
||||
use isahc::{AsyncBody, Response};
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use parse_duration::parse;
|
||||
use postage::watch;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
use std::ops::Add;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||
use util::http::{HttpClient, Request};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::auth::{CredentialProvider, ProviderCredential};
|
||||
use crate::embedding::{Embedding, EmbeddingProvider};
|
||||
use crate::models::LanguageModel;
|
||||
use crate::providers::open_ai::OpenAILanguageModel;
|
||||
|
||||
use crate::providers::open_ai::OPENAI_API_URL;
|
||||
|
||||
lazy_static! {
|
||||
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAIEmbeddingProvider {
|
||||
model: OpenAILanguageModel,
|
||||
credential: Arc<RwLock<ProviderCredential>>,
|
||||
pub client: Arc<dyn HttpClient>,
|
||||
pub executor: Arc<Background>,
|
||||
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
||||
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OpenAIEmbeddingRequest<'a> {
|
||||
model: &'static str,
|
||||
input: Vec<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIEmbeddingResponse {
|
||||
data: Vec<OpenAIEmbedding>,
|
||||
usage: OpenAIEmbeddingUsage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAIEmbedding {
|
||||
embedding: Vec<f32>,
|
||||
index: usize,
|
||||
object: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIEmbeddingUsage {
|
||||
prompt_tokens: usize,
|
||||
total_tokens: usize,
|
||||
}
|
||||
|
||||
impl OpenAIEmbeddingProvider {
|
||||
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
|
||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||
|
||||
let model = OpenAILanguageModel::load("text-embedding-ada-002");
|
||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||
|
||||
OpenAIEmbeddingProvider {
|
||||
model,
|
||||
credential,
|
||||
client,
|
||||
executor,
|
||||
rate_limit_count_rx,
|
||||
rate_limit_count_tx,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_api_key(&self) -> Result<String> {
|
||||
match self.credential.read().clone() {
|
||||
ProviderCredential::Credentials { api_key } => Ok(api_key),
|
||||
_ => Err(anyhow!("api credentials not provided")),
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_rate_limit(&self) {
|
||||
let reset_time = *self.rate_limit_count_tx.lock().borrow();
|
||||
|
||||
if let Some(reset_time) = reset_time {
|
||||
if Instant::now() >= reset_time {
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = None
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!(
|
||||
"resolving reset time: {:?}",
|
||||
*self.rate_limit_count_tx.lock().borrow()
|
||||
);
|
||||
}
|
||||
|
||||
fn update_reset_time(&self, reset_time: Instant) {
|
||||
let original_time = *self.rate_limit_count_tx.lock().borrow();
|
||||
|
||||
let updated_time = if let Some(original_time) = original_time {
|
||||
if reset_time < original_time {
|
||||
Some(reset_time)
|
||||
} else {
|
||||
Some(original_time)
|
||||
}
|
||||
} else {
|
||||
Some(reset_time)
|
||||
};
|
||||
|
||||
log::trace!("updating rate limit time: {:?}", updated_time);
|
||||
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
|
||||
}
|
||||
async fn send_request(
|
||||
&self,
|
||||
api_key: &str,
|
||||
spans: Vec<&str>,
|
||||
request_timeout: u64,
|
||||
) -> Result<Response<AsyncBody>> {
|
||||
let request = Request::post("https://api.openai.com/v1/embeddings")
|
||||
.redirect_policy(isahc::config::RedirectPolicy::Follow)
|
||||
.timeout(Duration::from_secs(request_timeout))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(
|
||||
serde_json::to_string(&OpenAIEmbeddingRequest {
|
||||
input: spans.clone(),
|
||||
model: "text-embedding-ada-002",
|
||||
})
|
||||
.unwrap()
|
||||
.into(),
|
||||
)?;
|
||||
|
||||
Ok(self.client.send(request).await?)
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for OpenAIEmbeddingProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
match *self.credential.read() {
|
||||
ProviderCredential::Credentials { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
|
||||
let mut credential = self.credential.write();
|
||||
match *credential {
|
||||
ProviderCredential::Credentials { .. } => {
|
||||
return credential.clone();
|
||||
}
|
||||
_ => {
|
||||
if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
*credential = ProviderCredential::Credentials { api_key };
|
||||
} else if let Some((_, api_key)) = cx
|
||||
.platform()
|
||||
.read_credentials(OPENAI_API_URL)
|
||||
.log_err()
|
||||
.flatten()
|
||||
{
|
||||
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||
*credential = ProviderCredential::Credentials { api_key };
|
||||
}
|
||||
} else {
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
credential.clone()
|
||||
}
|
||||
|
||||
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
|
||||
match credential.clone() {
|
||||
ProviderCredential::Credentials { api_key } => {
|
||||
cx.platform()
|
||||
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||
.log_err();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
*self.credential.write() = credential;
|
||||
}
|
||||
fn delete_credentials(&self, cx: &AppContext) {
|
||||
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
|
||||
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||
model
|
||||
}
|
||||
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
50000
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
*self.rate_limit_count_rx.borrow()
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||
const MAX_RETRIES: usize = 4;
|
||||
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
let mut request_number = 0;
|
||||
let mut rate_limiting = false;
|
||||
let mut request_timeout: u64 = 15;
|
||||
let mut response: Response<AsyncBody>;
|
||||
while request_number < MAX_RETRIES {
|
||||
response = self
|
||||
.send_request(
|
||||
&api_key,
|
||||
spans.iter().map(|x| &**x).collect(),
|
||||
request_timeout,
|
||||
)
|
||||
.await?;
|
||||
|
||||
request_number += 1;
|
||||
|
||||
match response.status() {
|
||||
StatusCode::REQUEST_TIMEOUT => {
|
||||
request_timeout += 5;
|
||||
}
|
||||
StatusCode::OK => {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
|
||||
|
||||
log::trace!(
|
||||
"openai embedding completed. tokens: {:?}",
|
||||
response.usage.total_tokens
|
||||
);
|
||||
|
||||
// If we complete a request successfully that was previously rate_limited
|
||||
// resolve the rate limit
|
||||
if rate_limiting {
|
||||
self.resolve_rate_limit()
|
||||
}
|
||||
|
||||
return Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|embedding| Embedding::from(embedding.embedding))
|
||||
.collect());
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
rate_limiting = true;
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
let delay_duration = {
|
||||
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
|
||||
if let Some(time_to_reset) =
|
||||
response.headers().get("x-ratelimit-reset-tokens")
|
||||
{
|
||||
if let Ok(time_str) = time_to_reset.to_str() {
|
||||
parse(time_str).unwrap_or(delay)
|
||||
} else {
|
||||
delay
|
||||
}
|
||||
} else {
|
||||
delay
|
||||
}
|
||||
};
|
||||
|
||||
// If we've previously rate limited, increment the duration but not the count
|
||||
let reset_time = Instant::now().add(delay_duration);
|
||||
self.update_reset_time(reset_time);
|
||||
|
||||
log::trace!(
|
||||
"openai rate limiting: waiting {:?} until lifted",
|
||||
&delay_duration
|
||||
);
|
||||
|
||||
self.executor.timer(delay_duration).await;
|
||||
}
|
||||
_ => {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(
|
||||
"open ai bad request: {:?} {:?}",
|
||||
&response.status(),
|
||||
body
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(anyhow!("openai max retries"))
|
||||
}
|
||||
}
|
9
crates/ai/src/providers/open_ai/mod.rs
Normal file
9
crates/ai/src/providers/open_ai/mod.rs
Normal file
@ -0,0 +1,9 @@
|
||||
pub mod completion;
|
||||
pub mod embedding;
|
||||
pub mod model;
|
||||
|
||||
pub use completion::*;
|
||||
pub use embedding::*;
|
||||
pub use model::OpenAILanguageModel;
|
||||
|
||||
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
57
crates/ai/src/providers/open_ai/model.rs
Normal file
57
crates/ai/src/providers/open_ai/model.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use anyhow::anyhow;
|
||||
use tiktoken_rs::CoreBPE;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::models::{LanguageModel, TruncationDirection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAILanguageModel {
|
||||
name: String,
|
||||
bpe: Option<CoreBPE>,
|
||||
}
|
||||
|
||||
impl OpenAILanguageModel {
|
||||
pub fn load(model_name: &str) -> Self {
|
||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
|
||||
OpenAILanguageModel {
|
||||
name: model_name.to_string(),
|
||||
bpe,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for OpenAILanguageModel {
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||
if let Some(bpe) = &self.bpe {
|
||||
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
|
||||
} else {
|
||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||
}
|
||||
}
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String> {
|
||||
if let Some(bpe) = &self.bpe {
|
||||
let tokens = bpe.encode_with_special_tokens(content);
|
||||
if tokens.len() > length {
|
||||
match direction {
|
||||
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
|
||||
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
|
||||
}
|
||||
} else {
|
||||
bpe.decode(tokens)
|
||||
}
|
||||
} else {
|
||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||
}
|
||||
}
|
||||
fn capacity(&self) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
|
||||
}
|
||||
}
|
11
crates/ai/src/providers/open_ai/new.rs
Normal file
11
crates/ai/src/providers/open_ai/new.rs
Normal file
@ -0,0 +1,11 @@
|
||||
pub trait LanguageModel {
|
||||
fn name(&self) -> String;
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String>;
|
||||
fn capacity(&self) -> anyhow::Result<usize>;
|
||||
}
|
191
crates/ai/src/test.rs
Normal file
191
crates/ai/src/test.rs
Normal file
@ -0,0 +1,191 @@
|
||||
use std::{
|
||||
sync::atomic::{self, AtomicUsize, Ordering},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::AppContext;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::{
|
||||
auth::{CredentialProvider, ProviderCredential},
|
||||
completion::{CompletionProvider, CompletionRequest},
|
||||
embedding::{Embedding, EmbeddingProvider},
|
||||
models::{LanguageModel, TruncationDirection},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FakeLanguageModel {
|
||||
pub capacity: usize,
|
||||
}
|
||||
|
||||
impl LanguageModel for FakeLanguageModel {
|
||||
fn name(&self) -> String {
|
||||
"dummy".to_string()
|
||||
}
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
||||
}
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String> {
|
||||
println!("TRYING TO TRUNCATE: {:?}", length.clone());
|
||||
|
||||
if length > self.count_tokens(content)? {
|
||||
println!("NOT TRUNCATING");
|
||||
return anyhow::Ok(content.to_string());
|
||||
}
|
||||
|
||||
anyhow::Ok(match direction {
|
||||
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
|
||||
.into_iter()
|
||||
.collect::<String>(),
|
||||
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
|
||||
.into_iter()
|
||||
.collect::<String>(),
|
||||
})
|
||||
}
|
||||
fn capacity(&self) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(self.capacity)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FakeEmbeddingProvider {
|
||||
pub embedding_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl Clone for FakeEmbeddingProvider {
|
||||
fn clone(&self) -> Self {
|
||||
FakeEmbeddingProvider {
|
||||
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FakeEmbeddingProvider {
|
||||
fn default() -> Self {
|
||||
FakeEmbeddingProvider {
|
||||
embedding_count: AtomicUsize::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeEmbeddingProvider {
|
||||
pub fn embedding_count(&self) -> usize {
|
||||
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub fn embed_sync(&self, span: &str) -> Embedding {
|
||||
let mut result = vec![1.0; 26];
|
||||
for letter in span.chars() {
|
||||
let letter = letter.to_ascii_lowercase();
|
||||
if letter as u32 >= 'a' as u32 {
|
||||
let ix = (letter as u32) - ('a' as u32);
|
||||
if ix < 26 {
|
||||
result[ix as usize] += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
for x in &mut result {
|
||||
*x /= norm;
|
||||
}
|
||||
|
||||
result.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for FakeEmbeddingProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
|
||||
ProviderCredential::NotNeeded
|
||||
}
|
||||
fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
|
||||
fn delete_credentials(&self, _cx: &AppContext) {}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
Box::new(FakeLanguageModel { capacity: 1000 })
|
||||
}
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
1000
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
None
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
|
||||
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FakeCompletionProvider {
|
||||
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||
}
|
||||
|
||||
impl Clone for FakeCompletionProvider {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
last_completion_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeCompletionProvider {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
last_completion_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_completion(&self, completion: impl Into<String>) {
|
||||
let mut tx = self.last_completion_tx.lock();
|
||||
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
}
|
||||
|
||||
pub fn finish_completion(&self) {
|
||||
self.last_completion_tx.lock().take().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for FakeCompletionProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
|
||||
ProviderCredential::NotNeeded
|
||||
}
|
||||
fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
|
||||
fn delete_credentials(&self, _cx: &AppContext) {}
|
||||
}
|
||||
|
||||
impl CompletionProvider for FakeCompletionProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
|
||||
model
|
||||
}
|
||||
fn complete(
|
||||
&self,
|
||||
_prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
*self.last_completion_tx.lock() = Some(tx);
|
||||
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||
}
|
||||
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||
Box::new((*self).clone())
|
||||
}
|
||||
}
|
38
crates/ai2/Cargo.toml
Normal file
38
crates/ai2/Cargo.toml
Normal file
@ -0,0 +1,38 @@
|
||||
[package]
|
||||
name = "ai2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/ai2.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
|
||||
[dependencies]
|
||||
gpui2 = { path = "../gpui2" }
|
||||
util = { path = "../util" }
|
||||
language2 = { path = "../language2" }
|
||||
async-trait.workspace = true
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
lazy_static.workspace = true
|
||||
ordered-float.workspace = true
|
||||
parking_lot.workspace = true
|
||||
isahc.workspace = true
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
postage.workspace = true
|
||||
rand.workspace = true
|
||||
log.workspace = true
|
||||
parse_duration = "2.1.1"
|
||||
tiktoken-rs = "0.5.0"
|
||||
matrixmultiply = "0.3.7"
|
||||
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
|
||||
bincode = "1.3.3"
|
||||
|
||||
[dev-dependencies]
|
||||
gpui2 = { path = "../gpui2", features = ["test-support"] }
|
8
crates/ai2/src/ai2.rs
Normal file
8
crates/ai2/src/ai2.rs
Normal file
@ -0,0 +1,8 @@
|
||||
pub mod auth;
|
||||
pub mod completion;
|
||||
pub mod embedding;
|
||||
pub mod models;
|
||||
pub mod prompts;
|
||||
pub mod providers;
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
17
crates/ai2/src/auth.rs
Normal file
17
crates/ai2/src/auth.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use async_trait::async_trait;
|
||||
use gpui2::AppContext;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ProviderCredential {
|
||||
Credentials { api_key: String },
|
||||
NoCredentials,
|
||||
NotNeeded,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait CredentialProvider: Send + Sync {
|
||||
fn has_credentials(&self) -> bool;
|
||||
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
|
||||
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
|
||||
async fn delete_credentials(&self, cx: &mut AppContext);
|
||||
}
|
23
crates/ai2/src/completion.rs
Normal file
23
crates/ai2/src/completion.rs
Normal file
@ -0,0 +1,23 @@
|
||||
use anyhow::Result;
|
||||
use futures::{future::BoxFuture, stream::BoxStream};
|
||||
|
||||
use crate::{auth::CredentialProvider, models::LanguageModel};
|
||||
|
||||
pub trait CompletionRequest: Send + Sync {
|
||||
fn data(&self) -> serde_json::Result<String>;
|
||||
}
|
||||
|
||||
pub trait CompletionProvider: CredentialProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||
fn box_clone(&self) -> Box<dyn CompletionProvider>;
|
||||
}
|
||||
|
||||
impl Clone for Box<dyn CompletionProvider> {
|
||||
fn clone(&self) -> Box<dyn CompletionProvider> {
|
||||
self.box_clone()
|
||||
}
|
||||
}
|
123
crates/ai2/src/embedding.rs
Normal file
123
crates/ai2/src/embedding.rs
Normal file
@ -0,0 +1,123 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use ordered_float::OrderedFloat;
|
||||
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
||||
use rusqlite::ToSql;
|
||||
|
||||
use crate::auth::CredentialProvider;
|
||||
use crate::models::LanguageModel;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct Embedding(pub Vec<f32>);
|
||||
|
||||
// This is needed for semantic index functionality
|
||||
// Unfortunately it has to live wherever the "Embedding" struct is created.
|
||||
// Keeping this in here though, introduces a 'rusqlite' dependency into AI
|
||||
// which is less than ideal
|
||||
impl FromSql for Embedding {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let bytes = value.as_blob()?;
|
||||
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||
if embedding.is_err() {
|
||||
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
|
||||
}
|
||||
Ok(Embedding(embedding.unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSql for Embedding {
|
||||
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
let bytes = bincode::serialize(&self.0)
|
||||
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
|
||||
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
|
||||
}
|
||||
}
|
||||
impl From<Vec<f32>> for Embedding {
|
||||
fn from(value: Vec<f32>) -> Self {
|
||||
Embedding(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
|
||||
let len = self.0.len();
|
||||
assert_eq!(len, other.0.len());
|
||||
|
||||
let mut result = 0.0;
|
||||
unsafe {
|
||||
matrixmultiply::sgemm(
|
||||
1,
|
||||
len,
|
||||
1,
|
||||
1.0,
|
||||
self.0.as_ptr(),
|
||||
len as isize,
|
||||
1,
|
||||
other.0.as_ptr(),
|
||||
1,
|
||||
len as isize,
|
||||
0.0,
|
||||
&mut result as *mut f32,
|
||||
1,
|
||||
1,
|
||||
);
|
||||
}
|
||||
OrderedFloat(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait EmbeddingProvider: CredentialProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
|
||||
fn max_tokens_per_batch(&self) -> usize;
|
||||
fn rate_limit_expiration(&self) -> Option<Instant>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::prelude::*;
|
||||
|
||||
#[gpui2::test]
|
||||
fn test_similarity(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
Embedding::from(vec![1., 0., 0., 0., 0.])
|
||||
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
|
||||
0.
|
||||
);
|
||||
assert_eq!(
|
||||
Embedding::from(vec![2., 0., 0., 0., 0.])
|
||||
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
|
||||
6.
|
||||
);
|
||||
|
||||
for _ in 0..100 {
|
||||
let size = 1536;
|
||||
let mut a = vec![0.; size];
|
||||
let mut b = vec![0.; size];
|
||||
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
|
||||
*a = rng.gen();
|
||||
*b = rng.gen();
|
||||
}
|
||||
let a = Embedding::from(a);
|
||||
let b = Embedding::from(b);
|
||||
|
||||
assert_eq!(
|
||||
round_to_decimals(a.similarity(&b), 1),
|
||||
round_to_decimals(reference_dot(&a.0, &b.0), 1)
|
||||
);
|
||||
}
|
||||
|
||||
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
|
||||
let factor = (10.0 as f32).powi(decimal_places);
|
||||
(n * factor).round() / factor
|
||||
}
|
||||
|
||||
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
|
||||
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
|
||||
}
|
||||
}
|
||||
}
|
16
crates/ai2/src/models.rs
Normal file
16
crates/ai2/src/models.rs
Normal file
@ -0,0 +1,16 @@
|
||||
pub enum TruncationDirection {
|
||||
Start,
|
||||
End,
|
||||
}
|
||||
|
||||
pub trait LanguageModel {
|
||||
fn name(&self) -> String;
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String>;
|
||||
fn capacity(&self) -> anyhow::Result<usize>;
|
||||
}
|
330
crates/ai2/src/prompts/base.rs
Normal file
330
crates/ai2/src/prompts/base.rs
Normal file
@ -0,0 +1,330 @@
|
||||
use std::cmp::Reverse;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
use language2::BufferSnapshot;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::models::LanguageModel;
|
||||
use crate::prompts::repository_context::PromptCodeSnippet;
|
||||
|
||||
pub(crate) enum PromptFileType {
|
||||
Text,
|
||||
Code,
|
||||
}
|
||||
|
||||
// TODO: Set this up to manage for defaults well
|
||||
pub struct PromptArguments {
|
||||
pub model: Arc<dyn LanguageModel>,
|
||||
pub user_prompt: Option<String>,
|
||||
pub language_name: Option<String>,
|
||||
pub project_name: Option<String>,
|
||||
pub snippets: Vec<PromptCodeSnippet>,
|
||||
pub reserved_tokens: usize,
|
||||
pub buffer: Option<BufferSnapshot>,
|
||||
pub selected_range: Option<Range<usize>>,
|
||||
}
|
||||
|
||||
impl PromptArguments {
|
||||
pub(crate) fn get_file_type(&self) -> PromptFileType {
|
||||
if self
|
||||
.language_name
|
||||
.as_ref()
|
||||
.and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
|
||||
.unwrap_or(true)
|
||||
{
|
||||
PromptFileType::Code
|
||||
} else {
|
||||
PromptFileType::Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait PromptTemplate {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)>;
|
||||
}
|
||||
|
||||
#[repr(i8)]
|
||||
#[derive(PartialEq, Eq, Ord)]
|
||||
pub enum PromptPriority {
|
||||
Mandatory, // Ignores truncation
|
||||
Ordered { order: usize }, // Truncates based on priority
|
||||
}
|
||||
|
||||
impl PartialOrd for PromptPriority {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
match (self, other) {
|
||||
(Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
|
||||
(Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
|
||||
(Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
|
||||
(Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PromptChain {
|
||||
args: PromptArguments,
|
||||
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
|
||||
}
|
||||
|
||||
impl PromptChain {
|
||||
pub fn new(
|
||||
args: PromptArguments,
|
||||
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
|
||||
) -> Self {
|
||||
PromptChain { args, templates }
|
||||
}
|
||||
|
||||
pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
|
||||
// Argsort based on Prompt Priority
|
||||
let seperator = "\n";
|
||||
let seperator_tokens = self.args.model.count_tokens(seperator)?;
|
||||
let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
|
||||
sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
|
||||
|
||||
// If Truncate
|
||||
let mut tokens_outstanding = if truncate {
|
||||
Some(self.args.model.capacity()? - self.args.reserved_tokens)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut prompts = vec!["".to_string(); sorted_indices.len()];
|
||||
for idx in sorted_indices {
|
||||
let (_, template) = &self.templates[idx];
|
||||
|
||||
if let Some((template_prompt, prompt_token_count)) =
|
||||
template.generate(&self.args, tokens_outstanding).log_err()
|
||||
{
|
||||
if template_prompt != "" {
|
||||
prompts[idx] = template_prompt;
|
||||
|
||||
if let Some(remaining_tokens) = tokens_outstanding {
|
||||
let new_tokens = prompt_token_count + seperator_tokens;
|
||||
tokens_outstanding = if remaining_tokens > new_tokens {
|
||||
Some(remaining_tokens - new_tokens)
|
||||
} else {
|
||||
Some(0)
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompts.retain(|x| x != "");
|
||||
|
||||
let full_prompt = prompts.join(seperator);
|
||||
let total_token_count = self.args.model.count_tokens(&full_prompt)?;
|
||||
anyhow::Ok((prompts.join(seperator), total_token_count))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use crate::models::TruncationDirection;
|
||||
use crate::test::FakeLanguageModel;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
pub fn test_prompt_chain() {
|
||||
struct TestPromptTemplate {}
|
||||
impl PromptTemplate for TestPromptTemplate {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let mut content = "This is a test prompt template".to_string();
|
||||
|
||||
let mut token_count = args.model.count_tokens(&content)?;
|
||||
if let Some(max_token_length) = max_token_length {
|
||||
if token_count > max_token_length {
|
||||
content = args.model.truncate(
|
||||
&content,
|
||||
max_token_length,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
token_count = max_token_length;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok((content, token_count))
|
||||
}
|
||||
}
|
||||
|
||||
struct TestLowPriorityTemplate {}
|
||||
impl PromptTemplate for TestLowPriorityTemplate {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let mut content = "This is a low priority test prompt template".to_string();
|
||||
|
||||
let mut token_count = args.model.count_tokens(&content)?;
|
||||
if let Some(max_token_length) = max_token_length {
|
||||
if token_count > max_token_length {
|
||||
content = args.model.truncate(
|
||||
&content,
|
||||
max_token_length,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
token_count = max_token_length;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok((content, token_count))
|
||||
}
|
||||
}
|
||||
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens: 0,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(false).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
prompt,
|
||||
"This is a test prompt template\nThis is a low priority test prompt template"
|
||||
.to_string()
|
||||
);
|
||||
|
||||
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
|
||||
|
||||
// Testing with Truncation Off
|
||||
// Should ignore capacity and return all prompts
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens: 0,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(false).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
prompt,
|
||||
"This is a test prompt template\nThis is a low priority test prompt template"
|
||||
.to_string()
|
||||
);
|
||||
|
||||
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
|
||||
|
||||
// Testing with Truncation Off
|
||||
// Should ignore capacity and return all prompts
|
||||
let capacity = 20;
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens: 0,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 2 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(true).unwrap();
|
||||
|
||||
assert_eq!(prompt, "This is a test promp".to_string());
|
||||
assert_eq!(token_count, capacity);
|
||||
|
||||
// Change Ordering of Prompts Based on Priority
|
||||
let capacity = 120;
|
||||
let reserved_tokens = 10;
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Mandatory,
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(true).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
prompt,
|
||||
"This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
|
||||
.to_string()
|
||||
);
|
||||
assert_eq!(token_count, capacity - reserved_tokens);
|
||||
}
|
||||
}
|
164
crates/ai2/src/prompts/file_context.rs
Normal file
164
crates/ai2/src/prompts/file_context.rs
Normal file
@ -0,0 +1,164 @@
|
||||
use anyhow::anyhow;
|
||||
use language2::BufferSnapshot;
|
||||
use language2::ToOffset;
|
||||
|
||||
use crate::models::LanguageModel;
|
||||
use crate::models::TruncationDirection;
|
||||
use crate::prompts::base::PromptArguments;
|
||||
use crate::prompts::base::PromptTemplate;
|
||||
use std::fmt::Write;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn retrieve_context(
|
||||
buffer: &BufferSnapshot,
|
||||
selected_range: &Option<Range<usize>>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
max_token_count: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize, bool)> {
|
||||
let mut prompt = String::new();
|
||||
let mut truncated = false;
|
||||
if let Some(selected_range) = selected_range {
|
||||
let start = selected_range.start.to_offset(buffer);
|
||||
let end = selected_range.end.to_offset(buffer);
|
||||
|
||||
let start_window = buffer.text_for_range(0..start).collect::<String>();
|
||||
|
||||
let mut selected_window = String::new();
|
||||
if start == end {
|
||||
write!(selected_window, "<|START|>").unwrap();
|
||||
} else {
|
||||
write!(selected_window, "<|START|").unwrap();
|
||||
}
|
||||
|
||||
write!(
|
||||
selected_window,
|
||||
"{}",
|
||||
buffer.text_for_range(start..end).collect::<String>()
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
if start != end {
|
||||
write!(selected_window, "|END|>").unwrap();
|
||||
}
|
||||
|
||||
let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
|
||||
|
||||
if let Some(max_token_count) = max_token_count {
|
||||
let selected_tokens = model.count_tokens(&selected_window)?;
|
||||
if selected_tokens > max_token_count {
|
||||
return Err(anyhow!(
|
||||
"selected range is greater than model context window, truncation not possible"
|
||||
));
|
||||
};
|
||||
|
||||
let mut remaining_tokens = max_token_count - selected_tokens;
|
||||
let start_window_tokens = model.count_tokens(&start_window)?;
|
||||
let end_window_tokens = model.count_tokens(&end_window)?;
|
||||
let outside_tokens = start_window_tokens + end_window_tokens;
|
||||
if outside_tokens > remaining_tokens {
|
||||
let (start_goal_tokens, end_goal_tokens) =
|
||||
if start_window_tokens < end_window_tokens {
|
||||
let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
|
||||
remaining_tokens -= start_goal_tokens;
|
||||
let end_goal_tokens = remaining_tokens.min(end_window_tokens);
|
||||
(start_goal_tokens, end_goal_tokens)
|
||||
} else {
|
||||
let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
|
||||
remaining_tokens -= end_goal_tokens;
|
||||
let start_goal_tokens = remaining_tokens.min(start_window_tokens);
|
||||
(start_goal_tokens, end_goal_tokens)
|
||||
};
|
||||
|
||||
let truncated_start_window =
|
||||
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
|
||||
let truncated_end_window =
|
||||
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
|
||||
writeln!(
|
||||
prompt,
|
||||
"{truncated_start_window}{selected_window}{truncated_end_window}"
|
||||
)
|
||||
.unwrap();
|
||||
truncated = true;
|
||||
} else {
|
||||
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
|
||||
}
|
||||
} else {
|
||||
// If we dont have a selected range, include entire file.
|
||||
writeln!(prompt, "{}", &buffer.text()).unwrap();
|
||||
|
||||
// Dumb truncation strategy
|
||||
if let Some(max_token_count) = max_token_count {
|
||||
if model.count_tokens(&prompt)? > max_token_count {
|
||||
truncated = true;
|
||||
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let token_count = model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, token_count, truncated))
|
||||
}
|
||||
|
||||
pub struct FileContext {}
|
||||
|
||||
impl PromptTemplate for FileContext {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
if let Some(buffer) = &args.buffer {
|
||||
let mut prompt = String::new();
|
||||
// Add Initial Preamble
|
||||
// TODO: Do we want to add the path in here?
|
||||
writeln!(
|
||||
prompt,
|
||||
"The file you are currently working on has the following content:"
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let language_name = args
|
||||
.language_name
|
||||
.clone()
|
||||
.unwrap_or("".to_string())
|
||||
.to_lowercase();
|
||||
|
||||
let (context, _, truncated) = retrieve_context(
|
||||
buffer,
|
||||
&args.selected_range,
|
||||
args.model.clone(),
|
||||
max_token_length,
|
||||
)?;
|
||||
writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
|
||||
|
||||
if truncated {
|
||||
writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
|
||||
}
|
||||
|
||||
if let Some(selected_range) = &args.selected_range {
|
||||
let start = selected_range.start.to_offset(buffer);
|
||||
let end = selected_range.end.to_offset(buffer);
|
||||
|
||||
if start == end {
|
||||
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// Really dumb truncation strategy
|
||||
if let Some(max_tokens) = max_token_length {
|
||||
prompt = args
|
||||
.model
|
||||
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
|
||||
}
|
||||
|
||||
let token_count = args.model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, token_count))
|
||||
} else {
|
||||
Err(anyhow!("no buffer provided to retrieve file context from"))
|
||||
}
|
||||
}
|
||||
}
|
99
crates/ai2/src/prompts/generate.rs
Normal file
99
crates/ai2/src/prompts/generate.rs
Normal file
@ -0,0 +1,99 @@
|
||||
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||
use anyhow::anyhow;
|
||||
use std::fmt::Write;
|
||||
|
||||
pub fn capitalize(s: &str) -> String {
|
||||
let mut c = s.chars();
|
||||
match c.next() {
|
||||
None => String::new(),
|
||||
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GenerateInlineContent {}
|
||||
|
||||
impl PromptTemplate for GenerateInlineContent {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let Some(user_prompt) = &args.user_prompt else {
|
||||
return Err(anyhow!("user prompt not provided"));
|
||||
};
|
||||
|
||||
let file_type = args.get_file_type();
|
||||
let content_type = match &file_type {
|
||||
PromptFileType::Code => "code",
|
||||
PromptFileType::Text => "text",
|
||||
};
|
||||
|
||||
let mut prompt = String::new();
|
||||
|
||||
if let Some(selected_range) = &args.selected_range {
|
||||
if selected_range.start == selected_range.end {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Assume the cursor is located where the `<|START|>` span is."
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"{} can't be replaced, so assume your answer will be inserted at the cursor.",
|
||||
capitalize(content_type)
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Generate {content_type} based on the users prompt: {user_prompt}",
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
|
||||
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
|
||||
writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
|
||||
}
|
||||
} else {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Generate {content_type} based on the users prompt: {user_prompt}"
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
if let Some(language_name) = &args.language_name {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Your answer MUST always and only be valid {}.",
|
||||
language_name
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
writeln!(prompt, "Never make remarks about the output.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Do not return anything else, except the generated {content_type}."
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
match file_type {
|
||||
PromptFileType::Code => {
|
||||
// writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Really dumb truncation strategy
|
||||
if let Some(max_tokens) = max_token_length {
|
||||
prompt = args.model.truncate(
|
||||
&prompt,
|
||||
max_tokens,
|
||||
crate::models::TruncationDirection::End,
|
||||
)?;
|
||||
}
|
||||
|
||||
let token_count = args.model.count_tokens(&prompt)?;
|
||||
|
||||
anyhow::Ok((prompt, token_count))
|
||||
}
|
||||
}
|
5
crates/ai2/src/prompts/mod.rs
Normal file
5
crates/ai2/src/prompts/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
pub mod base;
|
||||
pub mod file_context;
|
||||
pub mod generate;
|
||||
pub mod preamble;
|
||||
pub mod repository_context;
|
52
crates/ai2/src/prompts/preamble.rs
Normal file
52
crates/ai2/src/prompts/preamble.rs
Normal file
@ -0,0 +1,52 @@
|
||||
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||
use std::fmt::Write;
|
||||
|
||||
pub struct EngineerPreamble {}
|
||||
|
||||
impl PromptTemplate for EngineerPreamble {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let mut prompts = Vec::new();
|
||||
|
||||
match args.get_file_type() {
|
||||
PromptFileType::Code => {
|
||||
prompts.push(format!(
|
||||
"You are an expert {}engineer.",
|
||||
args.language_name.clone().unwrap_or("".to_string()) + " "
|
||||
));
|
||||
}
|
||||
PromptFileType::Text => {
|
||||
prompts.push("You are an expert engineer.".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(project_name) = args.project_name.clone() {
|
||||
prompts.push(format!(
|
||||
"You are currently working inside the '{project_name}' project in code editor Zed."
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(mut remaining_tokens) = max_token_length {
|
||||
let mut prompt = String::new();
|
||||
let mut total_count = 0;
|
||||
for prompt_piece in prompts {
|
||||
let prompt_token_count =
|
||||
args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
|
||||
if remaining_tokens > prompt_token_count {
|
||||
writeln!(prompt, "{prompt_piece}").unwrap();
|
||||
remaining_tokens -= prompt_token_count;
|
||||
total_count += prompt_token_count;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok((prompt, total_count))
|
||||
} else {
|
||||
let prompt = prompts.join("\n");
|
||||
let token_count = args.model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, token_count))
|
||||
}
|
||||
}
|
||||
}
|
98
crates/ai2/src/prompts/repository_context.rs
Normal file
98
crates/ai2/src/prompts/repository_context.rs
Normal file
@ -0,0 +1,98 @@
|
||||
use crate::prompts::base::{PromptArguments, PromptTemplate};
|
||||
use std::fmt::Write;
|
||||
use std::{ops::Range, path::PathBuf};
|
||||
|
||||
use gpui2::{AsyncAppContext, Model};
|
||||
use language2::{Anchor, Buffer};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PromptCodeSnippet {
|
||||
path: Option<PathBuf>,
|
||||
language_name: Option<String>,
|
||||
content: String,
|
||||
}
|
||||
|
||||
impl PromptCodeSnippet {
|
||||
pub fn new(
|
||||
buffer: Model<Buffer>,
|
||||
range: Range<Anchor>,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
let content = snapshot.text_for_range(range.clone()).collect::<String>();
|
||||
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.and_then(|language| Some(language.name().to_string().to_lowercase()));
|
||||
|
||||
let file_path = buffer
|
||||
.file()
|
||||
.and_then(|file| Some(file.path().to_path_buf()));
|
||||
|
||||
(content, language_name, file_path)
|
||||
})?;
|
||||
|
||||
anyhow::Ok(PromptCodeSnippet {
|
||||
path: file_path,
|
||||
language_name,
|
||||
content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ToString for PromptCodeSnippet {
|
||||
fn to_string(&self) -> String {
|
||||
let path = self
|
||||
.path
|
||||
.as_ref()
|
||||
.and_then(|path| Some(path.to_string_lossy().to_string()))
|
||||
.unwrap_or("".to_string());
|
||||
let language_name = self.language_name.clone().unwrap_or("".to_string());
|
||||
let content = self.content.clone();
|
||||
|
||||
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RepositoryContext {}
|
||||
|
||||
impl PromptTemplate for RepositoryContext {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
|
||||
let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
|
||||
let mut prompt = String::new();
|
||||
|
||||
let mut remaining_tokens = max_token_length.clone();
|
||||
let seperator_token_length = args.model.count_tokens("\n")?;
|
||||
for snippet in &args.snippets {
|
||||
let mut snippet_prompt = template.to_string();
|
||||
let content = snippet.to_string();
|
||||
writeln!(snippet_prompt, "{content}").unwrap();
|
||||
|
||||
let token_count = args.model.count_tokens(&snippet_prompt)?;
|
||||
if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
|
||||
if let Some(tokens_left) = remaining_tokens {
|
||||
if tokens_left >= token_count {
|
||||
writeln!(prompt, "{snippet_prompt}").unwrap();
|
||||
remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
|
||||
{
|
||||
Some(tokens_left - token_count - seperator_token_length)
|
||||
} else {
|
||||
Some(0)
|
||||
};
|
||||
}
|
||||
} else {
|
||||
writeln!(prompt, "{snippet_prompt}").unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let total_token_count = args.model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, total_token_count))
|
||||
}
|
||||
}
|
1
crates/ai2/src/providers/mod.rs
Normal file
1
crates/ai2/src/providers/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
pub mod open_ai;
|
306
crates/ai2/src/providers/open_ai/completion.rs
Normal file
306
crates/ai2/src/providers/open_ai/completion.rs
Normal file
@ -0,0 +1,306 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use futures::{
|
||||
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||
Stream, StreamExt,
|
||||
};
|
||||
use gpui2::{AppContext, Executor};
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
env,
|
||||
fmt::{self, Display},
|
||||
io,
|
||||
sync::Arc,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{
|
||||
auth::{CredentialProvider, ProviderCredential},
|
||||
completion::{CompletionProvider, CompletionRequest},
|
||||
models::LanguageModel,
|
||||
};
|
||||
|
||||
use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn cycle(&mut self) {
|
||||
*self = match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "User"),
|
||||
Role::Assistant => write!(f, "Assistant"),
|
||||
Role::System => write!(f, "System"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
pub struct OpenAIRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
impl CompletionRequest for OpenAIRequest {
|
||||
fn data(&self) -> serde_json::Result<String> {
|
||||
serde_json::to_string(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ChatChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: ResponseMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAIResponseStreamEvent {
|
||||
pub id: Option<String>,
|
||||
pub object: String,
|
||||
pub created: u32,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoiceDelta>,
|
||||
pub usage: Option<OpenAIUsage>,
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
credential: ProviderCredential,
|
||||
executor: Arc<Executor>,
|
||||
request: Box<dyn CompletionRequest>,
|
||||
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
|
||||
let api_key = match credential {
|
||||
ProviderCredential::Credentials { api_key } => api_key,
|
||||
_ => {
|
||||
return Err(anyhow!("no credentials provider for completion"));
|
||||
}
|
||||
};
|
||||
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
|
||||
|
||||
let json_data = request.data()?;
|
||||
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(json_data)?
|
||||
.send_async()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status == StatusCode::OK {
|
||||
executor
|
||||
.spawn(async move {
|
||||
let mut lines = BufReader::new(response.body_mut()).lines();
|
||||
|
||||
fn parse_line(
|
||||
line: Result<String, io::Error>,
|
||||
) -> Result<Option<OpenAIResponseStreamEvent>> {
|
||||
if let Some(data) = line?.strip_prefix("data: ") {
|
||||
let event = serde_json::from_str(&data)?;
|
||||
Ok(Some(event))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(line) = lines.next().await {
|
||||
if let Some(event) = parse_line(line).transpose() {
|
||||
let done = event.as_ref().map_or(false, |event| {
|
||||
event
|
||||
.choices
|
||||
.last()
|
||||
.map_or(false, |choice| choice.finish_reason.is_some())
|
||||
});
|
||||
if tx.unbounded_send(event).is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
Ok(rx)
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIResponse {
|
||||
error: OpenAIError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenAIResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {}",
|
||||
response.error.message,
|
||||
)),
|
||||
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAICompletionProvider {
|
||||
model: OpenAILanguageModel,
|
||||
credential: Arc<RwLock<ProviderCredential>>,
|
||||
executor: Arc<Executor>,
|
||||
}
|
||||
|
||||
impl OpenAICompletionProvider {
|
||||
pub fn new(model_name: &str, executor: Arc<Executor>) -> Self {
|
||||
let model = OpenAILanguageModel::load(model_name);
|
||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||
Self {
|
||||
model,
|
||||
credential,
|
||||
executor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CredentialProvider for OpenAICompletionProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
match *self.credential.read() {
|
||||
ProviderCredential::Credentials { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
|
||||
let existing_credential = self.credential.read().clone();
|
||||
|
||||
let retrieved_credential = cx
|
||||
.run_on_main(move |cx| match existing_credential {
|
||||
ProviderCredential::Credentials { .. } => {
|
||||
return existing_credential.clone();
|
||||
}
|
||||
_ => {
|
||||
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
||||
return ProviderCredential::Credentials { api_key };
|
||||
}
|
||||
|
||||
if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
|
||||
{
|
||||
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||
return ProviderCredential::Credentials { api_key };
|
||||
} else {
|
||||
return ProviderCredential::NoCredentials;
|
||||
}
|
||||
} else {
|
||||
return ProviderCredential::NoCredentials;
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
*self.credential.write() = retrieved_credential.clone();
|
||||
retrieved_credential
|
||||
}
|
||||
|
||||
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
|
||||
*self.credential.write() = credential.clone();
|
||||
let credential = credential.clone();
|
||||
cx.run_on_main(move |cx| match credential {
|
||||
ProviderCredential::Credentials { api_key } => {
|
||||
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||
.log_err();
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
async fn delete_credentials(&self, cx: &mut AppContext) {
|
||||
cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
|
||||
.await;
|
||||
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for OpenAICompletionProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||
model
|
||||
}
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
|
||||
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
|
||||
// which is currently model based, due to the langauge model.
|
||||
// At some point in the future we should rectify this.
|
||||
let credential = self.credential.read().clone();
|
||||
let request = stream_completion(credential, self.executor.clone(), prompt);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||
Box::new((*self).clone())
|
||||
}
|
||||
}
|
313
crates/ai2/src/providers/open_ai/embedding.rs
Normal file
313
crates/ai2/src/providers/open_ai/embedding.rs
Normal file
@ -0,0 +1,313 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use futures::AsyncReadExt;
|
||||
use gpui2::Executor;
|
||||
use gpui2::{serde_json, AppContext};
|
||||
use isahc::http::StatusCode;
|
||||
use isahc::prelude::Configurable;
|
||||
use isahc::{AsyncBody, Response};
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use parse_duration::parse;
|
||||
use postage::watch;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
use std::ops::Add;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||
use util::http::{HttpClient, Request};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::auth::{CredentialProvider, ProviderCredential};
|
||||
use crate::embedding::{Embedding, EmbeddingProvider};
|
||||
use crate::models::LanguageModel;
|
||||
use crate::providers::open_ai::OpenAILanguageModel;
|
||||
|
||||
use crate::providers::open_ai::OPENAI_API_URL;
|
||||
|
||||
lazy_static! {
|
||||
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAIEmbeddingProvider {
|
||||
model: OpenAILanguageModel,
|
||||
credential: Arc<RwLock<ProviderCredential>>,
|
||||
pub client: Arc<dyn HttpClient>,
|
||||
pub executor: Arc<Executor>,
|
||||
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
||||
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OpenAIEmbeddingRequest<'a> {
|
||||
model: &'static str,
|
||||
input: Vec<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIEmbeddingResponse {
|
||||
data: Vec<OpenAIEmbedding>,
|
||||
usage: OpenAIEmbeddingUsage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAIEmbedding {
|
||||
embedding: Vec<f32>,
|
||||
index: usize,
|
||||
object: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAIEmbeddingUsage {
|
||||
prompt_tokens: usize,
|
||||
total_tokens: usize,
|
||||
}
|
||||
|
||||
impl OpenAIEmbeddingProvider {
|
||||
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Executor>) -> Self {
|
||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||
|
||||
let model = OpenAILanguageModel::load("text-embedding-ada-002");
|
||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||
|
||||
OpenAIEmbeddingProvider {
|
||||
model,
|
||||
credential,
|
||||
client,
|
||||
executor,
|
||||
rate_limit_count_rx,
|
||||
rate_limit_count_tx,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_api_key(&self) -> Result<String> {
|
||||
match self.credential.read().clone() {
|
||||
ProviderCredential::Credentials { api_key } => Ok(api_key),
|
||||
_ => Err(anyhow!("api credentials not provided")),
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_rate_limit(&self) {
|
||||
let reset_time = *self.rate_limit_count_tx.lock().borrow();
|
||||
|
||||
if let Some(reset_time) = reset_time {
|
||||
if Instant::now() >= reset_time {
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = None
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!(
|
||||
"resolving reset time: {:?}",
|
||||
*self.rate_limit_count_tx.lock().borrow()
|
||||
);
|
||||
}
|
||||
|
||||
fn update_reset_time(&self, reset_time: Instant) {
|
||||
let original_time = *self.rate_limit_count_tx.lock().borrow();
|
||||
|
||||
let updated_time = if let Some(original_time) = original_time {
|
||||
if reset_time < original_time {
|
||||
Some(reset_time)
|
||||
} else {
|
||||
Some(original_time)
|
||||
}
|
||||
} else {
|
||||
Some(reset_time)
|
||||
};
|
||||
|
||||
log::trace!("updating rate limit time: {:?}", updated_time);
|
||||
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
|
||||
}
|
||||
async fn send_request(
|
||||
&self,
|
||||
api_key: &str,
|
||||
spans: Vec<&str>,
|
||||
request_timeout: u64,
|
||||
) -> Result<Response<AsyncBody>> {
|
||||
let request = Request::post("https://api.openai.com/v1/embeddings")
|
||||
.redirect_policy(isahc::config::RedirectPolicy::Follow)
|
||||
.timeout(Duration::from_secs(request_timeout))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(
|
||||
serde_json::to_string(&OpenAIEmbeddingRequest {
|
||||
input: spans.clone(),
|
||||
model: "text-embedding-ada-002",
|
||||
})
|
||||
.unwrap()
|
||||
.into(),
|
||||
)?;
|
||||
|
||||
Ok(self.client.send(request).await?)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CredentialProvider for OpenAIEmbeddingProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
match *self.credential.read() {
|
||||
ProviderCredential::Credentials { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
|
||||
let existing_credential = self.credential.read().clone();
|
||||
|
||||
let retrieved_credential = cx
|
||||
.run_on_main(move |cx| match existing_credential {
|
||||
ProviderCredential::Credentials { .. } => {
|
||||
return existing_credential.clone();
|
||||
}
|
||||
_ => {
|
||||
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
||||
return ProviderCredential::Credentials { api_key };
|
||||
}
|
||||
|
||||
if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
|
||||
{
|
||||
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||
return ProviderCredential::Credentials { api_key };
|
||||
} else {
|
||||
return ProviderCredential::NoCredentials;
|
||||
}
|
||||
} else {
|
||||
return ProviderCredential::NoCredentials;
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
*self.credential.write() = retrieved_credential.clone();
|
||||
retrieved_credential
|
||||
}
|
||||
|
||||
async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
|
||||
*self.credential.write() = credential.clone();
|
||||
let credential = credential.clone();
|
||||
cx.run_on_main(move |cx| match credential {
|
||||
ProviderCredential::Credentials { api_key } => {
|
||||
cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||
.log_err();
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
async fn delete_credentials(&self, cx: &mut AppContext) {
|
||||
cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
|
||||
.await;
|
||||
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for OpenAIEmbeddingProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||
model
|
||||
}
|
||||
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
50000
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
*self.rate_limit_count_rx.borrow()
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||
const MAX_RETRIES: usize = 4;
|
||||
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
let mut request_number = 0;
|
||||
let mut rate_limiting = false;
|
||||
let mut request_timeout: u64 = 15;
|
||||
let mut response: Response<AsyncBody>;
|
||||
while request_number < MAX_RETRIES {
|
||||
response = self
|
||||
.send_request(
|
||||
&api_key,
|
||||
spans.iter().map(|x| &**x).collect(),
|
||||
request_timeout,
|
||||
)
|
||||
.await?;
|
||||
|
||||
request_number += 1;
|
||||
|
||||
match response.status() {
|
||||
StatusCode::REQUEST_TIMEOUT => {
|
||||
request_timeout += 5;
|
||||
}
|
||||
StatusCode::OK => {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
|
||||
|
||||
log::trace!(
|
||||
"openai embedding completed. tokens: {:?}",
|
||||
response.usage.total_tokens
|
||||
);
|
||||
|
||||
// If we complete a request successfully that was previously rate_limited
|
||||
// resolve the rate limit
|
||||
if rate_limiting {
|
||||
self.resolve_rate_limit()
|
||||
}
|
||||
|
||||
return Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|embedding| Embedding::from(embedding.embedding))
|
||||
.collect());
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
rate_limiting = true;
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
let delay_duration = {
|
||||
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
|
||||
if let Some(time_to_reset) =
|
||||
response.headers().get("x-ratelimit-reset-tokens")
|
||||
{
|
||||
if let Ok(time_str) = time_to_reset.to_str() {
|
||||
parse(time_str).unwrap_or(delay)
|
||||
} else {
|
||||
delay
|
||||
}
|
||||
} else {
|
||||
delay
|
||||
}
|
||||
};
|
||||
|
||||
// If we've previously rate limited, increment the duration but not the count
|
||||
let reset_time = Instant::now().add(delay_duration);
|
||||
self.update_reset_time(reset_time);
|
||||
|
||||
log::trace!(
|
||||
"openai rate limiting: waiting {:?} until lifted",
|
||||
&delay_duration
|
||||
);
|
||||
|
||||
self.executor.timer(delay_duration).await;
|
||||
}
|
||||
_ => {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(
|
||||
"open ai bad request: {:?} {:?}",
|
||||
&response.status(),
|
||||
body
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(anyhow!("openai max retries"))
|
||||
}
|
||||
}
|
9
crates/ai2/src/providers/open_ai/mod.rs
Normal file
9
crates/ai2/src/providers/open_ai/mod.rs
Normal file
@ -0,0 +1,9 @@
|
||||
pub mod completion;
|
||||
pub mod embedding;
|
||||
pub mod model;
|
||||
|
||||
pub use completion::*;
|
||||
pub use embedding::*;
|
||||
pub use model::OpenAILanguageModel;
|
||||
|
||||
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
|
57
crates/ai2/src/providers/open_ai/model.rs
Normal file
57
crates/ai2/src/providers/open_ai/model.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use anyhow::anyhow;
|
||||
use tiktoken_rs::CoreBPE;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::models::{LanguageModel, TruncationDirection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAILanguageModel {
|
||||
name: String,
|
||||
bpe: Option<CoreBPE>,
|
||||
}
|
||||
|
||||
impl OpenAILanguageModel {
|
||||
pub fn load(model_name: &str) -> Self {
|
||||
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
|
||||
OpenAILanguageModel {
|
||||
name: model_name.to_string(),
|
||||
bpe,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for OpenAILanguageModel {
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||
if let Some(bpe) = &self.bpe {
|
||||
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
|
||||
} else {
|
||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||
}
|
||||
}
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String> {
|
||||
if let Some(bpe) = &self.bpe {
|
||||
let tokens = bpe.encode_with_special_tokens(content);
|
||||
if tokens.len() > length {
|
||||
match direction {
|
||||
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
|
||||
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
|
||||
}
|
||||
} else {
|
||||
bpe.decode(tokens)
|
||||
}
|
||||
} else {
|
||||
Err(anyhow!("bpe for open ai model was not retrieved"))
|
||||
}
|
||||
}
|
||||
fn capacity(&self) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
|
||||
}
|
||||
}
|
11
crates/ai2/src/providers/open_ai/new.rs
Normal file
11
crates/ai2/src/providers/open_ai/new.rs
Normal file
@ -0,0 +1,11 @@
|
||||
pub trait LanguageModel {
|
||||
fn name(&self) -> String;
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String>;
|
||||
fn capacity(&self) -> anyhow::Result<usize>;
|
||||
}
|
193
crates/ai2/src/test.rs
Normal file
193
crates/ai2/src/test.rs
Normal file
@ -0,0 +1,193 @@
|
||||
use std::{
|
||||
sync::atomic::{self, AtomicUsize, Ordering},
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui2::AppContext;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::{
|
||||
auth::{CredentialProvider, ProviderCredential},
|
||||
completion::{CompletionProvider, CompletionRequest},
|
||||
embedding::{Embedding, EmbeddingProvider},
|
||||
models::{LanguageModel, TruncationDirection},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FakeLanguageModel {
|
||||
pub capacity: usize,
|
||||
}
|
||||
|
||||
impl LanguageModel for FakeLanguageModel {
|
||||
fn name(&self) -> String {
|
||||
"dummy".to_string()
|
||||
}
|
||||
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
|
||||
}
|
||||
fn truncate(
|
||||
&self,
|
||||
content: &str,
|
||||
length: usize,
|
||||
direction: TruncationDirection,
|
||||
) -> anyhow::Result<String> {
|
||||
println!("TRYING TO TRUNCATE: {:?}", length.clone());
|
||||
|
||||
if length > self.count_tokens(content)? {
|
||||
println!("NOT TRUNCATING");
|
||||
return anyhow::Ok(content.to_string());
|
||||
}
|
||||
|
||||
anyhow::Ok(match direction {
|
||||
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
|
||||
.into_iter()
|
||||
.collect::<String>(),
|
||||
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
|
||||
.into_iter()
|
||||
.collect::<String>(),
|
||||
})
|
||||
}
|
||||
fn capacity(&self) -> anyhow::Result<usize> {
|
||||
anyhow::Ok(self.capacity)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FakeEmbeddingProvider {
|
||||
pub embedding_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl Clone for FakeEmbeddingProvider {
|
||||
fn clone(&self) -> Self {
|
||||
FakeEmbeddingProvider {
|
||||
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FakeEmbeddingProvider {
|
||||
fn default() -> Self {
|
||||
FakeEmbeddingProvider {
|
||||
embedding_count: AtomicUsize::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeEmbeddingProvider {
|
||||
pub fn embedding_count(&self) -> usize {
|
||||
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub fn embed_sync(&self, span: &str) -> Embedding {
|
||||
let mut result = vec![1.0; 26];
|
||||
for letter in span.chars() {
|
||||
let letter = letter.to_ascii_lowercase();
|
||||
if letter as u32 >= 'a' as u32 {
|
||||
let ix = (letter as u32) - ('a' as u32);
|
||||
if ix < 26 {
|
||||
result[ix as usize] += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
for x in &mut result {
|
||||
*x /= norm;
|
||||
}
|
||||
|
||||
result.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CredentialProvider for FakeEmbeddingProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
true
|
||||
}
|
||||
async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
|
||||
ProviderCredential::NotNeeded
|
||||
}
|
||||
async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
|
||||
async fn delete_credentials(&self, _cx: &mut AppContext) {}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
Box::new(FakeLanguageModel { capacity: 1000 })
|
||||
}
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
1000
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
None
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
|
||||
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FakeCompletionProvider {
|
||||
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||
}
|
||||
|
||||
impl Clone for FakeCompletionProvider {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
last_completion_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FakeCompletionProvider {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
last_completion_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_completion(&self, completion: impl Into<String>) {
|
||||
let mut tx = self.last_completion_tx.lock();
|
||||
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
}
|
||||
|
||||
pub fn finish_completion(&self) {
|
||||
self.last_completion_tx.lock().take().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CredentialProvider for FakeCompletionProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
true
|
||||
}
|
||||
async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
|
||||
ProviderCredential::NotNeeded
|
||||
}
|
||||
async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
|
||||
async fn delete_credentials(&self, _cx: &mut AppContext) {}
|
||||
}
|
||||
|
||||
impl CompletionProvider for FakeCompletionProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
|
||||
model
|
||||
}
|
||||
fn complete(
|
||||
&self,
|
||||
_prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
*self.last_completion_tx.lock() = Some(tx);
|
||||
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||
}
|
||||
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||
Box::new((*self).clone())
|
||||
}
|
||||
}
|
@ -45,6 +45,7 @@ tiktoken-rs = "0.5"
|
||||
[dev-dependencies]
|
||||
editor = { path = "../editor", features = ["test-support"] }
|
||||
project = { path = "../project", features = ["test-support"] }
|
||||
ai = { path = "../ai", features = ["test-support"]}
|
||||
|
||||
ctor.workspace = true
|
||||
env_logger.workspace = true
|
||||
|
@ -4,7 +4,7 @@ mod codegen;
|
||||
mod prompts;
|
||||
mod streaming_diff;
|
||||
|
||||
use ai::completion::Role;
|
||||
use ai::providers::open_ai::Role;
|
||||
use anyhow::Result;
|
||||
pub use assistant_panel::AssistantPanel;
|
||||
use assistant_settings::OpenAIModel;
|
||||
|
@ -5,12 +5,14 @@ use crate::{
|
||||
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
|
||||
SavedMessage,
|
||||
};
|
||||
|
||||
use ai::{
|
||||
completion::{
|
||||
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
|
||||
},
|
||||
templates::repository_context::PromptCodeSnippet,
|
||||
auth::ProviderCredential,
|
||||
completion::{CompletionProvider, CompletionRequest},
|
||||
providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
|
||||
};
|
||||
|
||||
use ai::prompts::repository_context::PromptCodeSnippet;
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Local};
|
||||
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
|
||||
@ -43,8 +45,8 @@ use search::BufferSearchBar;
|
||||
use semantic_index::{SemanticIndex, SemanticIndexStatus};
|
||||
use settings::SettingsStore;
|
||||
use std::{
|
||||
cell::{Cell, RefCell},
|
||||
cmp, env,
|
||||
cell::Cell,
|
||||
cmp,
|
||||
fmt::Write,
|
||||
iter,
|
||||
ops::Range,
|
||||
@ -97,8 +99,8 @@ pub fn init(cx: &mut AppContext) {
|
||||
cx.capture_action(ConversationEditor::copy);
|
||||
cx.add_action(ConversationEditor::split);
|
||||
cx.capture_action(ConversationEditor::cycle_message_role);
|
||||
cx.add_action(AssistantPanel::save_api_key);
|
||||
cx.add_action(AssistantPanel::reset_api_key);
|
||||
cx.add_action(AssistantPanel::save_credentials);
|
||||
cx.add_action(AssistantPanel::reset_credentials);
|
||||
cx.add_action(AssistantPanel::toggle_zoom);
|
||||
cx.add_action(AssistantPanel::deploy);
|
||||
cx.add_action(AssistantPanel::select_next_match);
|
||||
@ -140,9 +142,8 @@ pub struct AssistantPanel {
|
||||
zoomed: bool,
|
||||
has_focus: bool,
|
||||
toolbar: ViewHandle<Toolbar>,
|
||||
api_key: Rc<RefCell<Option<String>>>,
|
||||
completion_provider: Box<dyn CompletionProvider>,
|
||||
api_key_editor: Option<ViewHandle<Editor>>,
|
||||
has_read_credentials: bool,
|
||||
languages: Arc<LanguageRegistry>,
|
||||
fs: Arc<dyn Fs>,
|
||||
subscriptions: Vec<Subscription>,
|
||||
@ -202,6 +203,11 @@ impl AssistantPanel {
|
||||
});
|
||||
|
||||
let semantic_index = SemanticIndex::global(cx);
|
||||
// Defaulting currently to GPT4, allow for this to be set via config.
|
||||
let completion_provider = Box::new(OpenAICompletionProvider::new(
|
||||
"gpt-4",
|
||||
cx.background().clone(),
|
||||
));
|
||||
|
||||
let mut this = Self {
|
||||
workspace: workspace_handle,
|
||||
@ -213,9 +219,8 @@ impl AssistantPanel {
|
||||
zoomed: false,
|
||||
has_focus: false,
|
||||
toolbar,
|
||||
api_key: Rc::new(RefCell::new(None)),
|
||||
completion_provider,
|
||||
api_key_editor: None,
|
||||
has_read_credentials: false,
|
||||
languages: workspace.app_state().languages.clone(),
|
||||
fs: workspace.app_state().fs.clone(),
|
||||
width: None,
|
||||
@ -254,10 +259,7 @@ impl AssistantPanel {
|
||||
cx: &mut ViewContext<Workspace>,
|
||||
) {
|
||||
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
|
||||
if this
|
||||
.update(cx, |assistant, cx| assistant.load_api_key(cx))
|
||||
.is_some()
|
||||
{
|
||||
if this.update(cx, |assistant, _| assistant.has_credentials()) {
|
||||
this
|
||||
} else {
|
||||
workspace.focus_panel::<AssistantPanel>(cx);
|
||||
@ -289,12 +291,6 @@ impl AssistantPanel {
|
||||
cx: &mut ViewContext<Self>,
|
||||
project: &ModelHandle<Project>,
|
||||
) {
|
||||
let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
|
||||
api_key
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
|
||||
let selection = editor.read(cx).selections.newest_anchor().clone();
|
||||
if selection.start.excerpt_id != selection.end.excerpt_id {
|
||||
return;
|
||||
@ -325,10 +321,13 @@ impl AssistantPanel {
|
||||
|
||||
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
|
||||
let provider = Arc::new(OpenAICompletionProvider::new(
|
||||
api_key,
|
||||
"gpt-4",
|
||||
cx.background().clone(),
|
||||
));
|
||||
|
||||
// Retrieve Credentials Authenticates the Provider
|
||||
// provider.retrieve_credentials(cx);
|
||||
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
|
||||
});
|
||||
@ -745,13 +744,14 @@ impl AssistantPanel {
|
||||
content: prompt,
|
||||
});
|
||||
|
||||
let request = OpenAIRequest {
|
||||
let request = Box::new(OpenAIRequest {
|
||||
model: model.full_name().into(),
|
||||
messages,
|
||||
stream: true,
|
||||
stop: vec!["|END|>".to_string()],
|
||||
temperature,
|
||||
};
|
||||
});
|
||||
|
||||
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
|
||||
anyhow::Ok(())
|
||||
})
|
||||
@ -811,7 +811,7 @@ impl AssistantPanel {
|
||||
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
|
||||
let editor = cx.add_view(|cx| {
|
||||
ConversationEditor::new(
|
||||
self.api_key.clone(),
|
||||
self.completion_provider.clone(),
|
||||
self.languages.clone(),
|
||||
self.fs.clone(),
|
||||
self.workspace.clone(),
|
||||
@ -870,17 +870,19 @@ impl AssistantPanel {
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
if let Some(api_key) = self
|
||||
.api_key_editor
|
||||
.as_ref()
|
||||
.map(|editor| editor.read(cx).text(cx))
|
||||
{
|
||||
if !api_key.is_empty() {
|
||||
cx.platform()
|
||||
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
|
||||
.log_err();
|
||||
*self.api_key.borrow_mut() = Some(api_key);
|
||||
let credential = ProviderCredential::Credentials {
|
||||
api_key: api_key.clone(),
|
||||
};
|
||||
|
||||
self.completion_provider.save_credentials(cx, credential);
|
||||
|
||||
self.api_key_editor.take();
|
||||
cx.focus_self();
|
||||
cx.notify();
|
||||
@ -890,9 +892,8 @@ impl AssistantPanel {
|
||||
}
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
||||
cx.platform().delete_credentials(OPENAI_API_URL).log_err();
|
||||
self.api_key.take();
|
||||
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
|
||||
self.completion_provider.delete_credentials(cx);
|
||||
self.api_key_editor = Some(build_api_key_editor(cx));
|
||||
cx.focus_self();
|
||||
cx.notify();
|
||||
@ -1151,13 +1152,12 @@ impl AssistantPanel {
|
||||
|
||||
let fs = self.fs.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let languages = self.languages.clone();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let saved_conversation = fs.load(&path).await?;
|
||||
let saved_conversation = serde_json::from_str(&saved_conversation)?;
|
||||
let conversation = cx.add_model(|cx| {
|
||||
Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
|
||||
Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
|
||||
});
|
||||
this.update(&mut cx, |this, cx| {
|
||||
// If, by the time we've loaded the conversation, the user has already opened
|
||||
@ -1181,30 +1181,12 @@ impl AssistantPanel {
|
||||
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
|
||||
}
|
||||
|
||||
fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
|
||||
if self.api_key.borrow().is_none() && !self.has_read_credentials {
|
||||
self.has_read_credentials = true;
|
||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
Some(api_key)
|
||||
} else if let Some((_, api_key)) = cx
|
||||
.platform()
|
||||
.read_credentials(OPENAI_API_URL)
|
||||
.log_err()
|
||||
.flatten()
|
||||
{
|
||||
String::from_utf8(api_key).log_err()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(api_key) = api_key {
|
||||
*self.api_key.borrow_mut() = Some(api_key);
|
||||
} else if self.api_key_editor.is_none() {
|
||||
self.api_key_editor = Some(build_api_key_editor(cx));
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
fn has_credentials(&mut self) -> bool {
|
||||
self.completion_provider.has_credentials()
|
||||
}
|
||||
|
||||
self.api_key.borrow().clone()
|
||||
fn load_credentials(&mut self, cx: &mut ViewContext<Self>) {
|
||||
self.completion_provider.retrieve_credentials(cx);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1389,7 +1371,7 @@ impl Panel for AssistantPanel {
|
||||
|
||||
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
|
||||
if active {
|
||||
self.load_api_key(cx);
|
||||
self.load_credentials(cx);
|
||||
|
||||
if self.editors.is_empty() {
|
||||
self.new_conversation(cx);
|
||||
@ -1454,10 +1436,10 @@ struct Conversation {
|
||||
token_count: Option<usize>,
|
||||
max_token_count: usize,
|
||||
pending_token_count: Task<Option<()>>,
|
||||
api_key: Rc<RefCell<Option<String>>>,
|
||||
pending_save: Task<Result<()>>,
|
||||
path: Option<PathBuf>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
completion_provider: Box<dyn CompletionProvider>,
|
||||
}
|
||||
|
||||
impl Entity for Conversation {
|
||||
@ -1466,9 +1448,9 @@ impl Entity for Conversation {
|
||||
|
||||
impl Conversation {
|
||||
fn new(
|
||||
api_key: Rc<RefCell<Option<String>>>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
completion_provider: Box<dyn CompletionProvider>,
|
||||
) -> Self {
|
||||
let markdown = language_registry.language_for_name("Markdown");
|
||||
let buffer = cx.add_model(|cx| {
|
||||
@ -1507,8 +1489,8 @@ impl Conversation {
|
||||
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
||||
pending_save: Task::ready(Ok(())),
|
||||
path: None,
|
||||
api_key,
|
||||
buffer,
|
||||
completion_provider,
|
||||
};
|
||||
let message = MessageAnchor {
|
||||
id: MessageId(post_inc(&mut this.next_message_id.0)),
|
||||
@ -1554,7 +1536,6 @@ impl Conversation {
|
||||
fn deserialize(
|
||||
saved_conversation: SavedConversation,
|
||||
path: PathBuf,
|
||||
api_key: Rc<RefCell<Option<String>>>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
@ -1563,6 +1544,10 @@ impl Conversation {
|
||||
None => Some(Uuid::new_v4().to_string()),
|
||||
};
|
||||
let model = saved_conversation.model;
|
||||
let completion_provider: Box<dyn CompletionProvider> = Box::new(
|
||||
OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
|
||||
);
|
||||
completion_provider.retrieve_credentials(cx);
|
||||
let markdown = language_registry.language_for_name("Markdown");
|
||||
let mut message_anchors = Vec::new();
|
||||
let mut next_message_id = MessageId(0);
|
||||
@ -1609,8 +1594,8 @@ impl Conversation {
|
||||
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
||||
pending_save: Task::ready(Ok(())),
|
||||
path: Some(path),
|
||||
api_key,
|
||||
buffer,
|
||||
completion_provider,
|
||||
};
|
||||
this.count_remaining_tokens(cx);
|
||||
this
|
||||
@ -1731,11 +1716,11 @@ impl Conversation {
|
||||
}
|
||||
|
||||
if should_assist {
|
||||
let Some(api_key) = self.api_key.borrow().clone() else {
|
||||
if !self.completion_provider.has_credentials() {
|
||||
return Default::default();
|
||||
};
|
||||
}
|
||||
|
||||
let request = OpenAIRequest {
|
||||
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
||||
model: self.model.full_name().to_string(),
|
||||
messages: self
|
||||
.messages(cx)
|
||||
@ -1745,9 +1730,9 @@ impl Conversation {
|
||||
stream: true,
|
||||
stop: vec![],
|
||||
temperature: 1.0,
|
||||
};
|
||||
});
|
||||
|
||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
||||
let stream = self.completion_provider.complete(request);
|
||||
let assistant_message = self
|
||||
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
|
||||
.unwrap();
|
||||
@ -1765,33 +1750,28 @@ impl Conversation {
|
||||
let mut messages = stream.await?;
|
||||
|
||||
while let Some(message) = messages.next().await {
|
||||
let mut message = message?;
|
||||
if let Some(choice) = message.choices.pop() {
|
||||
this.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("conversation was dropped"))?
|
||||
.update(&mut cx, |this, cx| {
|
||||
let text: Arc<str> = choice.delta.content?.into();
|
||||
let message_ix =
|
||||
this.message_anchors.iter().position(|message| {
|
||||
message.id == assistant_message_id
|
||||
})?;
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
let offset = this.message_anchors[message_ix + 1..]
|
||||
.iter()
|
||||
.find(|message| message.start.is_valid(buffer))
|
||||
.map_or(buffer.len(), |message| {
|
||||
message
|
||||
.start
|
||||
.to_offset(buffer)
|
||||
.saturating_sub(1)
|
||||
});
|
||||
buffer.edit([(offset..offset, text)], None, cx);
|
||||
});
|
||||
cx.emit(ConversationEvent::StreamedCompletion);
|
||||
let text = message?;
|
||||
|
||||
Some(())
|
||||
this.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("conversation was dropped"))?
|
||||
.update(&mut cx, |this, cx| {
|
||||
let message_ix = this
|
||||
.message_anchors
|
||||
.iter()
|
||||
.position(|message| message.id == assistant_message_id)?;
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
let offset = this.message_anchors[message_ix + 1..]
|
||||
.iter()
|
||||
.find(|message| message.start.is_valid(buffer))
|
||||
.map_or(buffer.len(), |message| {
|
||||
message.start.to_offset(buffer).saturating_sub(1)
|
||||
});
|
||||
buffer.edit([(offset..offset, text)], None, cx);
|
||||
});
|
||||
}
|
||||
cx.emit(ConversationEvent::StreamedCompletion);
|
||||
|
||||
Some(())
|
||||
});
|
||||
smol::future::yield_now().await;
|
||||
}
|
||||
|
||||
@ -2013,57 +1993,54 @@ impl Conversation {
|
||||
|
||||
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
|
||||
if self.message_anchors.len() >= 2 && self.summary.is_none() {
|
||||
let api_key = self.api_key.borrow().clone();
|
||||
if let Some(api_key) = api_key {
|
||||
let messages = self
|
||||
.messages(cx)
|
||||
.take(2)
|
||||
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
|
||||
.chain(Some(RequestMessage {
|
||||
role: Role::User,
|
||||
content:
|
||||
"Summarize the conversation into a short title without punctuation"
|
||||
.into(),
|
||||
}));
|
||||
let request = OpenAIRequest {
|
||||
model: self.model.full_name().to_string(),
|
||||
messages: messages.collect(),
|
||||
stream: true,
|
||||
stop: vec![],
|
||||
temperature: 1.0,
|
||||
};
|
||||
|
||||
let stream = stream_completion(api_key, cx.background().clone(), request);
|
||||
self.pending_summary = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let mut messages = stream.await?;
|
||||
|
||||
while let Some(message) = messages.next().await {
|
||||
let mut message = message?;
|
||||
if let Some(choice) = message.choices.pop() {
|
||||
let text = choice.delta.content.unwrap_or_default();
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.summary
|
||||
.get_or_insert(Default::default())
|
||||
.text
|
||||
.push_str(&text);
|
||||
cx.emit(ConversationEvent::SummaryChanged);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Some(summary) = this.summary.as_mut() {
|
||||
summary.done = true;
|
||||
cx.emit(ConversationEvent::SummaryChanged);
|
||||
}
|
||||
});
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.log_err()
|
||||
});
|
||||
if !self.completion_provider.has_credentials() {
|
||||
return;
|
||||
}
|
||||
|
||||
let messages = self
|
||||
.messages(cx)
|
||||
.take(2)
|
||||
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
|
||||
.chain(Some(RequestMessage {
|
||||
role: Role::User,
|
||||
content: "Summarize the conversation into a short title without punctuation"
|
||||
.into(),
|
||||
}));
|
||||
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
|
||||
model: self.model.full_name().to_string(),
|
||||
messages: messages.collect(),
|
||||
stream: true,
|
||||
stop: vec![],
|
||||
temperature: 1.0,
|
||||
});
|
||||
|
||||
let stream = self.completion_provider.complete(request);
|
||||
self.pending_summary = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let mut messages = stream.await?;
|
||||
|
||||
while let Some(message) = messages.next().await {
|
||||
let text = message?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.summary
|
||||
.get_or_insert(Default::default())
|
||||
.text
|
||||
.push_str(&text);
|
||||
cx.emit(ConversationEvent::SummaryChanged);
|
||||
});
|
||||
}
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let Some(summary) = this.summary.as_mut() {
|
||||
summary.done = true;
|
||||
cx.emit(ConversationEvent::SummaryChanged);
|
||||
}
|
||||
});
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.log_err()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -2224,13 +2201,14 @@ struct ConversationEditor {
|
||||
|
||||
impl ConversationEditor {
|
||||
fn new(
|
||||
api_key: Rc<RefCell<Option<String>>>,
|
||||
completion_provider: Box<dyn CompletionProvider>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
fs: Arc<dyn Fs>,
|
||||
workspace: WeakViewHandle<Workspace>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Self {
|
||||
let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
|
||||
let conversation =
|
||||
cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider));
|
||||
Self::for_conversation(conversation, fs, workspace, cx)
|
||||
}
|
||||
|
||||
@ -3419,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::MessageId;
|
||||
use ai::test::FakeCompletionProvider;
|
||||
use gpui::AppContext;
|
||||
|
||||
#[gpui::test]
|
||||
@ -3426,7 +3405,9 @@ mod tests {
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
|
||||
|
||||
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
|
||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||
@ -3554,7 +3535,9 @@ mod tests {
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
|
||||
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||
|
||||
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
|
||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||
@ -3650,7 +3633,8 @@ mod tests {
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
|
||||
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
|
||||
let message_1 = conversation.read(cx).message_anchors[0].clone();
|
||||
@ -3732,8 +3716,9 @@ mod tests {
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
init(cx);
|
||||
let registry = Arc::new(LanguageRegistry::test());
|
||||
let completion_provider = Box::new(FakeCompletionProvider::new());
|
||||
let conversation =
|
||||
cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
|
||||
cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
|
||||
let buffer = conversation.read(cx).buffer.clone();
|
||||
let message_0 = conversation.read(cx).message_anchors[0].id;
|
||||
let message_1 = conversation.update(cx, |conversation, cx| {
|
||||
@ -3770,7 +3755,6 @@ mod tests {
|
||||
Conversation::deserialize(
|
||||
conversation.read(cx).serialize(cx),
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
registry.clone(),
|
||||
cx,
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::streaming_diff::{Hunk, StreamingDiff};
|
||||
use ai::completion::{CompletionProvider, OpenAIRequest};
|
||||
use ai::completion::{CompletionProvider, CompletionRequest};
|
||||
use anyhow::Result;
|
||||
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
|
||||
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
||||
@ -96,7 +96,7 @@ impl Codegen {
|
||||
self.error.as_ref()
|
||||
}
|
||||
|
||||
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
|
||||
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
|
||||
let range = self.range();
|
||||
let snapshot = self.snapshot.clone();
|
||||
let selected_text = snapshot
|
||||
@ -336,17 +336,25 @@ fn strip_markdown_codeblock(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use futures::{
|
||||
future::BoxFuture,
|
||||
stream::{self, BoxStream},
|
||||
};
|
||||
use ai::test::FakeCompletionProvider;
|
||||
use futures::stream::{self};
|
||||
use gpui::{executor::Deterministic, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
||||
use parking_lot::Mutex;
|
||||
use rand::prelude::*;
|
||||
use serde::Serialize;
|
||||
use settings::SettingsStore;
|
||||
use smol::future::FutureExt;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct DummyCompletionRequest {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl CompletionRequest for DummyCompletionRequest {
|
||||
fn data(&self) -> serde_json::Result<String> {
|
||||
serde_json::to_string(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_transform_autoindent(
|
||||
@ -372,7 +380,7 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let provider = Arc::new(TestCompletionProvider::new());
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
@ -381,7 +389,11 @@ mod tests {
|
||||
cx,
|
||||
)
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
" let mut x = 0;\n",
|
||||
@ -434,7 +446,7 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 6))
|
||||
});
|
||||
let provider = Arc::new(TestCompletionProvider::new());
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
@ -443,7 +455,11 @@ mod tests {
|
||||
cx,
|
||||
)
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
"t mut x = 0;\n",
|
||||
@ -496,7 +512,7 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 2))
|
||||
});
|
||||
let provider = Arc::new(TestCompletionProvider::new());
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
@ -505,7 +521,11 @@ mod tests {
|
||||
cx,
|
||||
)
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
|
||||
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
"let mut x = 0;\n",
|
||||
@ -593,38 +613,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
struct TestCompletionProvider {
|
||||
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
|
||||
}
|
||||
|
||||
impl TestCompletionProvider {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
last_completion_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn send_completion(&self, completion: impl Into<String>) {
|
||||
let mut tx = self.last_completion_tx.lock();
|
||||
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
|
||||
}
|
||||
|
||||
fn finish_completion(&self) {
|
||||
self.last_completion_tx.lock().take().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for TestCompletionProvider {
|
||||
fn complete(
|
||||
&self,
|
||||
_prompt: OpenAIRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
*self.last_completion_tx.lock() = Some(tx);
|
||||
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
|
@ -1,9 +1,10 @@
|
||||
use ai::models::{LanguageModel, OpenAILanguageModel};
|
||||
use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
|
||||
use ai::templates::file_context::FileContext;
|
||||
use ai::templates::generate::GenerateInlineContent;
|
||||
use ai::templates::preamble::EngineerPreamble;
|
||||
use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
|
||||
use ai::models::LanguageModel;
|
||||
use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
|
||||
use ai::prompts::file_context::FileContext;
|
||||
use ai::prompts::generate::GenerateInlineContent;
|
||||
use ai::prompts::preamble::EngineerPreamble;
|
||||
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
|
||||
use ai::providers::open_ai::OpenAILanguageModel;
|
||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||
use std::cmp::{self, Reverse};
|
||||
use std::ops::Range;
|
||||
|
@ -25,7 +25,7 @@ collections = { path = "../collections" }
|
||||
gpui2 = { path = "../gpui2" }
|
||||
log.workspace = true
|
||||
live_kit_client = { path = "../live_kit_client" }
|
||||
fs = { path = "../fs" }
|
||||
fs2 = { path = "../fs2" }
|
||||
language2 = { path = "../language2" }
|
||||
media = { path = "../media" }
|
||||
project2 = { path = "../project2" }
|
||||
@ -43,7 +43,7 @@ serde_derive.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
client2 = { path = "../client2", features = ["test-support"] }
|
||||
fs = { path = "../fs", features = ["test-support"] }
|
||||
fs2 = { path = "../fs2", features = ["test-support"] }
|
||||
language2 = { path = "../language2", features = ["test-support"] }
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui2 = { path = "../gpui2", features = ["test-support"] }
|
||||
|
@ -12,8 +12,8 @@ use client2::{
|
||||
use collections::HashSet;
|
||||
use futures::{future::Shared, FutureExt};
|
||||
use gpui2::{
|
||||
AppContext, AsyncAppContext, Context, EventEmitter, Handle, ModelContext, Subscription, Task,
|
||||
WeakHandle,
|
||||
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task,
|
||||
WeakModel,
|
||||
};
|
||||
use postage::watch;
|
||||
use project2::Project;
|
||||
@ -23,10 +23,10 @@ use std::sync::Arc;
|
||||
pub use participant::ParticipantLocation;
|
||||
pub use room::Room;
|
||||
|
||||
pub fn init(client: Arc<Client>, user_store: Handle<UserStore>, cx: &mut AppContext) {
|
||||
pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
|
||||
CallSettings::register(cx);
|
||||
|
||||
let active_call = cx.entity(|cx| ActiveCall::new(client, user_store, cx));
|
||||
let active_call = cx.build_model(|cx| ActiveCall::new(client, user_store, cx));
|
||||
cx.set_global(active_call);
|
||||
}
|
||||
|
||||
@ -40,16 +40,16 @@ pub struct IncomingCall {
|
||||
|
||||
/// Singleton global maintaining the user's participation in a room across workspaces.
|
||||
pub struct ActiveCall {
|
||||
room: Option<(Handle<Room>, Vec<Subscription>)>,
|
||||
pending_room_creation: Option<Shared<Task<Result<Handle<Room>, Arc<anyhow::Error>>>>>,
|
||||
location: Option<WeakHandle<Project>>,
|
||||
room: Option<(Model<Room>, Vec<Subscription>)>,
|
||||
pending_room_creation: Option<Shared<Task<Result<Model<Room>, Arc<anyhow::Error>>>>>,
|
||||
location: Option<WeakModel<Project>>,
|
||||
pending_invites: HashSet<u64>,
|
||||
incoming_call: (
|
||||
watch::Sender<Option<IncomingCall>>,
|
||||
watch::Receiver<Option<IncomingCall>>,
|
||||
),
|
||||
client: Arc<Client>,
|
||||
user_store: Handle<UserStore>,
|
||||
user_store: Model<UserStore>,
|
||||
_subscriptions: Vec<client2::Subscription>,
|
||||
}
|
||||
|
||||
@ -58,11 +58,7 @@ impl EventEmitter for ActiveCall {
|
||||
}
|
||||
|
||||
impl ActiveCall {
|
||||
fn new(
|
||||
client: Arc<Client>,
|
||||
user_store: Handle<UserStore>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
fn new(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut ModelContext<Self>) -> Self {
|
||||
Self {
|
||||
room: None,
|
||||
pending_room_creation: None,
|
||||
@ -84,7 +80,7 @@ impl ActiveCall {
|
||||
}
|
||||
|
||||
async fn handle_incoming_call(
|
||||
this: Handle<Self>,
|
||||
this: Model<Self>,
|
||||
envelope: TypedEnvelope<proto::IncomingCall>,
|
||||
_: Arc<Client>,
|
||||
mut cx: AsyncAppContext,
|
||||
@ -112,7 +108,7 @@ impl ActiveCall {
|
||||
}
|
||||
|
||||
async fn handle_call_canceled(
|
||||
this: Handle<Self>,
|
||||
this: Model<Self>,
|
||||
envelope: TypedEnvelope<proto::CallCanceled>,
|
||||
_: Arc<Client>,
|
||||
mut cx: AsyncAppContext,
|
||||
@ -129,14 +125,14 @@ impl ActiveCall {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn global(cx: &AppContext) -> Handle<Self> {
|
||||
cx.global::<Handle<Self>>().clone()
|
||||
pub fn global(cx: &AppContext) -> Model<Self> {
|
||||
cx.global::<Model<Self>>().clone()
|
||||
}
|
||||
|
||||
pub fn invite(
|
||||
&mut self,
|
||||
called_user_id: u64,
|
||||
initial_project: Option<Handle<Project>>,
|
||||
initial_project: Option<Model<Project>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
if !self.pending_invites.insert(called_user_id) {
|
||||
@ -291,7 +287,7 @@ impl ActiveCall {
|
||||
&mut self,
|
||||
channel_id: u64,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Handle<Room>>> {
|
||||
) -> Task<Result<Model<Room>>> {
|
||||
if let Some(room) = self.room().cloned() {
|
||||
if room.read(cx).channel_id() == Some(channel_id) {
|
||||
return Task::ready(Ok(room));
|
||||
@ -327,7 +323,7 @@ impl ActiveCall {
|
||||
|
||||
pub fn share_project(
|
||||
&mut self,
|
||||
project: Handle<Project>,
|
||||
project: Model<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<u64>> {
|
||||
if let Some((room, _)) = self.room.as_ref() {
|
||||
@ -340,7 +336,7 @@ impl ActiveCall {
|
||||
|
||||
pub fn unshare_project(
|
||||
&mut self,
|
||||
project: Handle<Project>,
|
||||
project: Model<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Result<()> {
|
||||
if let Some((room, _)) = self.room.as_ref() {
|
||||
@ -351,13 +347,13 @@ impl ActiveCall {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn location(&self) -> Option<&WeakHandle<Project>> {
|
||||
pub fn location(&self) -> Option<&WeakModel<Project>> {
|
||||
self.location.as_ref()
|
||||
}
|
||||
|
||||
pub fn set_location(
|
||||
&mut self,
|
||||
project: Option<&Handle<Project>>,
|
||||
project: Option<&Model<Project>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
if project.is_some() || !*ZED_ALWAYS_ACTIVE {
|
||||
@ -371,7 +367,7 @@ impl ActiveCall {
|
||||
|
||||
fn set_room(
|
||||
&mut self,
|
||||
room: Option<Handle<Room>>,
|
||||
room: Option<Model<Room>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
if room.as_ref() != self.room.as_ref().map(|room| &room.0) {
|
||||
@ -407,7 +403,7 @@ impl ActiveCall {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn room(&self) -> Option<&Handle<Room>> {
|
||||
pub fn room(&self) -> Option<&Model<Room>> {
|
||||
self.room.as_ref().map(|(room, _)| room)
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use client2::ParticipantIndex;
|
||||
use client2::{proto, User};
|
||||
use gpui2::WeakHandle;
|
||||
use gpui2::WeakModel;
|
||||
pub use live_kit_client::Frame;
|
||||
use project2::Project;
|
||||
use std::{fmt, sync::Arc};
|
||||
@ -33,7 +33,7 @@ impl ParticipantLocation {
|
||||
#[derive(Clone, Default)]
|
||||
pub struct LocalParticipant {
|
||||
pub projects: Vec<proto::ParticipantProject>,
|
||||
pub active_project: Option<WeakHandle<Project>>,
|
||||
pub active_project: Option<WeakModel<Project>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -13,10 +13,10 @@ use client2::{
|
||||
Client, ParticipantIndex, TypedEnvelope, User, UserStore,
|
||||
};
|
||||
use collections::{BTreeMap, HashMap, HashSet};
|
||||
use fs::Fs;
|
||||
use fs2::Fs;
|
||||
use futures::{FutureExt, StreamExt};
|
||||
use gpui2::{
|
||||
AppContext, AsyncAppContext, Context, EventEmitter, Handle, ModelContext, Task, WeakHandle,
|
||||
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel,
|
||||
};
|
||||
use language2::LanguageRegistry;
|
||||
use live_kit_client::{LocalTrackPublication, RemoteAudioTrackUpdate, RemoteVideoTrackUpdate};
|
||||
@ -61,8 +61,8 @@ pub struct Room {
|
||||
channel_id: Option<u64>,
|
||||
// live_kit: Option<LiveKitRoom>,
|
||||
status: RoomStatus,
|
||||
shared_projects: HashSet<WeakHandle<Project>>,
|
||||
joined_projects: HashSet<WeakHandle<Project>>,
|
||||
shared_projects: HashSet<WeakModel<Project>>,
|
||||
joined_projects: HashSet<WeakModel<Project>>,
|
||||
local_participant: LocalParticipant,
|
||||
remote_participants: BTreeMap<u64, RemoteParticipant>,
|
||||
pending_participants: Vec<Arc<User>>,
|
||||
@ -70,7 +70,7 @@ pub struct Room {
|
||||
pending_call_count: usize,
|
||||
leave_when_empty: bool,
|
||||
client: Arc<Client>,
|
||||
user_store: Handle<UserStore>,
|
||||
user_store: Model<UserStore>,
|
||||
follows_by_leader_id_project_id: HashMap<(PeerId, u64), Vec<PeerId>>,
|
||||
client_subscriptions: Vec<client2::Subscription>,
|
||||
_subscriptions: Vec<gpui2::Subscription>,
|
||||
@ -111,7 +111,7 @@ impl Room {
|
||||
channel_id: Option<u64>,
|
||||
live_kit_connection_info: Option<proto::LiveKitConnectionInfo>,
|
||||
client: Arc<Client>,
|
||||
user_store: Handle<UserStore>,
|
||||
user_store: Model<UserStore>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
todo!()
|
||||
@ -237,15 +237,15 @@ impl Room {
|
||||
|
||||
pub(crate) fn create(
|
||||
called_user_id: u64,
|
||||
initial_project: Option<Handle<Project>>,
|
||||
initial_project: Option<Model<Project>>,
|
||||
client: Arc<Client>,
|
||||
user_store: Handle<UserStore>,
|
||||
user_store: Model<UserStore>,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Handle<Self>>> {
|
||||
) -> Task<Result<Model<Self>>> {
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let response = client.request(proto::CreateRoom {}).await?;
|
||||
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
|
||||
let room = cx.entity(|cx| {
|
||||
let room = cx.build_model(|cx| {
|
||||
Self::new(
|
||||
room_proto.id,
|
||||
None,
|
||||
@ -283,9 +283,9 @@ impl Room {
|
||||
pub(crate) fn join_channel(
|
||||
channel_id: u64,
|
||||
client: Arc<Client>,
|
||||
user_store: Handle<UserStore>,
|
||||
user_store: Model<UserStore>,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Handle<Self>>> {
|
||||
) -> Task<Result<Model<Self>>> {
|
||||
cx.spawn(move |cx| async move {
|
||||
Self::from_join_response(
|
||||
client.request(proto::JoinChannel { channel_id }).await?,
|
||||
@ -299,9 +299,9 @@ impl Room {
|
||||
pub(crate) fn join(
|
||||
call: &IncomingCall,
|
||||
client: Arc<Client>,
|
||||
user_store: Handle<UserStore>,
|
||||
user_store: Model<UserStore>,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Handle<Self>>> {
|
||||
) -> Task<Result<Model<Self>>> {
|
||||
let id = call.room_id;
|
||||
cx.spawn(move |cx| async move {
|
||||
Self::from_join_response(
|
||||
@ -343,11 +343,11 @@ impl Room {
|
||||
fn from_join_response(
|
||||
response: proto::JoinRoomResponse,
|
||||
client: Arc<Client>,
|
||||
user_store: Handle<UserStore>,
|
||||
user_store: Model<UserStore>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Handle<Self>> {
|
||||
) -> Result<Model<Self>> {
|
||||
let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?;
|
||||
let room = cx.entity(|cx| {
|
||||
let room = cx.build_model(|cx| {
|
||||
Self::new(
|
||||
room_proto.id,
|
||||
response.channel_id,
|
||||
@ -424,7 +424,7 @@ impl Room {
|
||||
}
|
||||
|
||||
async fn maintain_connection(
|
||||
this: WeakHandle<Self>,
|
||||
this: WeakModel<Self>,
|
||||
client: Arc<Client>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
@ -661,7 +661,7 @@ impl Room {
|
||||
}
|
||||
|
||||
async fn handle_room_updated(
|
||||
this: Handle<Self>,
|
||||
this: Model<Self>,
|
||||
envelope: TypedEnvelope<proto::RoomUpdated>,
|
||||
_: Arc<Client>,
|
||||
mut cx: AsyncAppContext,
|
||||
@ -1101,7 +1101,7 @@ impl Room {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Handle<Project>>> {
|
||||
) -> Task<Result<Model<Project>>> {
|
||||
let client = self.client.clone();
|
||||
let user_store = self.user_store.clone();
|
||||
cx.emit(Event::RemoteProjectJoined { project_id: id });
|
||||
@ -1125,7 +1125,7 @@ impl Room {
|
||||
|
||||
pub(crate) fn share_project(
|
||||
&mut self,
|
||||
project: Handle<Project>,
|
||||
project: Model<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<u64>> {
|
||||
if let Some(project_id) = project.read(cx).remote_id() {
|
||||
@ -1161,7 +1161,7 @@ impl Room {
|
||||
|
||||
pub(crate) fn unshare_project(
|
||||
&mut self,
|
||||
project: Handle<Project>,
|
||||
project: Model<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Result<()> {
|
||||
let project_id = match project.read(cx).remote_id() {
|
||||
@ -1175,7 +1175,7 @@ impl Room {
|
||||
|
||||
pub(crate) fn set_location(
|
||||
&mut self,
|
||||
project: Option<&Handle<Project>>,
|
||||
project: Option<&Model<Project>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
if self.status.is_offline() {
|
||||
|
@ -14,8 +14,8 @@ use futures::{
|
||||
future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _, TryStreamExt,
|
||||
};
|
||||
use gpui2::{
|
||||
serde_json, AnyHandle, AnyWeakHandle, AppContext, AsyncAppContext, Handle, SemanticVersion,
|
||||
Task, WeakHandle,
|
||||
serde_json, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Model, SemanticVersion, Task,
|
||||
WeakModel,
|
||||
};
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::RwLock;
|
||||
@ -227,7 +227,7 @@ struct ClientState {
|
||||
_reconnect_task: Option<Task<()>>,
|
||||
reconnect_interval: Duration,
|
||||
entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
|
||||
models_by_message_type: HashMap<TypeId, AnyWeakHandle>,
|
||||
models_by_message_type: HashMap<TypeId, AnyWeakModel>,
|
||||
entity_types_by_message_type: HashMap<TypeId, TypeId>,
|
||||
#[allow(clippy::type_complexity)]
|
||||
message_handlers: HashMap<
|
||||
@ -236,7 +236,7 @@ struct ClientState {
|
||||
dyn Send
|
||||
+ Sync
|
||||
+ Fn(
|
||||
AnyHandle,
|
||||
AnyModel,
|
||||
Box<dyn AnyTypedEnvelope>,
|
||||
&Arc<Client>,
|
||||
AsyncAppContext,
|
||||
@ -246,7 +246,7 @@ struct ClientState {
|
||||
}
|
||||
|
||||
enum WeakSubscriber {
|
||||
Entity { handle: AnyWeakHandle },
|
||||
Entity { handle: AnyWeakModel },
|
||||
Pending(Vec<Box<dyn AnyTypedEnvelope>>),
|
||||
}
|
||||
|
||||
@ -314,7 +314,7 @@ impl<T> PendingEntitySubscription<T>
|
||||
where
|
||||
T: 'static + Send,
|
||||
{
|
||||
pub fn set_model(mut self, model: &Handle<T>, cx: &mut AsyncAppContext) -> Subscription {
|
||||
pub fn set_model(mut self, model: &Model<T>, cx: &mut AsyncAppContext) -> Subscription {
|
||||
self.consumed = true;
|
||||
let mut state = self.client.state.write();
|
||||
let id = (TypeId::of::<T>(), self.remote_id);
|
||||
@ -552,13 +552,13 @@ impl Client {
|
||||
#[track_caller]
|
||||
pub fn add_message_handler<M, E, H, F>(
|
||||
self: &Arc<Self>,
|
||||
entity: WeakHandle<E>,
|
||||
entity: WeakModel<E>,
|
||||
handler: H,
|
||||
) -> Subscription
|
||||
where
|
||||
M: EnvelopedMessage,
|
||||
E: 'static + Send,
|
||||
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||
F: 'static + Future<Output = Result<()>> + Send,
|
||||
{
|
||||
let message_type_id = TypeId::of::<M>();
|
||||
@ -594,13 +594,13 @@ impl Client {
|
||||
|
||||
pub fn add_request_handler<M, E, H, F>(
|
||||
self: &Arc<Self>,
|
||||
model: WeakHandle<E>,
|
||||
model: WeakModel<E>,
|
||||
handler: H,
|
||||
) -> Subscription
|
||||
where
|
||||
M: RequestMessage,
|
||||
E: 'static + Send,
|
||||
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||
F: 'static + Future<Output = Result<M::Response>> + Send,
|
||||
{
|
||||
self.add_message_handler(model, move |handle, envelope, this, cx| {
|
||||
@ -616,7 +616,7 @@ impl Client {
|
||||
where
|
||||
M: EntityMessage,
|
||||
E: 'static + Send,
|
||||
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||
F: 'static + Future<Output = Result<()>> + Send,
|
||||
{
|
||||
self.add_entity_message_handler::<M, E, _, _>(move |subscriber, message, client, cx| {
|
||||
@ -628,7 +628,7 @@ impl Client {
|
||||
where
|
||||
M: EntityMessage,
|
||||
E: 'static + Send,
|
||||
H: 'static + Send + Sync + Fn(AnyHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||
H: 'static + Send + Sync + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||
F: 'static + Future<Output = Result<()>> + Send,
|
||||
{
|
||||
let model_type_id = TypeId::of::<E>();
|
||||
@ -667,7 +667,7 @@ impl Client {
|
||||
where
|
||||
M: EntityMessage + RequestMessage,
|
||||
E: 'static + Send,
|
||||
H: 'static + Send + Sync + Fn(Handle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||
H: 'static + Send + Sync + Fn(Model<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
|
||||
F: 'static + Future<Output = Result<M::Response>> + Send,
|
||||
{
|
||||
self.add_model_message_handler(move |entity, envelope, client, cx| {
|
||||
@ -1546,7 +1546,7 @@ mod tests {
|
||||
let (done_tx1, mut done_rx1) = smol::channel::unbounded();
|
||||
let (done_tx2, mut done_rx2) = smol::channel::unbounded();
|
||||
client.add_model_message_handler(
|
||||
move |model: Handle<Model>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
|
||||
move |model: Model<TestModel>, _: TypedEnvelope<proto::JoinProject>, _, mut cx| {
|
||||
match model.update(&mut cx, |model, _| model.id).unwrap() {
|
||||
1 => done_tx1.try_send(()).unwrap(),
|
||||
2 => done_tx2.try_send(()).unwrap(),
|
||||
@ -1555,15 +1555,15 @@ mod tests {
|
||||
async { Ok(()) }
|
||||
},
|
||||
);
|
||||
let model1 = cx.entity(|_| Model {
|
||||
let model1 = cx.build_model(|_| TestModel {
|
||||
id: 1,
|
||||
subscription: None,
|
||||
});
|
||||
let model2 = cx.entity(|_| Model {
|
||||
let model2 = cx.build_model(|_| TestModel {
|
||||
id: 2,
|
||||
subscription: None,
|
||||
});
|
||||
let model3 = cx.entity(|_| Model {
|
||||
let model3 = cx.build_model(|_| TestModel {
|
||||
id: 3,
|
||||
subscription: None,
|
||||
});
|
||||
@ -1596,7 +1596,7 @@ mod tests {
|
||||
let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
|
||||
let server = FakeServer::for_client(user_id, &client, cx).await;
|
||||
|
||||
let model = cx.entity(|_| Model::default());
|
||||
let model = cx.build_model(|_| TestModel::default());
|
||||
let (done_tx1, _done_rx1) = smol::channel::unbounded();
|
||||
let (done_tx2, mut done_rx2) = smol::channel::unbounded();
|
||||
let subscription1 = client.add_message_handler(
|
||||
@ -1624,11 +1624,11 @@ mod tests {
|
||||
let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx));
|
||||
let server = FakeServer::for_client(user_id, &client, cx).await;
|
||||
|
||||
let model = cx.entity(|_| Model::default());
|
||||
let model = cx.build_model(|_| TestModel::default());
|
||||
let (done_tx, mut done_rx) = smol::channel::unbounded();
|
||||
let subscription = client.add_message_handler(
|
||||
model.clone().downgrade(),
|
||||
move |model: Handle<Model>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
|
||||
move |model: Model<TestModel>, _: TypedEnvelope<proto::Ping>, _, mut cx| {
|
||||
model
|
||||
.update(&mut cx, |model, _| model.subscription.take())
|
||||
.unwrap();
|
||||
@ -1644,7 +1644,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Model {
|
||||
struct TestModel {
|
||||
id: usize,
|
||||
subscription: Option<Subscription>,
|
||||
}
|
||||
|
@ -5,7 +5,9 @@ use parking_lot::Mutex;
|
||||
use serde::Serialize;
|
||||
use settings2::Settings;
|
||||
use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration};
|
||||
use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt};
|
||||
use sysinfo::{
|
||||
CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt,
|
||||
};
|
||||
use tempfile::NamedTempFile;
|
||||
use util::http::HttpClient;
|
||||
use util::{channel::ReleaseChannel, TryFutureExt};
|
||||
@ -161,8 +163,16 @@ impl Telemetry {
|
||||
|
||||
let this = self.clone();
|
||||
cx.spawn(|cx| async move {
|
||||
let mut system = System::new_all();
|
||||
system.refresh_all();
|
||||
// Avoiding calling `System::new_all()`, as there have been crashes related to it
|
||||
let refresh_kind = RefreshKind::new()
|
||||
.with_memory() // For memory usage
|
||||
.with_processes(ProcessRefreshKind::everything()) // For process usage
|
||||
.with_cpu(CpuRefreshKind::everything()); // For core count
|
||||
|
||||
let mut system = System::new_with_specifics(refresh_kind);
|
||||
|
||||
// Avoiding calling `refresh_all()`, just update what we need
|
||||
system.refresh_specifics(refresh_kind);
|
||||
|
||||
loop {
|
||||
// Waiting some amount of time before the first query is important to get a reasonable value
|
||||
@ -170,8 +180,7 @@ impl Telemetry {
|
||||
const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60);
|
||||
smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await;
|
||||
|
||||
system.refresh_memory();
|
||||
system.refresh_processes();
|
||||
system.refresh_specifics(refresh_kind);
|
||||
|
||||
let current_process = Pid::from_u32(std::process::id());
|
||||
let Some(process) = system.processes().get(¤t_process) else {
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{stream::BoxStream, StreamExt};
|
||||
use gpui2::{Context, Executor, Handle, TestAppContext};
|
||||
use gpui2::{Context, Executor, Model, TestAppContext};
|
||||
use parking_lot::Mutex;
|
||||
use rpc2::{
|
||||
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
|
||||
@ -194,9 +194,9 @@ impl FakeServer {
|
||||
&self,
|
||||
client: Arc<Client>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> Handle<UserStore> {
|
||||
) -> Model<UserStore> {
|
||||
let http_client = FakeHttpClient::with_404_response();
|
||||
let user_store = cx.entity(|cx| UserStore::new(client, http_client, cx));
|
||||
let user_store = cx.build_model(|cx| UserStore::new(client, http_client, cx));
|
||||
assert_eq!(
|
||||
self.receive::<proto::GetUsers>()
|
||||
.await
|
||||
|
@ -3,7 +3,7 @@ use anyhow::{anyhow, Context, Result};
|
||||
use collections::{hash_map::Entry, HashMap, HashSet};
|
||||
use feature_flags2::FeatureFlagAppExt;
|
||||
use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt};
|
||||
use gpui2::{AsyncAppContext, EventEmitter, Handle, ImageData, ModelContext, Task};
|
||||
use gpui2::{AsyncAppContext, EventEmitter, ImageData, Model, ModelContext, Task};
|
||||
use postage::{sink::Sink, watch};
|
||||
use rpc2::proto::{RequestMessage, UsersResponse};
|
||||
use std::sync::{Arc, Weak};
|
||||
@ -213,7 +213,7 @@ impl UserStore {
|
||||
}
|
||||
|
||||
async fn handle_update_invite_info(
|
||||
this: Handle<Self>,
|
||||
this: Model<Self>,
|
||||
message: TypedEnvelope<proto::UpdateInviteInfo>,
|
||||
_: Arc<Client>,
|
||||
mut cx: AsyncAppContext,
|
||||
@ -229,7 +229,7 @@ impl UserStore {
|
||||
}
|
||||
|
||||
async fn handle_show_contacts(
|
||||
this: Handle<Self>,
|
||||
this: Model<Self>,
|
||||
_: TypedEnvelope<proto::ShowContacts>,
|
||||
_: Arc<Client>,
|
||||
mut cx: AsyncAppContext,
|
||||
@ -243,7 +243,7 @@ impl UserStore {
|
||||
}
|
||||
|
||||
async fn handle_update_contacts(
|
||||
this: Handle<Self>,
|
||||
this: Model<Self>,
|
||||
message: TypedEnvelope<proto::UpdateContacts>,
|
||||
_: Arc<Client>,
|
||||
mut cx: AsyncAppContext,
|
||||
@ -690,7 +690,7 @@ impl User {
|
||||
impl Contact {
|
||||
async fn from_proto(
|
||||
contact: proto::Contact,
|
||||
user_store: &Handle<UserStore>,
|
||||
user_store: &Model<UserStore>,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let user = user_store
|
||||
|
@ -7,8 +7,8 @@ use async_tar::Archive;
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
|
||||
use gpui2::{
|
||||
AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Handle, ModelContext, Task,
|
||||
WeakHandle,
|
||||
AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Model, ModelContext, Task,
|
||||
WeakModel,
|
||||
};
|
||||
use language2::{
|
||||
language_settings::{all_language_settings, language_settings},
|
||||
@ -49,7 +49,7 @@ pub fn init(
|
||||
node_runtime: Arc<dyn NodeRuntime>,
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
let copilot = cx.entity({
|
||||
let copilot = cx.build_model({
|
||||
let node_runtime = node_runtime.clone();
|
||||
move |cx| Copilot::start(new_server_id, http, node_runtime, cx)
|
||||
});
|
||||
@ -183,7 +183,7 @@ struct RegisteredBuffer {
|
||||
impl RegisteredBuffer {
|
||||
fn report_changes(
|
||||
&mut self,
|
||||
buffer: &Handle<Buffer>,
|
||||
buffer: &Model<Buffer>,
|
||||
cx: &mut ModelContext<Copilot>,
|
||||
) -> oneshot::Receiver<(i32, BufferSnapshot)> {
|
||||
let (done_tx, done_rx) = oneshot::channel();
|
||||
@ -278,7 +278,7 @@ pub struct Copilot {
|
||||
http: Arc<dyn HttpClient>,
|
||||
node_runtime: Arc<dyn NodeRuntime>,
|
||||
server: CopilotServer,
|
||||
buffers: HashSet<WeakHandle<Buffer>>,
|
||||
buffers: HashSet<WeakModel<Buffer>>,
|
||||
server_id: LanguageServerId,
|
||||
_subscription: gpui2::Subscription,
|
||||
}
|
||||
@ -292,9 +292,9 @@ impl EventEmitter for Copilot {
|
||||
}
|
||||
|
||||
impl Copilot {
|
||||
pub fn global(cx: &AppContext) -> Option<Handle<Self>> {
|
||||
if cx.has_global::<Handle<Self>>() {
|
||||
Some(cx.global::<Handle<Self>>().clone())
|
||||
pub fn global(cx: &AppContext) -> Option<Model<Self>> {
|
||||
if cx.has_global::<Model<Self>>() {
|
||||
Some(cx.global::<Model<Self>>().clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@ -383,7 +383,7 @@ impl Copilot {
|
||||
new_server_id: LanguageServerId,
|
||||
http: Arc<dyn HttpClient>,
|
||||
node_runtime: Arc<dyn NodeRuntime>,
|
||||
this: WeakHandle<Self>,
|
||||
this: WeakModel<Self>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> impl Future<Output = ()> {
|
||||
async move {
|
||||
@ -590,7 +590,7 @@ impl Copilot {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_buffer(&mut self, buffer: &Handle<Buffer>, cx: &mut ModelContext<Self>) {
|
||||
pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
|
||||
let weak_buffer = buffer.downgrade();
|
||||
self.buffers.insert(weak_buffer.clone());
|
||||
|
||||
@ -646,7 +646,7 @@ impl Copilot {
|
||||
|
||||
fn handle_buffer_event(
|
||||
&mut self,
|
||||
buffer: Handle<Buffer>,
|
||||
buffer: Model<Buffer>,
|
||||
event: &language2::Event,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Result<()> {
|
||||
@ -706,7 +706,7 @@ impl Copilot {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unregister_buffer(&mut self, buffer: &WeakHandle<Buffer>) {
|
||||
fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
|
||||
if let Ok(server) = self.server.as_running() {
|
||||
if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
|
||||
server
|
||||
@ -723,7 +723,7 @@ impl Copilot {
|
||||
|
||||
pub fn completions<T>(
|
||||
&mut self,
|
||||
buffer: &Handle<Buffer>,
|
||||
buffer: &Model<Buffer>,
|
||||
position: T,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<Completion>>>
|
||||
@ -735,7 +735,7 @@ impl Copilot {
|
||||
|
||||
pub fn completions_cycling<T>(
|
||||
&mut self,
|
||||
buffer: &Handle<Buffer>,
|
||||
buffer: &Model<Buffer>,
|
||||
position: T,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<Completion>>>
|
||||
@ -792,7 +792,7 @@ impl Copilot {
|
||||
|
||||
fn request_completions<R, T>(
|
||||
&mut self,
|
||||
buffer: &Handle<Buffer>,
|
||||
buffer: &Model<Buffer>,
|
||||
position: T,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<Completion>>>
|
||||
@ -926,7 +926,7 @@ fn id_for_language(language: Option<&Arc<Language>>) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn uri_for_buffer(buffer: &Handle<Buffer>, cx: &AppContext) -> lsp2::Url {
|
||||
fn uri_for_buffer(buffer: &Model<Buffer>, cx: &AppContext) -> lsp2::Url {
|
||||
if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
|
||||
lsp2::Url::from_file_path(file.abs_path(cx)).unwrap()
|
||||
} else {
|
||||
|
@ -967,7 +967,6 @@ impl CompletionsMenu {
|
||||
self.selected_item -= 1;
|
||||
} else {
|
||||
self.selected_item = self.matches.len() - 1;
|
||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
||||
}
|
||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
||||
self.attempt_resolve_selected_completion_documentation(project, cx);
|
||||
@ -1538,7 +1537,6 @@ impl CodeActionsMenu {
|
||||
self.selected_item -= 1;
|
||||
} else {
|
||||
self.selected_item = self.actions.len() - 1;
|
||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
||||
}
|
||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
||||
cx.notify();
|
||||
@ -1547,11 +1545,10 @@ impl CodeActionsMenu {
|
||||
fn select_next(&mut self, cx: &mut ViewContext<Editor>) {
|
||||
if self.selected_item + 1 < self.actions.len() {
|
||||
self.selected_item += 1;
|
||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
||||
} else {
|
||||
self.selected_item = 0;
|
||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
||||
}
|
||||
self.list.scroll_to(ScrollTarget::Show(self.selected_item));
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
|
@ -16,7 +16,7 @@ use crate::{
|
||||
current_platform, image_cache::ImageCache, Action, AnyBox, AnyView, AnyWindowHandle,
|
||||
AppMetadata, AssetSource, ClipboardItem, Context, DispatchPhase, DisplayId, Executor,
|
||||
FocusEvent, FocusHandle, FocusId, KeyBinding, Keymap, LayoutId, MainThread, MainThreadOnly,
|
||||
Pixels, Platform, Point, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
|
||||
Pixels, Platform, Point, Render, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
|
||||
TextStyle, TextStyleRefinement, TextSystem, View, ViewContext, Window, WindowContext,
|
||||
WindowHandle, WindowId,
|
||||
};
|
||||
@ -309,10 +309,17 @@ impl AppContext {
|
||||
update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R,
|
||||
) -> Result<R>
|
||||
where
|
||||
V: 'static,
|
||||
V: 'static + Send,
|
||||
{
|
||||
self.update_window(handle.any_handle, |cx| {
|
||||
let root_view = cx.window.root_view.as_ref().unwrap().downcast().unwrap();
|
||||
let root_view = cx
|
||||
.window
|
||||
.root_view
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone()
|
||||
.downcast()
|
||||
.unwrap();
|
||||
root_view.update(cx, update)
|
||||
})
|
||||
}
|
||||
@ -685,7 +692,7 @@ impl AppContext {
|
||||
|
||||
pub fn observe_release<E: 'static>(
|
||||
&mut self,
|
||||
handle: &Handle<E>,
|
||||
handle: &Model<E>,
|
||||
mut on_release: impl FnMut(&mut E, &mut AppContext) + Send + 'static,
|
||||
) -> Subscription {
|
||||
self.release_listeners.insert(
|
||||
@ -750,35 +757,35 @@ impl AppContext {
|
||||
}
|
||||
|
||||
impl Context for AppContext {
|
||||
type EntityContext<'a, T> = ModelContext<'a, T>;
|
||||
type ModelContext<'a, T> = ModelContext<'a, T>;
|
||||
type Result<T> = T;
|
||||
|
||||
/// Build an entity that is owned by the application. The given function will be invoked with
|
||||
/// a `ModelContext` and must return an object representing the entity. A `Handle` will be returned
|
||||
/// a `ModelContext` and must return an object representing the entity. A `Model` will be returned
|
||||
/// which can be used to access the entity in a context.
|
||||
fn entity<T: 'static + Send>(
|
||||
fn build_model<T: 'static + Send>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
||||
) -> Handle<T> {
|
||||
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||
) -> Model<T> {
|
||||
self.update(|cx| {
|
||||
let slot = cx.entities.reserve();
|
||||
let entity = build_entity(&mut ModelContext::mutable(cx, slot.downgrade()));
|
||||
let entity = build_model(&mut ModelContext::mutable(cx, slot.downgrade()));
|
||||
cx.entities.insert(slot, entity)
|
||||
})
|
||||
}
|
||||
|
||||
/// Update the entity referenced by the given handle. The function is passed a mutable reference to the
|
||||
/// Update the entity referenced by the given model. The function is passed a mutable reference to the
|
||||
/// entity along with a `ModelContext` for the entity.
|
||||
fn update_entity<T: 'static, R>(
|
||||
&mut self,
|
||||
handle: &Handle<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
||||
model: &Model<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||
) -> R {
|
||||
self.update(|cx| {
|
||||
let mut entity = cx.entities.lease(handle);
|
||||
let mut entity = cx.entities.lease(model);
|
||||
let result = update(
|
||||
&mut entity,
|
||||
&mut ModelContext::mutable(cx, handle.downgrade()),
|
||||
&mut ModelContext::mutable(cx, model.downgrade()),
|
||||
);
|
||||
cx.entities.end_lease(entity);
|
||||
result
|
||||
@ -861,10 +868,17 @@ impl MainThread<AppContext> {
|
||||
update: impl FnOnce(&mut V, &mut MainThread<ViewContext<'_, '_, V>>) -> R,
|
||||
) -> Result<R>
|
||||
where
|
||||
V: 'static,
|
||||
V: 'static + Send,
|
||||
{
|
||||
self.update_window(handle.any_handle, |cx| {
|
||||
let root_view = cx.window.root_view.as_ref().unwrap().downcast().unwrap();
|
||||
let root_view = cx
|
||||
.window
|
||||
.root_view
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone()
|
||||
.downcast()
|
||||
.unwrap();
|
||||
root_view.update(cx, update)
|
||||
})
|
||||
}
|
||||
@ -872,7 +886,7 @@ impl MainThread<AppContext> {
|
||||
/// Opens a new window with the given option and the root view returned by the given function.
|
||||
/// The function is invoked with a `WindowContext`, which can be used to interact with window-specific
|
||||
/// functionality.
|
||||
pub fn open_window<V: 'static>(
|
||||
pub fn open_window<V: Render>(
|
||||
&mut self,
|
||||
options: crate::WindowOptions,
|
||||
build_root_view: impl FnOnce(&mut WindowContext) -> View<V> + Send + 'static,
|
||||
@ -955,10 +969,8 @@ impl<G: 'static> DerefMut for GlobalLease<G> {
|
||||
/// Contains state associated with an active drag operation, started by dragging an element
|
||||
/// within the window or by dragging into the app from the underlying platform.
|
||||
pub(crate) struct AnyDrag {
|
||||
pub drag_handle_view: Option<AnyView>,
|
||||
pub view: AnyView,
|
||||
pub cursor_offset: Point<Pixels>,
|
||||
pub state: AnyBox,
|
||||
pub state_type: TypeId,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
AnyWindowHandle, AppContext, Component, Context, Executor, Handle, MainThread, ModelContext,
|
||||
Result, Task, View, ViewContext, VisualContext, WindowContext, WindowHandle,
|
||||
AnyWindowHandle, AppContext, Context, Executor, MainThread, Model, ModelContext, Result, Task,
|
||||
View, ViewContext, VisualContext, WindowContext, WindowHandle,
|
||||
};
|
||||
use anyhow::Context as _;
|
||||
use derive_more::{Deref, DerefMut};
|
||||
@ -14,25 +14,25 @@ pub struct AsyncAppContext {
|
||||
}
|
||||
|
||||
impl Context for AsyncAppContext {
|
||||
type EntityContext<'a, T> = ModelContext<'a, T>;
|
||||
type ModelContext<'a, T> = ModelContext<'a, T>;
|
||||
type Result<T> = Result<T>;
|
||||
|
||||
fn entity<T: 'static>(
|
||||
fn build_model<T: 'static>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
||||
) -> Self::Result<Handle<T>>
|
||||
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||
) -> Self::Result<Model<T>>
|
||||
where
|
||||
T: 'static + Send,
|
||||
{
|
||||
let app = self.app.upgrade().context("app was released")?;
|
||||
let mut lock = app.lock(); // Need this to compile
|
||||
Ok(lock.entity(build_entity))
|
||||
Ok(lock.build_model(build_model))
|
||||
}
|
||||
|
||||
fn update_entity<T: 'static, R>(
|
||||
&mut self,
|
||||
handle: &Handle<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
||||
handle: &Model<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||
) -> Self::Result<R> {
|
||||
let app = self.app.upgrade().context("app was released")?;
|
||||
let mut lock = app.lock(); // Need this to compile
|
||||
@ -84,7 +84,7 @@ impl AsyncAppContext {
|
||||
update: impl FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> R,
|
||||
) -> Result<R>
|
||||
where
|
||||
V: 'static,
|
||||
V: 'static + Send,
|
||||
{
|
||||
let app = self.app.upgrade().context("app was released")?;
|
||||
let mut app_context = app.lock();
|
||||
@ -234,24 +234,24 @@ impl AsyncWindowContext {
|
||||
}
|
||||
|
||||
impl Context for AsyncWindowContext {
|
||||
type EntityContext<'a, T> = ModelContext<'a, T>;
|
||||
type ModelContext<'a, T> = ModelContext<'a, T>;
|
||||
type Result<T> = Result<T>;
|
||||
|
||||
fn entity<T>(
|
||||
fn build_model<T>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
||||
) -> Result<Handle<T>>
|
||||
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||
) -> Result<Model<T>>
|
||||
where
|
||||
T: 'static + Send,
|
||||
{
|
||||
self.app
|
||||
.update_window(self.window, |cx| cx.entity(build_entity))
|
||||
.update_window(self.window, |cx| cx.build_model(build_model))
|
||||
}
|
||||
|
||||
fn update_entity<T: 'static, R>(
|
||||
&mut self,
|
||||
handle: &Handle<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
||||
handle: &Model<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||
) -> Result<R> {
|
||||
self.app
|
||||
.update_window(self.window, |cx| cx.update_entity(handle, update))
|
||||
@ -261,17 +261,15 @@ impl Context for AsyncWindowContext {
|
||||
impl VisualContext for AsyncWindowContext {
|
||||
type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>;
|
||||
|
||||
fn build_view<E, V>(
|
||||
fn build_view<V>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
||||
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
|
||||
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
||||
) -> Self::Result<View<V>>
|
||||
where
|
||||
E: Component<V>,
|
||||
V: 'static + Send,
|
||||
{
|
||||
self.app
|
||||
.update_window(self.window, |cx| cx.build_view(build_entity, render))
|
||||
.update_window(self.window, |cx| cx.build_view(build_view_state))
|
||||
}
|
||||
|
||||
fn update_view<V: 'static, R>(
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::{AnyBox, AppContext, Context, EntityHandle};
|
||||
use crate::{AnyBox, AppContext, Context};
|
||||
use anyhow::{anyhow, Result};
|
||||
use derive_more::{Deref, DerefMut};
|
||||
use parking_lot::{RwLock, RwLockUpgradableReadGuard};
|
||||
@ -53,29 +53,29 @@ impl EntityMap {
|
||||
/// Reserve a slot for an entity, which you can subsequently use with `insert`.
|
||||
pub fn reserve<T: 'static>(&self) -> Slot<T> {
|
||||
let id = self.ref_counts.write().counts.insert(1.into());
|
||||
Slot(Handle::new(id, Arc::downgrade(&self.ref_counts)))
|
||||
Slot(Model::new(id, Arc::downgrade(&self.ref_counts)))
|
||||
}
|
||||
|
||||
/// Insert an entity into a slot obtained by calling `reserve`.
|
||||
pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Handle<T>
|
||||
pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Model<T>
|
||||
where
|
||||
T: 'static + Send,
|
||||
{
|
||||
let handle = slot.0;
|
||||
self.entities.insert(handle.entity_id, Box::new(entity));
|
||||
handle
|
||||
let model = slot.0;
|
||||
self.entities.insert(model.entity_id, Box::new(entity));
|
||||
model
|
||||
}
|
||||
|
||||
/// Move an entity to the stack.
|
||||
pub fn lease<'a, T>(&mut self, handle: &'a Handle<T>) -> Lease<'a, T> {
|
||||
self.assert_valid_context(handle);
|
||||
pub fn lease<'a, T>(&mut self, model: &'a Model<T>) -> Lease<'a, T> {
|
||||
self.assert_valid_context(model);
|
||||
let entity = Some(
|
||||
self.entities
|
||||
.remove(handle.entity_id)
|
||||
.remove(model.entity_id)
|
||||
.expect("Circular entity lease. Is the entity already being updated?"),
|
||||
);
|
||||
Lease {
|
||||
handle,
|
||||
model,
|
||||
entity,
|
||||
entity_type: PhantomData,
|
||||
}
|
||||
@ -84,18 +84,18 @@ impl EntityMap {
|
||||
/// Return an entity after moving it to the stack.
|
||||
pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
|
||||
self.entities
|
||||
.insert(lease.handle.entity_id, lease.entity.take().unwrap());
|
||||
.insert(lease.model.entity_id, lease.entity.take().unwrap());
|
||||
}
|
||||
|
||||
pub fn read<T: 'static>(&self, handle: &Handle<T>) -> &T {
|
||||
self.assert_valid_context(handle);
|
||||
self.entities[handle.entity_id].downcast_ref().unwrap()
|
||||
pub fn read<T: 'static>(&self, model: &Model<T>) -> &T {
|
||||
self.assert_valid_context(model);
|
||||
self.entities[model.entity_id].downcast_ref().unwrap()
|
||||
}
|
||||
|
||||
fn assert_valid_context(&self, handle: &AnyHandle) {
|
||||
fn assert_valid_context(&self, model: &AnyModel) {
|
||||
debug_assert!(
|
||||
Weak::ptr_eq(&handle.entity_map, &Arc::downgrade(&self.ref_counts)),
|
||||
"used a handle with the wrong context"
|
||||
Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)),
|
||||
"used a model with the wrong context"
|
||||
);
|
||||
}
|
||||
|
||||
@ -115,7 +115,7 @@ impl EntityMap {
|
||||
|
||||
pub struct Lease<'a, T> {
|
||||
entity: Option<AnyBox>,
|
||||
pub handle: &'a Handle<T>,
|
||||
pub model: &'a Model<T>,
|
||||
entity_type: PhantomData<T>,
|
||||
}
|
||||
|
||||
@ -143,15 +143,15 @@ impl<'a, T> Drop for Lease<'a, T> {
|
||||
}
|
||||
|
||||
#[derive(Deref, DerefMut)]
|
||||
pub struct Slot<T>(Handle<T>);
|
||||
pub struct Slot<T>(Model<T>);
|
||||
|
||||
pub struct AnyHandle {
|
||||
pub struct AnyModel {
|
||||
pub(crate) entity_id: EntityId,
|
||||
entity_type: TypeId,
|
||||
pub(crate) entity_type: TypeId,
|
||||
entity_map: Weak<RwLock<EntityRefCounts>>,
|
||||
}
|
||||
|
||||
impl AnyHandle {
|
||||
impl AnyModel {
|
||||
fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
|
||||
Self {
|
||||
entity_id: id,
|
||||
@ -164,18 +164,18 @@ impl AnyHandle {
|
||||
self.entity_id
|
||||
}
|
||||
|
||||
pub fn downgrade(&self) -> AnyWeakHandle {
|
||||
AnyWeakHandle {
|
||||
pub fn downgrade(&self) -> AnyWeakModel {
|
||||
AnyWeakModel {
|
||||
entity_id: self.entity_id,
|
||||
entity_type: self.entity_type,
|
||||
entity_ref_counts: self.entity_map.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn downcast<T: 'static>(&self) -> Option<Handle<T>> {
|
||||
pub fn downcast<T: 'static>(&self) -> Option<Model<T>> {
|
||||
if TypeId::of::<T>() == self.entity_type {
|
||||
Some(Handle {
|
||||
any_handle: self.clone(),
|
||||
Some(Model {
|
||||
any_model: self.clone(),
|
||||
entity_type: PhantomData,
|
||||
})
|
||||
} else {
|
||||
@ -184,16 +184,16 @@ impl AnyHandle {
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for AnyHandle {
|
||||
impl Clone for AnyModel {
|
||||
fn clone(&self) -> Self {
|
||||
if let Some(entity_map) = self.entity_map.upgrade() {
|
||||
let entity_map = entity_map.read();
|
||||
let count = entity_map
|
||||
.counts
|
||||
.get(self.entity_id)
|
||||
.expect("detected over-release of a handle");
|
||||
.expect("detected over-release of a model");
|
||||
let prev_count = count.fetch_add(1, SeqCst);
|
||||
assert_ne!(prev_count, 0, "Detected over-release of a handle.");
|
||||
assert_ne!(prev_count, 0, "Detected over-release of a model.");
|
||||
}
|
||||
|
||||
Self {
|
||||
@ -204,16 +204,16 @@ impl Clone for AnyHandle {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for AnyHandle {
|
||||
impl Drop for AnyModel {
|
||||
fn drop(&mut self) {
|
||||
if let Some(entity_map) = self.entity_map.upgrade() {
|
||||
let entity_map = entity_map.upgradable_read();
|
||||
let count = entity_map
|
||||
.counts
|
||||
.get(self.entity_id)
|
||||
.expect("Detected over-release of a handle.");
|
||||
.expect("Detected over-release of a model.");
|
||||
let prev_count = count.fetch_sub(1, SeqCst);
|
||||
assert_ne!(prev_count, 0, "Detected over-release of a handle.");
|
||||
assert_ne!(prev_count, 0, "Detected over-release of a model.");
|
||||
if prev_count == 1 {
|
||||
// We were the last reference to this entity, so we can remove it.
|
||||
let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
|
||||
@ -223,60 +223,65 @@ impl Drop for AnyHandle {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<Handle<T>> for AnyHandle {
|
||||
fn from(handle: Handle<T>) -> Self {
|
||||
handle.any_handle
|
||||
impl<T> From<Model<T>> for AnyModel {
|
||||
fn from(model: Model<T>) -> Self {
|
||||
model.any_model
|
||||
}
|
||||
}
|
||||
|
||||
impl Hash for AnyHandle {
|
||||
impl Hash for AnyModel {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.entity_id.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for AnyHandle {
|
||||
impl PartialEq for AnyModel {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.entity_id == other.entity_id
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for AnyHandle {}
|
||||
impl Eq for AnyModel {}
|
||||
|
||||
#[derive(Deref, DerefMut)]
|
||||
pub struct Handle<T> {
|
||||
pub struct Model<T> {
|
||||
#[deref]
|
||||
#[deref_mut]
|
||||
any_handle: AnyHandle,
|
||||
entity_type: PhantomData<T>,
|
||||
pub(crate) any_model: AnyModel,
|
||||
pub(crate) entity_type: PhantomData<T>,
|
||||
}
|
||||
|
||||
unsafe impl<T> Send for Handle<T> {}
|
||||
unsafe impl<T> Sync for Handle<T> {}
|
||||
unsafe impl<T> Send for Model<T> {}
|
||||
unsafe impl<T> Sync for Model<T> {}
|
||||
|
||||
impl<T: 'static> Handle<T> {
|
||||
impl<T: 'static> Model<T> {
|
||||
fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
|
||||
where
|
||||
T: 'static,
|
||||
{
|
||||
Self {
|
||||
any_handle: AnyHandle::new(id, TypeId::of::<T>(), entity_map),
|
||||
any_model: AnyModel::new(id, TypeId::of::<T>(), entity_map),
|
||||
entity_type: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn downgrade(&self) -> WeakHandle<T> {
|
||||
WeakHandle {
|
||||
any_handle: self.any_handle.downgrade(),
|
||||
pub fn downgrade(&self) -> WeakModel<T> {
|
||||
WeakModel {
|
||||
any_model: self.any_model.downgrade(),
|
||||
entity_type: self.entity_type,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert this into a dynamically typed model.
|
||||
pub fn into_any(self) -> AnyModel {
|
||||
self.any_model
|
||||
}
|
||||
|
||||
pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
|
||||
cx.entities.read(self)
|
||||
}
|
||||
|
||||
/// Update the entity referenced by this handle with the given function.
|
||||
/// Update the entity referenced by this model with the given function.
|
||||
///
|
||||
/// The update function receives a context appropriate for its environment.
|
||||
/// When updating in an `AppContext`, it receives a `ModelContext`.
|
||||
@ -284,7 +289,7 @@ impl<T: 'static> Handle<T> {
|
||||
pub fn update<C, R>(
|
||||
&self,
|
||||
cx: &mut C,
|
||||
update: impl FnOnce(&mut T, &mut C::EntityContext<'_, T>) -> R,
|
||||
update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R,
|
||||
) -> C::Result<R>
|
||||
where
|
||||
C: Context,
|
||||
@ -293,73 +298,54 @@ impl<T: 'static> Handle<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Clone for Handle<T> {
|
||||
impl<T> Clone for Model<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
any_handle: self.any_handle.clone(),
|
||||
any_model: self.any_model.clone(),
|
||||
entity_type: self.entity_type,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::fmt::Debug for Handle<T> {
|
||||
impl<T> std::fmt::Debug for Model<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Handle {{ entity_id: {:?}, entity_type: {:?} }}",
|
||||
self.any_handle.entity_id,
|
||||
"Model {{ entity_id: {:?}, entity_type: {:?} }}",
|
||||
self.any_model.entity_id,
|
||||
type_name::<T>()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Hash for Handle<T> {
|
||||
impl<T> Hash for Model<T> {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.any_handle.hash(state);
|
||||
self.any_model.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> PartialEq for Handle<T> {
|
||||
impl<T> PartialEq for Model<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.any_handle == other.any_handle
|
||||
self.any_model == other.any_model
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Eq for Handle<T> {}
|
||||
impl<T> Eq for Model<T> {}
|
||||
|
||||
impl<T> PartialEq<WeakHandle<T>> for Handle<T> {
|
||||
fn eq(&self, other: &WeakHandle<T>) -> bool {
|
||||
self.entity_id == other.entity_id
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: 'static> EntityHandle<T> for Handle<T> {
|
||||
type Weak = WeakHandle<T>;
|
||||
|
||||
fn entity_id(&self) -> EntityId {
|
||||
self.entity_id
|
||||
}
|
||||
|
||||
fn downgrade(&self) -> Self::Weak {
|
||||
self.downgrade()
|
||||
}
|
||||
|
||||
fn upgrade_from(weak: &Self::Weak) -> Option<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
weak.upgrade()
|
||||
impl<T> PartialEq<WeakModel<T>> for Model<T> {
|
||||
fn eq(&self, other: &WeakModel<T>) -> bool {
|
||||
self.entity_id() == other.entity_id()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AnyWeakHandle {
|
||||
pub struct AnyWeakModel {
|
||||
pub(crate) entity_id: EntityId,
|
||||
entity_type: TypeId,
|
||||
entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
|
||||
}
|
||||
|
||||
impl AnyWeakHandle {
|
||||
impl AnyWeakModel {
|
||||
pub fn entity_id(&self) -> EntityId {
|
||||
self.entity_id
|
||||
}
|
||||
@ -373,14 +359,14 @@ impl AnyWeakHandle {
|
||||
ref_count > 0
|
||||
}
|
||||
|
||||
pub fn upgrade(&self) -> Option<AnyHandle> {
|
||||
pub fn upgrade(&self) -> Option<AnyModel> {
|
||||
let entity_map = self.entity_ref_counts.upgrade()?;
|
||||
entity_map
|
||||
.read()
|
||||
.counts
|
||||
.get(self.entity_id)?
|
||||
.fetch_add(1, SeqCst);
|
||||
Some(AnyHandle {
|
||||
Some(AnyModel {
|
||||
entity_id: self.entity_id,
|
||||
entity_type: self.entity_type,
|
||||
entity_map: self.entity_ref_counts.clone(),
|
||||
@ -388,55 +374,55 @@ impl AnyWeakHandle {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<WeakHandle<T>> for AnyWeakHandle {
|
||||
fn from(handle: WeakHandle<T>) -> Self {
|
||||
handle.any_handle
|
||||
impl<T> From<WeakModel<T>> for AnyWeakModel {
|
||||
fn from(model: WeakModel<T>) -> Self {
|
||||
model.any_model
|
||||
}
|
||||
}
|
||||
|
||||
impl Hash for AnyWeakHandle {
|
||||
impl Hash for AnyWeakModel {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.entity_id.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for AnyWeakHandle {
|
||||
impl PartialEq for AnyWeakModel {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.entity_id == other.entity_id
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for AnyWeakHandle {}
|
||||
impl Eq for AnyWeakModel {}
|
||||
|
||||
#[derive(Deref, DerefMut)]
|
||||
pub struct WeakHandle<T> {
|
||||
pub struct WeakModel<T> {
|
||||
#[deref]
|
||||
#[deref_mut]
|
||||
any_handle: AnyWeakHandle,
|
||||
any_model: AnyWeakModel,
|
||||
entity_type: PhantomData<T>,
|
||||
}
|
||||
|
||||
unsafe impl<T> Send for WeakHandle<T> {}
|
||||
unsafe impl<T> Sync for WeakHandle<T> {}
|
||||
unsafe impl<T> Send for WeakModel<T> {}
|
||||
unsafe impl<T> Sync for WeakModel<T> {}
|
||||
|
||||
impl<T> Clone for WeakHandle<T> {
|
||||
impl<T> Clone for WeakModel<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
any_handle: self.any_handle.clone(),
|
||||
any_model: self.any_model.clone(),
|
||||
entity_type: self.entity_type,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: 'static> WeakHandle<T> {
|
||||
pub fn upgrade(&self) -> Option<Handle<T>> {
|
||||
Some(Handle {
|
||||
any_handle: self.any_handle.upgrade()?,
|
||||
impl<T: 'static> WeakModel<T> {
|
||||
pub fn upgrade(&self) -> Option<Model<T>> {
|
||||
Some(Model {
|
||||
any_model: self.any_model.upgrade()?,
|
||||
entity_type: self.entity_type,
|
||||
})
|
||||
}
|
||||
|
||||
/// Update the entity referenced by this handle with the given function if
|
||||
/// Update the entity referenced by this model with the given function if
|
||||
/// the referenced entity still exists. Returns an error if the entity has
|
||||
/// been released.
|
||||
///
|
||||
@ -446,7 +432,7 @@ impl<T: 'static> WeakHandle<T> {
|
||||
pub fn update<C, R>(
|
||||
&self,
|
||||
cx: &mut C,
|
||||
update: impl FnOnce(&mut T, &mut C::EntityContext<'_, T>) -> R,
|
||||
update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R,
|
||||
) -> Result<R>
|
||||
where
|
||||
C: Context,
|
||||
@ -460,22 +446,22 @@ impl<T: 'static> WeakHandle<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Hash for WeakHandle<T> {
|
||||
impl<T> Hash for WeakModel<T> {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.any_handle.hash(state);
|
||||
self.any_model.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> PartialEq for WeakHandle<T> {
|
||||
impl<T> PartialEq for WeakModel<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.any_handle == other.any_handle
|
||||
self.any_model == other.any_model
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Eq for WeakHandle<T> {}
|
||||
impl<T> Eq for WeakModel<T> {}
|
||||
|
||||
impl<T> PartialEq<Handle<T>> for WeakHandle<T> {
|
||||
fn eq(&self, other: &Handle<T>) -> bool {
|
||||
self.entity_id == other.entity_id
|
||||
impl<T> PartialEq<Model<T>> for WeakModel<T> {
|
||||
fn eq(&self, other: &Model<T>) -> bool {
|
||||
self.entity_id() == other.entity_id()
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
AppContext, AsyncAppContext, Context, Effect, EntityId, EventEmitter, Handle, MainThread,
|
||||
Reference, Subscription, Task, WeakHandle,
|
||||
AppContext, AsyncAppContext, Context, Effect, EntityId, EventEmitter, MainThread, Model,
|
||||
Reference, Subscription, Task, WeakModel,
|
||||
};
|
||||
use derive_more::{Deref, DerefMut};
|
||||
use futures::FutureExt;
|
||||
@ -15,11 +15,11 @@ pub struct ModelContext<'a, T> {
|
||||
#[deref]
|
||||
#[deref_mut]
|
||||
app: Reference<'a, AppContext>,
|
||||
model_state: WeakHandle<T>,
|
||||
model_state: WeakModel<T>,
|
||||
}
|
||||
|
||||
impl<'a, T: 'static> ModelContext<'a, T> {
|
||||
pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakHandle<T>) -> Self {
|
||||
pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakModel<T>) -> Self {
|
||||
Self {
|
||||
app: Reference::Mutable(app),
|
||||
model_state,
|
||||
@ -30,20 +30,20 @@ impl<'a, T: 'static> ModelContext<'a, T> {
|
||||
self.model_state.entity_id
|
||||
}
|
||||
|
||||
pub fn handle(&self) -> Handle<T> {
|
||||
pub fn handle(&self) -> Model<T> {
|
||||
self.weak_handle()
|
||||
.upgrade()
|
||||
.expect("The entity must be alive if we have a model context")
|
||||
}
|
||||
|
||||
pub fn weak_handle(&self) -> WeakHandle<T> {
|
||||
pub fn weak_handle(&self) -> WeakModel<T> {
|
||||
self.model_state.clone()
|
||||
}
|
||||
|
||||
pub fn observe<T2: 'static>(
|
||||
&mut self,
|
||||
handle: &Handle<T2>,
|
||||
mut on_notify: impl FnMut(&mut T, Handle<T2>, &mut ModelContext<'_, T>) + Send + 'static,
|
||||
handle: &Model<T2>,
|
||||
mut on_notify: impl FnMut(&mut T, Model<T2>, &mut ModelContext<'_, T>) + Send + 'static,
|
||||
) -> Subscription
|
||||
where
|
||||
T: 'static + Send,
|
||||
@ -65,10 +65,8 @@ impl<'a, T: 'static> ModelContext<'a, T> {
|
||||
|
||||
pub fn subscribe<E: 'static + EventEmitter>(
|
||||
&mut self,
|
||||
handle: &Handle<E>,
|
||||
mut on_event: impl FnMut(&mut T, Handle<E>, &E::Event, &mut ModelContext<'_, T>)
|
||||
+ Send
|
||||
+ 'static,
|
||||
handle: &Model<E>,
|
||||
mut on_event: impl FnMut(&mut T, Model<E>, &E::Event, &mut ModelContext<'_, T>) + Send + 'static,
|
||||
) -> Subscription
|
||||
where
|
||||
T: 'static + Send,
|
||||
@ -107,7 +105,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
|
||||
|
||||
pub fn observe_release<E: 'static>(
|
||||
&mut self,
|
||||
handle: &Handle<E>,
|
||||
handle: &Model<E>,
|
||||
mut on_release: impl FnMut(&mut T, &mut E, &mut ModelContext<'_, T>) + Send + 'static,
|
||||
) -> Subscription
|
||||
where
|
||||
@ -182,7 +180,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
|
||||
|
||||
pub fn spawn<Fut, R>(
|
||||
&self,
|
||||
f: impl FnOnce(WeakHandle<T>, AsyncAppContext) -> Fut + Send + 'static,
|
||||
f: impl FnOnce(WeakModel<T>, AsyncAppContext) -> Fut + Send + 'static,
|
||||
) -> Task<R>
|
||||
where
|
||||
T: 'static,
|
||||
@ -195,7 +193,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
|
||||
|
||||
pub fn spawn_on_main<Fut, R>(
|
||||
&self,
|
||||
f: impl FnOnce(WeakHandle<T>, MainThread<AsyncAppContext>) -> Fut + Send + 'static,
|
||||
f: impl FnOnce(WeakModel<T>, MainThread<AsyncAppContext>) -> Fut + Send + 'static,
|
||||
) -> Task<R>
|
||||
where
|
||||
Fut: Future<Output = R> + 'static,
|
||||
@ -220,23 +218,23 @@ where
|
||||
}
|
||||
|
||||
impl<'a, T> Context for ModelContext<'a, T> {
|
||||
type EntityContext<'b, U> = ModelContext<'b, U>;
|
||||
type ModelContext<'b, U> = ModelContext<'b, U>;
|
||||
type Result<U> = U;
|
||||
|
||||
fn entity<U>(
|
||||
fn build_model<U>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, U>) -> U,
|
||||
) -> Handle<U>
|
||||
build_model: impl FnOnce(&mut Self::ModelContext<'_, U>) -> U,
|
||||
) -> Model<U>
|
||||
where
|
||||
U: 'static + Send,
|
||||
{
|
||||
self.app.entity(build_entity)
|
||||
self.app.build_model(build_model)
|
||||
}
|
||||
|
||||
fn update_entity<U: 'static, R>(
|
||||
&mut self,
|
||||
handle: &Handle<U>,
|
||||
update: impl FnOnce(&mut U, &mut Self::EntityContext<'_, U>) -> R,
|
||||
handle: &Model<U>,
|
||||
update: impl FnOnce(&mut U, &mut Self::ModelContext<'_, U>) -> R,
|
||||
) -> R {
|
||||
self.app.update_entity(handle, update)
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
AnyWindowHandle, AppContext, AsyncAppContext, Context, Executor, Handle, MainThread,
|
||||
AnyWindowHandle, AppContext, AsyncAppContext, Context, Executor, MainThread, Model,
|
||||
ModelContext, Result, Task, TestDispatcher, TestPlatform, WindowContext,
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
@ -12,24 +12,24 @@ pub struct TestAppContext {
|
||||
}
|
||||
|
||||
impl Context for TestAppContext {
|
||||
type EntityContext<'a, T> = ModelContext<'a, T>;
|
||||
type ModelContext<'a, T> = ModelContext<'a, T>;
|
||||
type Result<T> = T;
|
||||
|
||||
fn entity<T: 'static>(
|
||||
fn build_model<T: 'static>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
||||
) -> Self::Result<Handle<T>>
|
||||
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||
) -> Self::Result<Model<T>>
|
||||
where
|
||||
T: 'static + Send,
|
||||
{
|
||||
let mut lock = self.app.lock();
|
||||
lock.entity(build_entity)
|
||||
lock.build_model(build_model)
|
||||
}
|
||||
|
||||
fn update_entity<T: 'static, R>(
|
||||
&mut self,
|
||||
handle: &Handle<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
||||
handle: &Model<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||
) -> Self::Result<R> {
|
||||
let mut lock = self.app.lock();
|
||||
lock.update_entity(handle, update)
|
||||
|
@ -4,7 +4,7 @@ pub(crate) use smallvec::SmallVec;
|
||||
use std::{any::Any, mem};
|
||||
|
||||
pub trait Element<V: 'static> {
|
||||
type ElementState: 'static;
|
||||
type ElementState: 'static + Send;
|
||||
|
||||
fn id(&self) -> Option<ElementId>;
|
||||
|
||||
|
@ -70,33 +70,31 @@ use taffy::TaffyLayoutEngine;
|
||||
type AnyBox = Box<dyn Any + Send>;
|
||||
|
||||
pub trait Context {
|
||||
type EntityContext<'a, T>;
|
||||
type ModelContext<'a, T>;
|
||||
type Result<T>;
|
||||
|
||||
fn entity<T>(
|
||||
fn build_model<T>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
||||
) -> Self::Result<Handle<T>>
|
||||
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||
) -> Self::Result<Model<T>>
|
||||
where
|
||||
T: 'static + Send;
|
||||
|
||||
fn update_entity<T: 'static, R>(
|
||||
&mut self,
|
||||
handle: &Handle<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
||||
handle: &Model<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||
) -> Self::Result<R>;
|
||||
}
|
||||
|
||||
pub trait VisualContext: Context {
|
||||
type ViewContext<'a, 'w, V>;
|
||||
|
||||
fn build_view<E, V>(
|
||||
fn build_view<V>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
||||
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
|
||||
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
||||
) -> Self::Result<View<V>>
|
||||
where
|
||||
E: Component<V>,
|
||||
V: 'static + Send;
|
||||
|
||||
fn update_view<V: 'static, R>(
|
||||
@ -140,37 +138,37 @@ impl<T> DerefMut for MainThread<T> {
|
||||
}
|
||||
|
||||
impl<C: Context> Context for MainThread<C> {
|
||||
type EntityContext<'a, T> = MainThread<C::EntityContext<'a, T>>;
|
||||
type ModelContext<'a, T> = MainThread<C::ModelContext<'a, T>>;
|
||||
type Result<T> = C::Result<T>;
|
||||
|
||||
fn entity<T>(
|
||||
fn build_model<T>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
||||
) -> Self::Result<Handle<T>>
|
||||
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||
) -> Self::Result<Model<T>>
|
||||
where
|
||||
T: 'static + Send,
|
||||
{
|
||||
self.0.entity(|cx| {
|
||||
self.0.build_model(|cx| {
|
||||
let cx = unsafe {
|
||||
mem::transmute::<
|
||||
&mut C::EntityContext<'_, T>,
|
||||
&mut MainThread<C::EntityContext<'_, T>>,
|
||||
&mut C::ModelContext<'_, T>,
|
||||
&mut MainThread<C::ModelContext<'_, T>>,
|
||||
>(cx)
|
||||
};
|
||||
build_entity(cx)
|
||||
build_model(cx)
|
||||
})
|
||||
}
|
||||
|
||||
fn update_entity<T: 'static, R>(
|
||||
&mut self,
|
||||
handle: &Handle<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
||||
handle: &Model<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||
) -> Self::Result<R> {
|
||||
self.0.update_entity(handle, |entity, cx| {
|
||||
let cx = unsafe {
|
||||
mem::transmute::<
|
||||
&mut C::EntityContext<'_, T>,
|
||||
&mut MainThread<C::EntityContext<'_, T>>,
|
||||
&mut C::ModelContext<'_, T>,
|
||||
&mut MainThread<C::ModelContext<'_, T>>,
|
||||
>(cx)
|
||||
};
|
||||
update(entity, cx)
|
||||
@ -181,27 +179,22 @@ impl<C: Context> Context for MainThread<C> {
|
||||
impl<C: VisualContext> VisualContext for MainThread<C> {
|
||||
type ViewContext<'a, 'w, V> = MainThread<C::ViewContext<'a, 'w, V>>;
|
||||
|
||||
fn build_view<E, V>(
|
||||
fn build_view<V>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
||||
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
|
||||
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
||||
) -> Self::Result<View<V>>
|
||||
where
|
||||
E: Component<V>,
|
||||
V: 'static + Send,
|
||||
{
|
||||
self.0.build_view(
|
||||
|cx| {
|
||||
let cx = unsafe {
|
||||
mem::transmute::<
|
||||
&mut C::ViewContext<'_, '_, V>,
|
||||
&mut MainThread<C::ViewContext<'_, '_, V>>,
|
||||
>(cx)
|
||||
};
|
||||
build_entity(cx)
|
||||
},
|
||||
render,
|
||||
)
|
||||
self.0.build_view(|cx| {
|
||||
let cx = unsafe {
|
||||
mem::transmute::<
|
||||
&mut C::ViewContext<'_, '_, V>,
|
||||
&mut MainThread<C::ViewContext<'_, '_, V>>,
|
||||
>(cx)
|
||||
};
|
||||
build_view_state(cx)
|
||||
})
|
||||
}
|
||||
|
||||
fn update_view<V: 'static, R>(
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
point, px, Action, AnyBox, AnyDrag, AppContext, BorrowWindow, Bounds, Component,
|
||||
DispatchContext, DispatchPhase, Element, ElementId, FocusHandle, KeyMatch, Keystroke,
|
||||
Modifiers, Overflow, Pixels, Point, SharedString, Size, Style, StyleRefinement, View,
|
||||
div, point, px, Action, AnyDrag, AnyView, AppContext, BorrowWindow, Bounds, Component,
|
||||
DispatchContext, DispatchPhase, Div, Element, ElementId, FocusHandle, KeyMatch, Keystroke,
|
||||
Modifiers, Overflow, Pixels, Point, Render, SharedString, Size, Style, StyleRefinement, View,
|
||||
ViewContext,
|
||||
};
|
||||
use collections::HashMap;
|
||||
@ -258,17 +258,17 @@ pub trait StatelessInteractive<V: 'static>: Element<V> {
|
||||
self
|
||||
}
|
||||
|
||||
fn on_drop<S: 'static>(
|
||||
fn on_drop<W: 'static + Send>(
|
||||
mut self,
|
||||
listener: impl Fn(&mut V, S, &mut ViewContext<V>) + Send + 'static,
|
||||
listener: impl Fn(&mut V, View<W>, &mut ViewContext<V>) + Send + 'static,
|
||||
) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.stateless_interaction().drop_listeners.push((
|
||||
TypeId::of::<S>(),
|
||||
Box::new(move |view, drag_state, cx| {
|
||||
listener(view, *drag_state.downcast().unwrap(), cx);
|
||||
TypeId::of::<W>(),
|
||||
Box::new(move |view, dragged_view, cx| {
|
||||
listener(view, dragged_view.downcast().unwrap(), cx);
|
||||
}),
|
||||
));
|
||||
self
|
||||
@ -314,36 +314,22 @@ pub trait StatefulInteractive<V: 'static>: StatelessInteractive<V> {
|
||||
self
|
||||
}
|
||||
|
||||
fn on_drag<S, R, E>(
|
||||
fn on_drag<W>(
|
||||
mut self,
|
||||
listener: impl Fn(&mut V, &mut ViewContext<V>) -> Drag<S, R, V, E> + Send + 'static,
|
||||
listener: impl Fn(&mut V, &mut ViewContext<V>) -> View<W> + Send + 'static,
|
||||
) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
S: Any + Send,
|
||||
R: Fn(&mut V, &mut ViewContext<V>) -> E,
|
||||
R: 'static + Send,
|
||||
E: Component<V>,
|
||||
W: 'static + Send + Render,
|
||||
{
|
||||
debug_assert!(
|
||||
self.stateful_interaction().drag_listener.is_none(),
|
||||
"calling on_drag more than once on the same element is not supported"
|
||||
);
|
||||
self.stateful_interaction().drag_listener =
|
||||
Some(Box::new(move |view_state, cursor_offset, cx| {
|
||||
let drag = listener(view_state, cx);
|
||||
let drag_handle_view = Some(
|
||||
View::for_handle(cx.handle().upgrade().unwrap(), move |view_state, cx| {
|
||||
(drag.render_drag_handle)(view_state, cx)
|
||||
})
|
||||
.into_any(),
|
||||
);
|
||||
AnyDrag {
|
||||
drag_handle_view,
|
||||
cursor_offset,
|
||||
state: Box::new(drag.state),
|
||||
state_type: TypeId::of::<S>(),
|
||||
}
|
||||
Some(Box::new(move |view_state, cursor_offset, cx| AnyDrag {
|
||||
view: listener(view_state, cx).into_any(),
|
||||
cursor_offset,
|
||||
}));
|
||||
self
|
||||
}
|
||||
@ -412,7 +398,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
|
||||
if let Some(drag) = cx.active_drag.take() {
|
||||
for (state_type, group_drag_style) in &self.as_stateless().group_drag_over_styles {
|
||||
if let Some(group_bounds) = GroupBounds::get(&group_drag_style.group, cx) {
|
||||
if *state_type == drag.state_type
|
||||
if *state_type == drag.view.entity_type()
|
||||
&& group_bounds.contains_point(&mouse_position)
|
||||
{
|
||||
style.refine(&group_drag_style.style);
|
||||
@ -421,7 +407,8 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
|
||||
}
|
||||
|
||||
for (state_type, drag_over_style) in &self.as_stateless().drag_over_styles {
|
||||
if *state_type == drag.state_type && bounds.contains_point(&mouse_position) {
|
||||
if *state_type == drag.view.entity_type() && bounds.contains_point(&mouse_position)
|
||||
{
|
||||
style.refine(drag_over_style);
|
||||
}
|
||||
}
|
||||
@ -509,7 +496,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
|
||||
cx.on_mouse_event(move |view, event: &MouseUpEvent, phase, cx| {
|
||||
if phase == DispatchPhase::Bubble && bounds.contains_point(&event.position) {
|
||||
if let Some(drag_state_type) =
|
||||
cx.active_drag.as_ref().map(|drag| drag.state_type)
|
||||
cx.active_drag.as_ref().map(|drag| drag.view.entity_type())
|
||||
{
|
||||
for (drop_state_type, listener) in &drop_listeners {
|
||||
if *drop_state_type == drag_state_type {
|
||||
@ -517,7 +504,7 @@ pub trait ElementInteraction<V: 'static>: 'static + Send {
|
||||
.active_drag
|
||||
.take()
|
||||
.expect("checked for type drag state type above");
|
||||
listener(view, drag.state, cx);
|
||||
listener(view, drag.view.clone(), cx);
|
||||
cx.notify();
|
||||
cx.stop_propagation();
|
||||
}
|
||||
@ -685,7 +672,7 @@ impl<V> From<ElementId> for StatefulInteraction<V> {
|
||||
}
|
||||
}
|
||||
|
||||
type DropListener<V> = dyn Fn(&mut V, AnyBox, &mut ViewContext<V>) + 'static + Send;
|
||||
type DropListener<V> = dyn Fn(&mut V, AnyView, &mut ViewContext<V>) + 'static + Send;
|
||||
|
||||
pub struct StatelessInteraction<V> {
|
||||
pub dispatch_context: DispatchContext,
|
||||
@ -866,7 +853,7 @@ pub struct Drag<S, R, V, E>
|
||||
where
|
||||
R: Fn(&mut V, &mut ViewContext<V>) -> E,
|
||||
V: 'static,
|
||||
E: Component<V>,
|
||||
E: Component<()>,
|
||||
{
|
||||
pub state: S,
|
||||
pub render_drag_handle: R,
|
||||
@ -877,7 +864,7 @@ impl<S, R, V, E> Drag<S, R, V, E>
|
||||
where
|
||||
R: Fn(&mut V, &mut ViewContext<V>) -> E,
|
||||
V: 'static,
|
||||
E: Component<V>,
|
||||
E: Component<()>,
|
||||
{
|
||||
pub fn new(state: S, render_drag_handle: R) -> Self {
|
||||
Drag {
|
||||
@ -888,6 +875,10 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// impl<S, R, V, E> Render for Drag<S, R, V, E> {
|
||||
// // fn render(&mut self, cx: ViewContext<Self>) ->
|
||||
// }
|
||||
|
||||
#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug)]
|
||||
pub enum MouseButton {
|
||||
Left,
|
||||
@ -995,6 +986,14 @@ impl Deref for MouseExitEvent {
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ExternalPaths(pub(crate) SmallVec<[PathBuf; 2]>);
|
||||
|
||||
impl Render for ExternalPaths {
|
||||
type Element = Div<Self>;
|
||||
|
||||
fn render(&mut self, _: &mut ViewContext<Self>) -> Self::Element {
|
||||
div() // Intentionally left empty because the platform will render icons for the dragged files
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum FileDropEvent {
|
||||
Entered {
|
||||
|
@ -1,45 +1,35 @@
|
||||
use crate::{
|
||||
AnyBox, AnyElement, AppContext, AvailableSpace, BorrowWindow, Bounds, Component, Element,
|
||||
ElementId, EntityHandle, EntityId, Flatten, Handle, LayoutId, Pixels, Size, ViewContext,
|
||||
VisualContext, WeakHandle, WindowContext,
|
||||
AnyBox, AnyElement, AnyModel, AppContext, AvailableSpace, BorrowWindow, Bounds, Component,
|
||||
Element, ElementId, EntityHandle, EntityId, Flatten, LayoutId, Model, Pixels, Size,
|
||||
ViewContext, VisualContext, WeakModel, WindowContext,
|
||||
};
|
||||
use anyhow::{Context, Result};
|
||||
use parking_lot::Mutex;
|
||||
use std::{
|
||||
any::Any,
|
||||
any::{Any, TypeId},
|
||||
marker::PhantomData,
|
||||
sync::{Arc, Weak},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub struct View<V> {
|
||||
pub(crate) state: Handle<V>,
|
||||
render: Arc<Mutex<dyn Fn(&mut V, &mut ViewContext<V>) -> AnyElement<V> + Send + 'static>>,
|
||||
pub trait Render: 'static + Sized {
|
||||
type Element: Element<Self> + 'static + Send;
|
||||
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element;
|
||||
}
|
||||
|
||||
impl<V: 'static> View<V> {
|
||||
pub fn for_handle<E>(
|
||||
state: Handle<V>,
|
||||
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
|
||||
) -> View<V>
|
||||
where
|
||||
E: Component<V>,
|
||||
{
|
||||
View {
|
||||
state,
|
||||
render: Arc::new(Mutex::new(
|
||||
move |state: &mut V, cx: &mut ViewContext<'_, '_, V>| render(state, cx).render(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
pub struct View<V> {
|
||||
pub(crate) model: Model<V>,
|
||||
}
|
||||
|
||||
impl<V: Render> View<V> {
|
||||
pub fn into_any(self) -> AnyView {
|
||||
AnyView(Arc::new(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: 'static> View<V> {
|
||||
pub fn downgrade(&self) -> WeakView<V> {
|
||||
WeakView {
|
||||
state: self.state.downgrade(),
|
||||
render: Arc::downgrade(&self.render),
|
||||
model: self.model.downgrade(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,20 +45,19 @@ impl<V: 'static> View<V> {
|
||||
}
|
||||
|
||||
pub fn read<'a>(&self, cx: &'a AppContext) -> &'a V {
|
||||
cx.entities.read(&self.state)
|
||||
self.model.read(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<V> Clone for View<V> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
state: self.state.clone(),
|
||||
render: self.render.clone(),
|
||||
model: self.model.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: 'static, ParentViewState: 'static> Component<ParentViewState> for View<V> {
|
||||
impl<V: Render, ParentViewState: 'static> Component<ParentViewState> for View<V> {
|
||||
fn render(self) -> AnyElement<ParentViewState> {
|
||||
AnyElement::new(EraseViewState {
|
||||
view: self,
|
||||
@ -77,11 +66,14 @@ impl<V: 'static, ParentViewState: 'static> Component<ParentViewState> for View<V
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: 'static> Element<()> for View<V> {
|
||||
impl<V> Element<()> for View<V>
|
||||
where
|
||||
V: Render,
|
||||
{
|
||||
type ElementState = AnyElement<V>;
|
||||
|
||||
fn id(&self) -> Option<ElementId> {
|
||||
Some(ElementId::View(self.state.entity_id))
|
||||
fn id(&self) -> Option<crate::ElementId> {
|
||||
Some(ElementId::View(self.model.entity_id))
|
||||
}
|
||||
|
||||
fn initialize(
|
||||
@ -91,7 +83,7 @@ impl<V: 'static> Element<()> for View<V> {
|
||||
cx: &mut ViewContext<()>,
|
||||
) -> Self::ElementState {
|
||||
self.update(cx, |state, cx| {
|
||||
let mut any_element = (self.render.lock())(state, cx);
|
||||
let mut any_element = AnyElement::new(state.render(cx));
|
||||
any_element.initialize(state, cx);
|
||||
any_element
|
||||
})
|
||||
@ -121,7 +113,7 @@ impl<T: 'static> EntityHandle<T> for View<T> {
|
||||
type Weak = WeakView<T>;
|
||||
|
||||
fn entity_id(&self) -> EntityId {
|
||||
self.state.entity_id
|
||||
self.model.entity_id
|
||||
}
|
||||
|
||||
fn downgrade(&self) -> Self::Weak {
|
||||
@ -137,15 +129,13 @@ impl<T: 'static> EntityHandle<T> for View<T> {
|
||||
}
|
||||
|
||||
pub struct WeakView<V> {
|
||||
pub(crate) state: WeakHandle<V>,
|
||||
render: Weak<Mutex<dyn Fn(&mut V, &mut ViewContext<V>) -> AnyElement<V> + Send + 'static>>,
|
||||
pub(crate) model: WeakModel<V>,
|
||||
}
|
||||
|
||||
impl<V: 'static> WeakView<V> {
|
||||
pub fn upgrade(&self) -> Option<View<V>> {
|
||||
let state = self.state.upgrade()?;
|
||||
let render = self.render.upgrade()?;
|
||||
Some(View { state, render })
|
||||
let model = self.model.upgrade()?;
|
||||
Some(View { model })
|
||||
}
|
||||
|
||||
pub fn update<C, R>(
|
||||
@ -165,8 +155,7 @@ impl<V: 'static> WeakView<V> {
|
||||
impl<V> Clone for WeakView<V> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
state: self.state.clone(),
|
||||
render: self.render.clone(),
|
||||
model: self.model.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -178,13 +167,13 @@ struct EraseViewState<V, ParentV> {
|
||||
|
||||
unsafe impl<V, ParentV> Send for EraseViewState<V, ParentV> {}
|
||||
|
||||
impl<V: 'static, ParentV: 'static> Component<ParentV> for EraseViewState<V, ParentV> {
|
||||
impl<V: Render, ParentV: 'static> Component<ParentV> for EraseViewState<V, ParentV> {
|
||||
fn render(self) -> AnyElement<ParentV> {
|
||||
AnyElement::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: 'static, ParentV: 'static> Element<ParentV> for EraseViewState<V, ParentV> {
|
||||
impl<V: Render, ParentV: 'static> Element<ParentV> for EraseViewState<V, ParentV> {
|
||||
type ElementState = AnyBox;
|
||||
|
||||
fn id(&self) -> Option<ElementId> {
|
||||
@ -221,30 +210,43 @@ impl<V: 'static, ParentV: 'static> Element<ParentV> for EraseViewState<V, Parent
|
||||
}
|
||||
|
||||
trait ViewObject: Send + Sync {
|
||||
fn entity_type(&self) -> TypeId;
|
||||
fn entity_id(&self) -> EntityId;
|
||||
fn model(&self) -> AnyModel;
|
||||
fn initialize(&self, cx: &mut WindowContext) -> AnyBox;
|
||||
fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId;
|
||||
fn paint(&self, bounds: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext);
|
||||
fn as_any(&self) -> &dyn Any;
|
||||
}
|
||||
|
||||
impl<V: 'static> ViewObject for View<V> {
|
||||
impl<V> ViewObject for View<V>
|
||||
where
|
||||
V: Render,
|
||||
{
|
||||
fn entity_type(&self) -> TypeId {
|
||||
TypeId::of::<V>()
|
||||
}
|
||||
|
||||
fn entity_id(&self) -> EntityId {
|
||||
self.state.entity_id
|
||||
self.model.entity_id
|
||||
}
|
||||
|
||||
fn model(&self) -> AnyModel {
|
||||
self.model.clone().into_any()
|
||||
}
|
||||
|
||||
fn initialize(&self, cx: &mut WindowContext) -> AnyBox {
|
||||
cx.with_element_id(self.state.entity_id, |_global_id, cx| {
|
||||
cx.with_element_id(self.model.entity_id, |_global_id, cx| {
|
||||
self.update(cx, |state, cx| {
|
||||
let mut any_element = Box::new((self.render.lock())(state, cx));
|
||||
let mut any_element = Box::new(AnyElement::new(state.render(cx)));
|
||||
any_element.initialize(state, cx);
|
||||
any_element as AnyBox
|
||||
any_element
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn layout(&self, element: &mut AnyBox, cx: &mut WindowContext) -> LayoutId {
|
||||
cx.with_element_id(self.state.entity_id, |_global_id, cx| {
|
||||
cx.with_element_id(self.model.entity_id, |_global_id, cx| {
|
||||
self.update(cx, |state, cx| {
|
||||
let element = element.downcast_mut::<AnyElement<V>>().unwrap();
|
||||
element.layout(state, cx)
|
||||
@ -253,7 +255,7 @@ impl<V: 'static> ViewObject for View<V> {
|
||||
}
|
||||
|
||||
fn paint(&self, _: Bounds<Pixels>, element: &mut AnyBox, cx: &mut WindowContext) {
|
||||
cx.with_element_id(self.state.entity_id, |_global_id, cx| {
|
||||
cx.with_element_id(self.model.entity_id, |_global_id, cx| {
|
||||
self.update(cx, |state, cx| {
|
||||
let element = element.downcast_mut::<AnyElement<V>>().unwrap();
|
||||
element.paint(state, cx);
|
||||
@ -270,8 +272,12 @@ impl<V: 'static> ViewObject for View<V> {
|
||||
pub struct AnyView(Arc<dyn ViewObject>);
|
||||
|
||||
impl AnyView {
|
||||
pub fn downcast<V: 'static>(&self) -> Option<View<V>> {
|
||||
self.0.as_any().downcast_ref().cloned()
|
||||
pub fn downcast<V: 'static + Send>(self) -> Option<View<V>> {
|
||||
self.0.model().downcast().map(|model| View { model })
|
||||
}
|
||||
|
||||
pub(crate) fn entity_type(&self) -> TypeId {
|
||||
self.0.entity_type()
|
||||
}
|
||||
|
||||
pub(crate) fn draw(&self, available_space: Size<AvailableSpace>, cx: &mut WindowContext) {
|
||||
@ -343,6 +349,18 @@ impl<ParentV: 'static> Component<ParentV> for EraseAnyViewState<ParentV> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, E> Render for T
|
||||
where
|
||||
T: 'static + FnMut(&mut WindowContext) -> E,
|
||||
E: 'static + Send + Element<T>,
|
||||
{
|
||||
type Element = E;
|
||||
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
|
||||
(self)(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<ParentV: 'static> Element<ParentV> for EraseAnyViewState<ParentV> {
|
||||
type ElementState = AnyBox;
|
||||
|
||||
|
@ -1,14 +1,14 @@
|
||||
use crate::{
|
||||
px, size, Action, AnyBox, AnyDrag, AnyView, AppContext, AsyncWindowContext, AvailableSpace,
|
||||
Bounds, BoxShadow, Context, Corners, DevicePixels, DispatchContext, DisplayId, Edges, Effect,
|
||||
EntityHandle, EntityId, EventEmitter, ExternalPaths, FileDropEvent, FocusEvent, FontId,
|
||||
GlobalElementId, GlyphId, Handle, Hsla, ImageData, InputEvent, IsZero, KeyListener, KeyMatch,
|
||||
KeyMatcher, Keystroke, LayoutId, MainThread, MainThreadOnly, ModelContext, Modifiers,
|
||||
MonochromeSprite, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Path, Pixels,
|
||||
PlatformAtlas, PlatformWindow, Point, PolychromeSprite, Quad, Reference, RenderGlyphParams,
|
||||
RenderImageParams, RenderSvgParams, ScaledPixels, SceneBuilder, Shadow, SharedString, Size,
|
||||
Style, Subscription, TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext,
|
||||
WeakHandle, WeakView, WindowOptions, SUBPIXEL_VARIANTS,
|
||||
EntityHandle, EntityId, EventEmitter, FileDropEvent, FocusEvent, FontId, GlobalElementId,
|
||||
GlyphId, Hsla, ImageData, InputEvent, IsZero, KeyListener, KeyMatch, KeyMatcher, Keystroke,
|
||||
LayoutId, MainThread, MainThreadOnly, Model, ModelContext, Modifiers, MonochromeSprite,
|
||||
MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Path, Pixels, PlatformAtlas,
|
||||
PlatformWindow, Point, PolychromeSprite, Quad, Reference, RenderGlyphParams, RenderImageParams,
|
||||
RenderSvgParams, ScaledPixels, SceneBuilder, Shadow, SharedString, Size, Style, Subscription,
|
||||
TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext, WeakModel, WeakView,
|
||||
WindowOptions, SUBPIXEL_VARIANTS,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
@ -918,15 +918,13 @@ impl<'a, 'w> WindowContext<'a, 'w> {
|
||||
root_view.draw(available_space, cx);
|
||||
});
|
||||
|
||||
if let Some(mut active_drag) = self.app.active_drag.take() {
|
||||
if let Some(active_drag) = self.app.active_drag.take() {
|
||||
self.stack(1, |cx| {
|
||||
let offset = cx.mouse_position() - active_drag.cursor_offset;
|
||||
cx.with_element_offset(Some(offset), |cx| {
|
||||
let available_space =
|
||||
size(AvailableSpace::MinContent, AvailableSpace::MinContent);
|
||||
if let Some(drag_handle_view) = &mut active_drag.drag_handle_view {
|
||||
drag_handle_view.draw(available_space, cx);
|
||||
}
|
||||
active_drag.view.draw(available_space, cx);
|
||||
cx.active_drag = Some(active_drag);
|
||||
});
|
||||
});
|
||||
@ -994,12 +992,12 @@ impl<'a, 'w> WindowContext<'a, 'w> {
|
||||
InputEvent::FileDrop(file_drop) => match file_drop {
|
||||
FileDropEvent::Entered { position, files } => {
|
||||
self.window.mouse_position = position;
|
||||
self.active_drag.get_or_insert_with(|| AnyDrag {
|
||||
drag_handle_view: None,
|
||||
cursor_offset: position,
|
||||
state: Box::new(files),
|
||||
state_type: TypeId::of::<ExternalPaths>(),
|
||||
});
|
||||
if self.active_drag.is_none() {
|
||||
self.active_drag = Some(AnyDrag {
|
||||
view: self.build_view(|_| files).into_any(),
|
||||
cursor_offset: position,
|
||||
});
|
||||
}
|
||||
InputEvent::MouseDown(MouseDownEvent {
|
||||
position,
|
||||
button: MouseButton::Left,
|
||||
@ -1267,30 +1265,30 @@ impl<'a, 'w> WindowContext<'a, 'w> {
|
||||
}
|
||||
|
||||
impl Context for WindowContext<'_, '_> {
|
||||
type EntityContext<'a, T> = ModelContext<'a, T>;
|
||||
type ModelContext<'a, T> = ModelContext<'a, T>;
|
||||
type Result<T> = T;
|
||||
|
||||
fn entity<T>(
|
||||
fn build_model<T>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
||||
) -> Handle<T>
|
||||
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||
) -> Model<T>
|
||||
where
|
||||
T: 'static + Send,
|
||||
{
|
||||
let slot = self.app.entities.reserve();
|
||||
let entity = build_entity(&mut ModelContext::mutable(&mut *self.app, slot.downgrade()));
|
||||
self.entities.insert(slot, entity)
|
||||
let model = build_model(&mut ModelContext::mutable(&mut *self.app, slot.downgrade()));
|
||||
self.entities.insert(slot, model)
|
||||
}
|
||||
|
||||
fn update_entity<T: 'static, R>(
|
||||
&mut self,
|
||||
handle: &Handle<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
||||
model: &Model<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||
) -> R {
|
||||
let mut entity = self.entities.lease(handle);
|
||||
let mut entity = self.entities.lease(model);
|
||||
let result = update(
|
||||
&mut *entity,
|
||||
&mut ModelContext::mutable(&mut *self.app, handle.downgrade()),
|
||||
&mut ModelContext::mutable(&mut *self.app, model.downgrade()),
|
||||
);
|
||||
self.entities.end_lease(entity);
|
||||
result
|
||||
@ -1300,21 +1298,17 @@ impl Context for WindowContext<'_, '_> {
|
||||
impl VisualContext for WindowContext<'_, '_> {
|
||||
type ViewContext<'a, 'w, V> = ViewContext<'a, 'w, V>;
|
||||
|
||||
/// Builds a new view in the current window. The first argument is a function that builds
|
||||
/// an entity representing the view's state. It is invoked with a `ViewContext` that provides
|
||||
/// entity-specific access to the window and application state during construction. The second
|
||||
/// argument is a render function that returns a component based on the view's state.
|
||||
fn build_view<E, V>(
|
||||
fn build_view<V>(
|
||||
&mut self,
|
||||
build_view_state: impl FnOnce(&mut Self::ViewContext<'_, '_, V>) -> V,
|
||||
render: impl Fn(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static,
|
||||
) -> Self::Result<View<V>>
|
||||
where
|
||||
E: crate::Component<V>,
|
||||
V: 'static + Send,
|
||||
{
|
||||
let slot = self.app.entities.reserve();
|
||||
let view = View::for_handle(slot.clone(), render);
|
||||
let view = View {
|
||||
model: slot.clone(),
|
||||
};
|
||||
let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade());
|
||||
let entity = build_view_state(&mut cx);
|
||||
self.entities.insert(slot, entity);
|
||||
@ -1327,7 +1321,7 @@ impl VisualContext for WindowContext<'_, '_> {
|
||||
view: &View<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::ViewContext<'_, '_, T>) -> R,
|
||||
) -> Self::Result<R> {
|
||||
let mut lease = self.app.entities.lease(&view.state);
|
||||
let mut lease = self.app.entities.lease(&view.model);
|
||||
let mut cx = ViewContext::mutable(&mut *self.app, &mut *self.window, view.downgrade());
|
||||
let result = update(&mut *lease, &mut cx);
|
||||
cx.app.entities.end_lease(lease);
|
||||
@ -1582,8 +1576,8 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
|
||||
self.view.clone()
|
||||
}
|
||||
|
||||
pub fn handle(&self) -> WeakHandle<V> {
|
||||
self.view.state.clone()
|
||||
pub fn model(&self) -> WeakModel<V> {
|
||||
self.view.model.clone()
|
||||
}
|
||||
|
||||
pub fn stack<R>(&mut self, order: u32, f: impl FnOnce(&mut Self) -> R) -> R {
|
||||
@ -1603,8 +1597,8 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
|
||||
|
||||
pub fn observe<E>(
|
||||
&mut self,
|
||||
handle: &Handle<E>,
|
||||
mut on_notify: impl FnMut(&mut V, Handle<E>, &mut ViewContext<'_, '_, V>) + Send + 'static,
|
||||
handle: &Model<E>,
|
||||
mut on_notify: impl FnMut(&mut V, Model<E>, &mut ViewContext<'_, '_, V>) + Send + 'static,
|
||||
) -> Subscription
|
||||
where
|
||||
E: 'static,
|
||||
@ -1665,7 +1659,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
|
||||
) -> Subscription {
|
||||
let window_handle = self.window.handle;
|
||||
self.app.release_listeners.insert(
|
||||
self.view.state.entity_id,
|
||||
self.view.model.entity_id,
|
||||
Box::new(move |this, cx| {
|
||||
let this = this.downcast_mut().expect("invalid entity type");
|
||||
// todo!("are we okay with silently swallowing the error?")
|
||||
@ -1676,7 +1670,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
|
||||
|
||||
pub fn observe_release<T: 'static>(
|
||||
&mut self,
|
||||
handle: &Handle<T>,
|
||||
handle: &Model<T>,
|
||||
mut on_release: impl FnMut(&mut V, &mut T, &mut ViewContext<'_, '_, V>) + Send + 'static,
|
||||
) -> Subscription
|
||||
where
|
||||
@ -1698,7 +1692,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
|
||||
pub fn notify(&mut self) {
|
||||
self.window_cx.notify();
|
||||
self.window_cx.app.push_effect(Effect::Notify {
|
||||
emitter: self.view.state.entity_id,
|
||||
emitter: self.view.model.entity_id,
|
||||
});
|
||||
}
|
||||
|
||||
@ -1878,7 +1872,7 @@ where
|
||||
V::Event: Any + Send,
|
||||
{
|
||||
pub fn emit(&mut self, event: V::Event) {
|
||||
let emitter = self.view.state.entity_id;
|
||||
let emitter = self.view.model.entity_id;
|
||||
self.app.push_effect(Effect::Emit {
|
||||
emitter,
|
||||
event: Box::new(event),
|
||||
@ -1897,41 +1891,36 @@ impl<'a, 'w, V: 'static> MainThread<ViewContext<'a, 'w, V>> {
|
||||
}
|
||||
|
||||
impl<'a, 'w, V> Context for ViewContext<'a, 'w, V> {
|
||||
type EntityContext<'b, U> = ModelContext<'b, U>;
|
||||
type ModelContext<'b, U> = ModelContext<'b, U>;
|
||||
type Result<U> = U;
|
||||
|
||||
fn entity<T>(
|
||||
fn build_model<T>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::EntityContext<'_, T>) -> T,
|
||||
) -> Handle<T>
|
||||
build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T,
|
||||
) -> Model<T>
|
||||
where
|
||||
T: 'static + Send,
|
||||
{
|
||||
self.window_cx.entity(build_entity)
|
||||
self.window_cx.build_model(build_model)
|
||||
}
|
||||
|
||||
fn update_entity<T: 'static, R>(
|
||||
&mut self,
|
||||
handle: &Handle<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::EntityContext<'_, T>) -> R,
|
||||
model: &Model<T>,
|
||||
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
|
||||
) -> R {
|
||||
self.window_cx.update_entity(handle, update)
|
||||
self.window_cx.update_entity(model, update)
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: 'static> VisualContext for ViewContext<'_, '_, V> {
|
||||
type ViewContext<'a, 'w, V2> = ViewContext<'a, 'w, V2>;
|
||||
|
||||
fn build_view<E, V2>(
|
||||
fn build_view<W: 'static + Send>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Self::ViewContext<'_, '_, V2>) -> V2,
|
||||
render: impl Fn(&mut V2, &mut ViewContext<'_, '_, V2>) -> E + Send + 'static,
|
||||
) -> Self::Result<View<V2>>
|
||||
where
|
||||
E: crate::Component<V2>,
|
||||
V2: 'static + Send,
|
||||
{
|
||||
self.window_cx.build_view(build_entity, render)
|
||||
build_view: impl FnOnce(&mut Self::ViewContext<'_, '_, W>) -> W,
|
||||
) -> Self::Result<View<W>> {
|
||||
self.window_cx.build_view(build_view)
|
||||
}
|
||||
|
||||
fn update_view<V2: 'static, R>(
|
||||
|
@ -5,7 +5,7 @@ use crate::language_settings::{
|
||||
use crate::Buffer;
|
||||
use clock::ReplicaId;
|
||||
use collections::BTreeMap;
|
||||
use gpui2::{AppContext, Handle};
|
||||
use gpui2::{AppContext, Model};
|
||||
use gpui2::{Context, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use proto::deserialize_operation;
|
||||
@ -42,7 +42,7 @@ fn init_logger() {
|
||||
fn test_line_endings(cx: &mut gpui2::AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "one\r\ntwo\rthree")
|
||||
.with_language(Arc::new(rust_lang()), cx);
|
||||
assert_eq!(buffer.text(), "one\ntwo\nthree");
|
||||
@ -138,8 +138,8 @@ fn test_edit_events(cx: &mut gpui2::AppContext) {
|
||||
let buffer_1_events = Arc::new(Mutex::new(Vec::new()));
|
||||
let buffer_2_events = Arc::new(Mutex::new(Vec::new()));
|
||||
|
||||
let buffer1 = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), "abcdef"));
|
||||
let buffer2 = cx.entity(|cx| Buffer::new(1, cx.entity_id().as_u64(), "abcdef"));
|
||||
let buffer1 = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), "abcdef"));
|
||||
let buffer2 = cx.build_model(|cx| Buffer::new(1, cx.entity_id().as_u64(), "abcdef"));
|
||||
let buffer1_ops = Arc::new(Mutex::new(Vec::new()));
|
||||
buffer1.update(cx, {
|
||||
let buffer1_ops = buffer1_ops.clone();
|
||||
@ -218,7 +218,7 @@ fn test_edit_events(cx: &mut gpui2::AppContext) {
|
||||
#[gpui2::test]
|
||||
async fn test_apply_diff(cx: &mut TestAppContext) {
|
||||
let text = "a\nbb\nccc\ndddd\neeeee\nffffff\n";
|
||||
let buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
|
||||
let buffer = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
|
||||
let anchor = buffer.update(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)));
|
||||
|
||||
let text = "a\nccc\ndddd\nffffff\n";
|
||||
@ -250,7 +250,7 @@ async fn test_normalize_whitespace(cx: &mut gpui2::TestAppContext) {
|
||||
]
|
||||
.join("\n");
|
||||
|
||||
let buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
|
||||
let buffer = cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), text));
|
||||
|
||||
// Spawn a task to format the buffer's whitespace.
|
||||
// Pause so that the foratting task starts running.
|
||||
@ -314,7 +314,7 @@ async fn test_normalize_whitespace(cx: &mut gpui2::TestAppContext) {
|
||||
#[gpui2::test]
|
||||
async fn test_reparse(cx: &mut gpui2::TestAppContext) {
|
||||
let text = "fn a() {}";
|
||||
let buffer = cx.entity(|cx| {
|
||||
let buffer = cx.build_model(|cx| {
|
||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
||||
});
|
||||
|
||||
@ -442,7 +442,7 @@ async fn test_reparse(cx: &mut gpui2::TestAppContext) {
|
||||
|
||||
#[gpui2::test]
|
||||
async fn test_resetting_language(cx: &mut gpui2::TestAppContext) {
|
||||
let buffer = cx.entity(|cx| {
|
||||
let buffer = cx.build_model(|cx| {
|
||||
let mut buffer =
|
||||
Buffer::new(0, cx.entity_id().as_u64(), "{}").with_language(Arc::new(rust_lang()), cx);
|
||||
buffer.set_sync_parse_timeout(Duration::ZERO);
|
||||
@ -492,7 +492,7 @@ async fn test_outline(cx: &mut gpui2::TestAppContext) {
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let buffer = cx.entity(|cx| {
|
||||
let buffer = cx.build_model(|cx| {
|
||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
||||
});
|
||||
let outline = buffer
|
||||
@ -578,7 +578,7 @@ async fn test_outline_nodes_with_newlines(cx: &mut gpui2::TestAppContext) {
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let buffer = cx.entity(|cx| {
|
||||
let buffer = cx.build_model(|cx| {
|
||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
||||
});
|
||||
let outline = buffer
|
||||
@ -616,7 +616,7 @@ async fn test_outline_with_extra_context(cx: &mut gpui2::TestAppContext) {
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let buffer = cx.entity(|cx| {
|
||||
let buffer = cx.build_model(|cx| {
|
||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(language), cx)
|
||||
});
|
||||
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
|
||||
@ -660,7 +660,7 @@ async fn test_symbols_containing(cx: &mut gpui2::TestAppContext) {
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let buffer = cx.entity(|cx| {
|
||||
let buffer = cx.build_model(|cx| {
|
||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx)
|
||||
});
|
||||
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
|
||||
@ -881,7 +881,7 @@ fn test_enclosing_bracket_ranges_where_brackets_are_not_outermost_children(cx: &
|
||||
|
||||
#[gpui2::test]
|
||||
fn test_range_for_syntax_ancestor(cx: &mut AppContext) {
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let text = "fn a() { b(|c| {}) }";
|
||||
let buffer =
|
||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
||||
@ -922,7 +922,7 @@ fn test_range_for_syntax_ancestor(cx: &mut AppContext) {
|
||||
fn test_autoindent_with_soft_tabs(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let text = "fn a() {}";
|
||||
let mut buffer =
|
||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
||||
@ -965,7 +965,7 @@ fn test_autoindent_with_hard_tabs(cx: &mut AppContext) {
|
||||
settings.defaults.hard_tabs = Some(true);
|
||||
});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let text = "fn a() {}";
|
||||
let mut buffer =
|
||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
||||
@ -1006,7 +1006,7 @@ fn test_autoindent_with_hard_tabs(cx: &mut AppContext) {
|
||||
fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let entity_id = cx.entity_id();
|
||||
let mut buffer = Buffer::new(
|
||||
0,
|
||||
@ -1080,7 +1080,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppC
|
||||
buffer
|
||||
});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
eprintln!("second buffer: {:?}", cx.entity_id());
|
||||
|
||||
let mut buffer = Buffer::new(
|
||||
@ -1147,7 +1147,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut AppC
|
||||
fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let mut buffer = Buffer::new(
|
||||
0,
|
||||
cx.entity_id().as_u64(),
|
||||
@ -1209,7 +1209,7 @@ fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut Ap
|
||||
fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let mut buffer = Buffer::new(
|
||||
0,
|
||||
cx.entity_id().as_u64(),
|
||||
@ -1266,7 +1266,7 @@ fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut AppContext) {
|
||||
fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let text = "a\nb";
|
||||
let mut buffer =
|
||||
Buffer::new(0, cx.entity_id().as_u64(), text).with_language(Arc::new(rust_lang()), cx);
|
||||
@ -1284,7 +1284,7 @@ fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut AppContext) {
|
||||
fn test_autoindent_multi_line_insertion(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let text = "
|
||||
const a: usize = 1;
|
||||
fn b() {
|
||||
@ -1326,7 +1326,7 @@ fn test_autoindent_multi_line_insertion(cx: &mut AppContext) {
|
||||
fn test_autoindent_block_mode(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let text = r#"
|
||||
fn a() {
|
||||
b();
|
||||
@ -1410,7 +1410,7 @@ fn test_autoindent_block_mode(cx: &mut AppContext) {
|
||||
fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let text = r#"
|
||||
fn a() {
|
||||
if b() {
|
||||
@ -1490,7 +1490,7 @@ fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut AppContex
|
||||
fn test_autoindent_language_without_indents_query(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let text = "
|
||||
* one
|
||||
- a
|
||||
@ -1559,7 +1559,7 @@ fn test_autoindent_with_injected_languages(cx: &mut AppContext) {
|
||||
language_registry.add(html_language.clone());
|
||||
language_registry.add(javascript_language.clone());
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let (text, ranges) = marked_text_ranges(
|
||||
&"
|
||||
<div>ˇ
|
||||
@ -1610,7 +1610,7 @@ fn test_autoindent_query_with_outdent_captures(cx: &mut AppContext) {
|
||||
settings.defaults.tab_size = Some(2.try_into().unwrap());
|
||||
});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let mut buffer =
|
||||
Buffer::new(0, cx.entity_id().as_u64(), "").with_language(Arc::new(ruby_lang()), cx);
|
||||
|
||||
@ -1653,7 +1653,7 @@ fn test_autoindent_query_with_outdent_captures(cx: &mut AppContext) {
|
||||
fn test_language_scope_at_with_javascript(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let language = Language::new(
|
||||
LanguageConfig {
|
||||
name: "JavaScript".into(),
|
||||
@ -1742,7 +1742,7 @@ fn test_language_scope_at_with_javascript(cx: &mut AppContext) {
|
||||
fn test_language_scope_at_with_rust(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let language = Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
@ -1810,7 +1810,7 @@ fn test_language_scope_at_with_rust(cx: &mut AppContext) {
|
||||
fn test_language_scope_at_with_combined_injections(cx: &mut AppContext) {
|
||||
init_settings(cx, |_| {});
|
||||
|
||||
cx.entity(|cx| {
|
||||
cx.build_model(|cx| {
|
||||
let text = r#"
|
||||
<ol>
|
||||
<% people.each do |person| %>
|
||||
@ -1858,7 +1858,7 @@ fn test_language_scope_at_with_combined_injections(cx: &mut AppContext) {
|
||||
fn test_serialization(cx: &mut gpui2::AppContext) {
|
||||
let mut now = Instant::now();
|
||||
|
||||
let buffer1 = cx.entity(|cx| {
|
||||
let buffer1 = cx.build_model(|cx| {
|
||||
let mut buffer = Buffer::new(0, cx.entity_id().as_u64(), "abc");
|
||||
buffer.edit([(3..3, "D")], None, cx);
|
||||
|
||||
@ -1881,7 +1881,7 @@ fn test_serialization(cx: &mut gpui2::AppContext) {
|
||||
let ops = cx
|
||||
.executor()
|
||||
.block(buffer1.read(cx).serialize_ops(None, cx));
|
||||
let buffer2 = cx.entity(|cx| {
|
||||
let buffer2 = cx.build_model(|cx| {
|
||||
let mut buffer = Buffer::from_proto(1, state, None).unwrap();
|
||||
buffer
|
||||
.apply_ops(
|
||||
@ -1914,10 +1914,11 @@ fn test_random_collaboration(cx: &mut AppContext, mut rng: StdRng) {
|
||||
let mut replica_ids = Vec::new();
|
||||
let mut buffers = Vec::new();
|
||||
let network = Arc::new(Mutex::new(Network::new(rng.clone())));
|
||||
let base_buffer = cx.entity(|cx| Buffer::new(0, cx.entity_id().as_u64(), base_text.as_str()));
|
||||
let base_buffer =
|
||||
cx.build_model(|cx| Buffer::new(0, cx.entity_id().as_u64(), base_text.as_str()));
|
||||
|
||||
for i in 0..rng.gen_range(min_peers..=max_peers) {
|
||||
let buffer = cx.entity(|cx| {
|
||||
let buffer = cx.build_model(|cx| {
|
||||
let state = base_buffer.read(cx).to_proto();
|
||||
let ops = cx
|
||||
.executor()
|
||||
@ -2034,7 +2035,7 @@ fn test_random_collaboration(cx: &mut AppContext, mut rng: StdRng) {
|
||||
new_replica_id,
|
||||
replica_id
|
||||
);
|
||||
new_buffer = Some(cx.entity(|cx| {
|
||||
new_buffer = Some(cx.build_model(|cx| {
|
||||
let mut new_buffer =
|
||||
Buffer::from_proto(new_replica_id, old_buffer_state, None).unwrap();
|
||||
new_buffer
|
||||
@ -2396,7 +2397,7 @@ fn javascript_lang() -> Language {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn get_tree_sexp(buffer: &Handle<Buffer>, cx: &mut gpui2::TestAppContext) -> String {
|
||||
fn get_tree_sexp(buffer: &Model<Buffer>, cx: &mut gpui2::TestAppContext) -> String {
|
||||
buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
let layers = snapshot.syntax.layers(buffer.as_text_snapshot());
|
||||
@ -2412,7 +2413,7 @@ fn assert_bracket_pairs(
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
let (expected_text, selection_ranges) = marked_text_ranges(selection_text, false);
|
||||
let buffer = cx.entity(|cx| {
|
||||
let buffer = cx.build_model(|cx| {
|
||||
Buffer::new(0, cx.entity_id().as_u64(), expected_text.clone())
|
||||
.with_language(Arc::new(language), cx)
|
||||
});
|
||||
|
12
crates/menu2/Cargo.toml
Normal file
12
crates/menu2/Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "menu2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/menu2.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
gpui2 = { path = "../gpui2" }
|
25
crates/menu2/src/menu2.rs
Normal file
25
crates/menu2/src/menu2.rs
Normal file
@ -0,0 +1,25 @@
|
||||
// todo!(use actions! macro)
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct Cancel;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct Confirm;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct SecondaryConfirm;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct SelectPrev;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct SelectNext;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct SelectFirst;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct SelectLast;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct ShowContextMenu;
|
@ -16,7 +16,7 @@ client2 = { path = "../client2" }
|
||||
collections = { path = "../collections"}
|
||||
language2 = { path = "../language2" }
|
||||
gpui2 = { path = "../gpui2" }
|
||||
fs = { path = "../fs" }
|
||||
fs2 = { path = "../fs2" }
|
||||
lsp2 = { path = "../lsp2" }
|
||||
node_runtime = { path = "../node_runtime"}
|
||||
util = { path = "../util" }
|
||||
@ -32,4 +32,4 @@ parking_lot.workspace = true
|
||||
[dev-dependencies]
|
||||
language2 = { path = "../language2", features = ["test-support"] }
|
||||
gpui2 = { path = "../gpui2", features = ["test-support"] }
|
||||
fs = { path = "../fs", features = ["test-support"] }
|
||||
fs2 = { path = "../fs2", features = ["test-support"] }
|
||||
|
@ -1,7 +1,7 @@
|
||||
use anyhow::Context;
|
||||
use collections::{HashMap, HashSet};
|
||||
use fs::Fs;
|
||||
use gpui2::{AsyncAppContext, Handle};
|
||||
use fs2::Fs;
|
||||
use gpui2::{AsyncAppContext, Model};
|
||||
use language2::{language_settings::language_settings, Buffer, BundledFormatter, Diff};
|
||||
use lsp2::{LanguageServer, LanguageServerId};
|
||||
use node_runtime::NodeRuntime;
|
||||
@ -183,7 +183,7 @@ impl Prettier {
|
||||
|
||||
pub async fn format(
|
||||
&self,
|
||||
buffer: &Handle<Buffer>,
|
||||
buffer: &Model<Buffer>,
|
||||
buffer_path: Option<PathBuf>,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> anyhow::Result<Diff> {
|
||||
|
@ -25,7 +25,7 @@ client2 = { path = "../client2" }
|
||||
clock = { path = "../clock" }
|
||||
collections = { path = "../collections" }
|
||||
db2 = { path = "../db2" }
|
||||
fs = { path = "../fs" }
|
||||
fs2 = { path = "../fs2" }
|
||||
fsevent = { path = "../fsevent" }
|
||||
fuzzy2 = { path = "../fuzzy2" }
|
||||
git = { path = "../git" }
|
||||
@ -71,7 +71,7 @@ pretty_assertions.workspace = true
|
||||
client2 = { path = "../client2", features = ["test-support"] }
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
db2 = { path = "../db2", features = ["test-support"] }
|
||||
fs = { path = "../fs", features = ["test-support"] }
|
||||
fs2 = { path = "../fs2", features = ["test-support"] }
|
||||
gpui2 = { path = "../gpui2", features = ["test-support"] }
|
||||
language2 = { path = "../language2", features = ["test-support"] }
|
||||
lsp2 = { path = "../lsp2", features = ["test-support"] }
|
||||
|
@ -7,7 +7,7 @@ use anyhow::{anyhow, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use client2::proto::{self, PeerId};
|
||||
use futures::future;
|
||||
use gpui2::{AppContext, AsyncAppContext, Handle};
|
||||
use gpui2::{AppContext, AsyncAppContext, Model};
|
||||
use language2::{
|
||||
language_settings::{language_settings, InlayHintKind},
|
||||
point_from_lsp, point_to_lsp,
|
||||
@ -53,8 +53,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
message: <Self::LspRequest as lsp2::request::Request>::Result,
|
||||
project: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
cx: AsyncAppContext,
|
||||
) -> Result<Self::Response>;
|
||||
@ -63,8 +63,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
|
||||
|
||||
async fn from_proto(
|
||||
message: Self::ProtoRequest,
|
||||
project: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
cx: AsyncAppContext,
|
||||
) -> Result<Self>;
|
||||
|
||||
@ -79,8 +79,8 @@ pub(crate) trait LspCommand: 'static + Sized + Send {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: <Self::ProtoRequest as proto::RequestMessage>::Response,
|
||||
project: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
cx: AsyncAppContext,
|
||||
) -> Result<Self::Response>;
|
||||
|
||||
@ -180,8 +180,8 @@ impl LspCommand for PrepareRename {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
message: Option<lsp2::PrepareRenameResponse>,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
_: LanguageServerId,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Option<Range<Anchor>>> {
|
||||
@ -215,8 +215,8 @@ impl LspCommand for PrepareRename {
|
||||
|
||||
async fn from_proto(
|
||||
message: proto::PrepareRename,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let position = message
|
||||
@ -256,8 +256,8 @@ impl LspCommand for PrepareRename {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::PrepareRenameResponse,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Option<Range<Anchor>>> {
|
||||
if message.can_rename {
|
||||
@ -307,8 +307,8 @@ impl LspCommand for PerformRename {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
message: Option<lsp2::WorkspaceEdit>,
|
||||
project: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<ProjectTransaction> {
|
||||
@ -343,8 +343,8 @@ impl LspCommand for PerformRename {
|
||||
|
||||
async fn from_proto(
|
||||
message: proto::PerformRename,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let position = message
|
||||
@ -379,8 +379,8 @@ impl LspCommand for PerformRename {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::PerformRenameResponse,
|
||||
project: Handle<Project>,
|
||||
_: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
_: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<ProjectTransaction> {
|
||||
let message = message
|
||||
@ -426,8 +426,8 @@ impl LspCommand for GetDefinition {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
message: Option<lsp2::GotoDefinitionResponse>,
|
||||
project: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
cx: AsyncAppContext,
|
||||
) -> Result<Vec<LocationLink>> {
|
||||
@ -447,8 +447,8 @@ impl LspCommand for GetDefinition {
|
||||
|
||||
async fn from_proto(
|
||||
message: proto::GetDefinition,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let position = message
|
||||
@ -479,8 +479,8 @@ impl LspCommand for GetDefinition {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::GetDefinitionResponse,
|
||||
project: Handle<Project>,
|
||||
_: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
_: Model<Buffer>,
|
||||
cx: AsyncAppContext,
|
||||
) -> Result<Vec<LocationLink>> {
|
||||
location_links_from_proto(message.links, project, cx).await
|
||||
@ -527,8 +527,8 @@ impl LspCommand for GetTypeDefinition {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
message: Option<lsp2::GotoTypeDefinitionResponse>,
|
||||
project: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
cx: AsyncAppContext,
|
||||
) -> Result<Vec<LocationLink>> {
|
||||
@ -548,8 +548,8 @@ impl LspCommand for GetTypeDefinition {
|
||||
|
||||
async fn from_proto(
|
||||
message: proto::GetTypeDefinition,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let position = message
|
||||
@ -580,8 +580,8 @@ impl LspCommand for GetTypeDefinition {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::GetTypeDefinitionResponse,
|
||||
project: Handle<Project>,
|
||||
_: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
_: Model<Buffer>,
|
||||
cx: AsyncAppContext,
|
||||
) -> Result<Vec<LocationLink>> {
|
||||
location_links_from_proto(message.links, project, cx).await
|
||||
@ -593,8 +593,8 @@ impl LspCommand for GetTypeDefinition {
|
||||
}
|
||||
|
||||
fn language_server_for_buffer(
|
||||
project: &Handle<Project>,
|
||||
buffer: &Handle<Buffer>,
|
||||
project: &Model<Project>,
|
||||
buffer: &Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> Result<(Arc<CachedLspAdapter>, Arc<LanguageServer>)> {
|
||||
@ -609,7 +609,7 @@ fn language_server_for_buffer(
|
||||
|
||||
async fn location_links_from_proto(
|
||||
proto_links: Vec<proto::LocationLink>,
|
||||
project: Handle<Project>,
|
||||
project: Model<Project>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Vec<LocationLink>> {
|
||||
let mut links = Vec::new();
|
||||
@ -671,8 +671,8 @@ async fn location_links_from_proto(
|
||||
|
||||
async fn location_links_from_lsp(
|
||||
message: Option<lsp2::GotoDefinitionResponse>,
|
||||
project: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Vec<LocationLink>> {
|
||||
@ -814,8 +814,8 @@ impl LspCommand for GetReferences {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
locations: Option<Vec<lsp2::Location>>,
|
||||
project: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Vec<Location>> {
|
||||
@ -868,8 +868,8 @@ impl LspCommand for GetReferences {
|
||||
|
||||
async fn from_proto(
|
||||
message: proto::GetReferences,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let position = message
|
||||
@ -910,8 +910,8 @@ impl LspCommand for GetReferences {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::GetReferencesResponse,
|
||||
project: Handle<Project>,
|
||||
_: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
_: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Vec<Location>> {
|
||||
let mut locations = Vec::new();
|
||||
@ -977,8 +977,8 @@ impl LspCommand for GetDocumentHighlights {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
lsp_highlights: Option<Vec<lsp2::DocumentHighlight>>,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
_: LanguageServerId,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Vec<DocumentHighlight>> {
|
||||
@ -1016,8 +1016,8 @@ impl LspCommand for GetDocumentHighlights {
|
||||
|
||||
async fn from_proto(
|
||||
message: proto::GetDocumentHighlights,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let position = message
|
||||
@ -1060,8 +1060,8 @@ impl LspCommand for GetDocumentHighlights {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::GetDocumentHighlightsResponse,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Vec<DocumentHighlight>> {
|
||||
let mut highlights = Vec::new();
|
||||
@ -1123,8 +1123,8 @@ impl LspCommand for GetHover {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
message: Option<lsp2::Hover>,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
_: LanguageServerId,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self::Response> {
|
||||
@ -1206,8 +1206,8 @@ impl LspCommand for GetHover {
|
||||
|
||||
async fn from_proto(
|
||||
message: Self::ProtoRequest,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let position = message
|
||||
@ -1272,8 +1272,8 @@ impl LspCommand for GetHover {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::GetHoverResponse,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self::Response> {
|
||||
let contents: Vec<_> = message
|
||||
@ -1341,8 +1341,8 @@ impl LspCommand for GetCompletions {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
completions: Option<lsp2::CompletionResponse>,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Vec<Completion>> {
|
||||
@ -1484,8 +1484,8 @@ impl LspCommand for GetCompletions {
|
||||
|
||||
async fn from_proto(
|
||||
message: proto::GetCompletions,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let version = deserialize_version(&message.version);
|
||||
@ -1523,8 +1523,8 @@ impl LspCommand for GetCompletions {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::GetCompletionsResponse,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Vec<Completion>> {
|
||||
buffer
|
||||
@ -1589,8 +1589,8 @@ impl LspCommand for GetCodeActions {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
actions: Option<lsp2::CodeActionResponse>,
|
||||
_: Handle<Project>,
|
||||
_: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
_: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
_: AsyncAppContext,
|
||||
) -> Result<Vec<CodeAction>> {
|
||||
@ -1623,8 +1623,8 @@ impl LspCommand for GetCodeActions {
|
||||
|
||||
async fn from_proto(
|
||||
message: proto::GetCodeActions,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let start = message
|
||||
@ -1663,8 +1663,8 @@ impl LspCommand for GetCodeActions {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::GetCodeActionsResponse,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Vec<CodeAction>> {
|
||||
buffer
|
||||
@ -1726,8 +1726,8 @@ impl LspCommand for OnTypeFormatting {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
message: Option<Vec<lsp2::TextEdit>>,
|
||||
project: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Option<Transaction>> {
|
||||
@ -1763,8 +1763,8 @@ impl LspCommand for OnTypeFormatting {
|
||||
|
||||
async fn from_proto(
|
||||
message: proto::OnTypeFormatting,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let position = message
|
||||
@ -1805,8 +1805,8 @@ impl LspCommand for OnTypeFormatting {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::OnTypeFormattingResponse,
|
||||
_: Handle<Project>,
|
||||
_: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
_: Model<Buffer>,
|
||||
_: AsyncAppContext,
|
||||
) -> Result<Option<Transaction>> {
|
||||
let Some(transaction) = message.transaction else {
|
||||
@ -1825,7 +1825,7 @@ impl LspCommand for OnTypeFormatting {
|
||||
impl InlayHints {
|
||||
pub async fn lsp_to_project_hint(
|
||||
lsp_hint: lsp2::InlayHint,
|
||||
buffer_handle: &Handle<Buffer>,
|
||||
buffer_handle: &Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
resolve_state: ResolveState,
|
||||
force_no_type_left_padding: bool,
|
||||
@ -2230,8 +2230,8 @@ impl LspCommand for InlayHints {
|
||||
async fn response_from_lsp(
|
||||
self,
|
||||
message: Option<Vec<lsp2::InlayHint>>,
|
||||
project: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
project: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> anyhow::Result<Vec<InlayHint>> {
|
||||
@ -2286,8 +2286,8 @@ impl LspCommand for InlayHints {
|
||||
|
||||
async fn from_proto(
|
||||
message: proto::InlayHints,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<Self> {
|
||||
let start = message
|
||||
@ -2326,8 +2326,8 @@ impl LspCommand for InlayHints {
|
||||
async fn response_from_proto(
|
||||
self,
|
||||
message: proto::InlayHintsResponse,
|
||||
_: Handle<Project>,
|
||||
buffer: Handle<Buffer>,
|
||||
_: Model<Project>,
|
||||
buffer: Model<Buffer>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> anyhow::Result<Vec<InlayHint>> {
|
||||
buffer
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
||||
use crate::Project;
|
||||
use gpui2::{AnyWindowHandle, Context, Handle, ModelContext, WeakHandle};
|
||||
use gpui2::{AnyWindowHandle, Context, Model, ModelContext, WeakModel};
|
||||
use settings2::Settings;
|
||||
use std::path::{Path, PathBuf};
|
||||
use terminal2::{
|
||||
@ -11,7 +11,7 @@ use terminal2::{
|
||||
use std::os::unix::ffi::OsStrExt;
|
||||
|
||||
pub struct Terminals {
|
||||
pub(crate) local_handles: Vec<WeakHandle<terminal2::Terminal>>,
|
||||
pub(crate) local_handles: Vec<WeakModel<terminal2::Terminal>>,
|
||||
}
|
||||
|
||||
impl Project {
|
||||
@ -20,7 +20,7 @@ impl Project {
|
||||
working_directory: Option<PathBuf>,
|
||||
window: AnyWindowHandle,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> anyhow::Result<Handle<Terminal>> {
|
||||
) -> anyhow::Result<Model<Terminal>> {
|
||||
if self.is_remote() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"creating terminals as a guest is not supported yet"
|
||||
@ -40,7 +40,7 @@ impl Project {
|
||||
|_, _| todo!("color_for_index"),
|
||||
)
|
||||
.map(|builder| {
|
||||
let terminal_handle = cx.entity(|cx| builder.subscribe(cx));
|
||||
let terminal_handle = cx.build_model(|cx| builder.subscribe(cx));
|
||||
|
||||
self.terminals
|
||||
.local_handles
|
||||
@ -108,7 +108,7 @@ impl Project {
|
||||
fn activate_python_virtual_environment(
|
||||
&mut self,
|
||||
activate_script: Option<PathBuf>,
|
||||
terminal_handle: &Handle<Terminal>,
|
||||
terminal_handle: &Model<Terminal>,
|
||||
cx: &mut ModelContext<Project>,
|
||||
) {
|
||||
if let Some(activate_script) = activate_script {
|
||||
@ -121,7 +121,7 @@ impl Project {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_terminal_handles(&self) -> &Vec<WeakHandle<terminal2::Terminal>> {
|
||||
pub fn local_terminal_handles(&self) -> &Vec<WeakModel<terminal2::Terminal>> {
|
||||
&self.terminals.local_handles
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ use anyhow::{anyhow, Context as _, Result};
|
||||
use client2::{proto, Client};
|
||||
use clock::ReplicaId;
|
||||
use collections::{HashMap, HashSet, VecDeque};
|
||||
use fs::{
|
||||
use fs2::{
|
||||
repository::{GitFileStatus, GitRepository, RepoPath},
|
||||
Fs,
|
||||
};
|
||||
@ -22,7 +22,7 @@ use futures::{
|
||||
use fuzzy2::CharBag;
|
||||
use git::{DOT_GIT, GITIGNORE};
|
||||
use gpui2::{
|
||||
AppContext, AsyncAppContext, Context, EventEmitter, Executor, Handle, ModelContext, Task,
|
||||
AppContext, AsyncAppContext, Context, EventEmitter, Executor, Model, ModelContext, Task,
|
||||
};
|
||||
use language2::{
|
||||
proto::{
|
||||
@ -292,7 +292,7 @@ impl Worktree {
|
||||
fs: Arc<dyn Fs>,
|
||||
next_entry_id: Arc<AtomicUsize>,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> Result<Handle<Self>> {
|
||||
) -> Result<Model<Self>> {
|
||||
// After determining whether the root entry is a file or a directory, populate the
|
||||
// snapshot's "root name", which will be used for the purpose of fuzzy matching.
|
||||
let abs_path = path.into();
|
||||
@ -301,7 +301,7 @@ impl Worktree {
|
||||
.await
|
||||
.context("failed to stat worktree path")?;
|
||||
|
||||
cx.entity(move |cx: &mut ModelContext<Worktree>| {
|
||||
cx.build_model(move |cx: &mut ModelContext<Worktree>| {
|
||||
let root_name = abs_path
|
||||
.file_name()
|
||||
.map_or(String::new(), |f| f.to_string_lossy().to_string());
|
||||
@ -406,8 +406,8 @@ impl Worktree {
|
||||
worktree: proto::WorktreeMetadata,
|
||||
client: Arc<Client>,
|
||||
cx: &mut AppContext,
|
||||
) -> Handle<Self> {
|
||||
cx.entity(|cx: &mut ModelContext<Self>| {
|
||||
) -> Model<Self> {
|
||||
cx.build_model(|cx: &mut ModelContext<Self>| {
|
||||
let snapshot = Snapshot {
|
||||
id: WorktreeId(worktree.id as usize),
|
||||
abs_path: Arc::from(PathBuf::from(worktree.abs_path)),
|
||||
@ -593,7 +593,7 @@ impl LocalWorktree {
|
||||
id: u64,
|
||||
path: &Path,
|
||||
cx: &mut ModelContext<Worktree>,
|
||||
) -> Task<Result<Handle<Buffer>>> {
|
||||
) -> Task<Result<Model<Buffer>>> {
|
||||
let path = Arc::from(path);
|
||||
cx.spawn(move |this, mut cx| async move {
|
||||
let (file, contents, diff_base) = this
|
||||
@ -603,7 +603,7 @@ impl LocalWorktree {
|
||||
.executor()
|
||||
.spawn(async move { text::Buffer::new(0, id, contents) })
|
||||
.await;
|
||||
cx.entity(|_| Buffer::build(text_buffer, diff_base, Some(Arc::new(file))))
|
||||
cx.build_model(|_| Buffer::build(text_buffer, diff_base, Some(Arc::new(file))))
|
||||
})
|
||||
}
|
||||
|
||||
@ -920,7 +920,7 @@ impl LocalWorktree {
|
||||
|
||||
pub fn save_buffer(
|
||||
&self,
|
||||
buffer_handle: Handle<Buffer>,
|
||||
buffer_handle: Model<Buffer>,
|
||||
path: Arc<Path>,
|
||||
has_changed_file: bool,
|
||||
cx: &mut ModelContext<Worktree>,
|
||||
@ -1331,7 +1331,7 @@ impl RemoteWorktree {
|
||||
|
||||
pub fn save_buffer(
|
||||
&self,
|
||||
buffer_handle: Handle<Buffer>,
|
||||
buffer_handle: Model<Buffer>,
|
||||
cx: &mut ModelContext<Worktree>,
|
||||
) -> Task<Result<()>> {
|
||||
let buffer = buffer_handle.read(cx);
|
||||
@ -2577,7 +2577,7 @@ impl fmt::Debug for Snapshot {
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub struct File {
|
||||
pub worktree: Handle<Worktree>,
|
||||
pub worktree: Model<Worktree>,
|
||||
pub path: Arc<Path>,
|
||||
pub mtime: SystemTime,
|
||||
pub(crate) entry_id: ProjectEntryId,
|
||||
@ -2701,7 +2701,7 @@ impl language2::LocalFile for File {
|
||||
}
|
||||
|
||||
impl File {
|
||||
pub fn for_entry(entry: Entry, worktree: Handle<Worktree>) -> Arc<Self> {
|
||||
pub fn for_entry(entry: Entry, worktree: Model<Worktree>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
worktree,
|
||||
path: entry.path.clone(),
|
||||
@ -2714,7 +2714,7 @@ impl File {
|
||||
|
||||
pub fn from_proto(
|
||||
proto: rpc2::proto::File,
|
||||
worktree: Handle<Worktree>,
|
||||
worktree: Model<Worktree>,
|
||||
cx: &AppContext,
|
||||
) -> Result<Self> {
|
||||
let worktree_id = worktree
|
||||
@ -2815,7 +2815,7 @@ pub type UpdatedGitRepositoriesSet = Arc<[(Arc<Path>, GitRepositoryChange)]>;
|
||||
impl Entry {
|
||||
fn new(
|
||||
path: Arc<Path>,
|
||||
metadata: &fs::Metadata,
|
||||
metadata: &fs2::Metadata,
|
||||
next_entry_id: &AtomicUsize,
|
||||
root_char_bag: CharBag,
|
||||
) -> Self {
|
||||
|
@ -42,6 +42,7 @@ sha1 = "0.10.5"
|
||||
ndarray = { version = "0.15.0" }
|
||||
|
||||
[dev-dependencies]
|
||||
ai = { path = "../ai", features = ["test-support"] }
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
language = { path = "../language", features = ["test-support"] }
|
||||
|
@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
|
||||
pending_batch_token_count: usize,
|
||||
finished_files_tx: channel::Sender<FileToEmbed>,
|
||||
finished_files_rx: channel::Receiver<FileToEmbed>,
|
||||
api_key: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
|
||||
}
|
||||
|
||||
impl EmbeddingQueue {
|
||||
pub fn new(
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
executor: Arc<Background>,
|
||||
api_key: Option<String>,
|
||||
) -> Self {
|
||||
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
|
||||
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
||||
Self {
|
||||
embedding_provider,
|
||||
@ -64,14 +59,9 @@ impl EmbeddingQueue {
|
||||
pending_batch_token_count: 0,
|
||||
finished_files_tx,
|
||||
finished_files_rx,
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_api_key(&mut self, api_key: Option<String>) {
|
||||
self.api_key = api_key
|
||||
}
|
||||
|
||||
pub fn push(&mut self, file: FileToEmbed) {
|
||||
if file.spans.is_empty() {
|
||||
self.finished_files_tx.try_send(file).unwrap();
|
||||
@ -118,7 +108,6 @@ impl EmbeddingQueue {
|
||||
|
||||
let finished_files_tx = self.finished_files_tx.clone();
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
@ -143,7 +132,7 @@ impl EmbeddingQueue {
|
||||
return;
|
||||
};
|
||||
|
||||
match embedding_provider.embed_batch(spans, api_key).await {
|
||||
match embedding_provider.embed_batch(spans).await {
|
||||
Ok(embeddings) => {
|
||||
let mut embeddings = embeddings.into_iter();
|
||||
for fragment in batch {
|
||||
|
@ -1,4 +1,7 @@
|
||||
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||
use ai::{
|
||||
embedding::{Embedding, EmbeddingProvider},
|
||||
models::TruncationDirection,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use language::{Grammar, Language};
|
||||
use rusqlite::{
|
||||
@ -108,7 +111,14 @@ impl CodeContextRetriever {
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_span = model.truncate(
|
||||
&document_span,
|
||||
model.capacity()?,
|
||||
ai::models::TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_span)?;
|
||||
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
@ -131,7 +141,15 @@ impl CodeContextRetriever {
|
||||
)
|
||||
.replace("<item>", &content);
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_span = model.truncate(
|
||||
&document_span,
|
||||
model.capacity()?,
|
||||
ai::models::TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_span)?;
|
||||
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
@ -222,8 +240,13 @@ impl CodeContextRetriever {
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("item", &span.content);
|
||||
|
||||
let (document_content, token_count) =
|
||||
self.embedding_provider.truncate(&document_content);
|
||||
let model = self.embedding_provider.base_model();
|
||||
let document_content = model.truncate(
|
||||
&document_content,
|
||||
model.capacity()?,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
let token_count = model.count_tokens(&document_content)?;
|
||||
|
||||
span.content = document_content;
|
||||
span.token_count = token_count;
|
||||
|
@ -7,7 +7,8 @@ pub mod semantic_index_settings;
|
||||
mod semantic_index_tests;
|
||||
|
||||
use crate::semantic_index_settings::SemanticIndexSettings;
|
||||
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
|
||||
use ai::embedding::{Embedding, EmbeddingProvider};
|
||||
use ai::providers::open_ai::OpenAIEmbeddingProvider;
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::{BTreeMap, HashMap, HashSet};
|
||||
use db::VectorDatabase;
|
||||
@ -88,7 +89,7 @@ pub fn init(
|
||||
let semantic_index = SemanticIndex::new(
|
||||
fs,
|
||||
db_file_path,
|
||||
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
|
||||
Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
|
||||
language_registry,
|
||||
cx.clone(),
|
||||
)
|
||||
@ -123,8 +124,6 @@ pub struct SemanticIndex {
|
||||
_embedding_task: Task<()>,
|
||||
_parsing_files_tasks: Vec<Task<()>>,
|
||||
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
|
||||
api_key: Option<String>,
|
||||
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
|
||||
}
|
||||
|
||||
struct ProjectState {
|
||||
@ -278,18 +277,18 @@ impl SemanticIndex {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn authenticate(&mut self, cx: &AppContext) {
|
||||
if self.api_key.is_none() {
|
||||
self.api_key = self.embedding_provider.retrieve_credentials(cx);
|
||||
|
||||
self.embedding_queue
|
||||
.lock()
|
||||
.set_api_key(self.api_key.clone());
|
||||
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
|
||||
if !self.embedding_provider.has_credentials() {
|
||||
self.embedding_provider.retrieve_credentials(cx);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
|
||||
self.embedding_provider.has_credentials()
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
self.embedding_provider.has_credentials()
|
||||
}
|
||||
|
||||
pub fn enabled(cx: &AppContext) -> bool {
|
||||
@ -339,7 +338,7 @@ impl SemanticIndex {
|
||||
Ok(cx.add_model(|cx| {
|
||||
let t0 = Instant::now();
|
||||
let embedding_queue =
|
||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
|
||||
EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
|
||||
let _embedding_task = cx.background().spawn({
|
||||
let embedded_files = embedding_queue.finished_files();
|
||||
let db = db.clone();
|
||||
@ -404,8 +403,6 @@ impl SemanticIndex {
|
||||
_embedding_task,
|
||||
_parsing_files_tasks,
|
||||
projects: Default::default(),
|
||||
api_key: None,
|
||||
embedding_queue
|
||||
}
|
||||
}))
|
||||
}
|
||||
@ -720,13 +717,13 @@ impl SemanticIndex {
|
||||
|
||||
let index = self.index_project(project.clone(), cx);
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
index.await?;
|
||||
let t0 = Instant::now();
|
||||
|
||||
let query = embedding_provider
|
||||
.embed_batch(vec![query], api_key)
|
||||
.embed_batch(vec![query])
|
||||
.await?
|
||||
.pop()
|
||||
.ok_or_else(|| anyhow!("could not embed query"))?;
|
||||
@ -944,7 +941,6 @@ impl SemanticIndex {
|
||||
let fs = self.fs.clone();
|
||||
let db_path = self.db.path().clone();
|
||||
let background = cx.background().clone();
|
||||
let api_key = self.api_key.clone();
|
||||
cx.background().spawn(async move {
|
||||
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
|
||||
let mut results = Vec::<SearchResult>::new();
|
||||
@ -959,15 +955,10 @@ impl SemanticIndex {
|
||||
.parse_file_with_template(None, &snapshot.text(), language)
|
||||
.log_err()
|
||||
.unwrap_or_default();
|
||||
if Self::embed_spans(
|
||||
&mut spans,
|
||||
embedding_provider.as_ref(),
|
||||
&db,
|
||||
api_key.clone(),
|
||||
)
|
||||
.await
|
||||
.log_err()
|
||||
.is_some()
|
||||
if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
|
||||
.await
|
||||
.log_err()
|
||||
.is_some()
|
||||
{
|
||||
for span in spans {
|
||||
let similarity = span.embedding.unwrap().similarity(&query);
|
||||
@ -1007,9 +998,8 @@ impl SemanticIndex {
|
||||
project: ModelHandle<Project>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
if self.api_key.is_none() {
|
||||
self.authenticate(cx);
|
||||
if self.api_key.is_none() {
|
||||
if !self.is_authenticated() {
|
||||
if !self.authenticate(cx) {
|
||||
return Task::ready(Err(anyhow!("user is not authenticated")));
|
||||
}
|
||||
}
|
||||
@ -1192,7 +1182,6 @@ impl SemanticIndex {
|
||||
spans: &mut [Span],
|
||||
embedding_provider: &dyn EmbeddingProvider,
|
||||
db: &VectorDatabase,
|
||||
api_key: Option<String>,
|
||||
) -> Result<()> {
|
||||
let mut batch = Vec::new();
|
||||
let mut batch_tokens = 0;
|
||||
@ -1215,7 +1204,7 @@ impl SemanticIndex {
|
||||
|
||||
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch), api_key.clone())
|
||||
.embed_batch(mem::take(&mut batch))
|
||||
.await?;
|
||||
embeddings.extend(batch_embeddings);
|
||||
batch_tokens = 0;
|
||||
@ -1227,7 +1216,7 @@ impl SemanticIndex {
|
||||
|
||||
if !batch.is_empty() {
|
||||
let batch_embeddings = embedding_provider
|
||||
.embed_batch(mem::take(&mut batch), api_key)
|
||||
.embed_batch(mem::take(&mut batch))
|
||||
.await?;
|
||||
|
||||
embeddings.extend(batch_embeddings);
|
||||
|
@ -4,10 +4,9 @@ use crate::{
|
||||
semantic_index_settings::SemanticIndexSettings,
|
||||
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
|
||||
};
|
||||
use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
|
||||
use ai::test::FakeEmbeddingProvider;
|
||||
|
||||
use gpui::{executor::Deterministic, Task, TestAppContext};
|
||||
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
|
||||
use parking_lot::Mutex;
|
||||
use pretty_assertions::assert_eq;
|
||||
@ -15,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
|
||||
use rand::{rngs::StdRng, Rng};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::{
|
||||
path::Path,
|
||||
sync::{
|
||||
atomic::{self, AtomicUsize},
|
||||
Arc,
|
||||
},
|
||||
time::{Instant, SystemTime},
|
||||
};
|
||||
use std::{path::Path, sync::Arc, time::SystemTime};
|
||||
use unindent::Unindent;
|
||||
use util::RandomCharIter;
|
||||
|
||||
@ -228,7 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
|
||||
for file in &files {
|
||||
queue.push(file.clone());
|
||||
}
|
||||
@ -280,7 +272,7 @@ fn assert_search_results(
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_rust() {
|
||||
let language = rust_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = "
|
||||
@ -382,7 +374,7 @@ async fn test_code_context_retrieval_rust() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_json() {
|
||||
let language = json_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
@ -466,7 +458,7 @@ fn assert_documents_eq(
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_javascript() {
|
||||
let language = js_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = "
|
||||
@ -565,7 +557,7 @@ async fn test_code_context_retrieval_javascript() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_lua() {
|
||||
let language = lua_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
@ -639,7 +631,7 @@ async fn test_code_context_retrieval_lua() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_elixir() {
|
||||
let language = elixir_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
@ -756,7 +748,7 @@ async fn test_code_context_retrieval_elixir() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_cpp() {
|
||||
let language = cpp_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = "
|
||||
@ -909,7 +901,7 @@ async fn test_code_context_retrieval_cpp() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_ruby() {
|
||||
let language = ruby_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
@ -1100,7 +1092,7 @@ async fn test_code_context_retrieval_ruby() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_php() {
|
||||
let language = php_lang();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
@ -1248,65 +1240,6 @@ async fn test_code_context_retrieval_php() {
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct FakeEmbeddingProvider {
|
||||
embedding_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl FakeEmbeddingProvider {
|
||||
fn embedding_count(&self) -> usize {
|
||||
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
fn embed_sync(&self, span: &str) -> Embedding {
|
||||
let mut result = vec![1.0; 26];
|
||||
for letter in span.chars() {
|
||||
let letter = letter.to_ascii_lowercase();
|
||||
if letter as u32 >= 'a' as u32 {
|
||||
let ix = (letter as u32) - ('a' as u32);
|
||||
if ix < 26 {
|
||||
result[ix as usize] += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
for x in &mut result {
|
||||
*x /= norm;
|
||||
}
|
||||
|
||||
result.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||
fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
|
||||
Some("Fake Credentials".to_string())
|
||||
}
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
(span.to_string(), 1)
|
||||
}
|
||||
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
200
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
None
|
||||
}
|
||||
|
||||
async fn embed_batch(
|
||||
&self,
|
||||
spans: Vec<String>,
|
||||
_api_key: Option<String>,
|
||||
) -> Result<Vec<Embedding>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
fn js_lang() -> Arc<Language> {
|
||||
Arc::new(
|
||||
Language::new(
|
||||
|
@ -1,9 +1,11 @@
|
||||
mod colors;
|
||||
mod focus;
|
||||
mod kitchen_sink;
|
||||
mod scroll;
|
||||
mod text;
|
||||
mod z_index;
|
||||
|
||||
pub use colors::*;
|
||||
pub use focus::*;
|
||||
pub use kitchen_sink::*;
|
||||
pub use scroll::*;
|
||||
|
38
crates/storybook2/src/stories/colors.rs
Normal file
38
crates/storybook2/src/stories/colors.rs
Normal file
@ -0,0 +1,38 @@
|
||||
use crate::story::Story;
|
||||
use gpui2::{px, Div, Render};
|
||||
use ui::prelude::*;
|
||||
|
||||
pub struct ColorsStory;
|
||||
|
||||
impl Render for ColorsStory {
|
||||
type Element = Div<Self>;
|
||||
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
|
||||
let color_scales = theme2::default_color_scales();
|
||||
|
||||
Story::container(cx)
|
||||
.child(Story::title(cx, "Colors"))
|
||||
.child(
|
||||
div()
|
||||
.id("colors")
|
||||
.flex()
|
||||
.flex_col()
|
||||
.gap_1()
|
||||
.overflow_y_scroll()
|
||||
.text_color(gpui2::white())
|
||||
.children(color_scales.into_iter().map(|(name, scale)| {
|
||||
div()
|
||||
.flex()
|
||||
.child(
|
||||
div()
|
||||
.w(px(75.))
|
||||
.line_height(px(24.))
|
||||
.child(name.to_string()),
|
||||
)
|
||||
.child(div().flex().gap_1().children(
|
||||
(1..=12).map(|step| div().flex().size_6().bg(scale.step(cx, step))),
|
||||
))
|
||||
})),
|
||||
)
|
||||
}
|
||||
}
|
@ -1,9 +1,9 @@
|
||||
use crate::themes::rose_pine;
|
||||
use gpui2::{
|
||||
div, Focusable, KeyBinding, ParentElement, StatelessInteractive, Styled, View, VisualContext,
|
||||
WindowContext,
|
||||
div, Div, FocusEnabled, Focusable, KeyBinding, ParentElement, Render, StatefulInteraction,
|
||||
StatelessInteractive, Styled, View, VisualContext, WindowContext,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use theme2::theme;
|
||||
|
||||
#[derive(Clone, Default, PartialEq, Deserialize)]
|
||||
struct ActionA;
|
||||
@ -14,12 +14,10 @@ struct ActionB;
|
||||
#[derive(Clone, Default, PartialEq, Deserialize)]
|
||||
struct ActionC;
|
||||
|
||||
pub struct FocusStory {
|
||||
text: View<()>,
|
||||
}
|
||||
pub struct FocusStory {}
|
||||
|
||||
impl FocusStory {
|
||||
pub fn view(cx: &mut WindowContext) -> View<()> {
|
||||
pub fn view(cx: &mut WindowContext) -> View<Self> {
|
||||
cx.bind_keys([
|
||||
KeyBinding::new("cmd-a", ActionA, Some("parent")),
|
||||
KeyBinding::new("cmd-a", ActionB, Some("child-1")),
|
||||
@ -27,91 +25,92 @@ impl FocusStory {
|
||||
]);
|
||||
cx.register_action_type::<ActionA>();
|
||||
cx.register_action_type::<ActionB>();
|
||||
let theme = rose_pine();
|
||||
|
||||
let color_1 = theme.lowest.negative.default.foreground;
|
||||
let color_2 = theme.lowest.positive.default.foreground;
|
||||
let color_3 = theme.lowest.warning.default.foreground;
|
||||
let color_4 = theme.lowest.accent.default.foreground;
|
||||
let color_5 = theme.lowest.variant.default.foreground;
|
||||
let color_6 = theme.highest.negative.default.foreground;
|
||||
cx.build_view(move |cx| Self {})
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for FocusStory {
|
||||
type Element = Div<Self, StatefulInteraction<Self>, FocusEnabled<Self>>;
|
||||
|
||||
fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
|
||||
let theme = theme(cx);
|
||||
let color_1 = theme.git_created;
|
||||
let color_2 = theme.git_modified;
|
||||
let color_3 = theme.git_deleted;
|
||||
let color_4 = theme.git_conflict;
|
||||
let color_5 = theme.git_ignored;
|
||||
let color_6 = theme.git_renamed;
|
||||
let child_1 = cx.focus_handle();
|
||||
let child_2 = cx.focus_handle();
|
||||
|
||||
cx.build_view(
|
||||
|_| (),
|
||||
move |_, cx| {
|
||||
div()
|
||||
.id("parent")
|
||||
.focusable()
|
||||
.context("parent")
|
||||
.on_action(|_, action: &ActionA, phase, cx| {
|
||||
println!("Action A dispatched on parent during {:?}", phase);
|
||||
})
|
||||
.on_action(|_, action: &ActionB, phase, cx| {
|
||||
println!("Action B dispatched on parent during {:?}", phase);
|
||||
})
|
||||
.on_focus(|_, _, _| println!("Parent focused"))
|
||||
.on_blur(|_, _, _| println!("Parent blurred"))
|
||||
.on_focus_in(|_, _, _| println!("Parent focus_in"))
|
||||
.on_focus_out(|_, _, _| println!("Parent focus_out"))
|
||||
.on_key_down(|_, event, phase, _| {
|
||||
println!("Key down on parent {:?} {:?}", phase, event)
|
||||
})
|
||||
.on_key_up(|_, event, phase, _| println!("Key up on parent {:?} {:?}", phase, event))
|
||||
.size_full()
|
||||
.bg(color_1)
|
||||
.focus(|style| style.bg(color_2))
|
||||
.focus_in(|style| style.bg(color_3))
|
||||
.child(
|
||||
div()
|
||||
.id("parent")
|
||||
.focusable()
|
||||
.context("parent")
|
||||
.on_action(|_, action: &ActionA, phase, cx| {
|
||||
println!("Action A dispatched on parent during {:?}", phase);
|
||||
})
|
||||
.track_focus(&child_1)
|
||||
.context("child-1")
|
||||
.on_action(|_, action: &ActionB, phase, cx| {
|
||||
println!("Action B dispatched on parent during {:?}", phase);
|
||||
println!("Action B dispatched on child 1 during {:?}", phase);
|
||||
})
|
||||
.on_focus(|_, _, _| println!("Parent focused"))
|
||||
.on_blur(|_, _, _| println!("Parent blurred"))
|
||||
.on_focus_in(|_, _, _| println!("Parent focus_in"))
|
||||
.on_focus_out(|_, _, _| println!("Parent focus_out"))
|
||||
.w_full()
|
||||
.h_6()
|
||||
.bg(color_4)
|
||||
.focus(|style| style.bg(color_5))
|
||||
.in_focus(|style| style.bg(color_6))
|
||||
.on_focus(|_, _, _| println!("Child 1 focused"))
|
||||
.on_blur(|_, _, _| println!("Child 1 blurred"))
|
||||
.on_focus_in(|_, _, _| println!("Child 1 focus_in"))
|
||||
.on_focus_out(|_, _, _| println!("Child 1 focus_out"))
|
||||
.on_key_down(|_, event, phase, _| {
|
||||
println!("Key down on parent {:?} {:?}", phase, event)
|
||||
println!("Key down on child 1 {:?} {:?}", phase, event)
|
||||
})
|
||||
.on_key_up(|_, event, phase, _| {
|
||||
println!("Key up on parent {:?} {:?}", phase, event)
|
||||
println!("Key up on child 1 {:?} {:?}", phase, event)
|
||||
})
|
||||
.size_full()
|
||||
.bg(color_1)
|
||||
.focus(|style| style.bg(color_2))
|
||||
.focus_in(|style| style.bg(color_3))
|
||||
.child(
|
||||
div()
|
||||
.track_focus(&child_1)
|
||||
.context("child-1")
|
||||
.on_action(|_, action: &ActionB, phase, cx| {
|
||||
println!("Action B dispatched on child 1 during {:?}", phase);
|
||||
})
|
||||
.w_full()
|
||||
.h_6()
|
||||
.bg(color_4)
|
||||
.focus(|style| style.bg(color_5))
|
||||
.in_focus(|style| style.bg(color_6))
|
||||
.on_focus(|_, _, _| println!("Child 1 focused"))
|
||||
.on_blur(|_, _, _| println!("Child 1 blurred"))
|
||||
.on_focus_in(|_, _, _| println!("Child 1 focus_in"))
|
||||
.on_focus_out(|_, _, _| println!("Child 1 focus_out"))
|
||||
.on_key_down(|_, event, phase, _| {
|
||||
println!("Key down on child 1 {:?} {:?}", phase, event)
|
||||
})
|
||||
.on_key_up(|_, event, phase, _| {
|
||||
println!("Key up on child 1 {:?} {:?}", phase, event)
|
||||
})
|
||||
.child("Child 1"),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.track_focus(&child_2)
|
||||
.context("child-2")
|
||||
.on_action(|_, action: &ActionC, phase, cx| {
|
||||
println!("Action C dispatched on child 2 during {:?}", phase);
|
||||
})
|
||||
.w_full()
|
||||
.h_6()
|
||||
.bg(color_4)
|
||||
.on_focus(|_, _, _| println!("Child 2 focused"))
|
||||
.on_blur(|_, _, _| println!("Child 2 blurred"))
|
||||
.on_focus_in(|_, _, _| println!("Child 2 focus_in"))
|
||||
.on_focus_out(|_, _, _| println!("Child 2 focus_out"))
|
||||
.on_key_down(|_, event, phase, _| {
|
||||
println!("Key down on child 2 {:?} {:?}", phase, event)
|
||||
})
|
||||
.on_key_up(|_, event, phase, _| {
|
||||
println!("Key up on child 2 {:?} {:?}", phase, event)
|
||||
})
|
||||
.child("Child 2"),
|
||||
)
|
||||
},
|
||||
)
|
||||
.child("Child 1"),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.track_focus(&child_2)
|
||||
.context("child-2")
|
||||
.on_action(|_, action: &ActionC, phase, cx| {
|
||||
println!("Action C dispatched on child 2 during {:?}", phase);
|
||||
})
|
||||
.w_full()
|
||||
.h_6()
|
||||
.bg(color_4)
|
||||
.on_focus(|_, _, _| println!("Child 2 focused"))
|
||||
.on_blur(|_, _, _| println!("Child 2 blurred"))
|
||||
.on_focus_in(|_, _, _| println!("Child 2 focus_in"))
|
||||
.on_focus_out(|_, _, _| println!("Child 2 focus_out"))
|
||||
.on_key_down(|_, event, phase, _| {
|
||||
println!("Key down on child 2 {:?} {:?}", phase, event)
|
||||
})
|
||||
.on_key_up(|_, event, phase, _| {
|
||||
println!("Key up on child 2 {:?} {:?}", phase, event)
|
||||
})
|
||||
.child("Child 2"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -1,26 +1,23 @@
|
||||
use gpui2::{AppContext, Context, View};
|
||||
use crate::{
|
||||
story::Story,
|
||||
story_selector::{ComponentStory, ElementStory},
|
||||
};
|
||||
use gpui2::{Div, Render, StatefulInteraction, View, VisualContext};
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::prelude::*;
|
||||
|
||||
use crate::story::Story;
|
||||
use crate::story_selector::{ComponentStory, ElementStory};
|
||||
|
||||
pub struct KitchenSinkStory {}
|
||||
pub struct KitchenSinkStory;
|
||||
|
||||
impl KitchenSinkStory {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
pub fn view(cx: &mut WindowContext) -> View<Self> {
|
||||
cx.build_view(|cx| Self)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn view(cx: &mut AppContext) -> View<Self> {
|
||||
{
|
||||
let state = cx.entity(|cx| Self::new());
|
||||
let render = Self::render;
|
||||
View::for_handle(state, render)
|
||||
}
|
||||
}
|
||||
impl Render for KitchenSinkStory {
|
||||
type Element = Div<Self, StatefulInteraction<Self>>;
|
||||
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Component<Self> {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
|
||||
let element_stories = ElementStory::iter()
|
||||
.map(|selector| selector.story(cx))
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -1,58 +1,54 @@
|
||||
use crate::themes::rose_pine;
|
||||
use gpui2::{
|
||||
div, px, Component, ParentElement, SharedString, Styled, View, VisualContext, WindowContext,
|
||||
div, px, Component, Div, ParentElement, Render, SharedString, StatefulInteraction, Styled,
|
||||
View, VisualContext, WindowContext,
|
||||
};
|
||||
use theme2::theme;
|
||||
|
||||
pub struct ScrollStory {
|
||||
text: View<()>,
|
||||
}
|
||||
pub struct ScrollStory;
|
||||
|
||||
impl ScrollStory {
|
||||
pub fn view(cx: &mut WindowContext) -> View<()> {
|
||||
let theme = rose_pine();
|
||||
|
||||
{
|
||||
cx.build_view(|cx| (), move |_, cx| checkerboard(1))
|
||||
}
|
||||
pub fn view(cx: &mut WindowContext) -> View<ScrollStory> {
|
||||
cx.build_view(|cx| ScrollStory)
|
||||
}
|
||||
}
|
||||
|
||||
fn checkerboard<S>(depth: usize) -> impl Component<S>
|
||||
where
|
||||
S: 'static + Send + Sync,
|
||||
{
|
||||
let theme = rose_pine();
|
||||
let color_1 = theme.lowest.positive.default.background;
|
||||
let color_2 = theme.lowest.warning.default.background;
|
||||
impl Render for ScrollStory {
|
||||
type Element = Div<Self, StatefulInteraction<Self>>;
|
||||
|
||||
div()
|
||||
.id("parent")
|
||||
.bg(theme.lowest.base.default.background)
|
||||
.size_full()
|
||||
.overflow_scroll()
|
||||
.children((0..10).map(|row| {
|
||||
div()
|
||||
.w(px(1000.))
|
||||
.h(px(100.))
|
||||
.flex()
|
||||
.flex_row()
|
||||
.children((0..10).map(|column| {
|
||||
let id = SharedString::from(format!("{}, {}", row, column));
|
||||
let bg = if row % 2 == column % 2 {
|
||||
color_1
|
||||
} else {
|
||||
color_2
|
||||
};
|
||||
div().id(id).bg(bg).size(px(100. / depth as f32)).when(
|
||||
row >= 5 && column >= 5,
|
||||
|d| {
|
||||
d.overflow_scroll()
|
||||
.child(div().size(px(50.)).bg(color_1))
|
||||
.child(div().size(px(50.)).bg(color_2))
|
||||
.child(div().size(px(50.)).bg(color_1))
|
||||
.child(div().size(px(50.)).bg(color_2))
|
||||
},
|
||||
)
|
||||
}))
|
||||
}))
|
||||
fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
|
||||
let theme = theme(cx);
|
||||
let color_1 = theme.git_created;
|
||||
let color_2 = theme.git_modified;
|
||||
|
||||
div()
|
||||
.id("parent")
|
||||
.bg(theme.background)
|
||||
.size_full()
|
||||
.overflow_scroll()
|
||||
.children((0..10).map(|row| {
|
||||
div()
|
||||
.w(px(1000.))
|
||||
.h(px(100.))
|
||||
.flex()
|
||||
.flex_row()
|
||||
.children((0..10).map(|column| {
|
||||
let id = SharedString::from(format!("{}, {}", row, column));
|
||||
let bg = if row % 2 == column % 2 {
|
||||
color_1
|
||||
} else {
|
||||
color_2
|
||||
};
|
||||
div().id(id).bg(bg).size(px(100. as f32)).when(
|
||||
row >= 5 && column >= 5,
|
||||
|d| {
|
||||
d.overflow_scroll()
|
||||
.child(div().size(px(50.)).bg(color_1))
|
||||
.child(div().size(px(50.)).bg(color_2))
|
||||
.child(div().size(px(50.)).bg(color_1))
|
||||
.child(div().size(px(50.)).bg(color_2))
|
||||
},
|
||||
)
|
||||
}))
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
@ -1,20 +1,21 @@
|
||||
use gpui2::{div, white, ParentElement, Styled, View, VisualContext, WindowContext};
|
||||
use gpui2::{div, white, Div, ParentElement, Render, Styled, View, VisualContext, WindowContext};
|
||||
|
||||
pub struct TextStory {
|
||||
text: View<()>,
|
||||
}
|
||||
pub struct TextStory;
|
||||
|
||||
impl TextStory {
|
||||
pub fn view(cx: &mut WindowContext) -> View<()> {
|
||||
cx.build_view(|cx| (), |_, cx| {
|
||||
div()
|
||||
.size_full()
|
||||
.bg(white())
|
||||
.child(concat!(
|
||||
"The quick brown fox jumps over the lazy dog. ",
|
||||
"Meanwhile, the lazy dog decided it was time for a change. ",
|
||||
"He started daily workout routines, ate healthier and became the fastest dog in town.",
|
||||
))
|
||||
})
|
||||
pub fn view(cx: &mut WindowContext) -> View<Self> {
|
||||
cx.build_view(|cx| Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for TextStory {
|
||||
type Element = Div<Self>;
|
||||
|
||||
fn render(&mut self, cx: &mut gpui2::ViewContext<Self>) -> Self::Element {
|
||||
div().size_full().bg(white()).child(concat!(
|
||||
"The quick brown fox jumps over the lazy dog. ",
|
||||
"Meanwhile, the lazy dog decided it was time for a change. ",
|
||||
"He started daily workout routines, ate healthier and became the fastest dog in town.",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
@ -1,15 +1,16 @@
|
||||
use gpui2::{px, rgb, Div, Hsla};
|
||||
use gpui2::{px, rgb, Div, Hsla, Render};
|
||||
use ui::prelude::*;
|
||||
|
||||
use crate::story::Story;
|
||||
|
||||
/// A reimplementation of the MDN `z-index` example, found here:
|
||||
/// [https://developer.mozilla.org/en-US/docs/Web/CSS/z-index](https://developer.mozilla.org/en-US/docs/Web/CSS/z-index).
|
||||
#[derive(Component)]
|
||||
pub struct ZIndexStory;
|
||||
|
||||
impl ZIndexStory {
|
||||
fn render<V: 'static>(self, _view: &mut V, cx: &mut ViewContext<V>) -> impl Component<V> {
|
||||
impl Render for ZIndexStory {
|
||||
type Element = Div<Self>;
|
||||
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
|
||||
Story::container(cx)
|
||||
.child(Story::title(cx, "z-index"))
|
||||
.child(
|
||||
|
@ -7,13 +7,14 @@ use clap::builder::PossibleValue;
|
||||
use clap::ValueEnum;
|
||||
use gpui2::{AnyView, VisualContext};
|
||||
use strum::{EnumIter, EnumString, IntoEnumIterator};
|
||||
use ui::prelude::*;
|
||||
use ui::{prelude::*, AvatarStory, ButtonStory, DetailsStory, IconStory, InputStory, LabelStory};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, strum::Display, EnumString, EnumIter)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum ElementStory {
|
||||
Avatar,
|
||||
Button,
|
||||
Colors,
|
||||
Details,
|
||||
Focus,
|
||||
Icon,
|
||||
@ -27,18 +28,17 @@ pub enum ElementStory {
|
||||
impl ElementStory {
|
||||
pub fn story(&self, cx: &mut WindowContext) -> AnyView {
|
||||
match self {
|
||||
Self::Avatar => { cx.build_view(|cx| (), |_, _| ui::AvatarStory.render()) }.into_any(),
|
||||
Self::Button => { cx.build_view(|cx| (), |_, _| ui::ButtonStory.render()) }.into_any(),
|
||||
Self::Details => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::DetailsStory.render()) }.into_any()
|
||||
}
|
||||
Self::Colors => cx.build_view(|_| ColorsStory).into_any(),
|
||||
Self::Avatar => cx.build_view(|_| AvatarStory).into_any(),
|
||||
Self::Button => cx.build_view(|_| ButtonStory).into_any(),
|
||||
Self::Details => cx.build_view(|_| DetailsStory).into_any(),
|
||||
Self::Focus => FocusStory::view(cx).into_any(),
|
||||
Self::Icon => { cx.build_view(|cx| (), |_, _| ui::IconStory.render()) }.into_any(),
|
||||
Self::Input => { cx.build_view(|cx| (), |_, _| ui::InputStory.render()) }.into_any(),
|
||||
Self::Label => { cx.build_view(|cx| (), |_, _| ui::LabelStory.render()) }.into_any(),
|
||||
Self::Icon => cx.build_view(|_| IconStory).into_any(),
|
||||
Self::Input => cx.build_view(|_| InputStory).into_any(),
|
||||
Self::Label => cx.build_view(|_| LabelStory).into_any(),
|
||||
Self::Scroll => ScrollStory::view(cx).into_any(),
|
||||
Self::Text => TextStory::view(cx).into_any(),
|
||||
Self::ZIndex => { cx.build_view(|cx| (), |_, _| ZIndexStory.render()) }.into_any(),
|
||||
Self::ZIndex => cx.build_view(|_| ZIndexStory).into_any(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -77,69 +77,31 @@ pub enum ComponentStory {
|
||||
impl ComponentStory {
|
||||
pub fn story(&self, cx: &mut WindowContext) -> AnyView {
|
||||
match self {
|
||||
Self::AssistantPanel => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::AssistantPanelStory.render()) }.into_any()
|
||||
}
|
||||
Self::Buffer => { cx.build_view(|cx| (), |_, _| ui::BufferStory.render()) }.into_any(),
|
||||
Self::Breadcrumb => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::BreadcrumbStory.render()) }.into_any()
|
||||
}
|
||||
Self::ChatPanel => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::ChatPanelStory.render()) }.into_any()
|
||||
}
|
||||
Self::CollabPanel => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::CollabPanelStory.render()) }.into_any()
|
||||
}
|
||||
Self::CommandPalette => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::CommandPaletteStory.render()) }.into_any()
|
||||
}
|
||||
Self::ContextMenu => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::ContextMenuStory.render()) }.into_any()
|
||||
}
|
||||
Self::Facepile => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::FacepileStory.render()) }.into_any()
|
||||
}
|
||||
Self::Keybinding => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::KeybindingStory.render()) }.into_any()
|
||||
}
|
||||
Self::LanguageSelector => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::LanguageSelectorStory.render()) }.into_any()
|
||||
}
|
||||
Self::MultiBuffer => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::MultiBufferStory.render()) }.into_any()
|
||||
}
|
||||
Self::NotificationsPanel => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::NotificationsPanelStory.render()) }.into_any()
|
||||
}
|
||||
Self::Palette => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::PaletteStory.render()) }.into_any()
|
||||
}
|
||||
Self::Panel => { cx.build_view(|cx| (), |_, _| ui::PanelStory.render()) }.into_any(),
|
||||
Self::ProjectPanel => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::ProjectPanelStory.render()) }.into_any()
|
||||
}
|
||||
Self::RecentProjects => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::RecentProjectsStory.render()) }.into_any()
|
||||
}
|
||||
Self::Tab => { cx.build_view(|cx| (), |_, _| ui::TabStory.render()) }.into_any(),
|
||||
Self::TabBar => { cx.build_view(|cx| (), |_, _| ui::TabBarStory.render()) }.into_any(),
|
||||
Self::Terminal => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::TerminalStory.render()) }.into_any()
|
||||
}
|
||||
Self::ThemeSelector => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::ThemeSelectorStory.render()) }.into_any()
|
||||
}
|
||||
Self::AssistantPanel => cx.build_view(|_| ui::AssistantPanelStory).into_any(),
|
||||
Self::Buffer => cx.build_view(|_| ui::BufferStory).into_any(),
|
||||
Self::Breadcrumb => cx.build_view(|_| ui::BreadcrumbStory).into_any(),
|
||||
Self::ChatPanel => cx.build_view(|_| ui::ChatPanelStory).into_any(),
|
||||
Self::CollabPanel => cx.build_view(|_| ui::CollabPanelStory).into_any(),
|
||||
Self::CommandPalette => cx.build_view(|_| ui::CommandPaletteStory).into_any(),
|
||||
Self::ContextMenu => cx.build_view(|_| ui::ContextMenuStory).into_any(),
|
||||
Self::Facepile => cx.build_view(|_| ui::FacepileStory).into_any(),
|
||||
Self::Keybinding => cx.build_view(|_| ui::KeybindingStory).into_any(),
|
||||
Self::LanguageSelector => cx.build_view(|_| ui::LanguageSelectorStory).into_any(),
|
||||
Self::MultiBuffer => cx.build_view(|_| ui::MultiBufferStory).into_any(),
|
||||
Self::NotificationsPanel => cx.build_view(|cx| ui::NotificationsPanelStory).into_any(),
|
||||
Self::Palette => cx.build_view(|cx| ui::PaletteStory).into_any(),
|
||||
Self::Panel => cx.build_view(|cx| ui::PanelStory).into_any(),
|
||||
Self::ProjectPanel => cx.build_view(|_| ui::ProjectPanelStory).into_any(),
|
||||
Self::RecentProjects => cx.build_view(|_| ui::RecentProjectsStory).into_any(),
|
||||
Self::Tab => cx.build_view(|_| ui::TabStory).into_any(),
|
||||
Self::TabBar => cx.build_view(|_| ui::TabBarStory).into_any(),
|
||||
Self::Terminal => cx.build_view(|_| ui::TerminalStory).into_any(),
|
||||
Self::ThemeSelector => cx.build_view(|_| ui::ThemeSelectorStory).into_any(),
|
||||
Self::Toast => cx.build_view(|_| ui::ToastStory).into_any(),
|
||||
Self::Toolbar => cx.build_view(|_| ui::ToolbarStory).into_any(),
|
||||
Self::TrafficLights => cx.build_view(|_| ui::TrafficLightsStory).into_any(),
|
||||
Self::Copilot => cx.build_view(|_| ui::CopilotModalStory).into_any(),
|
||||
Self::TitleBar => ui::TitleBarStory::view(cx).into_any(),
|
||||
Self::Toast => { cx.build_view(|cx| (), |_, _| ui::ToastStory.render()) }.into_any(),
|
||||
Self::Toolbar => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::ToolbarStory.render()) }.into_any()
|
||||
}
|
||||
Self::TrafficLights => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::TrafficLightsStory.render()) }.into_any()
|
||||
}
|
||||
Self::Copilot => {
|
||||
{ cx.build_view(|cx| (), |_, _| ui::CopilotModalStory.render()) }.into_any()
|
||||
}
|
||||
Self::Workspace => ui::WorkspaceStory::view(cx).into_any(),
|
||||
}
|
||||
}
|
||||
|
@ -4,21 +4,20 @@ mod assets;
|
||||
mod stories;
|
||||
mod story;
|
||||
mod story_selector;
|
||||
mod themes;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use clap::Parser;
|
||||
use gpui2::{
|
||||
div, px, size, AnyView, AppContext, Bounds, ViewContext, VisualContext, WindowBounds,
|
||||
WindowOptions,
|
||||
div, px, size, AnyView, AppContext, Bounds, Div, Render, ViewContext, VisualContext,
|
||||
WindowBounds, WindowOptions,
|
||||
};
|
||||
use log::LevelFilter;
|
||||
use settings2::{default_settings, Settings, SettingsStore};
|
||||
use simplelog::SimpleLogger;
|
||||
use story_selector::ComponentStory;
|
||||
use theme2::{ThemeRegistry, ThemeSettings};
|
||||
use ui::{prelude::*, themed};
|
||||
use ui::prelude::*;
|
||||
|
||||
use crate::assets::Assets;
|
||||
use crate::story_selector::StorySelector;
|
||||
@ -50,7 +49,6 @@ fn main() {
|
||||
|
||||
let story_selector = args.story.clone();
|
||||
let theme_name = args.theme.unwrap_or("One Dark".to_string());
|
||||
let theme = themes::load_theme(theme_name.clone()).unwrap();
|
||||
|
||||
let asset_source = Arc::new(Assets);
|
||||
gpui2::App::production(asset_source).run(move |cx| {
|
||||
@ -84,12 +82,7 @@ fn main() {
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
move |cx| {
|
||||
cx.build_view(
|
||||
|cx| StoryWrapper::new(selector.story(cx), theme),
|
||||
StoryWrapper::render,
|
||||
)
|
||||
},
|
||||
move |cx| cx.build_view(|cx| StoryWrapper::new(selector.story(cx))),
|
||||
);
|
||||
|
||||
cx.activate(true);
|
||||
@ -99,22 +92,23 @@ fn main() {
|
||||
#[derive(Clone)]
|
||||
pub struct StoryWrapper {
|
||||
story: AnyView,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl StoryWrapper {
|
||||
pub(crate) fn new(story: AnyView, theme: Theme) -> Self {
|
||||
Self { story, theme }
|
||||
pub(crate) fn new(story: AnyView) -> Self {
|
||||
Self { story }
|
||||
}
|
||||
}
|
||||
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl Component<Self> {
|
||||
themed(self.theme.clone(), cx, |cx| {
|
||||
div()
|
||||
.flex()
|
||||
.flex_col()
|
||||
.size_full()
|
||||
.child(self.story.clone())
|
||||
})
|
||||
impl Render for StoryWrapper {
|
||||
type Element = Div<Self>;
|
||||
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> Self::Element {
|
||||
div()
|
||||
.flex()
|
||||
.flex_col()
|
||||
.size_full()
|
||||
.child(self.story.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,30 +0,0 @@
|
||||
mod rose_pine;
|
||||
|
||||
pub use rose_pine::*;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use gpui2::serde_json;
|
||||
use serde::Deserialize;
|
||||
use ui::Theme;
|
||||
|
||||
use crate::assets::Assets;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LegacyTheme {
|
||||
pub base_theme: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Loads the [`Theme`] with the given name.
|
||||
pub fn load_theme(name: String) -> Result<Theme> {
|
||||
let theme_contents = Assets::get(&format!("themes/{name}.json"))
|
||||
.with_context(|| format!("theme file not found: '{name}'"))?;
|
||||
|
||||
let legacy_theme: LegacyTheme =
|
||||
serde_json::from_str(std::str::from_utf8(&theme_contents.data)?)
|
||||
.context("failed to parse legacy theme")?;
|
||||
|
||||
let theme: Theme = serde_json::from_value(legacy_theme.base_theme.clone())
|
||||
.context("failed to parse `base_theme`")?;
|
||||
|
||||
Ok(theme)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
2118
crates/theme2/src/default.rs
Normal file
2118
crates/theme2/src/default.rs
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,7 +1,4 @@
|
||||
use crate::{
|
||||
themes::{one_dark, rose_pine, rose_pine_dawn, rose_pine_moon, sandcastle},
|
||||
Theme, ThemeMetadata,
|
||||
};
|
||||
use crate::{themes, Theme, ThemeMetadata};
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui2::SharedString;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
@ -41,11 +38,45 @@ impl Default for ThemeRegistry {
|
||||
};
|
||||
|
||||
this.insert_themes([
|
||||
one_dark(),
|
||||
rose_pine(),
|
||||
rose_pine_dawn(),
|
||||
rose_pine_moon(),
|
||||
sandcastle(),
|
||||
themes::andromeda(),
|
||||
themes::atelier_cave_dark(),
|
||||
themes::atelier_cave_light(),
|
||||
themes::atelier_dune_dark(),
|
||||
themes::atelier_dune_light(),
|
||||
themes::atelier_estuary_dark(),
|
||||
themes::atelier_estuary_light(),
|
||||
themes::atelier_forest_dark(),
|
||||
themes::atelier_forest_light(),
|
||||
themes::atelier_heath_dark(),
|
||||
themes::atelier_heath_light(),
|
||||
themes::atelier_lakeside_dark(),
|
||||
themes::atelier_lakeside_light(),
|
||||
themes::atelier_plateau_dark(),
|
||||
themes::atelier_plateau_light(),
|
||||
themes::atelier_savanna_dark(),
|
||||
themes::atelier_savanna_light(),
|
||||
themes::atelier_seaside_dark(),
|
||||
themes::atelier_seaside_light(),
|
||||
themes::atelier_sulphurpool_dark(),
|
||||
themes::atelier_sulphurpool_light(),
|
||||
themes::ayu_dark(),
|
||||
themes::ayu_light(),
|
||||
themes::ayu_mirage(),
|
||||
themes::gruvbox_dark(),
|
||||
themes::gruvbox_dark_hard(),
|
||||
themes::gruvbox_dark_soft(),
|
||||
themes::gruvbox_light(),
|
||||
themes::gruvbox_light_hard(),
|
||||
themes::gruvbox_light_soft(),
|
||||
themes::one_dark(),
|
||||
themes::one_light(),
|
||||
themes::rose_pine(),
|
||||
themes::rose_pine_dawn(),
|
||||
themes::rose_pine_moon(),
|
||||
themes::sandcastle(),
|
||||
themes::solarized_dark(),
|
||||
themes::solarized_light(),
|
||||
themes::summercamp(),
|
||||
]);
|
||||
|
||||
this
|
||||
|
164
crates/theme2/src/scale.rs
Normal file
164
crates/theme2/src/scale.rs
Normal file
@ -0,0 +1,164 @@
|
||||
use gpui2::{AppContext, Hsla};
|
||||
use indexmap::IndexMap;
|
||||
|
||||
use crate::{theme, Appearance};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub enum ColorScaleName {
|
||||
Gray,
|
||||
Mauve,
|
||||
Slate,
|
||||
Sage,
|
||||
Olive,
|
||||
Sand,
|
||||
Gold,
|
||||
Bronze,
|
||||
Brown,
|
||||
Yellow,
|
||||
Amber,
|
||||
Orange,
|
||||
Tomato,
|
||||
Red,
|
||||
Ruby,
|
||||
Crimson,
|
||||
Pink,
|
||||
Plum,
|
||||
Purple,
|
||||
Violet,
|
||||
Iris,
|
||||
Indigo,
|
||||
Blue,
|
||||
Cyan,
|
||||
Teal,
|
||||
Jade,
|
||||
Green,
|
||||
Grass,
|
||||
Lime,
|
||||
Mint,
|
||||
Sky,
|
||||
Black,
|
||||
White,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ColorScaleName {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
Self::Gray => "Gray",
|
||||
Self::Mauve => "Mauve",
|
||||
Self::Slate => "Slate",
|
||||
Self::Sage => "Sage",
|
||||
Self::Olive => "Olive",
|
||||
Self::Sand => "Sand",
|
||||
Self::Gold => "Gold",
|
||||
Self::Bronze => "Bronze",
|
||||
Self::Brown => "Brown",
|
||||
Self::Yellow => "Yellow",
|
||||
Self::Amber => "Amber",
|
||||
Self::Orange => "Orange",
|
||||
Self::Tomato => "Tomato",
|
||||
Self::Red => "Red",
|
||||
Self::Ruby => "Ruby",
|
||||
Self::Crimson => "Crimson",
|
||||
Self::Pink => "Pink",
|
||||
Self::Plum => "Plum",
|
||||
Self::Purple => "Purple",
|
||||
Self::Violet => "Violet",
|
||||
Self::Iris => "Iris",
|
||||
Self::Indigo => "Indigo",
|
||||
Self::Blue => "Blue",
|
||||
Self::Cyan => "Cyan",
|
||||
Self::Teal => "Teal",
|
||||
Self::Jade => "Jade",
|
||||
Self::Green => "Green",
|
||||
Self::Grass => "Grass",
|
||||
Self::Lime => "Lime",
|
||||
Self::Mint => "Mint",
|
||||
Self::Sky => "Sky",
|
||||
Self::Black => "Black",
|
||||
Self::White => "White",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub type ColorScale = [Hsla; 12];
|
||||
|
||||
pub type ColorScales = IndexMap<ColorScaleName, ColorScaleSet>;
|
||||
|
||||
/// A one-based step in a [`ColorScale`].
|
||||
pub type ColorScaleStep = usize;
|
||||
|
||||
pub struct ColorScaleSet {
|
||||
name: ColorScaleName,
|
||||
light: ColorScale,
|
||||
dark: ColorScale,
|
||||
light_alpha: ColorScale,
|
||||
dark_alpha: ColorScale,
|
||||
}
|
||||
|
||||
impl ColorScaleSet {
|
||||
pub fn new(
|
||||
name: ColorScaleName,
|
||||
light: ColorScale,
|
||||
light_alpha: ColorScale,
|
||||
dark: ColorScale,
|
||||
dark_alpha: ColorScale,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
light,
|
||||
light_alpha,
|
||||
dark,
|
||||
dark_alpha,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> String {
|
||||
self.name.to_string()
|
||||
}
|
||||
|
||||
pub fn light(&self, step: ColorScaleStep) -> Hsla {
|
||||
self.light[step - 1]
|
||||
}
|
||||
|
||||
pub fn light_alpha(&self, step: ColorScaleStep) -> Hsla {
|
||||
self.light_alpha[step - 1]
|
||||
}
|
||||
|
||||
pub fn dark(&self, step: ColorScaleStep) -> Hsla {
|
||||
self.dark[step - 1]
|
||||
}
|
||||
|
||||
pub fn dark_alpha(&self, step: ColorScaleStep) -> Hsla {
|
||||
self.dark_alpha[step - 1]
|
||||
}
|
||||
|
||||
fn current_appearance(cx: &AppContext) -> Appearance {
|
||||
let theme = theme(cx);
|
||||
if theme.metadata.is_light {
|
||||
Appearance::Light
|
||||
} else {
|
||||
Appearance::Dark
|
||||
}
|
||||
}
|
||||
|
||||
pub fn step(&self, cx: &AppContext, step: ColorScaleStep) -> Hsla {
|
||||
let appearance = Self::current_appearance(cx);
|
||||
|
||||
match appearance {
|
||||
Appearance::Light => self.light(step),
|
||||
Appearance::Dark => self.dark(step),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn step_alpha(&self, cx: &AppContext, step: ColorScaleStep) -> Hsla {
|
||||
let appearance = Self::current_appearance(cx);
|
||||
match appearance {
|
||||
Appearance::Light => self.light_alpha(step),
|
||||
Appearance::Dark => self.dark_alpha(step),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,14 +1,24 @@
|
||||
mod default;
|
||||
mod registry;
|
||||
mod scale;
|
||||
mod settings;
|
||||
mod themes;
|
||||
|
||||
pub use default::*;
|
||||
pub use registry::*;
|
||||
pub use scale::*;
|
||||
pub use settings::*;
|
||||
|
||||
use gpui2::{AppContext, HighlightStyle, Hsla, SharedString};
|
||||
use settings2::Settings;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Appearance {
|
||||
Light,
|
||||
Dark,
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
cx.set_global(ThemeRegistry::default());
|
||||
ThemeSettings::register(cx);
|
||||
@ -18,6 +28,10 @@ pub fn active_theme<'a>(cx: &'a AppContext) -> &'a Arc<Theme> {
|
||||
&ThemeSettings::get_global(cx).active_theme
|
||||
}
|
||||
|
||||
pub fn theme(cx: &AppContext) -> Arc<Theme> {
|
||||
active_theme(cx).clone()
|
||||
}
|
||||
|
||||
pub struct Theme {
|
||||
pub metadata: ThemeMetadata,
|
||||
|
||||
|
130
crates/theme2/src/themes/andromeda.rs
Normal file
130
crates/theme2/src/themes/andromeda.rs
Normal file
@ -0,0 +1,130 @@
|
||||
use gpui2::rgba;
|
||||
|
||||
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
|
||||
|
||||
pub fn andromeda() -> Theme {
|
||||
Theme {
|
||||
metadata: ThemeMetadata {
|
||||
name: "Andromeda".into(),
|
||||
is_light: false,
|
||||
},
|
||||
transparent: rgba(0x00000000).into(),
|
||||
mac_os_traffic_light_red: rgba(0xec695eff).into(),
|
||||
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
|
||||
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
|
||||
border: rgba(0x2b2f38ff).into(),
|
||||
border_variant: rgba(0x2b2f38ff).into(),
|
||||
border_focused: rgba(0x183934ff).into(),
|
||||
border_transparent: rgba(0x00000000).into(),
|
||||
elevated_surface: rgba(0x262933ff).into(),
|
||||
surface: rgba(0x21242bff).into(),
|
||||
background: rgba(0x262933ff).into(),
|
||||
filled_element: rgba(0x262933ff).into(),
|
||||
filled_element_hover: rgba(0xffffff1e).into(),
|
||||
filled_element_active: rgba(0xffffff28).into(),
|
||||
filled_element_selected: rgba(0x12231fff).into(),
|
||||
filled_element_disabled: rgba(0x00000000).into(),
|
||||
ghost_element: rgba(0x00000000).into(),
|
||||
ghost_element_hover: rgba(0xffffff14).into(),
|
||||
ghost_element_active: rgba(0xffffff1e).into(),
|
||||
ghost_element_selected: rgba(0x12231fff).into(),
|
||||
ghost_element_disabled: rgba(0x00000000).into(),
|
||||
text: rgba(0xf7f7f8ff).into(),
|
||||
text_muted: rgba(0xaca8aeff).into(),
|
||||
text_placeholder: rgba(0xf82871ff).into(),
|
||||
text_disabled: rgba(0x6b6b73ff).into(),
|
||||
text_accent: rgba(0x10a793ff).into(),
|
||||
icon_muted: rgba(0xaca8aeff).into(),
|
||||
syntax: SyntaxTheme {
|
||||
highlights: vec![
|
||||
("emphasis".into(), rgba(0x10a793ff).into()),
|
||||
("punctuation.bracket".into(), rgba(0xd8d5dbff).into()),
|
||||
("attribute".into(), rgba(0x10a793ff).into()),
|
||||
("variable".into(), rgba(0xf7f7f8ff).into()),
|
||||
("predictive".into(), rgba(0x315f70ff).into()),
|
||||
("property".into(), rgba(0x10a793ff).into()),
|
||||
("variant".into(), rgba(0x10a793ff).into()),
|
||||
("embedded".into(), rgba(0xf7f7f8ff).into()),
|
||||
("string.special".into(), rgba(0xf29c14ff).into()),
|
||||
("keyword".into(), rgba(0x10a793ff).into()),
|
||||
("tag".into(), rgba(0x10a793ff).into()),
|
||||
("enum".into(), rgba(0xf29c14ff).into()),
|
||||
("link_text".into(), rgba(0xf29c14ff).into()),
|
||||
("primary".into(), rgba(0xf7f7f8ff).into()),
|
||||
("punctuation".into(), rgba(0xd8d5dbff).into()),
|
||||
("punctuation.special".into(), rgba(0xd8d5dbff).into()),
|
||||
("function".into(), rgba(0xfee56cff).into()),
|
||||
("number".into(), rgba(0x96df71ff).into()),
|
||||
("preproc".into(), rgba(0xf7f7f8ff).into()),
|
||||
("operator".into(), rgba(0xf29c14ff).into()),
|
||||
("constructor".into(), rgba(0x10a793ff).into()),
|
||||
("string.escape".into(), rgba(0xafabb1ff).into()),
|
||||
("string.special.symbol".into(), rgba(0xf29c14ff).into()),
|
||||
("string".into(), rgba(0xf29c14ff).into()),
|
||||
("comment".into(), rgba(0xafabb1ff).into()),
|
||||
("hint".into(), rgba(0x618399ff).into()),
|
||||
("type".into(), rgba(0x08e7c5ff).into()),
|
||||
("label".into(), rgba(0x10a793ff).into()),
|
||||
("comment.doc".into(), rgba(0xafabb1ff).into()),
|
||||
("text.literal".into(), rgba(0xf29c14ff).into()),
|
||||
("constant".into(), rgba(0x96df71ff).into()),
|
||||
("string.regex".into(), rgba(0xf29c14ff).into()),
|
||||
("emphasis.strong".into(), rgba(0x10a793ff).into()),
|
||||
("title".into(), rgba(0xf7f7f8ff).into()),
|
||||
("punctuation.delimiter".into(), rgba(0xd8d5dbff).into()),
|
||||
("link_uri".into(), rgba(0x96df71ff).into()),
|
||||
("boolean".into(), rgba(0x96df71ff).into()),
|
||||
("punctuation.list_marker".into(), rgba(0xd8d5dbff).into()),
|
||||
],
|
||||
},
|
||||
status_bar: rgba(0x262933ff).into(),
|
||||
title_bar: rgba(0x262933ff).into(),
|
||||
toolbar: rgba(0x1e2025ff).into(),
|
||||
tab_bar: rgba(0x21242bff).into(),
|
||||
editor: rgba(0x1e2025ff).into(),
|
||||
editor_subheader: rgba(0x21242bff).into(),
|
||||
editor_active_line: rgba(0x21242bff).into(),
|
||||
terminal: rgba(0x1e2025ff).into(),
|
||||
image_fallback_background: rgba(0x262933ff).into(),
|
||||
git_created: rgba(0x96df71ff).into(),
|
||||
git_modified: rgba(0x10a793ff).into(),
|
||||
git_deleted: rgba(0xf82871ff).into(),
|
||||
git_conflict: rgba(0xfee56cff).into(),
|
||||
git_ignored: rgba(0x6b6b73ff).into(),
|
||||
git_renamed: rgba(0xfee56cff).into(),
|
||||
players: [
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x10a793ff).into(),
|
||||
selection: rgba(0x10a7933d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x96df71ff).into(),
|
||||
selection: rgba(0x96df713d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xc74cecff).into(),
|
||||
selection: rgba(0xc74cec3d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xf29c14ff).into(),
|
||||
selection: rgba(0xf29c143d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x893ea6ff).into(),
|
||||
selection: rgba(0x893ea63d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x08e7c5ff).into(),
|
||||
selection: rgba(0x08e7c53d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xf82871ff).into(),
|
||||
selection: rgba(0xf828713d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xfee56cff).into(),
|
||||
selection: rgba(0xfee56c3d).into(),
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
136
crates/theme2/src/themes/atelier_cave_dark.rs
Normal file
136
crates/theme2/src/themes/atelier_cave_dark.rs
Normal file
@ -0,0 +1,136 @@
|
||||
use gpui2::rgba;
|
||||
|
||||
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
|
||||
|
||||
pub fn atelier_cave_dark() -> Theme {
|
||||
Theme {
|
||||
metadata: ThemeMetadata {
|
||||
name: "Atelier Cave Dark".into(),
|
||||
is_light: false,
|
||||
},
|
||||
transparent: rgba(0x00000000).into(),
|
||||
mac_os_traffic_light_red: rgba(0xec695eff).into(),
|
||||
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
|
||||
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
|
||||
border: rgba(0x56505eff).into(),
|
||||
border_variant: rgba(0x56505eff).into(),
|
||||
border_focused: rgba(0x222953ff).into(),
|
||||
border_transparent: rgba(0x00000000).into(),
|
||||
elevated_surface: rgba(0x3a353fff).into(),
|
||||
surface: rgba(0x221f26ff).into(),
|
||||
background: rgba(0x3a353fff).into(),
|
||||
filled_element: rgba(0x3a353fff).into(),
|
||||
filled_element_hover: rgba(0xffffff1e).into(),
|
||||
filled_element_active: rgba(0xffffff28).into(),
|
||||
filled_element_selected: rgba(0x161a35ff).into(),
|
||||
filled_element_disabled: rgba(0x00000000).into(),
|
||||
ghost_element: rgba(0x00000000).into(),
|
||||
ghost_element_hover: rgba(0xffffff14).into(),
|
||||
ghost_element_active: rgba(0xffffff1e).into(),
|
||||
ghost_element_selected: rgba(0x161a35ff).into(),
|
||||
ghost_element_disabled: rgba(0x00000000).into(),
|
||||
text: rgba(0xefecf4ff).into(),
|
||||
text_muted: rgba(0x898591ff).into(),
|
||||
text_placeholder: rgba(0xbe4677ff).into(),
|
||||
text_disabled: rgba(0x756f7eff).into(),
|
||||
text_accent: rgba(0x566ddaff).into(),
|
||||
icon_muted: rgba(0x898591ff).into(),
|
||||
syntax: SyntaxTheme {
|
||||
highlights: vec![
|
||||
("comment.doc".into(), rgba(0x8b8792ff).into()),
|
||||
("tag".into(), rgba(0x566ddaff).into()),
|
||||
("link_text".into(), rgba(0xaa563bff).into()),
|
||||
("constructor".into(), rgba(0x566ddaff).into()),
|
||||
("punctuation".into(), rgba(0xe2dfe7ff).into()),
|
||||
("punctuation.special".into(), rgba(0xbf3fbfff).into()),
|
||||
("string.special.symbol".into(), rgba(0x299292ff).into()),
|
||||
("string.escape".into(), rgba(0x8b8792ff).into()),
|
||||
("emphasis".into(), rgba(0x566ddaff).into()),
|
||||
("type".into(), rgba(0xa06d3aff).into()),
|
||||
("punctuation.delimiter".into(), rgba(0x8b8792ff).into()),
|
||||
("variant".into(), rgba(0xa06d3aff).into()),
|
||||
("variable.special".into(), rgba(0x9559e7ff).into()),
|
||||
("text.literal".into(), rgba(0xaa563bff).into()),
|
||||
("punctuation.list_marker".into(), rgba(0xe2dfe7ff).into()),
|
||||
("comment".into(), rgba(0x655f6dff).into()),
|
||||
("function.method".into(), rgba(0x576cdbff).into()),
|
||||
("property".into(), rgba(0xbe4677ff).into()),
|
||||
("operator".into(), rgba(0x8b8792ff).into()),
|
||||
("emphasis.strong".into(), rgba(0x566ddaff).into()),
|
||||
("label".into(), rgba(0x566ddaff).into()),
|
||||
("enum".into(), rgba(0xaa563bff).into()),
|
||||
("number".into(), rgba(0xaa563bff).into()),
|
||||
("primary".into(), rgba(0xe2dfe7ff).into()),
|
||||
("keyword".into(), rgba(0x9559e7ff).into()),
|
||||
(
|
||||
"function.special.definition".into(),
|
||||
rgba(0xa06d3aff).into(),
|
||||
),
|
||||
("punctuation.bracket".into(), rgba(0x8b8792ff).into()),
|
||||
("constant".into(), rgba(0x2b9292ff).into()),
|
||||
("string.special".into(), rgba(0xbf3fbfff).into()),
|
||||
("title".into(), rgba(0xefecf4ff).into()),
|
||||
("preproc".into(), rgba(0xefecf4ff).into()),
|
||||
("link_uri".into(), rgba(0x2b9292ff).into()),
|
||||
("string".into(), rgba(0x299292ff).into()),
|
||||
("embedded".into(), rgba(0xefecf4ff).into()),
|
||||
("hint".into(), rgba(0x706897ff).into()),
|
||||
("boolean".into(), rgba(0x2b9292ff).into()),
|
||||
("variable".into(), rgba(0xe2dfe7ff).into()),
|
||||
("predictive".into(), rgba(0x615787ff).into()),
|
||||
("string.regex".into(), rgba(0x388bc6ff).into()),
|
||||
("function".into(), rgba(0x576cdbff).into()),
|
||||
("attribute".into(), rgba(0x566ddaff).into()),
|
||||
],
|
||||
},
|
||||
status_bar: rgba(0x3a353fff).into(),
|
||||
title_bar: rgba(0x3a353fff).into(),
|
||||
toolbar: rgba(0x19171cff).into(),
|
||||
tab_bar: rgba(0x221f26ff).into(),
|
||||
editor: rgba(0x19171cff).into(),
|
||||
editor_subheader: rgba(0x221f26ff).into(),
|
||||
editor_active_line: rgba(0x221f26ff).into(),
|
||||
terminal: rgba(0x19171cff).into(),
|
||||
image_fallback_background: rgba(0x3a353fff).into(),
|
||||
git_created: rgba(0x2b9292ff).into(),
|
||||
git_modified: rgba(0x566ddaff).into(),
|
||||
git_deleted: rgba(0xbe4677ff).into(),
|
||||
git_conflict: rgba(0xa06d3aff).into(),
|
||||
git_ignored: rgba(0x756f7eff).into(),
|
||||
git_renamed: rgba(0xa06d3aff).into(),
|
||||
players: [
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x566ddaff).into(),
|
||||
selection: rgba(0x566dda3d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x2b9292ff).into(),
|
||||
selection: rgba(0x2b92923d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xbf41bfff).into(),
|
||||
selection: rgba(0xbf41bf3d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xaa563bff).into(),
|
||||
selection: rgba(0xaa563b3d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x955ae6ff).into(),
|
||||
selection: rgba(0x955ae63d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x3a8bc6ff).into(),
|
||||
selection: rgba(0x3a8bc63d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xbe4677ff).into(),
|
||||
selection: rgba(0xbe46773d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xa06d3aff).into(),
|
||||
selection: rgba(0xa06d3a3d).into(),
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
136
crates/theme2/src/themes/atelier_cave_light.rs
Normal file
136
crates/theme2/src/themes/atelier_cave_light.rs
Normal file
@ -0,0 +1,136 @@
|
||||
use gpui2::rgba;
|
||||
|
||||
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
|
||||
|
||||
pub fn atelier_cave_light() -> Theme {
|
||||
Theme {
|
||||
metadata: ThemeMetadata {
|
||||
name: "Atelier Cave Light".into(),
|
||||
is_light: true,
|
||||
},
|
||||
transparent: rgba(0x00000000).into(),
|
||||
mac_os_traffic_light_red: rgba(0xec695eff).into(),
|
||||
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
|
||||
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
|
||||
border: rgba(0x8f8b96ff).into(),
|
||||
border_variant: rgba(0x8f8b96ff).into(),
|
||||
border_focused: rgba(0xc8c7f2ff).into(),
|
||||
border_transparent: rgba(0x00000000).into(),
|
||||
elevated_surface: rgba(0xbfbcc5ff).into(),
|
||||
surface: rgba(0xe6e3ebff).into(),
|
||||
background: rgba(0xbfbcc5ff).into(),
|
||||
filled_element: rgba(0xbfbcc5ff).into(),
|
||||
filled_element_hover: rgba(0xffffff1e).into(),
|
||||
filled_element_active: rgba(0xffffff28).into(),
|
||||
filled_element_selected: rgba(0xe1e0f9ff).into(),
|
||||
filled_element_disabled: rgba(0x00000000).into(),
|
||||
ghost_element: rgba(0x00000000).into(),
|
||||
ghost_element_hover: rgba(0xffffff14).into(),
|
||||
ghost_element_active: rgba(0xffffff1e).into(),
|
||||
ghost_element_selected: rgba(0xe1e0f9ff).into(),
|
||||
ghost_element_disabled: rgba(0x00000000).into(),
|
||||
text: rgba(0x19171cff).into(),
|
||||
text_muted: rgba(0x5a5462ff).into(),
|
||||
text_placeholder: rgba(0xbd4677ff).into(),
|
||||
text_disabled: rgba(0x6e6876ff).into(),
|
||||
text_accent: rgba(0x586cdaff).into(),
|
||||
icon_muted: rgba(0x5a5462ff).into(),
|
||||
syntax: SyntaxTheme {
|
||||
highlights: vec![
|
||||
("link_text".into(), rgba(0xaa573cff).into()),
|
||||
("string".into(), rgba(0x299292ff).into()),
|
||||
("emphasis".into(), rgba(0x586cdaff).into()),
|
||||
("label".into(), rgba(0x586cdaff).into()),
|
||||
("property".into(), rgba(0xbe4677ff).into()),
|
||||
("emphasis.strong".into(), rgba(0x586cdaff).into()),
|
||||
("constant".into(), rgba(0x2b9292ff).into()),
|
||||
(
|
||||
"function.special.definition".into(),
|
||||
rgba(0xa06d3aff).into(),
|
||||
),
|
||||
("embedded".into(), rgba(0x19171cff).into()),
|
||||
("punctuation.special".into(), rgba(0xbf3fbfff).into()),
|
||||
("function".into(), rgba(0x576cdbff).into()),
|
||||
("tag".into(), rgba(0x586cdaff).into()),
|
||||
("number".into(), rgba(0xaa563bff).into()),
|
||||
("primary".into(), rgba(0x26232aff).into()),
|
||||
("text.literal".into(), rgba(0xaa573cff).into()),
|
||||
("variant".into(), rgba(0xa06d3aff).into()),
|
||||
("type".into(), rgba(0xa06d3aff).into()),
|
||||
("punctuation".into(), rgba(0x26232aff).into()),
|
||||
("string.escape".into(), rgba(0x585260ff).into()),
|
||||
("keyword".into(), rgba(0x9559e7ff).into()),
|
||||
("title".into(), rgba(0x19171cff).into()),
|
||||
("constructor".into(), rgba(0x586cdaff).into()),
|
||||
("punctuation.list_marker".into(), rgba(0x26232aff).into()),
|
||||
("string.special".into(), rgba(0xbf3fbfff).into()),
|
||||
("operator".into(), rgba(0x585260ff).into()),
|
||||
("function.method".into(), rgba(0x576cdbff).into()),
|
||||
("link_uri".into(), rgba(0x2b9292ff).into()),
|
||||
("variable.special".into(), rgba(0x9559e7ff).into()),
|
||||
("hint".into(), rgba(0x776d9dff).into()),
|
||||
("punctuation.bracket".into(), rgba(0x585260ff).into()),
|
||||
("string.special.symbol".into(), rgba(0x299292ff).into()),
|
||||
("predictive".into(), rgba(0x887fafff).into()),
|
||||
("attribute".into(), rgba(0x586cdaff).into()),
|
||||
("enum".into(), rgba(0xaa573cff).into()),
|
||||
("preproc".into(), rgba(0x19171cff).into()),
|
||||
("boolean".into(), rgba(0x2b9292ff).into()),
|
||||
("variable".into(), rgba(0x26232aff).into()),
|
||||
("comment.doc".into(), rgba(0x585260ff).into()),
|
||||
("string.regex".into(), rgba(0x388bc6ff).into()),
|
||||
("punctuation.delimiter".into(), rgba(0x585260ff).into()),
|
||||
("comment".into(), rgba(0x7d7787ff).into()),
|
||||
],
|
||||
},
|
||||
status_bar: rgba(0xbfbcc5ff).into(),
|
||||
title_bar: rgba(0xbfbcc5ff).into(),
|
||||
toolbar: rgba(0xefecf4ff).into(),
|
||||
tab_bar: rgba(0xe6e3ebff).into(),
|
||||
editor: rgba(0xefecf4ff).into(),
|
||||
editor_subheader: rgba(0xe6e3ebff).into(),
|
||||
editor_active_line: rgba(0xe6e3ebff).into(),
|
||||
terminal: rgba(0xefecf4ff).into(),
|
||||
image_fallback_background: rgba(0xbfbcc5ff).into(),
|
||||
git_created: rgba(0x2b9292ff).into(),
|
||||
git_modified: rgba(0x586cdaff).into(),
|
||||
git_deleted: rgba(0xbd4677ff).into(),
|
||||
git_conflict: rgba(0xa06e3bff).into(),
|
||||
git_ignored: rgba(0x6e6876ff).into(),
|
||||
git_renamed: rgba(0xa06e3bff).into(),
|
||||
players: [
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x586cdaff).into(),
|
||||
selection: rgba(0x586cda3d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x2b9292ff).into(),
|
||||
selection: rgba(0x2b92923d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xbf41bfff).into(),
|
||||
selection: rgba(0xbf41bf3d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xaa573cff).into(),
|
||||
selection: rgba(0xaa573c3d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x955ae6ff).into(),
|
||||
selection: rgba(0x955ae63d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x3a8bc6ff).into(),
|
||||
selection: rgba(0x3a8bc63d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xbd4677ff).into(),
|
||||
selection: rgba(0xbd46773d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xa06e3bff).into(),
|
||||
selection: rgba(0xa06e3b3d).into(),
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
136
crates/theme2/src/themes/atelier_dune_dark.rs
Normal file
136
crates/theme2/src/themes/atelier_dune_dark.rs
Normal file
@ -0,0 +1,136 @@
|
||||
use gpui2::rgba;
|
||||
|
||||
use crate::{PlayerTheme, SyntaxTheme, Theme, ThemeMetadata};
|
||||
|
||||
pub fn atelier_dune_dark() -> Theme {
|
||||
Theme {
|
||||
metadata: ThemeMetadata {
|
||||
name: "Atelier Dune Dark".into(),
|
||||
is_light: false,
|
||||
},
|
||||
transparent: rgba(0x00000000).into(),
|
||||
mac_os_traffic_light_red: rgba(0xec695eff).into(),
|
||||
mac_os_traffic_light_yellow: rgba(0xf4bf4eff).into(),
|
||||
mac_os_traffic_light_green: rgba(0x61c553ff).into(),
|
||||
border: rgba(0x6c695cff).into(),
|
||||
border_variant: rgba(0x6c695cff).into(),
|
||||
border_focused: rgba(0x262f56ff).into(),
|
||||
border_transparent: rgba(0x00000000).into(),
|
||||
elevated_surface: rgba(0x45433bff).into(),
|
||||
surface: rgba(0x262622ff).into(),
|
||||
background: rgba(0x45433bff).into(),
|
||||
filled_element: rgba(0x45433bff).into(),
|
||||
filled_element_hover: rgba(0xffffff1e).into(),
|
||||
filled_element_active: rgba(0xffffff28).into(),
|
||||
filled_element_selected: rgba(0x171e38ff).into(),
|
||||
filled_element_disabled: rgba(0x00000000).into(),
|
||||
ghost_element: rgba(0x00000000).into(),
|
||||
ghost_element_hover: rgba(0xffffff14).into(),
|
||||
ghost_element_active: rgba(0xffffff1e).into(),
|
||||
ghost_element_selected: rgba(0x171e38ff).into(),
|
||||
ghost_element_disabled: rgba(0x00000000).into(),
|
||||
text: rgba(0xfefbecff).into(),
|
||||
text_muted: rgba(0xa4a08bff).into(),
|
||||
text_placeholder: rgba(0xd73837ff).into(),
|
||||
text_disabled: rgba(0x8f8b77ff).into(),
|
||||
text_accent: rgba(0x6684e0ff).into(),
|
||||
icon_muted: rgba(0xa4a08bff).into(),
|
||||
syntax: SyntaxTheme {
|
||||
highlights: vec![
|
||||
("constructor".into(), rgba(0x6684e0ff).into()),
|
||||
("punctuation".into(), rgba(0xe8e4cfff).into()),
|
||||
("punctuation.delimiter".into(), rgba(0xa6a28cff).into()),
|
||||
("string.special".into(), rgba(0xd43451ff).into()),
|
||||
("string.escape".into(), rgba(0xa6a28cff).into()),
|
||||
("comment".into(), rgba(0x7d7a68ff).into()),
|
||||
("enum".into(), rgba(0xb65611ff).into()),
|
||||
("variable.special".into(), rgba(0xb854d4ff).into()),
|
||||
("primary".into(), rgba(0xe8e4cfff).into()),
|
||||
("comment.doc".into(), rgba(0xa6a28cff).into()),
|
||||
("label".into(), rgba(0x6684e0ff).into()),
|
||||
("operator".into(), rgba(0xa6a28cff).into()),
|
||||
("string".into(), rgba(0x5fac38ff).into()),
|
||||
("variant".into(), rgba(0xae9512ff).into()),
|
||||
("variable".into(), rgba(0xe8e4cfff).into()),
|
||||
("function.method".into(), rgba(0x6583e1ff).into()),
|
||||
(
|
||||
"function.special.definition".into(),
|
||||
rgba(0xae9512ff).into(),
|
||||
),
|
||||
("string.regex".into(), rgba(0x1ead82ff).into()),
|
||||
("emphasis.strong".into(), rgba(0x6684e0ff).into()),
|
||||
("punctuation.special".into(), rgba(0xd43451ff).into()),
|
||||
("punctuation.bracket".into(), rgba(0xa6a28cff).into()),
|
||||
("link_text".into(), rgba(0xb65611ff).into()),
|
||||
("link_uri".into(), rgba(0x5fac39ff).into()),
|
||||
("boolean".into(), rgba(0x5fac39ff).into()),
|
||||
("hint".into(), rgba(0xb17272ff).into()),
|
||||
("tag".into(), rgba(0x6684e0ff).into()),
|
||||
("function".into(), rgba(0x6583e1ff).into()),
|
||||
("title".into(), rgba(0xfefbecff).into()),
|
||||
("property".into(), rgba(0xd73737ff).into()),
|
||||
("type".into(), rgba(0xae9512ff).into()),
|
||||
("constant".into(), rgba(0x5fac39ff).into()),
|
||||
("attribute".into(), rgba(0x6684e0ff).into()),
|
||||
("predictive".into(), rgba(0x9c6262ff).into()),
|
||||
("string.special.symbol".into(), rgba(0x5fac38ff).into()),
|
||||
("punctuation.list_marker".into(), rgba(0xe8e4cfff).into()),
|
||||
("emphasis".into(), rgba(0x6684e0ff).into()),
|
||||
("keyword".into(), rgba(0xb854d4ff).into()),
|
||||
("text.literal".into(), rgba(0xb65611ff).into()),
|
||||
("number".into(), rgba(0xb65610ff).into()),
|
||||
("preproc".into(), rgba(0xfefbecff).into()),
|
||||
("embedded".into(), rgba(0xfefbecff).into()),
|
||||
],
|
||||
},
|
||||
status_bar: rgba(0x45433bff).into(),
|
||||
title_bar: rgba(0x45433bff).into(),
|
||||
toolbar: rgba(0x20201dff).into(),
|
||||
tab_bar: rgba(0x262622ff).into(),
|
||||
editor: rgba(0x20201dff).into(),
|
||||
editor_subheader: rgba(0x262622ff).into(),
|
||||
editor_active_line: rgba(0x262622ff).into(),
|
||||
terminal: rgba(0x20201dff).into(),
|
||||
image_fallback_background: rgba(0x45433bff).into(),
|
||||
git_created: rgba(0x5fac39ff).into(),
|
||||
git_modified: rgba(0x6684e0ff).into(),
|
||||
git_deleted: rgba(0xd73837ff).into(),
|
||||
git_conflict: rgba(0xae9414ff).into(),
|
||||
git_ignored: rgba(0x8f8b77ff).into(),
|
||||
git_renamed: rgba(0xae9414ff).into(),
|
||||
players: [
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x6684e0ff).into(),
|
||||
selection: rgba(0x6684e03d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x5fac39ff).into(),
|
||||
selection: rgba(0x5fac393d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xd43651ff).into(),
|
||||
selection: rgba(0xd436513d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xb65611ff).into(),
|
||||
selection: rgba(0xb656113d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xb854d3ff).into(),
|
||||
selection: rgba(0xb854d33d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0x20ad83ff).into(),
|
||||
selection: rgba(0x20ad833d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xd73837ff).into(),
|
||||
selection: rgba(0xd738373d).into(),
|
||||
},
|
||||
PlayerTheme {
|
||||
cursor: rgba(0xae9414ff).into(),
|
||||
selection: rgba(0xae94143d).into(),
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user