Merge branch 'main' into guest-exp

This commit is contained in:
Conrad Irwin 2023-10-23 17:47:21 +02:00
commit ea4e67fb76
141 changed files with 6720 additions and 2077 deletions

57
Cargo.lock generated
View File

@ -91,6 +91,7 @@ dependencies = [
"futures 0.3.28", "futures 0.3.28",
"gpui", "gpui",
"isahc", "isahc",
"language",
"lazy_static", "lazy_static",
"log", "log",
"matrixmultiply", "matrixmultiply",
@ -1467,7 +1468,7 @@ dependencies = [
[[package]] [[package]]
name = "collab" name = "collab"
version = "0.24.0" version = "0.25.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
@ -1503,6 +1504,7 @@ dependencies = [
"lsp", "lsp",
"nanoid", "nanoid",
"node_runtime", "node_runtime",
"notifications",
"parking_lot 0.11.2", "parking_lot 0.11.2",
"pretty_assertions", "pretty_assertions",
"project", "project",
@ -1558,13 +1560,17 @@ dependencies = [
"fuzzy", "fuzzy",
"gpui", "gpui",
"language", "language",
"lazy_static",
"log", "log",
"menu", "menu",
"notifications",
"picker", "picker",
"postage", "postage",
"pretty_assertions",
"project", "project",
"recent_projects", "recent_projects",
"rich_text", "rich_text",
"rpc",
"schemars", "schemars",
"serde", "serde",
"serde_derive", "serde_derive",
@ -1573,6 +1579,7 @@ dependencies = [
"theme", "theme",
"theme_selector", "theme_selector",
"time", "time",
"tree-sitter-markdown",
"util", "util",
"vcs_menu", "vcs_menu",
"workspace", "workspace",
@ -4730,6 +4737,26 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "notifications"
version = "0.1.0"
dependencies = [
"anyhow",
"channel",
"client",
"clock",
"collections",
"db",
"feature_flags",
"gpui",
"rpc",
"settings",
"sum_tree",
"text",
"time",
"util",
]
[[package]] [[package]]
name = "ntapi" name = "ntapi"
version = "0.3.7" version = "0.3.7"
@ -6404,8 +6431,10 @@ dependencies = [
"rsa 0.4.0", "rsa 0.4.0",
"serde", "serde",
"serde_derive", "serde_derive",
"serde_json",
"smol", "smol",
"smol-timeout", "smol-timeout",
"strum",
"tempdir", "tempdir",
"tracing", "tracing",
"util", "util",
@ -6626,6 +6655,12 @@ dependencies = [
"untrusted", "untrusted",
] ]
[[package]]
name = "rustversion"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
[[package]] [[package]]
name = "rustybuzz" name = "rustybuzz"
version = "0.3.0" version = "0.3.0"
@ -7700,6 +7735,22 @@ name = "strum"
version = "0.25.0" version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.25.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad8d03b598d3d0fff69bf533ee3ef19b8eeb342729596df84bcc7e1f96ec4059"
dependencies = [
"heck 0.4.1",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.37",
]
[[package]] [[package]]
name = "subtle" name = "subtle"
@ -9098,6 +9149,7 @@ name = "vcs_menu"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"fs",
"fuzzy", "fuzzy",
"gpui", "gpui",
"picker", "picker",
@ -10042,7 +10094,7 @@ dependencies = [
[[package]] [[package]]
name = "zed" name = "zed"
version = "0.109.0" version = "0.110.0"
dependencies = [ dependencies = [
"activity_indicator", "activity_indicator",
"ai", "ai",
@ -10097,6 +10149,7 @@ dependencies = [
"log", "log",
"lsp", "lsp",
"node_runtime", "node_runtime",
"notifications",
"num_cpus", "num_cpus",
"outline", "outline",
"parking_lot 0.11.2", "parking_lot 0.11.2",

View File

@ -47,6 +47,7 @@ members = [
"crates/media", "crates/media",
"crates/menu", "crates/menu",
"crates/node_runtime", "crates/node_runtime",
"crates/notifications",
"crates/outline", "crates/outline",
"crates/picker", "crates/picker",
"crates/plugin", "crates/plugin",
@ -112,6 +113,7 @@ serde_derive = { version = "1.0", features = ["deserialize_in_place"] }
serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] } serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] }
smallvec = { version = "1.6", features = ["union"] } smallvec = { version = "1.6", features = ["union"] }
smol = { version = "1.2" } smol = { version = "1.2" }
strum = { version = "0.25.0", features = ["derive"] }
sysinfo = "0.29.10" sysinfo = "0.29.10"
tempdir = { version = "0.3.7" } tempdir = { version = "0.3.7" }
thiserror = { version = "1.0.29" } thiserror = { version = "1.0.29" }

8
assets/icons/bell.svg Normal file
View File

@ -0,0 +1,8 @@
<svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
<path
fill-rule="evenodd"
clip-rule="evenodd"
d="M8.60124 1.25086C8.60124 1.75459 8.26278 2.17927 7.80087 2.30989C10.1459 2.4647 12 4.41582 12 6.79999V10.25C12 11.0563 12.0329 11.7074 12.7236 12.0528C12.931 12.1565 13.0399 12.3892 12.9866 12.6149C12.9333 12.8406 12.7319 13 12.5 13H8.16144C8.36904 13.1832 8.49997 13.4513 8.49997 13.75C8.49997 14.3023 8.05226 14.75 7.49997 14.75C6.94769 14.75 6.49997 14.3023 6.49997 13.75C6.49997 13.4513 6.63091 13.1832 6.83851 13H2.49999C2.2681 13 2.06664 12.8406 2.01336 12.6149C1.96009 12.3892 2.06897 12.1565 2.27638 12.0528C2.96708 11.7074 2.99999 11.0563 2.99999 10.25V6.79999C2.99999 4.41537 4.85481 2.46396 7.20042 2.3098C6.73867 2.17908 6.40036 1.75448 6.40036 1.25086C6.40036 0.643104 6.89304 0.150421 7.5008 0.150421C8.10855 0.150421 8.60124 0.643104 8.60124 1.25086ZM7.49999 3.29999C5.56699 3.29999 3.99999 4.86699 3.99999 6.79999V10.25L4.00002 10.3009C4.0005 10.7463 4.00121 11.4084 3.69929 12H11.3007C10.9988 11.4084 10.9995 10.7463 11 10.3009L11 10.25V6.79999C11 4.86699 9.43299 3.29999 7.49999 3.29999Z"
fill="currentColor"
/>
</svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@ -142,6 +142,14 @@
// Default width of the channels panel. // Default width of the channels panel.
"default_width": 240 "default_width": 240
}, },
"notification_panel": {
// Whether to show the collaboration panel button in the status bar.
"button": true,
// Where to dock channels panel. Can be 'left' or 'right'.
"dock": "right",
// Default width of the channels panel.
"default_width": 240
},
"assistant": { "assistant": {
// Whether to show the assistant panel button in the status bar. // Whether to show the assistant panel button in the status bar.
"button": true, "button": true,

View File

@ -11,6 +11,7 @@ doctest = false
[dependencies] [dependencies]
gpui = { path = "../gpui" } gpui = { path = "../gpui" }
util = { path = "../util" } util = { path = "../util" }
language = { path = "../language" }
async-trait.workspace = true async-trait.workspace = true
anyhow.workspace = true anyhow.workspace = true
futures.workspace = true futures.workspace = true

View File

@ -1,2 +1,4 @@
pub mod completion; pub mod completion;
pub mod embedding; pub mod embedding;
pub mod models;
pub mod templates;

View File

@ -53,6 +53,8 @@ pub struct OpenAIRequest {
pub model: String, pub model: String,
pub messages: Vec<RequestMessage>, pub messages: Vec<RequestMessage>,
pub stream: bool, pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
} }
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]

View File

@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use async_trait::async_trait; use async_trait::async_trait;
use futures::AsyncReadExt; use futures::AsyncReadExt;
use gpui::executor::Background; use gpui::executor::Background;
use gpui::serde_json; use gpui::{serde_json, ViewContext};
use isahc::http::StatusCode; use isahc::http::StatusCode;
use isahc::prelude::Configurable; use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response}; use isahc::{AsyncBody, Response};
@ -20,9 +20,11 @@ use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE}; use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request}; use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::completion::OPENAI_API_URL;
lazy_static! { lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
} }
@ -87,6 +89,7 @@ impl Embedding {
#[derive(Clone)] #[derive(Clone)]
pub struct OpenAIEmbeddings { pub struct OpenAIEmbeddings {
pub api_key: Option<String>,
pub client: Arc<dyn HttpClient>, pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>, pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>, rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@ -166,11 +169,36 @@ impl EmbeddingProvider for DummyEmbeddings {
const OPENAI_INPUT_LIMIT: usize = 8190; const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings { impl OpenAIEmbeddings {
pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self { pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
if self.api_key.is_none() {
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
Some(api_key)
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
{
String::from_utf8(api_key).log_err()
} else {
None
};
if let Some(api_key) = api_key {
self.api_key = Some(api_key);
}
}
}
pub fn new(
api_key: Option<String>,
client: Arc<dyn HttpClient>,
executor: Arc<Background>,
) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
OpenAIEmbeddings { OpenAIEmbeddings {
api_key,
client, client,
executor, executor,
rate_limit_count_rx, rate_limit_count_rx,
@ -237,8 +265,9 @@ impl OpenAIEmbeddings {
#[async_trait] #[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings { impl EmbeddingProvider for OpenAIEmbeddings {
fn is_authenticated(&self) -> bool { fn is_authenticated(&self) -> bool {
OPENAI_API_KEY.as_ref().is_some() self.api_key.is_some()
} }
fn max_tokens_per_batch(&self) -> usize { fn max_tokens_per_batch(&self) -> usize {
50000 50000
} }
@ -265,9 +294,9 @@ impl EmbeddingProvider for OpenAIEmbeddings {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4; const MAX_RETRIES: usize = 4;
let api_key = OPENAI_API_KEY let Some(api_key) = self.api_key.clone() else {
.as_ref() return Err(anyhow!("no open ai key provided"));
.ok_or_else(|| anyhow!("no api key"))?; };
let mut request_number = 0; let mut request_number = 0;
let mut rate_limiting = false; let mut rate_limiting = false;
@ -276,7 +305,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
while request_number < MAX_RETRIES { while request_number < MAX_RETRIES {
response = self response = self
.send_request( .send_request(
api_key, &api_key,
spans.iter().map(|x| &**x).collect(), spans.iter().map(|x| &**x).collect(),
request_timeout, request_timeout,
) )

66
crates/ai/src/models.rs Normal file
View File

@ -0,0 +1,66 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}
pub struct OpenAILanguageModel {
name: String,
bpe: Option<CoreBPE>,
}
impl OpenAILanguageModel {
pub fn load(model_name: &str) -> Self {
let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
OpenAILanguageModel {
name: model_name.to_string(),
bpe,
}
}
}
impl LanguageModel for OpenAILanguageModel {
fn name(&self) -> String {
self.name.clone()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
if let Some(bpe) = &self.bpe {
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
bpe.decode(tokens[..length].to_vec())
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
bpe.decode(tokens[length..].to_vec())
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
}
}

View File

@ -0,0 +1,350 @@
use std::cmp::Reverse;
use std::ops::Range;
use std::sync::Arc;
use language::BufferSnapshot;
use util::ResultExt;
use crate::models::LanguageModel;
use crate::templates::repository_context::PromptCodeSnippet;
pub(crate) enum PromptFileType {
Text,
Code,
}
// TODO: Set this up to manage for defaults well
pub struct PromptArguments {
pub model: Arc<dyn LanguageModel>,
pub user_prompt: Option<String>,
pub language_name: Option<String>,
pub project_name: Option<String>,
pub snippets: Vec<PromptCodeSnippet>,
pub reserved_tokens: usize,
pub buffer: Option<BufferSnapshot>,
pub selected_range: Option<Range<usize>>,
}
impl PromptArguments {
pub(crate) fn get_file_type(&self) -> PromptFileType {
if self
.language_name
.as_ref()
.and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
.unwrap_or(true)
{
PromptFileType::Code
} else {
PromptFileType::Text
}
}
}
pub trait PromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)>;
}
#[repr(i8)]
#[derive(PartialEq, Eq, Ord)]
pub enum PromptPriority {
Mandatory, // Ignores truncation
Ordered { order: usize }, // Truncates based on priority
}
impl PartialOrd for PromptPriority {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
match (self, other) {
(Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
(Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
(Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
(Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
}
}
}
pub struct PromptChain {
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
impl PromptChain {
pub fn new(
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
) -> Self {
PromptChain { args, templates }
}
pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
// Argsort based on Prompt Priority
let seperator = "\n";
let seperator_tokens = self.args.model.count_tokens(seperator)?;
let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
// If Truncate
let mut tokens_outstanding = if truncate {
Some(self.args.model.capacity()? - self.args.reserved_tokens)
} else {
None
};
let mut prompts = vec!["".to_string(); sorted_indices.len()];
for idx in sorted_indices {
let (_, template) = &self.templates[idx];
if let Some((template_prompt, prompt_token_count)) =
template.generate(&self.args, tokens_outstanding).log_err()
{
if template_prompt != "" {
prompts[idx] = template_prompt;
if let Some(remaining_tokens) = tokens_outstanding {
let new_tokens = prompt_token_count + seperator_tokens;
tokens_outstanding = if remaining_tokens > new_tokens {
Some(remaining_tokens - new_tokens)
} else {
Some(0)
};
}
}
}
}
prompts.retain(|x| x != "");
let full_prompt = prompts.join(seperator);
let total_token_count = self.args.model.count_tokens(&full_prompt)?;
anyhow::Ok((prompts.join(seperator), total_token_count))
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
#[test]
pub fn test_prompt_chain() {
struct TestPromptTemplate {}
impl PromptTemplate for TestPromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(&content, max_token_length)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
struct TestLowPriorityTemplate {}
impl PromptTemplate for TestLowPriorityTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a low priority test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(&content, max_token_length)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
#[derive(Clone)]
struct DummyLanguageModel {
capacity: usize,
}
impl LanguageModel for DummyLanguageModel {
fn name(&self) -> String {
"dummy".to_string()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
}
fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
anyhow::Ok(
content.chars().collect::<Vec<char>>()[..length]
.into_iter()
.collect::<String>(),
)
}
fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
anyhow::Ok(
content.chars().collect::<Vec<char>>()[length..]
.into_iter()
.collect::<String>(),
)
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(self.capacity)
}
}
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let capacity = 20;
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 2 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(prompt, "This is a test promp".to_string());
assert_eq!(token_count, capacity);
// Change Ordering of Prompts Based on Priority
let capacity = 120;
let reserved_tokens = 10;
let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Mandatory,
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(
prompt,
"This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
.to_string()
);
assert_eq!(token_count, capacity - reserved_tokens);
}
}

View File

@ -0,0 +1,160 @@
use anyhow::anyhow;
use language::BufferSnapshot;
use language::ToOffset;
use crate::models::LanguageModel;
use crate::templates::base::PromptArguments;
use crate::templates::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;
fn retrieve_context(
buffer: &BufferSnapshot,
selected_range: &Option<Range<usize>>,
model: Arc<dyn LanguageModel>,
max_token_count: Option<usize>,
) -> anyhow::Result<(String, usize, bool)> {
let mut prompt = String::new();
let mut truncated = false;
if let Some(selected_range) = selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
let start_window = buffer.text_for_range(0..start).collect::<String>();
let mut selected_window = String::new();
if start == end {
write!(selected_window, "<|START|>").unwrap();
} else {
write!(selected_window, "<|START|").unwrap();
}
write!(
selected_window,
"{}",
buffer.text_for_range(start..end).collect::<String>()
)
.unwrap();
if start != end {
write!(selected_window, "|END|>").unwrap();
}
let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
if let Some(max_token_count) = max_token_count {
let selected_tokens = model.count_tokens(&selected_window)?;
if selected_tokens > max_token_count {
return Err(anyhow!(
"selected range is greater than model context window, truncation not possible"
));
};
let mut remaining_tokens = max_token_count - selected_tokens;
let start_window_tokens = model.count_tokens(&start_window)?;
let end_window_tokens = model.count_tokens(&end_window)?;
let outside_tokens = start_window_tokens + end_window_tokens;
if outside_tokens > remaining_tokens {
let (start_goal_tokens, end_goal_tokens) =
if start_window_tokens < end_window_tokens {
let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
remaining_tokens -= start_goal_tokens;
let end_goal_tokens = remaining_tokens.min(end_window_tokens);
(start_goal_tokens, end_goal_tokens)
} else {
let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
remaining_tokens -= end_goal_tokens;
let start_goal_tokens = remaining_tokens.min(start_window_tokens);
(start_goal_tokens, end_goal_tokens)
};
let truncated_start_window =
model.truncate_start(&start_window, start_goal_tokens)?;
let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
writeln!(
prompt,
"{truncated_start_window}{selected_window}{truncated_end_window}"
)
.unwrap();
truncated = true;
} else {
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
}
} else {
// If we dont have a selected range, include entire file.
writeln!(prompt, "{}", &buffer.text()).unwrap();
// Dumb truncation strategy
if let Some(max_token_count) = max_token_count {
if model.count_tokens(&prompt)? > max_token_count {
truncated = true;
prompt = model.truncate(&prompt, max_token_count)?;
}
}
}
}
let token_count = model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count, truncated))
}
pub struct FileContext {}
impl PromptTemplate for FileContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
if let Some(buffer) = &args.buffer {
let mut prompt = String::new();
// Add Initial Preamble
// TODO: Do we want to add the path in here?
writeln!(
prompt,
"The file you are currently working on has the following content:"
)
.unwrap();
let language_name = args
.language_name
.clone()
.unwrap_or("".to_string())
.to_lowercase();
let (context, _, truncated) = retrieve_context(
buffer,
&args.selected_range,
args.model.clone(),
max_token_length,
)?;
writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
if truncated {
writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
}
if let Some(selected_range) = &args.selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
if start == end {
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
} else {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args.model.truncate(&prompt, max_tokens)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
} else {
Err(anyhow!("no buffer provided to retrieve file context from"))
}
}
}

View File

@ -0,0 +1,95 @@
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow;
use std::fmt::Write;
pub fn capitalize(s: &str) -> String {
let mut c = s.chars();
match c.next() {
None => String::new(),
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
}
}
pub struct GenerateInlineContent {}
impl PromptTemplate for GenerateInlineContent {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let Some(user_prompt) = &args.user_prompt else {
return Err(anyhow!("user prompt not provided"));
};
let file_type = args.get_file_type();
let content_type = match &file_type {
PromptFileType::Code => "code",
PromptFileType::Text => "text",
};
let mut prompt = String::new();
if let Some(selected_range) = &args.selected_range {
if selected_range.start == selected_range.end {
writeln!(
prompt,
"Assume the cursor is located where the `<|START|>` span is."
)
.unwrap();
writeln!(
prompt,
"{} can't be replaced, so assume your answer will be inserted at the cursor.",
capitalize(content_type)
)
.unwrap();
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}",
)
.unwrap();
} else {
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
}
} else {
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}"
)
.unwrap();
}
if let Some(language_name) = &args.language_name {
writeln!(
prompt,
"Your answer MUST always and only be valid {}.",
language_name
)
.unwrap();
}
writeln!(prompt, "Never make remarks about the output.").unwrap();
writeln!(
prompt,
"Do not return anything else, except the generated {content_type}."
)
.unwrap();
match file_type {
PromptFileType::Code => {
// writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
}
_ => {}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args.model.truncate(&prompt, max_tokens)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}

View File

@ -0,0 +1,5 @@
pub mod base;
pub mod file_context;
pub mod generate;
pub mod preamble;
pub mod repository_context;

View File

@ -0,0 +1,52 @@
use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;
pub struct EngineerPreamble {}
impl PromptTemplate for EngineerPreamble {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut prompts = Vec::new();
match args.get_file_type() {
PromptFileType::Code => {
prompts.push(format!(
"You are an expert {}engineer.",
args.language_name.clone().unwrap_or("".to_string()) + " "
));
}
PromptFileType::Text => {
prompts.push("You are an expert engineer.".to_string());
}
}
if let Some(project_name) = args.project_name.clone() {
prompts.push(format!(
"You are currently working inside the '{project_name}' project in code editor Zed."
));
}
if let Some(mut remaining_tokens) = max_token_length {
let mut prompt = String::new();
let mut total_count = 0;
for prompt_piece in prompts {
let prompt_token_count =
args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
if remaining_tokens > prompt_token_count {
writeln!(prompt, "{prompt_piece}").unwrap();
remaining_tokens -= prompt_token_count;
total_count += prompt_token_count;
}
}
anyhow::Ok((prompt, total_count))
} else {
let prompt = prompts.join("\n");
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}
}

View File

@ -0,0 +1,94 @@
use crate::templates::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
use gpui::{AsyncAppContext, ModelHandle};
use language::{Anchor, Buffer};
#[derive(Clone)]
pub struct PromptCodeSnippet {
path: Option<PathBuf>,
language_name: Option<String>,
content: String,
}
impl PromptCodeSnippet {
pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
let snapshot = buffer.snapshot();
let content = snapshot.text_for_range(range.clone()).collect::<String>();
let language_name = buffer
.language()
.and_then(|language| Some(language.name().to_string().to_lowercase()));
let file_path = buffer
.file()
.and_then(|file| Some(file.path().to_path_buf()));
(content, language_name, file_path)
});
PromptCodeSnippet {
path: file_path,
language_name,
content,
}
}
}
impl ToString for PromptCodeSnippet {
fn to_string(&self) -> String {
let path = self
.path
.as_ref()
.and_then(|path| Some(path.to_string_lossy().to_string()))
.unwrap_or("".to_string());
let language_name = self.language_name.clone().unwrap_or("".to_string());
let content = self.content.clone();
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
}
}
pub struct RepositoryContext {}
impl PromptTemplate for RepositoryContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
let mut prompt = String::new();
let mut remaining_tokens = max_token_length.clone();
let seperator_token_length = args.model.count_tokens("\n")?;
for snippet in &args.snippets {
let mut snippet_prompt = template.to_string();
let content = snippet.to_string();
writeln!(snippet_prompt, "{content}").unwrap();
let token_count = args.model.count_tokens(&snippet_prompt)?;
if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
if let Some(tokens_left) = remaining_tokens {
if tokens_left >= token_count {
writeln!(prompt, "{snippet_prompt}").unwrap();
remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
{
Some(tokens_left - token_count - seperator_token_length)
} else {
Some(0)
};
}
} else {
writeln!(prompt, "{snippet_prompt}").unwrap();
}
}
}
let total_token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, total_token_count))
}
}

View File

@ -1,12 +1,15 @@
use crate::{ use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
codegen::{self, Codegen, CodegenKind}, codegen::{self, Codegen, CodegenKind},
prompts::{generate_content_prompt, PromptCodeSnippet}, prompts::generate_content_prompt,
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata, MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage, SavedMessage,
}; };
use ai::completion::{ use ai::{
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, completion::{
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
},
templates::repository_context::PromptCodeSnippet,
}; };
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::{DateTime, Local}; use chrono::{DateTime, Local};
@ -609,6 +612,18 @@ impl AssistantPanel {
let project = pending_assist.project.clone(); let project = pending_assist.project.clone();
let project_name = if let Some(project) = project.upgrade(cx) {
Some(
project
.read(cx)
.worktree_root_names(cx)
.collect::<Vec<&str>>()
.join("/"),
)
} else {
None
};
self.inline_prompt_history self.inline_prompt_history
.retain(|prompt| prompt != user_prompt); .retain(|prompt| prompt != user_prompt);
self.inline_prompt_history.push_back(user_prompt.into()); self.inline_prompt_history.push_back(user_prompt.into());
@ -646,7 +661,19 @@ impl AssistantPanel {
None None
}; };
let codegen_kind = codegen.read(cx).kind().clone(); // Higher Temperature increases the randomness of model outputs.
// If Markdown or No Language is Known, increase the randomness for more creative output
// If Code, decrease temperature to get more deterministic outputs
let temperature = if let Some(language) = language_name.clone() {
if language.to_string() != "Markdown".to_string() {
0.5
} else {
1.0
}
} else {
1.0
};
let user_prompt = user_prompt.to_string(); let user_prompt = user_prompt.to_string();
let snippets = if retrieve_context { let snippets = if retrieve_context {
@ -668,14 +695,7 @@ impl AssistantPanel {
let snippets = cx.spawn(|_, cx| async move { let snippets = cx.spawn(|_, cx| async move {
let mut snippets = Vec::new(); let mut snippets = Vec::new();
for result in search_results.await { for result in search_results.await {
snippets.push(PromptCodeSnippet::new(result, &cx)); snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx));
// snippets.push(result.buffer.read_with(&cx, |buffer, _| {
// buffer
// .snapshot()
// .text_for_range(result.range)
// .collect::<String>()
// }));
} }
snippets snippets
}); });
@ -696,11 +716,11 @@ impl AssistantPanel {
generate_content_prompt( generate_content_prompt(
user_prompt, user_prompt,
language_name, language_name,
&buffer, buffer,
range, range,
codegen_kind,
snippets, snippets,
model_name, model_name,
project_name,
) )
}); });
@ -717,18 +737,23 @@ impl AssistantPanel {
} }
cx.spawn(|_, mut cx| async move { cx.spawn(|_, mut cx| async move {
let prompt = prompt.await; // I Don't know if we want to return a ? here.
let prompt = prompt.await?;
messages.push(RequestMessage { messages.push(RequestMessage {
role: Role::User, role: Role::User,
content: prompt, content: prompt,
}); });
let request = OpenAIRequest { let request = OpenAIRequest {
model: model.full_name().into(), model: model.full_name().into(),
messages, messages,
stream: true, stream: true,
stop: vec!["|END|>".to_string()],
temperature,
}; };
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx)); codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
anyhow::Ok(())
}) })
.detach(); .detach();
} }
@ -1718,6 +1743,8 @@ impl Conversation {
.map(|message| message.to_open_ai_message(self.buffer.read(cx))) .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
.collect(), .collect(),
stream: true, stream: true,
stop: vec![],
temperature: 1.0,
}; };
let stream = stream_completion(api_key, cx.background().clone(), request); let stream = stream_completion(api_key, cx.background().clone(), request);
@ -2002,6 +2029,8 @@ impl Conversation {
model: self.model.full_name().to_string(), model: self.model.full_name().to_string(),
messages: messages.collect(), messages: messages.collect(),
stream: true, stream: true,
stop: vec![],
temperature: 1.0,
}; };
let stream = stream_completion(api_key, cx.background().clone(), request); let stream = stream_completion(api_key, cx.background().clone(), request);

View File

@ -1,60 +1,13 @@
use crate::codegen::CodegenKind; use ai::models::{LanguageModel, OpenAILanguageModel};
use gpui::AsyncAppContext; use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::templates::file_context::FileContext;
use ai::templates::generate::GenerateInlineContent;
use ai::templates::preamble::EngineerPreamble;
use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use semantic_index::SearchResult;
use std::cmp::{self, Reverse}; use std::cmp::{self, Reverse};
use std::fmt::Write;
use std::ops::Range; use std::ops::Range;
use std::path::PathBuf; use std::sync::Arc;
use tiktoken_rs::ChatCompletionRequestMessage;
pub struct PromptCodeSnippet {
path: Option<PathBuf>,
language_name: Option<String>,
content: String,
}
impl PromptCodeSnippet {
pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
let (content, language_name, file_path) =
search_result.buffer.read_with(cx, |buffer, _| {
let snapshot = buffer.snapshot();
let content = snapshot
.text_for_range(search_result.range.clone())
.collect::<String>();
let language_name = buffer
.language()
.and_then(|language| Some(language.name().to_string()));
let file_path = buffer
.file()
.and_then(|file| Some(file.path().to_path_buf()));
(content, language_name, file_path)
});
PromptCodeSnippet {
path: file_path,
language_name,
content,
}
}
}
impl ToString for PromptCodeSnippet {
fn to_string(&self) -> String {
let path = self
.path
.as_ref()
.and_then(|path| Some(path.to_string_lossy().to_string()))
.unwrap_or("".to_string());
let language_name = self.language_name.clone().unwrap_or("".to_string());
let content = self.content.clone();
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
}
}
#[allow(dead_code)] #[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String { fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
@ -170,134 +123,50 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> S
pub fn generate_content_prompt( pub fn generate_content_prompt(
user_prompt: String, user_prompt: String,
language_name: Option<&str>, language_name: Option<&str>,
buffer: &BufferSnapshot, buffer: BufferSnapshot,
range: Range<impl ToOffset>, range: Range<usize>,
kind: CodegenKind,
search_results: Vec<PromptCodeSnippet>, search_results: Vec<PromptCodeSnippet>,
model: &str, model: &str,
) -> String { project_name: Option<String>,
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; ) -> anyhow::Result<String> {
const RESERVED_TOKENS_FOR_GENERATION: usize = 1000; // Using new Prompt Templates
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
let mut prompts = Vec::new(); let lang_name = if let Some(language_name) = language_name {
let range = range.to_offset(buffer); Some(language_name.to_string())
// General Preamble
if let Some(language_name) = language_name {
prompts.push(format!("You're an expert {language_name} engineer.\n"));
} else { } else {
prompts.push("You're an expert engineer.\n".to_string()); None
}
// Snippets
let mut snippet_position = prompts.len() - 1;
let mut content = String::new();
content.extend(buffer.text_for_range(0..range.start));
if range.start == range.end {
content.push_str("<|START|>");
} else {
content.push_str("<|START|");
}
content.extend(buffer.text_for_range(range.clone()));
if range.start != range.end {
content.push_str("|END|>");
}
content.extend(buffer.text_for_range(range.end..buffer.len()));
prompts.push("The file you are currently working on has the following content:\n".to_string());
if let Some(language_name) = language_name {
let language_name = language_name.to_lowercase();
prompts.push(format!("```{language_name}\n{content}\n```"));
} else {
prompts.push(format!("```\n{content}\n```"));
}
match kind {
CodegenKind::Generate { position: _ } => {
prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
prompts
.push("Assume the cursor is located where the `<|START|` marker is.".to_string());
prompts.push(
"Text can't be replaced, so assume your answer will be inserted at the cursor."
.to_string(),
);
prompts.push(format!(
"Generate text based on the users prompt: {user_prompt}"
));
}
CodegenKind::Transform { range: _ } => {
prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
prompts.push(format!(
"Modify the users code selected text based upon the users prompt: '{user_prompt}'"
));
prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
}
}
if let Some(language_name) = language_name {
prompts.push(format!(
"Your answer MUST always and only be valid {language_name}"
));
}
prompts.push("Never make remarks about the output.".to_string());
prompts.push("Do not return any text, except the generated code.".to_string());
prompts.push("Do not wrap your text in a Markdown block".to_string());
let current_messages = [ChatCompletionRequestMessage {
role: "user".to_string(),
content: Some(prompts.join("\n")),
function_call: None,
name: None,
}];
let mut remaining_token_count = if let Ok(current_token_count) =
tiktoken_rs::num_tokens_from_messages(model, &current_messages)
{
let max_token_count = tiktoken_rs::model::get_context_size(model);
let intermediate_token_count = max_token_count - current_token_count;
if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
0
} else {
intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
}
} else {
// If tiktoken fails to count token count, assume we have no space remaining.
0
}; };
// TODO: let args = PromptArguments {
// - add repository name to snippet model: openai_model,
// - add file path language_name: lang_name.clone(),
// - add language project_name,
if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) { snippets: search_results.clone(),
let mut template = "You are working inside a large repository, here are a few code snippets that may be useful"; reserved_tokens: 1000,
buffer: Some(buffer),
selected_range: Some(range),
user_prompt: Some(user_prompt.clone()),
};
for search_result in search_results { let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
let mut snippet_prompt = template.to_string(); (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
let snippet = search_result.to_string(); (
writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap(); PromptPriority::Ordered { order: 1 },
Box::new(RepositoryContext {}),
),
(
PromptPriority::Ordered { order: 0 },
Box::new(FileContext {}),
),
(
PromptPriority::Mandatory,
Box::new(GenerateInlineContent {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, _) = chain.generate(true)?;
let token_count = encoding anyhow::Ok(prompt)
.encode_with_special_tokens(snippet_prompt.as_str())
.len();
if token_count <= remaining_token_count {
if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
prompts.insert(snippet_position, snippet_prompt);
snippet_position += 1;
remaining_token_count -= token_count;
// If you have already added the template to the prompt, remove the template.
template = "";
}
} else {
break;
}
}
}
prompts.join("\n")
} }
#[cfg(test)] #[cfg(test)]

View File

@ -7,7 +7,10 @@ use gpui::{AppContext, ModelHandle};
use std::sync::Arc; use std::sync::Arc;
pub use channel_buffer::{ChannelBuffer, ChannelBufferEvent, ACKNOWLEDGE_DEBOUNCE_INTERVAL}; pub use channel_buffer::{ChannelBuffer, ChannelBufferEvent, ACKNOWLEDGE_DEBOUNCE_INTERVAL};
pub use channel_chat::{ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId}; pub use channel_chat::{
mentions_to_proto, ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId,
MessageParams,
};
pub use channel_store::{ pub use channel_store::{
Channel, ChannelData, ChannelEvent, ChannelId, ChannelMembership, ChannelPath, ChannelStore, Channel, ChannelData, ChannelEvent, ChannelId, ChannelMembership, ChannelPath, ChannelStore,
}; };

View File

@ -3,12 +3,17 @@ use anyhow::{anyhow, Result};
use client::{ use client::{
proto, proto,
user::{User, UserStore}, user::{User, UserStore},
Client, Subscription, TypedEnvelope, Client, Subscription, TypedEnvelope, UserId,
}; };
use futures::lock::Mutex; use futures::lock::Mutex;
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task}; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
use rand::prelude::*; use rand::prelude::*;
use std::{collections::HashSet, mem, ops::Range, sync::Arc}; use std::{
collections::HashSet,
mem,
ops::{ControlFlow, Range},
sync::Arc,
};
use sum_tree::{Bias, SumTree}; use sum_tree::{Bias, SumTree};
use time::OffsetDateTime; use time::OffsetDateTime;
use util::{post_inc, ResultExt as _, TryFutureExt}; use util::{post_inc, ResultExt as _, TryFutureExt};
@ -16,6 +21,7 @@ use util::{post_inc, ResultExt as _, TryFutureExt};
pub struct ChannelChat { pub struct ChannelChat {
pub channel_id: ChannelId, pub channel_id: ChannelId,
messages: SumTree<ChannelMessage>, messages: SumTree<ChannelMessage>,
acknowledged_message_ids: HashSet<u64>,
channel_store: ModelHandle<ChannelStore>, channel_store: ModelHandle<ChannelStore>,
loaded_all_messages: bool, loaded_all_messages: bool,
last_acknowledged_id: Option<u64>, last_acknowledged_id: Option<u64>,
@ -27,6 +33,12 @@ pub struct ChannelChat {
_subscription: Subscription, _subscription: Subscription,
} }
#[derive(Debug, PartialEq, Eq)]
pub struct MessageParams {
pub text: String,
pub mentions: Vec<(Range<usize>, UserId)>,
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct ChannelMessage { pub struct ChannelMessage {
pub id: ChannelMessageId, pub id: ChannelMessageId,
@ -34,6 +46,7 @@ pub struct ChannelMessage {
pub timestamp: OffsetDateTime, pub timestamp: OffsetDateTime,
pub sender: Arc<User>, pub sender: Arc<User>,
pub nonce: u128, pub nonce: u128,
pub mentions: Vec<(Range<usize>, UserId)>,
} }
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
@ -105,6 +118,7 @@ impl ChannelChat {
rpc: client, rpc: client,
outgoing_messages_lock: Default::default(), outgoing_messages_lock: Default::default(),
messages: Default::default(), messages: Default::default(),
acknowledged_message_ids: Default::default(),
loaded_all_messages, loaded_all_messages,
next_pending_message_id: 0, next_pending_message_id: 0,
last_acknowledged_id: None, last_acknowledged_id: None,
@ -123,12 +137,16 @@ impl ChannelChat {
.cloned() .cloned()
} }
pub fn client(&self) -> &Arc<Client> {
&self.rpc
}
pub fn send_message( pub fn send_message(
&mut self, &mut self,
body: String, message: MessageParams,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Result<Task<Result<()>>> { ) -> Result<Task<Result<u64>>> {
if body.is_empty() { if message.text.is_empty() {
Err(anyhow!("message body can't be empty"))?; Err(anyhow!("message body can't be empty"))?;
} }
@ -145,9 +163,10 @@ impl ChannelChat {
SumTree::from_item( SumTree::from_item(
ChannelMessage { ChannelMessage {
id: pending_id, id: pending_id,
body: body.clone(), body: message.text.clone(),
sender: current_user, sender: current_user,
timestamp: OffsetDateTime::now_utc(), timestamp: OffsetDateTime::now_utc(),
mentions: message.mentions.clone(),
nonce, nonce,
}, },
&(), &(),
@ -161,20 +180,18 @@ impl ChannelChat {
let outgoing_message_guard = outgoing_messages_lock.lock().await; let outgoing_message_guard = outgoing_messages_lock.lock().await;
let request = rpc.request(proto::SendChannelMessage { let request = rpc.request(proto::SendChannelMessage {
channel_id, channel_id,
body, body: message.text,
nonce: Some(nonce.into()), nonce: Some(nonce.into()),
mentions: mentions_to_proto(&message.mentions),
}); });
let response = request.await?; let response = request.await?;
drop(outgoing_message_guard); drop(outgoing_message_guard);
let message = ChannelMessage::from_proto( let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
response.message.ok_or_else(|| anyhow!("invalid message"))?, let id = response.id;
&user_store, let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
&mut cx,
)
.await?;
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx); this.insert_messages(SumTree::from_item(message, &()), cx);
Ok(()) Ok(id)
}) })
})) }))
} }
@ -194,41 +211,76 @@ impl ChannelChat {
}) })
} }
pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool { pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
if !self.loaded_all_messages { if self.loaded_all_messages {
let rpc = self.rpc.clone(); return None;
let user_store = self.user_store.clone(); }
let channel_id = self.channel_id;
if let Some(before_message_id) = let rpc = self.rpc.clone();
self.messages.first().and_then(|message| match message.id { let user_store = self.user_store.clone();
ChannelMessageId::Saved(id) => Some(id), let channel_id = self.channel_id;
ChannelMessageId::Pending(_) => None, let before_message_id = self.first_loaded_message_id()?;
}) Some(cx.spawn(|this, mut cx| {
{ async move {
cx.spawn(|this, mut cx| { let response = rpc
async move { .request(proto::GetChannelMessages {
let response = rpc channel_id,
.request(proto::GetChannelMessages { before_message_id,
channel_id, })
before_message_id, .await?;
}) let loaded_all_messages = response.done;
.await?; let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
let loaded_all_messages = response.done; this.update(&mut cx, |this, cx| {
let messages = this.loaded_all_messages = loaded_all_messages;
messages_from_proto(response.messages, &user_store, &mut cx).await?; this.insert_messages(messages, cx);
this.update(&mut cx, |this, cx| { });
this.loaded_all_messages = loaded_all_messages; anyhow::Ok(())
this.insert_messages(messages, cx); }
}); .log_err()
anyhow::Ok(()) }))
}
pub fn first_loaded_message_id(&mut self) -> Option<u64> {
self.messages.first().and_then(|message| match message.id {
ChannelMessageId::Saved(id) => Some(id),
ChannelMessageId::Pending(_) => None,
})
}
/// Load all of the chat messages since a certain message id.
///
/// For now, we always maintain a suffix of the channel's messages.
pub async fn load_history_since_message(
chat: ModelHandle<Self>,
message_id: u64,
mut cx: AsyncAppContext,
) -> Option<usize> {
loop {
let step = chat.update(&mut cx, |chat, cx| {
if let Some(first_id) = chat.first_loaded_message_id() {
if first_id <= message_id {
let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
let message_id = ChannelMessageId::Saved(message_id);
cursor.seek(&message_id, Bias::Left, &());
return ControlFlow::Break(
if cursor
.item()
.map_or(false, |message| message.id == message_id)
{
Some(cursor.start().1 .0)
} else {
None
},
);
} }
.log_err() }
}) ControlFlow::Continue(chat.load_more_messages(cx))
.detach(); });
return true; match step {
ControlFlow::Break(ix) => return ix,
ControlFlow::Continue(task) => task?.await?,
} }
} }
false
} }
pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) { pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
@ -287,6 +339,7 @@ impl ChannelChat {
let request = rpc.request(proto::SendChannelMessage { let request = rpc.request(proto::SendChannelMessage {
channel_id, channel_id,
body: pending_message.body, body: pending_message.body,
mentions: mentions_to_proto(&pending_message.mentions),
nonce: Some(pending_message.nonce.into()), nonce: Some(pending_message.nonce.into()),
}); });
let response = request.await?; let response = request.await?;
@ -322,6 +375,17 @@ impl ChannelChat {
cursor.item().unwrap() cursor.item().unwrap()
} }
pub fn acknowledge_message(&mut self, id: u64) {
if self.acknowledged_message_ids.insert(id) {
self.rpc
.send(proto::AckChannelMessage {
channel_id: self.channel_id,
message_id: id,
})
.ok();
}
}
pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> { pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
let mut cursor = self.messages.cursor::<Count>(); let mut cursor = self.messages.cursor::<Count>();
cursor.seek(&Count(range.start), Bias::Right, &()); cursor.seek(&Count(range.start), Bias::Right, &());
@ -454,22 +518,7 @@ async fn messages_from_proto(
user_store: &ModelHandle<UserStore>, user_store: &ModelHandle<UserStore>,
cx: &mut AsyncAppContext, cx: &mut AsyncAppContext,
) -> Result<SumTree<ChannelMessage>> { ) -> Result<SumTree<ChannelMessage>> {
let unique_user_ids = proto_messages let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
.iter()
.map(|m| m.sender_id)
.collect::<HashSet<_>>()
.into_iter()
.collect();
user_store
.update(cx, |user_store, cx| {
user_store.get_users(unique_user_ids, cx)
})
.await?;
let mut messages = Vec::with_capacity(proto_messages.len());
for message in proto_messages {
messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
}
let mut result = SumTree::new(); let mut result = SumTree::new();
result.extend(messages, &()); result.extend(messages, &());
Ok(result) Ok(result)
@ -489,6 +538,14 @@ impl ChannelMessage {
Ok(ChannelMessage { Ok(ChannelMessage {
id: ChannelMessageId::Saved(message.id), id: ChannelMessageId::Saved(message.id),
body: message.body, body: message.body,
mentions: message
.mentions
.into_iter()
.filter_map(|mention| {
let range = mention.range?;
Some((range.start as usize..range.end as usize, mention.user_id))
})
.collect(),
timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?, timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
sender, sender,
nonce: message nonce: message
@ -501,6 +558,43 @@ impl ChannelMessage {
pub fn is_pending(&self) -> bool { pub fn is_pending(&self) -> bool {
matches!(self.id, ChannelMessageId::Pending(_)) matches!(self.id, ChannelMessageId::Pending(_))
} }
pub async fn from_proto_vec(
proto_messages: Vec<proto::ChannelMessage>,
user_store: &ModelHandle<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<Vec<Self>> {
let unique_user_ids = proto_messages
.iter()
.map(|m| m.sender_id)
.collect::<HashSet<_>>()
.into_iter()
.collect();
user_store
.update(cx, |user_store, cx| {
user_store.get_users(unique_user_ids, cx)
})
.await?;
let mut messages = Vec::with_capacity(proto_messages.len());
for message in proto_messages {
messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
}
Ok(messages)
}
}
pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
mentions
.iter()
.map(|(range, user_id)| proto::ChatMention {
range: Some(proto::Range {
start: range.start as u64,
end: range.end as u64,
}),
user_id: *user_id as u64,
})
.collect()
} }
impl sum_tree::Item for ChannelMessage { impl sum_tree::Item for ChannelMessage {
@ -541,3 +635,12 @@ impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
self.0 += summary.count; self.0 += summary.count;
} }
} }
impl<'a> From<&'a str> for MessageParams {
fn from(value: &'a str) -> Self {
Self {
text: value.into(),
mentions: Vec::new(),
}
}
}

View File

@ -1,6 +1,6 @@
mod channel_index; mod channel_index;
use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat}; use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat, ChannelMessage};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use channel_index::ChannelIndex; use channel_index::ChannelIndex;
use client::{Client, Subscription, User, UserId, UserStore}; use client::{Client, Subscription, User, UserId, UserStore};
@ -157,9 +157,6 @@ impl ChannelStore {
this.update(&mut cx, |this, cx| this.handle_disconnect(true, cx)); this.update(&mut cx, |this, cx| this.handle_disconnect(true, cx));
} }
} }
if status.is_connected() {
} else {
}
} }
Some(()) Some(())
}); });
@ -245,6 +242,12 @@ impl ChannelStore {
self.channel_index.by_id().values().nth(ix) self.channel_index.by_id().values().nth(ix)
} }
pub fn has_channel_invitation(&self, channel_id: ChannelId) -> bool {
self.channel_invitations
.iter()
.any(|channel| channel.id == channel_id)
}
pub fn channel_invitations(&self) -> &[Arc<Channel>] { pub fn channel_invitations(&self) -> &[Arc<Channel>] {
&self.channel_invitations &self.channel_invitations
} }
@ -278,6 +281,33 @@ impl ChannelStore {
) )
} }
pub fn fetch_channel_messages(
&self,
message_ids: Vec<u64>,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<ChannelMessage>>> {
let request = if message_ids.is_empty() {
None
} else {
Some(
self.client
.request(proto::GetChannelMessagesById { message_ids }),
)
};
cx.spawn_weak(|this, mut cx| async move {
if let Some(request) = request {
let response = request.await?;
let this = this
.upgrade(&cx)
.ok_or_else(|| anyhow!("channel store dropped"))?;
let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
ChannelMessage::from_proto_vec(response.messages, &user_store, &mut cx).await
} else {
Ok(Vec::new())
}
})
}
pub fn has_channel_buffer_changed(&self, channel_id: ChannelId) -> Option<bool> { pub fn has_channel_buffer_changed(&self, channel_id: ChannelId) -> Option<bool> {
self.channel_index self.channel_index
.by_id() .by_id()
@ -689,14 +719,15 @@ impl ChannelStore {
&mut self, &mut self,
channel_id: ChannelId, channel_id: ChannelId,
accept: bool, accept: bool,
) -> impl Future<Output = Result<()>> { cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.clone(); let client = self.client.clone();
async move { cx.background().spawn(async move {
client client
.request(proto::RespondToChannelInvite { channel_id, accept }) .request(proto::RespondToChannelInvite { channel_id, accept })
.await?; .await?;
Ok(()) Ok(())
} })
} }
pub fn get_channel_member_details( pub fn get_channel_member_details(
@ -764,6 +795,11 @@ impl ChannelStore {
} }
fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> { fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
self.channel_index.clear();
self.channel_invitations.clear();
self.channel_participants.clear();
self.channel_index.clear();
self.outgoing_invites.clear();
self.disconnect_channel_buffers_task.take(); self.disconnect_channel_buffers_task.take();
for chat in self.opened_chats.values() { for chat in self.opened_chats.values() {
@ -873,11 +909,6 @@ impl ChannelStore {
} }
fn handle_disconnect(&mut self, wait_for_reconnect: bool, cx: &mut ModelContext<Self>) { fn handle_disconnect(&mut self, wait_for_reconnect: bool, cx: &mut ModelContext<Self>) {
self.channel_index.clear();
self.channel_invitations.clear();
self.channel_participants.clear();
self.channel_index.clear();
self.outgoing_invites.clear();
cx.notify(); cx.notify();
self.disconnect_channel_buffers_task.get_or_insert_with(|| { self.disconnect_channel_buffers_task.get_or_insert_with(|| {

View File

@ -210,6 +210,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
body: "a".into(), body: "a".into(),
timestamp: 1000, timestamp: 1000,
sender_id: 5, sender_id: 5,
mentions: vec![],
nonce: Some(1.into()), nonce: Some(1.into()),
}, },
proto::ChannelMessage { proto::ChannelMessage {
@ -217,6 +218,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
body: "b".into(), body: "b".into(),
timestamp: 1001, timestamp: 1001,
sender_id: 6, sender_id: 6,
mentions: vec![],
nonce: Some(2.into()), nonce: Some(2.into()),
}, },
], ],
@ -263,6 +265,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
body: "c".into(), body: "c".into(),
timestamp: 1002, timestamp: 1002,
sender_id: 7, sender_id: 7,
mentions: vec![],
nonce: Some(3.into()), nonce: Some(3.into()),
}), }),
}); });
@ -300,7 +303,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
// Scroll up to view older messages. // Scroll up to view older messages.
channel.update(cx, |channel, cx| { channel.update(cx, |channel, cx| {
assert!(channel.load_more_messages(cx)); channel.load_more_messages(cx).unwrap().detach();
}); });
let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap(); let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
assert_eq!(get_messages.payload.channel_id, 5); assert_eq!(get_messages.payload.channel_id, 5);
@ -316,6 +319,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
timestamp: 998, timestamp: 998,
sender_id: 5, sender_id: 5,
nonce: Some(4.into()), nonce: Some(4.into()),
mentions: vec![],
}, },
proto::ChannelMessage { proto::ChannelMessage {
id: 9, id: 9,
@ -323,6 +327,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
timestamp: 999, timestamp: 999,
sender_id: 6, sender_id: 6,
nonce: Some(5.into()), nonce: Some(5.into()),
mentions: vec![],
}, },
], ],
}, },

View File

@ -293,21 +293,19 @@ impl UserStore {
// No need to paralellize here // No need to paralellize here
let mut updated_contacts = Vec::new(); let mut updated_contacts = Vec::new();
for contact in message.contacts { for contact in message.contacts {
let should_notify = contact.should_notify; updated_contacts.push(Arc::new(
updated_contacts.push(( Contact::from_proto(contact, &this, &mut cx).await?,
Arc::new(Contact::from_proto(contact, &this, &mut cx).await?),
should_notify,
)); ));
} }
let mut incoming_requests = Vec::new(); let mut incoming_requests = Vec::new();
for request in message.incoming_requests { for request in message.incoming_requests {
incoming_requests.push({ incoming_requests.push(
let user = this this.update(&mut cx, |this, cx| {
.update(&mut cx, |this, cx| this.get_user(request.requester_id, cx)) this.get_user(request.requester_id, cx)
.await?; })
(user, request.should_notify) .await?,
}); );
} }
let mut outgoing_requests = Vec::new(); let mut outgoing_requests = Vec::new();
@ -330,13 +328,7 @@ impl UserStore {
this.contacts this.contacts
.retain(|contact| !removed_contacts.contains(&contact.user.id)); .retain(|contact| !removed_contacts.contains(&contact.user.id));
// Update existing contacts and insert new ones // Update existing contacts and insert new ones
for (updated_contact, should_notify) in updated_contacts { for updated_contact in updated_contacts {
if should_notify {
cx.emit(Event::Contact {
user: updated_contact.user.clone(),
kind: ContactEventKind::Accepted,
});
}
match this.contacts.binary_search_by_key( match this.contacts.binary_search_by_key(
&&updated_contact.user.github_login, &&updated_contact.user.github_login,
|contact| &contact.user.github_login, |contact| &contact.user.github_login,
@ -359,14 +351,7 @@ impl UserStore {
} }
}); });
// Update existing incoming requests and insert new ones // Update existing incoming requests and insert new ones
for (user, should_notify) in incoming_requests { for user in incoming_requests {
if should_notify {
cx.emit(Event::Contact {
user: user.clone(),
kind: ContactEventKind::Requested,
});
}
match this match this
.incoming_contact_requests .incoming_contact_requests
.binary_search_by_key(&&user.github_login, |contact| { .binary_search_by_key(&&user.github_login, |contact| {
@ -415,6 +400,12 @@ impl UserStore {
&self.incoming_contact_requests &self.incoming_contact_requests
} }
pub fn has_incoming_contact_request(&self, user_id: u64) -> bool {
self.incoming_contact_requests
.iter()
.any(|user| user.id == user_id)
}
pub fn outgoing_contact_requests(&self) -> &[Arc<User>] { pub fn outgoing_contact_requests(&self) -> &[Arc<User>] {
&self.outgoing_contact_requests &self.outgoing_contact_requests
} }

View File

@ -3,7 +3,7 @@ authors = ["Nathan Sobo <nathan@zed.dev>"]
default-run = "collab" default-run = "collab"
edition = "2021" edition = "2021"
name = "collab" name = "collab"
version = "0.24.0" version = "0.25.0"
publish = false publish = false
[[bin]] [[bin]]
@ -73,6 +73,7 @@ git = { path = "../git", features = ["test-support"] }
live_kit_client = { path = "../live_kit_client", features = ["test-support"] } live_kit_client = { path = "../live_kit_client", features = ["test-support"] }
lsp = { path = "../lsp", features = ["test-support"] } lsp = { path = "../lsp", features = ["test-support"] }
node_runtime = { path = "../node_runtime" } node_runtime = { path = "../node_runtime" }
notifications = { path = "../notifications", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] } project = { path = "../project", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] } rpc = { path = "../rpc", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"] } settings = { path = "../settings", features = ["test-support"] }

View File

@ -192,7 +192,7 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id");
CREATE TABLE "channels" ( CREATE TABLE "channels" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT, "id" INTEGER PRIMARY KEY AUTOINCREMENT,
"name" VARCHAR NOT NULL, "name" VARCHAR NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT now, "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"visibility" VARCHAR NOT NULL "visibility" VARCHAR NOT NULL
); );
@ -214,7 +214,15 @@ CREATE TABLE IF NOT EXISTS "channel_messages" (
"nonce" BLOB NOT NULL "nonce" BLOB NOT NULL
); );
CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id"); CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id");
CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce"); CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce");
CREATE TABLE "channel_message_mentions" (
"message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE,
"start_offset" INTEGER NOT NULL,
"end_offset" INTEGER NOT NULL,
"user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
PRIMARY KEY(message_id, start_offset)
);
CREATE TABLE "channel_paths" ( CREATE TABLE "channel_paths" (
"id_path" TEXT NOT NULL PRIMARY KEY, "id_path" TEXT NOT NULL PRIMARY KEY,
@ -314,3 +322,26 @@ CREATE TABLE IF NOT EXISTS "observed_channel_messages" (
); );
CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id"); CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id");
CREATE TABLE "notification_kinds" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"name" VARCHAR NOT NULL
);
CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name");
CREATE TABLE "notifications" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"created_at" TIMESTAMP NOT NULL default CURRENT_TIMESTAMP,
"recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
"kind" INTEGER NOT NULL REFERENCES notification_kinds (id),
"entity_id" INTEGER,
"content" TEXT,
"is_read" BOOLEAN NOT NULL DEFAULT FALSE,
"response" BOOLEAN
);
CREATE INDEX
"index_notifications_on_recipient_id_is_read_kind_entity_id"
ON "notifications"
("recipient_id", "is_read", "kind", "entity_id");

View File

@ -0,0 +1,22 @@
CREATE TABLE "notification_kinds" (
"id" SERIAL PRIMARY KEY,
"name" VARCHAR NOT NULL
);
CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name");
CREATE TABLE notifications (
"id" SERIAL PRIMARY KEY,
"created_at" TIMESTAMP NOT NULL DEFAULT now(),
"recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
"kind" INTEGER NOT NULL REFERENCES notification_kinds (id),
"entity_id" INTEGER,
"content" TEXT,
"is_read" BOOLEAN NOT NULL DEFAULT FALSE,
"response" BOOLEAN
);
CREATE INDEX
"index_notifications_on_recipient_id_is_read_kind_entity_id"
ON "notifications"
("recipient_id", "is_read", "kind", "entity_id");

View File

@ -0,0 +1,11 @@
CREATE TABLE "channel_message_mentions" (
"message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE,
"start_offset" INTEGER NOT NULL,
"end_offset" INTEGER NOT NULL,
"user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
PRIMARY KEY(message_id, start_offset)
);
-- We use 'on conflict update' with this index, so it should be per-user.
CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce");
DROP INDEX "index_channel_messages_on_nonce";

View File

@ -71,7 +71,6 @@ async fn main() {
db::NewUserParams { db::NewUserParams {
github_login: github_user.login, github_login: github_user.login,
github_user_id: github_user.id, github_user_id: github_user.id,
invite_count: 5,
}, },
) )
.await .await

View File

@ -13,6 +13,7 @@ use anyhow::anyhow;
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use dashmap::DashMap; use dashmap::DashMap;
use futures::StreamExt; use futures::StreamExt;
use queries::channels::ChannelGraph;
use rand::{prelude::StdRng, Rng, SeedableRng}; use rand::{prelude::StdRng, Rng, SeedableRng};
use rpc::{ use rpc::{
proto::{self}, proto::{self},
@ -20,7 +21,7 @@ use rpc::{
}; };
use sea_orm::{ use sea_orm::{
entity::prelude::*, entity::prelude::*,
sea_query::{Alias, Expr, OnConflict, Query}, sea_query::{Alias, Expr, OnConflict},
ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr, ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
TransactionTrait, TransactionTrait,
@ -47,14 +48,14 @@ pub use ids::*;
pub use sea_orm::ConnectOptions; pub use sea_orm::ConnectOptions;
pub use tables::user::Model as User; pub use tables::user::Model as User;
use self::queries::channels::ChannelGraph;
pub struct Database { pub struct Database {
options: ConnectOptions, options: ConnectOptions,
pool: DatabaseConnection, pool: DatabaseConnection,
rooms: DashMap<RoomId, Arc<Mutex<()>>>, rooms: DashMap<RoomId, Arc<Mutex<()>>>,
rng: Mutex<StdRng>, rng: Mutex<StdRng>,
executor: Executor, executor: Executor,
notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
notification_kinds_by_name: HashMap<String, NotificationKindId>,
#[cfg(test)] #[cfg(test)]
runtime: Option<tokio::runtime::Runtime>, runtime: Option<tokio::runtime::Runtime>,
} }
@ -69,6 +70,8 @@ impl Database {
pool: sea_orm::Database::connect(options).await?, pool: sea_orm::Database::connect(options).await?,
rooms: DashMap::with_capacity(16384), rooms: DashMap::with_capacity(16384),
rng: Mutex::new(StdRng::seed_from_u64(0)), rng: Mutex::new(StdRng::seed_from_u64(0)),
notification_kinds_by_id: HashMap::default(),
notification_kinds_by_name: HashMap::default(),
executor, executor,
#[cfg(test)] #[cfg(test)]
runtime: None, runtime: None,
@ -121,6 +124,11 @@ impl Database {
Ok(new_migrations) Ok(new_migrations)
} }
pub async fn initialize_static_data(&mut self) -> Result<()> {
self.initialize_notification_kinds().await?;
Ok(())
}
pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T> pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
where where
F: Send + Fn(TransactionHandle) -> Fut, F: Send + Fn(TransactionHandle) -> Fut,
@ -361,18 +369,9 @@ impl<T> RoomGuard<T> {
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact { pub enum Contact {
Accepted { Accepted { user_id: UserId, busy: bool },
user_id: UserId, Outgoing { user_id: UserId },
should_notify: bool, Incoming { user_id: UserId },
busy: bool,
},
Outgoing {
user_id: UserId,
},
Incoming {
user_id: UserId,
should_notify: bool,
},
} }
impl Contact { impl Contact {
@ -385,6 +384,15 @@ impl Contact {
} }
} }
pub type NotificationBatch = Vec<(UserId, proto::Notification)>;
pub struct CreatedChannelMessage {
pub message_id: MessageId,
pub participant_connection_ids: Vec<ConnectionId>,
pub channel_members: Vec<UserId>,
pub notifications: NotificationBatch,
}
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite { pub struct Invite {
pub email_address: String, pub email_address: String,
@ -417,7 +425,6 @@ pub struct WaitlistSummary {
pub struct NewUserParams { pub struct NewUserParams {
pub github_login: String, pub github_login: String,
pub github_user_id: i32, pub github_user_id: i32,
pub invite_count: i32,
} }
#[derive(Debug)] #[derive(Debug)]
@ -466,6 +473,24 @@ pub enum SetMemberRoleResult {
MembershipUpdated(MembershipUpdated), MembershipUpdated(MembershipUpdated),
} }
#[derive(Debug)]
pub struct InviteMemberResult {
pub channel: Channel,
pub notifications: NotificationBatch,
}
#[derive(Debug)]
pub struct RespondToChannelInvite {
pub membership_update: Option<MembershipUpdated>,
pub notifications: NotificationBatch,
}
#[derive(Debug)]
pub struct RemoveChannelMemberResult {
pub membership_update: MembershipUpdated,
pub notification_id: Option<NotificationId>,
}
#[derive(FromQueryResult, Debug, PartialEq, Eq, Hash)] #[derive(FromQueryResult, Debug, PartialEq, Eq, Hash)]
pub struct Channel { pub struct Channel {
pub id: ChannelId, pub id: ChannelId,

View File

@ -81,6 +81,8 @@ id_type!(SignupId);
id_type!(UserId); id_type!(UserId);
id_type!(ChannelBufferCollaboratorId); id_type!(ChannelBufferCollaboratorId);
id_type!(FlagId); id_type!(FlagId);
id_type!(NotificationId);
id_type!(NotificationKindId);
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)] #[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)]
#[sea_orm(rs_type = "String", db_type = "String(None)")] #[sea_orm(rs_type = "String", db_type = "String(None)")]

View File

@ -5,6 +5,7 @@ pub mod buffers;
pub mod channels; pub mod channels;
pub mod contacts; pub mod contacts;
pub mod messages; pub mod messages;
pub mod notifications;
pub mod projects; pub mod projects;
pub mod rooms; pub mod rooms;
pub mod servers; pub mod servers;

View File

@ -1,4 +1,5 @@
use super::*; use super::*;
use sea_orm::sea_query::Query;
impl Database { impl Database {
pub async fn create_access_token( pub async fn create_access_token(

View File

@ -349,11 +349,11 @@ impl Database {
&self, &self,
channel_id: ChannelId, channel_id: ChannelId,
invitee_id: UserId, invitee_id: UserId,
admin_id: UserId, inviter_id: UserId,
role: ChannelRole, role: ChannelRole,
) -> Result<Channel> { ) -> Result<InviteMemberResult> {
self.transaction(move |tx| async move { self.transaction(move |tx| async move {
self.check_user_is_channel_admin(channel_id, admin_id, &*tx) self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
.await?; .await?;
channel_member::ActiveModel { channel_member::ActiveModel {
@ -371,11 +371,31 @@ impl Database {
.await? .await?
.unwrap(); .unwrap();
Ok(Channel { let channel = Channel {
id: channel.id, id: channel.id,
visibility: channel.visibility, visibility: channel.visibility,
name: channel.name, name: channel.name,
role, role,
};
let notifications = self
.create_notification(
invitee_id,
rpc::Notification::ChannelInvitation {
channel_id: channel_id.to_proto(),
channel_name: channel.name.clone(),
inviter_id: inviter_id.to_proto(),
},
true,
&*tx,
)
.await?
.into_iter()
.collect();
Ok(InviteMemberResult {
channel,
notifications,
}) })
}) })
.await .await
@ -445,9 +465,9 @@ impl Database {
channel_id: ChannelId, channel_id: ChannelId,
user_id: UserId, user_id: UserId,
accept: bool, accept: bool,
) -> Result<Option<MembershipUpdated>> { ) -> Result<RespondToChannelInvite> {
self.transaction(move |tx| async move { self.transaction(move |tx| async move {
if accept { let membership_update = if accept {
let rows_affected = channel_member::Entity::update_many() let rows_affected = channel_member::Entity::update_many()
.set(channel_member::ActiveModel { .set(channel_member::ActiveModel {
accepted: ActiveValue::Set(accept), accepted: ActiveValue::Set(accept),
@ -467,26 +487,45 @@ impl Database {
Err(anyhow!("no such invitation"))?; Err(anyhow!("no such invitation"))?;
} }
return Ok(Some( Some(
self.calculate_membership_updated(channel_id, user_id, &*tx) self.calculate_membership_updated(channel_id, user_id, &*tx)
.await?, .await?,
)); )
} } else {
let rows_affected = channel_member::Entity::delete_many()
.filter(
channel_member::Column::ChannelId
.eq(channel_id)
.and(channel_member::Column::UserId.eq(user_id))
.and(channel_member::Column::Accepted.eq(false)),
)
.exec(&*tx)
.await?
.rows_affected;
if rows_affected == 0 {
Err(anyhow!("no such invitation"))?;
}
let rows_affected = channel_member::ActiveModel { None
channel_id: ActiveValue::Unchanged(channel_id), };
user_id: ActiveValue::Unchanged(user_id),
..Default::default()
}
.delete(&*tx)
.await?
.rows_affected;
if rows_affected == 0 { Ok(RespondToChannelInvite {
Err(anyhow!("no such invitation"))?; membership_update,
} notifications: self
.mark_notification_as_read_with_response(
Ok(None) user_id,
&rpc::Notification::ChannelInvitation {
channel_id: channel_id.to_proto(),
channel_name: Default::default(),
inviter_id: Default::default(),
},
accept,
&*tx,
)
.await?
.into_iter()
.collect(),
})
}) })
.await .await
} }
@ -550,7 +589,7 @@ impl Database {
channel_id: ChannelId, channel_id: ChannelId,
member_id: UserId, member_id: UserId,
admin_id: UserId, admin_id: UserId,
) -> Result<MembershipUpdated> { ) -> Result<RemoveChannelMemberResult> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
self.check_user_is_channel_admin(channel_id, admin_id, &*tx) self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?; .await?;
@ -568,9 +607,22 @@ impl Database {
Err(anyhow!("no such member"))?; Err(anyhow!("no such member"))?;
} }
Ok(self Ok(RemoveChannelMemberResult {
.calculate_membership_updated(channel_id, member_id, &*tx) membership_update: self
.await?) .calculate_membership_updated(channel_id, member_id, &*tx)
.await?,
notification_id: self
.remove_notification(
member_id,
rpc::Notification::ChannelInvitation {
channel_id: channel_id.to_proto(),
channel_name: Default::default(),
inviter_id: Default::default(),
},
&*tx,
)
.await?,
})
}) })
.await .await
} }
@ -911,6 +963,47 @@ impl Database {
.await .await
} }
pub async fn get_channel_participant_details(
&self,
channel_id: ChannelId,
user_id: UserId,
) -> Result<Vec<proto::ChannelMember>> {
let (role, members) = self
.transaction(move |tx| async move {
let role = self
.check_user_is_channel_participant(channel_id, user_id, &*tx)
.await?;
Ok((
role,
self.get_channel_participant_details_internal(channel_id, &*tx)
.await?,
))
})
.await?;
if role == ChannelRole::Admin {
Ok(members
.into_iter()
.map(|channel_member| channel_member.to_proto())
.collect())
} else {
return Ok(members
.into_iter()
.filter_map(|member| {
if member.kind == proto::channel_member::Kind::Invitee {
return None;
}
Some(ChannelMember {
role: member.role,
user_id: member.user_id,
kind: proto::channel_member::Kind::Member,
})
})
.map(|channel_member| channel_member.to_proto())
.collect());
}
}
async fn get_channel_participant_details_internal( async fn get_channel_participant_details_internal(
&self, &self,
channel_id: ChannelId, channel_id: ChannelId,
@ -1003,28 +1096,6 @@ impl Database {
.collect()) .collect())
} }
pub async fn get_channel_participant_details(
&self,
channel_id: ChannelId,
admin_id: UserId,
) -> Result<Vec<proto::ChannelMember>> {
let members = self
.transaction(move |tx| async move {
self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?;
Ok(self
.get_channel_participant_details_internal(channel_id, &*tx)
.await?)
})
.await?;
Ok(members
.into_iter()
.map(|channel_member| channel_member.to_proto())
.collect())
}
pub async fn get_channel_participants( pub async fn get_channel_participants(
&self, &self,
channel_id: ChannelId, channel_id: ChannelId,
@ -1062,9 +1133,10 @@ impl Database {
channel_id: ChannelId, channel_id: ChannelId,
user_id: UserId, user_id: UserId,
tx: &DatabaseTransaction, tx: &DatabaseTransaction,
) -> Result<()> { ) -> Result<ChannelRole> {
match self.channel_role_for_user(channel_id, user_id, tx).await? { let channel_role = self.channel_role_for_user(channel_id, user_id, tx).await?;
Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(()), match channel_role {
Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(channel_role.unwrap()),
Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!( Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!(
"user is not a channel member or channel does not exist" "user is not a channel member or channel does not exist"
))?, ))?,

View File

@ -8,7 +8,6 @@ impl Database {
user_id_b: UserId, user_id_b: UserId,
a_to_b: bool, a_to_b: bool,
accepted: bool, accepted: bool,
should_notify: bool,
user_a_busy: bool, user_a_busy: bool,
user_b_busy: bool, user_b_busy: bool,
} }
@ -53,7 +52,6 @@ impl Database {
if db_contact.accepted { if db_contact.accepted {
contacts.push(Contact::Accepted { contacts.push(Contact::Accepted {
user_id: db_contact.user_id_b, user_id: db_contact.user_id_b,
should_notify: db_contact.should_notify && db_contact.a_to_b,
busy: db_contact.user_b_busy, busy: db_contact.user_b_busy,
}); });
} else if db_contact.a_to_b { } else if db_contact.a_to_b {
@ -63,19 +61,16 @@ impl Database {
} else { } else {
contacts.push(Contact::Incoming { contacts.push(Contact::Incoming {
user_id: db_contact.user_id_b, user_id: db_contact.user_id_b,
should_notify: db_contact.should_notify,
}); });
} }
} else if db_contact.accepted { } else if db_contact.accepted {
contacts.push(Contact::Accepted { contacts.push(Contact::Accepted {
user_id: db_contact.user_id_a, user_id: db_contact.user_id_a,
should_notify: db_contact.should_notify && !db_contact.a_to_b,
busy: db_contact.user_a_busy, busy: db_contact.user_a_busy,
}); });
} else if db_contact.a_to_b { } else if db_contact.a_to_b {
contacts.push(Contact::Incoming { contacts.push(Contact::Incoming {
user_id: db_contact.user_id_a, user_id: db_contact.user_id_a,
should_notify: db_contact.should_notify,
}); });
} else { } else {
contacts.push(Contact::Outgoing { contacts.push(Contact::Outgoing {
@ -124,7 +119,11 @@ impl Database {
.await .await
} }
pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { pub async fn send_contact_request(
&self,
sender_id: UserId,
receiver_id: UserId,
) -> Result<NotificationBatch> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
let (id_a, id_b, a_to_b) = if sender_id < receiver_id { let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
(sender_id, receiver_id, true) (sender_id, receiver_id, true)
@ -161,11 +160,22 @@ impl Database {
.exec_without_returning(&*tx) .exec_without_returning(&*tx)
.await?; .await?;
if rows_affected == 1 { if rows_affected == 0 {
Ok(()) Err(anyhow!("contact already requested"))?;
} else {
Err(anyhow!("contact already requested"))?
} }
Ok(self
.create_notification(
receiver_id,
rpc::Notification::ContactRequest {
sender_id: sender_id.to_proto(),
},
true,
&*tx,
)
.await?
.into_iter()
.collect())
}) })
.await .await
} }
@ -179,7 +189,11 @@ impl Database {
/// ///
/// * `requester_id` - The user that initiates this request /// * `requester_id` - The user that initiates this request
/// * `responder_id` - The user that will be removed /// * `responder_id` - The user that will be removed
pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<bool> { pub async fn remove_contact(
&self,
requester_id: UserId,
responder_id: UserId,
) -> Result<(bool, Option<NotificationId>)> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
let (id_a, id_b) = if responder_id < requester_id { let (id_a, id_b) = if responder_id < requester_id {
(responder_id, requester_id) (responder_id, requester_id)
@ -198,7 +212,21 @@ impl Database {
.ok_or_else(|| anyhow!("no such contact"))?; .ok_or_else(|| anyhow!("no such contact"))?;
contact::Entity::delete_by_id(contact.id).exec(&*tx).await?; contact::Entity::delete_by_id(contact.id).exec(&*tx).await?;
Ok(contact.accepted)
let mut deleted_notification_id = None;
if !contact.accepted {
deleted_notification_id = self
.remove_notification(
responder_id,
rpc::Notification::ContactRequest {
sender_id: requester_id.to_proto(),
},
&*tx,
)
.await?;
}
Ok((contact.accepted, deleted_notification_id))
}) })
.await .await
} }
@ -249,7 +277,7 @@ impl Database {
responder_id: UserId, responder_id: UserId,
requester_id: UserId, requester_id: UserId,
accept: bool, accept: bool,
) -> Result<()> { ) -> Result<NotificationBatch> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
let (id_a, id_b, a_to_b) = if responder_id < requester_id { let (id_a, id_b, a_to_b) = if responder_id < requester_id {
(responder_id, requester_id, false) (responder_id, requester_id, false)
@ -287,11 +315,38 @@ impl Database {
result.rows_affected result.rows_affected
}; };
if rows_affected == 1 { if rows_affected == 0 {
Ok(())
} else {
Err(anyhow!("no such contact request"))? Err(anyhow!("no such contact request"))?
} }
let mut notifications = Vec::new();
notifications.extend(
self.mark_notification_as_read_with_response(
responder_id,
&rpc::Notification::ContactRequest {
sender_id: requester_id.to_proto(),
},
accept,
&*tx,
)
.await?,
);
if accept {
notifications.extend(
self.create_notification(
requester_id,
rpc::Notification::ContactRequestAccepted {
responder_id: responder_id.to_proto(),
},
true,
&*tx,
)
.await?,
);
}
Ok(notifications)
}) })
.await .await
} }

View File

@ -1,4 +1,7 @@
use super::*; use super::*;
use futures::Stream;
use rpc::Notification;
use sea_orm::TryInsertResult;
use time::OffsetDateTime; use time::OffsetDateTime;
impl Database { impl Database {
@ -87,43 +90,118 @@ impl Database {
condition = condition.add(channel_message::Column::Id.lt(before_message_id)); condition = condition.add(channel_message::Column::Id.lt(before_message_id));
} }
let mut rows = channel_message::Entity::find() let rows = channel_message::Entity::find()
.filter(condition) .filter(condition)
.order_by_desc(channel_message::Column::Id) .order_by_desc(channel_message::Column::Id)
.limit(count as u64) .limit(count as u64)
.stream(&*tx) .stream(&*tx)
.await?; .await?;
let mut messages = Vec::new(); self.load_channel_messages(rows, &*tx).await
while let Some(row) = rows.next().await { })
let row = row?; .await
let nonce = row.nonce.as_u64_pair(); }
messages.push(proto::ChannelMessage {
id: row.id.to_proto(), pub async fn get_channel_messages_by_id(
sender_id: row.sender_id.to_proto(), &self,
body: row.body, user_id: UserId,
timestamp: row.sent_at.assume_utc().unix_timestamp() as u64, message_ids: &[MessageId],
nonce: Some(proto::Nonce { ) -> Result<Vec<proto::ChannelMessage>> {
upper_half: nonce.0, self.transaction(|tx| async move {
lower_half: nonce.1, let rows = channel_message::Entity::find()
.filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
.order_by_desc(channel_message::Column::Id)
.stream(&*tx)
.await?;
let mut channel_ids = HashSet::<ChannelId>::default();
let messages = self
.load_channel_messages(
rows.map(|row| {
row.map(|row| {
channel_ids.insert(row.channel_id);
row
})
}), }),
}); &*tx,
)
.await?;
for channel_id in channel_ids {
self.check_user_is_channel_member(channel_id, user_id, &*tx)
.await?;
} }
drop(rows);
messages.reverse();
Ok(messages) Ok(messages)
}) })
.await .await
} }
async fn load_channel_messages(
&self,
mut rows: impl Send + Unpin + Stream<Item = Result<channel_message::Model, sea_orm::DbErr>>,
tx: &DatabaseTransaction,
) -> Result<Vec<proto::ChannelMessage>> {
let mut messages = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
let nonce = row.nonce.as_u64_pair();
messages.push(proto::ChannelMessage {
id: row.id.to_proto(),
sender_id: row.sender_id.to_proto(),
body: row.body,
timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
mentions: vec![],
nonce: Some(proto::Nonce {
upper_half: nonce.0,
lower_half: nonce.1,
}),
});
}
drop(rows);
messages.reverse();
let mut mentions = channel_message_mention::Entity::find()
.filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
.order_by_asc(channel_message_mention::Column::MessageId)
.order_by_asc(channel_message_mention::Column::StartOffset)
.stream(&*tx)
.await?;
let mut message_ix = 0;
while let Some(mention) = mentions.next().await {
let mention = mention?;
let message_id = mention.message_id.to_proto();
while let Some(message) = messages.get_mut(message_ix) {
if message.id < message_id {
message_ix += 1;
} else {
if message.id == message_id {
message.mentions.push(proto::ChatMention {
range: Some(proto::Range {
start: mention.start_offset as u64,
end: mention.end_offset as u64,
}),
user_id: mention.user_id.to_proto(),
});
}
break;
}
}
}
Ok(messages)
}
pub async fn create_channel_message( pub async fn create_channel_message(
&self, &self,
channel_id: ChannelId, channel_id: ChannelId,
user_id: UserId, user_id: UserId,
body: &str, body: &str,
mentions: &[proto::ChatMention],
timestamp: OffsetDateTime, timestamp: OffsetDateTime,
nonce: u128, nonce: u128,
) -> Result<(MessageId, Vec<ConnectionId>, Vec<UserId>)> { ) -> Result<CreatedChannelMessage> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
self.check_user_is_channel_participant(channel_id, user_id, &*tx) self.check_user_is_channel_participant(channel_id, user_id, &*tx)
.await?; .await?;
@ -153,7 +231,7 @@ impl Database {
let timestamp = timestamp.to_offset(time::UtcOffset::UTC); let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time()); let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
let message = channel_message::Entity::insert(channel_message::ActiveModel { let result = channel_message::Entity::insert(channel_message::ActiveModel {
channel_id: ActiveValue::Set(channel_id), channel_id: ActiveValue::Set(channel_id),
sender_id: ActiveValue::Set(user_id), sender_id: ActiveValue::Set(user_id),
body: ActiveValue::Set(body.to_string()), body: ActiveValue::Set(body.to_string()),
@ -162,35 +240,85 @@ impl Database {
id: ActiveValue::NotSet, id: ActiveValue::NotSet,
}) })
.on_conflict( .on_conflict(
OnConflict::column(channel_message::Column::Nonce) OnConflict::columns([
.update_column(channel_message::Column::Nonce) channel_message::Column::SenderId,
.to_owned(), channel_message::Column::Nonce,
])
.do_nothing()
.to_owned(),
) )
.do_nothing()
.exec(&*tx) .exec(&*tx)
.await?; .await?;
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] let message_id;
enum QueryConnectionId { let mut notifications = Vec::new();
ConnectionId, match result {
} TryInsertResult::Inserted(result) => {
message_id = result.last_insert_id;
let mentioned_user_ids =
mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
let mentions = mentions
.iter()
.filter_map(|mention| {
let range = mention.range.as_ref()?;
if !body.is_char_boundary(range.start as usize)
|| !body.is_char_boundary(range.end as usize)
{
return None;
}
Some(channel_message_mention::ActiveModel {
message_id: ActiveValue::Set(message_id),
start_offset: ActiveValue::Set(range.start as i32),
end_offset: ActiveValue::Set(range.end as i32),
user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
})
})
.collect::<Vec<_>>();
if !mentions.is_empty() {
channel_message_mention::Entity::insert_many(mentions)
.exec(&*tx)
.await?;
}
// Observe this message for the sender for mentioned_user in mentioned_user_ids {
self.observe_channel_message_internal( notifications.extend(
channel_id, self.create_notification(
user_id, UserId::from_proto(mentioned_user),
message.last_insert_id, rpc::Notification::ChannelMessageMention {
&*tx, message_id: message_id.to_proto(),
) sender_id: user_id.to_proto(),
.await?; channel_id: channel_id.to_proto(),
},
false,
&*tx,
)
.await?,
);
}
self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
.await?;
}
_ => {
message_id = channel_message::Entity::find()
.filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("failed to insert message"))?
.id;
}
}
let mut channel_members = self.get_channel_participants(channel_id, &*tx).await?; let mut channel_members = self.get_channel_participants(channel_id, &*tx).await?;
channel_members.retain(|member| !participant_user_ids.contains(member)); channel_members.retain(|member| !participant_user_ids.contains(member));
Ok(( Ok(CreatedChannelMessage {
message.last_insert_id, message_id,
participant_connection_ids, participant_connection_ids,
channel_members, channel_members,
)) notifications,
})
}) })
.await .await
} }
@ -200,11 +328,24 @@ impl Database {
channel_id: ChannelId, channel_id: ChannelId,
user_id: UserId, user_id: UserId,
message_id: MessageId, message_id: MessageId,
) -> Result<()> { ) -> Result<NotificationBatch> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
.await?; .await?;
Ok(()) let mut batch = NotificationBatch::default();
batch.extend(
self.mark_notification_as_read(
user_id,
&Notification::ChannelMessageMention {
message_id: message_id.to_proto(),
sender_id: Default::default(),
channel_id: Default::default(),
},
&*tx,
)
.await?,
);
Ok(batch)
}) })
.await .await
} }

View File

@ -0,0 +1,262 @@
use super::*;
use rpc::Notification;
impl Database {
pub async fn initialize_notification_kinds(&mut self) -> Result<()> {
notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map(
|kind| notification_kind::ActiveModel {
name: ActiveValue::Set(kind.to_string()),
..Default::default()
},
))
.on_conflict(OnConflict::new().do_nothing().to_owned())
.exec_without_returning(&self.pool)
.await?;
let mut rows = notification_kind::Entity::find().stream(&self.pool).await?;
while let Some(row) = rows.next().await {
let row = row?;
self.notification_kinds_by_name.insert(row.name, row.id);
}
for name in Notification::all_variant_names() {
if let Some(id) = self.notification_kinds_by_name.get(*name).copied() {
self.notification_kinds_by_id.insert(id, name);
}
}
Ok(())
}
pub async fn get_notifications(
&self,
recipient_id: UserId,
limit: usize,
before_id: Option<NotificationId>,
) -> Result<Vec<proto::Notification>> {
self.transaction(|tx| async move {
let mut result = Vec::new();
let mut condition =
Condition::all().add(notification::Column::RecipientId.eq(recipient_id));
if let Some(before_id) = before_id {
condition = condition.add(notification::Column::Id.lt(before_id));
}
let mut rows = notification::Entity::find()
.filter(condition)
.order_by_desc(notification::Column::Id)
.limit(limit as u64)
.stream(&*tx)
.await?;
while let Some(row) = rows.next().await {
let row = row?;
let kind = row.kind;
if let Some(proto) = model_to_proto(self, row) {
result.push(proto);
} else {
log::warn!("unknown notification kind {:?}", kind);
}
}
result.reverse();
Ok(result)
})
.await
}
/// Create a notification. If `avoid_duplicates` is set to true, then avoid
/// creating a new notification if the given recipient already has an
/// unread notification with the given kind and entity id.
pub async fn create_notification(
&self,
recipient_id: UserId,
notification: Notification,
avoid_duplicates: bool,
tx: &DatabaseTransaction,
) -> Result<Option<(UserId, proto::Notification)>> {
if avoid_duplicates {
if self
.find_notification(recipient_id, &notification, tx)
.await?
.is_some()
{
return Ok(None);
}
}
let proto = notification.to_proto();
let kind = notification_kind_from_proto(self, &proto)?;
let model = notification::ActiveModel {
recipient_id: ActiveValue::Set(recipient_id),
kind: ActiveValue::Set(kind),
entity_id: ActiveValue::Set(proto.entity_id.map(|id| id as i32)),
content: ActiveValue::Set(proto.content.clone()),
..Default::default()
}
.save(&*tx)
.await?;
Ok(Some((
recipient_id,
proto::Notification {
id: model.id.as_ref().to_proto(),
kind: proto.kind,
timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64,
is_read: false,
response: None,
content: proto.content,
entity_id: proto.entity_id,
},
)))
}
/// Remove an unread notification with the given recipient, kind and
/// entity id.
pub async fn remove_notification(
&self,
recipient_id: UserId,
notification: Notification,
tx: &DatabaseTransaction,
) -> Result<Option<NotificationId>> {
let id = self
.find_notification(recipient_id, &notification, tx)
.await?;
if let Some(id) = id {
notification::Entity::delete_by_id(id).exec(tx).await?;
}
Ok(id)
}
/// Populate the response for the notification with the given kind and
/// entity id.
pub async fn mark_notification_as_read_with_response(
&self,
recipient_id: UserId,
notification: &Notification,
response: bool,
tx: &DatabaseTransaction,
) -> Result<Option<(UserId, proto::Notification)>> {
self.mark_notification_as_read_internal(recipient_id, notification, Some(response), tx)
.await
}
pub async fn mark_notification_as_read(
&self,
recipient_id: UserId,
notification: &Notification,
tx: &DatabaseTransaction,
) -> Result<Option<(UserId, proto::Notification)>> {
self.mark_notification_as_read_internal(recipient_id, notification, None, tx)
.await
}
pub async fn mark_notification_as_read_by_id(
&self,
recipient_id: UserId,
notification_id: NotificationId,
) -> Result<NotificationBatch> {
self.transaction(|tx| async move {
let row = notification::Entity::update(notification::ActiveModel {
id: ActiveValue::Unchanged(notification_id),
recipient_id: ActiveValue::Unchanged(recipient_id),
is_read: ActiveValue::Set(true),
..Default::default()
})
.exec(&*tx)
.await?;
Ok(model_to_proto(self, row)
.map(|notification| (recipient_id, notification))
.into_iter()
.collect())
})
.await
}
async fn mark_notification_as_read_internal(
&self,
recipient_id: UserId,
notification: &Notification,
response: Option<bool>,
tx: &DatabaseTransaction,
) -> Result<Option<(UserId, proto::Notification)>> {
if let Some(id) = self
.find_notification(recipient_id, notification, &*tx)
.await?
{
let row = notification::Entity::update(notification::ActiveModel {
id: ActiveValue::Unchanged(id),
recipient_id: ActiveValue::Unchanged(recipient_id),
is_read: ActiveValue::Set(true),
response: if let Some(response) = response {
ActiveValue::Set(Some(response))
} else {
ActiveValue::NotSet
},
..Default::default()
})
.exec(tx)
.await?;
Ok(model_to_proto(self, row).map(|notification| (recipient_id, notification)))
} else {
Ok(None)
}
}
/// Find an unread notification by its recipient, kind and entity id.
async fn find_notification(
&self,
recipient_id: UserId,
notification: &Notification,
tx: &DatabaseTransaction,
) -> Result<Option<NotificationId>> {
let proto = notification.to_proto();
let kind = notification_kind_from_proto(self, &proto)?;
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryIds {
Id,
}
Ok(notification::Entity::find()
.select_only()
.column(notification::Column::Id)
.filter(
Condition::all()
.add(notification::Column::RecipientId.eq(recipient_id))
.add(notification::Column::IsRead.eq(false))
.add(notification::Column::Kind.eq(kind))
.add(if proto.entity_id.is_some() {
notification::Column::EntityId.eq(proto.entity_id)
} else {
notification::Column::EntityId.is_null()
}),
)
.into_values::<_, QueryIds>()
.one(&*tx)
.await?)
}
}
fn model_to_proto(this: &Database, row: notification::Model) -> Option<proto::Notification> {
let kind = this.notification_kinds_by_id.get(&row.kind)?;
Some(proto::Notification {
id: row.id.to_proto(),
kind: kind.to_string(),
timestamp: row.created_at.assume_utc().unix_timestamp() as u64,
is_read: row.is_read,
response: row.response,
content: row.content,
entity_id: row.entity_id.map(|id| id as u64),
})
}
fn notification_kind_from_proto(
this: &Database,
proto: &proto::Notification,
) -> Result<NotificationKindId> {
Ok(this
.notification_kinds_by_name
.get(&proto.kind)
.copied()
.ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?)
}

View File

@ -7,11 +7,14 @@ pub mod channel_buffer_collaborator;
pub mod channel_chat_participant; pub mod channel_chat_participant;
pub mod channel_member; pub mod channel_member;
pub mod channel_message; pub mod channel_message;
pub mod channel_message_mention;
pub mod channel_path; pub mod channel_path;
pub mod contact; pub mod contact;
pub mod feature_flag; pub mod feature_flag;
pub mod follower; pub mod follower;
pub mod language_server; pub mod language_server;
pub mod notification;
pub mod notification_kind;
pub mod observed_buffer_edits; pub mod observed_buffer_edits;
pub mod observed_channel_messages; pub mod observed_channel_messages;
pub mod project; pub mod project;

View File

@ -0,0 +1,43 @@
use crate::db::{MessageId, UserId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "channel_message_mentions")]
pub struct Model {
#[sea_orm(primary_key)]
pub message_id: MessageId,
#[sea_orm(primary_key)]
pub start_offset: i32,
pub end_offset: i32,
pub user_id: UserId,
}
impl ActiveModelBehavior for ActiveModel {}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::channel_message::Entity",
from = "Column::MessageId",
to = "super::channel_message::Column::Id"
)]
Message,
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
MentionedUser,
}
impl Related<super::channel::Entity> for Entity {
fn to() -> RelationDef {
Relation::Message.def()
}
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::MentionedUser.def()
}
}

View File

@ -0,0 +1,29 @@
use crate::db::{NotificationId, NotificationKindId, UserId};
use sea_orm::entity::prelude::*;
use time::PrimitiveDateTime;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "notifications")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: NotificationId,
pub created_at: PrimitiveDateTime,
pub recipient_id: UserId,
pub kind: NotificationKindId,
pub entity_id: Option<i32>,
pub content: String,
pub is_read: bool,
pub response: Option<bool>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::RecipientId",
to = "super::user::Column::Id"
)]
Recipient,
}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -0,0 +1,15 @@
use crate::db::NotificationKindId;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "notification_kinds")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: NotificationKindId,
pub name: String,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@ -10,7 +10,10 @@ use parking_lot::Mutex;
use rpc::proto::ChannelEdge; use rpc::proto::ChannelEdge;
use sea_orm::ConnectionTrait; use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase; use sqlx::migrate::MigrateDatabase;
use std::sync::Arc; use std::sync::{
atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
Arc,
};
const TEST_RELEASE_CHANNEL: &'static str = "test"; const TEST_RELEASE_CHANNEL: &'static str = "test";
@ -31,7 +34,7 @@ impl TestDb {
let mut db = runtime.block_on(async { let mut db = runtime.block_on(async {
let mut options = ConnectOptions::new(url); let mut options = ConnectOptions::new(url);
options.max_connections(5); options.max_connections(5);
let db = Database::new(options, Executor::Deterministic(background)) let mut db = Database::new(options, Executor::Deterministic(background))
.await .await
.unwrap(); .unwrap();
let sql = include_str!(concat!( let sql = include_str!(concat!(
@ -45,6 +48,7 @@ impl TestDb {
)) ))
.await .await
.unwrap(); .unwrap();
db.initialize_notification_kinds().await.unwrap();
db db
}); });
@ -79,11 +83,12 @@ impl TestDb {
options options
.max_connections(5) .max_connections(5)
.idle_timeout(Duration::from_secs(0)); .idle_timeout(Duration::from_secs(0));
let db = Database::new(options, Executor::Deterministic(background)) let mut db = Database::new(options, Executor::Deterministic(background))
.await .await
.unwrap(); .unwrap();
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
db.migrate(Path::new(migrations_path), false).await.unwrap(); db.migrate(Path::new(migrations_path), false).await.unwrap();
db.initialize_notification_kinds().await.unwrap();
db db
}); });
@ -176,3 +181,27 @@ fn graph(
graph graph
} }
static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
db.create_user(
email,
false,
NewUserParams {
github_login: email[0..email.find("@").unwrap()].to_string(),
github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
},
)
.await
.unwrap()
.user_id
}
static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
fn new_test_connection(server: ServerId) -> ConnectionId {
ConnectionId {
id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
owner_id: server.0 as u32,
}
}

View File

@ -17,7 +17,6 @@ async fn test_channel_buffers(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "user_a".into(), github_login: "user_a".into(),
github_user_id: 101, github_user_id: 101,
invite_count: 0,
}, },
) )
.await .await
@ -30,7 +29,6 @@ async fn test_channel_buffers(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "user_b".into(), github_login: "user_b".into(),
github_user_id: 102, github_user_id: 102,
invite_count: 0,
}, },
) )
.await .await
@ -45,7 +43,6 @@ async fn test_channel_buffers(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "user_c".into(), github_login: "user_c".into(),
github_user_id: 102, github_user_id: 102,
invite_count: 0,
}, },
) )
.await .await
@ -178,7 +175,6 @@ async fn test_channel_buffers_last_operations(db: &Database) {
NewUserParams { NewUserParams {
github_login: "user_a".into(), github_login: "user_a".into(),
github_user_id: 101, github_user_id: 101,
invite_count: 0,
}, },
) )
.await .await
@ -191,7 +187,6 @@ async fn test_channel_buffers_last_operations(db: &Database) {
NewUserParams { NewUserParams {
github_login: "user_b".into(), github_login: "user_b".into(),
github_user_id: 102, github_user_id: 102,
invite_count: 0,
}, },
) )
.await .await

View File

@ -1,20 +1,17 @@
use collections::{HashMap, HashSet}; use std::sync::Arc;
use rpc::{
proto::{self},
ConnectionId,
};
use crate::{ use crate::{
db::{ db::{
queries::channels::ChannelGraph, queries::channels::ChannelGraph,
tests::{graph, TEST_RELEASE_CHANNEL}, tests::{graph, new_test_connection, new_test_user, TEST_RELEASE_CHANNEL},
ChannelId, ChannelRole, Database, NewUserParams, RoomId, ServerId, UserId, ChannelId, ChannelRole, Database, NewUserParams, RoomId,
}, },
test_both_dbs, test_both_dbs,
}; };
use std::sync::{ use collections::{HashMap, HashSet};
atomic::{AtomicI32, AtomicU32, Ordering}, use rpc::{
Arc, proto::{self},
ConnectionId,
}; };
test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite); test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
@ -305,7 +302,6 @@ async fn test_channel_renames(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "user1".into(), github_login: "user1".into(),
github_user_id: 5, github_user_id: 5,
invite_count: 0,
}, },
) )
.await .await
@ -319,7 +315,6 @@ async fn test_channel_renames(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "user2".into(), github_login: "user2".into(),
github_user_id: 6, github_user_id: 6,
invite_count: 0,
}, },
) )
.await .await
@ -360,7 +355,6 @@ async fn test_db_channel_moving(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "user1".into(), github_login: "user1".into(),
github_user_id: 5, github_user_id: 5,
invite_count: 0,
}, },
) )
.await .await
@ -727,7 +721,6 @@ async fn test_db_channel_moving_bugs(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "user1".into(), github_login: "user1".into(),
github_user_id: 5, github_user_id: 5,
invite_count: 0,
}, },
) )
.await .await
@ -1122,28 +1115,3 @@ fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option<ChannelId>)])
pretty_assertions::assert_eq!(actual_map, expected_map) pretty_assertions::assert_eq!(actual_map, expected_map)
} }
static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
db.create_user(
email,
false,
NewUserParams {
github_login: email[0..email.find("@").unwrap()].to_string(),
github_user_id: GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst),
invite_count: 0,
},
)
.await
.unwrap()
.user_id
}
static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
fn new_test_connection(server: ServerId) -> ConnectionId {
ConnectionId {
id: TEST_CONNECTION_ID.fetch_add(1, Ordering::SeqCst),
owner_id: server.0 as u32,
}
}

View File

@ -22,7 +22,6 @@ async fn test_get_users(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: format!("user{i}"), github_login: format!("user{i}"),
github_user_id: i, github_user_id: i,
invite_count: 0,
}, },
) )
.await .await
@ -88,7 +87,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "login1".into(), github_login: "login1".into(),
github_user_id: 101, github_user_id: 101,
invite_count: 0,
}, },
) )
.await .await
@ -101,7 +99,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "login2".into(), github_login: "login2".into(),
github_user_id: 102, github_user_id: 102,
invite_count: 0,
}, },
) )
.await .await
@ -156,7 +153,6 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "u1".into(), github_login: "u1".into(),
github_user_id: 1, github_user_id: 1,
invite_count: 0,
}, },
) )
.await .await
@ -238,7 +234,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: format!("user{i}"), github_login: format!("user{i}"),
github_user_id: i, github_user_id: i,
invite_count: 0,
}, },
) )
.await .await
@ -264,10 +259,7 @@ async fn test_add_contacts(db: &Arc<Database>) {
); );
assert_eq!( assert_eq!(
db.get_contacts(user_2).await.unwrap(), db.get_contacts(user_2).await.unwrap(),
&[Contact::Incoming { &[Contact::Incoming { user_id: user_1 }]
user_id: user_1,
should_notify: true
}]
); );
// User 2 dismisses the contact request notification without accepting or rejecting. // User 2 dismisses the contact request notification without accepting or rejecting.
@ -280,10 +272,7 @@ async fn test_add_contacts(db: &Arc<Database>) {
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
db.get_contacts(user_2).await.unwrap(), db.get_contacts(user_2).await.unwrap(),
&[Contact::Incoming { &[Contact::Incoming { user_id: user_1 }]
user_id: user_1,
should_notify: false
}]
); );
// User can't accept their own contact request // User can't accept their own contact request
@ -299,7 +288,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_1).await.unwrap(), db.get_contacts(user_1).await.unwrap(),
&[Contact::Accepted { &[Contact::Accepted {
user_id: user_2, user_id: user_2,
should_notify: true,
busy: false, busy: false,
}], }],
); );
@ -309,7 +297,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_2).await.unwrap(), db.get_contacts(user_2).await.unwrap(),
&[Contact::Accepted { &[Contact::Accepted {
user_id: user_1, user_id: user_1,
should_notify: false,
busy: false, busy: false,
}] }]
); );
@ -326,7 +313,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_1).await.unwrap(), db.get_contacts(user_1).await.unwrap(),
&[Contact::Accepted { &[Contact::Accepted {
user_id: user_2, user_id: user_2,
should_notify: true,
busy: false, busy: false,
}] }]
); );
@ -339,7 +325,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_1).await.unwrap(), db.get_contacts(user_1).await.unwrap(),
&[Contact::Accepted { &[Contact::Accepted {
user_id: user_2, user_id: user_2,
should_notify: false,
busy: false, busy: false,
}] }]
); );
@ -353,12 +338,10 @@ async fn test_add_contacts(db: &Arc<Database>) {
&[ &[
Contact::Accepted { Contact::Accepted {
user_id: user_2, user_id: user_2,
should_notify: false,
busy: false, busy: false,
}, },
Contact::Accepted { Contact::Accepted {
user_id: user_3, user_id: user_3,
should_notify: false,
busy: false, busy: false,
} }
] ]
@ -367,7 +350,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_3).await.unwrap(), db.get_contacts(user_3).await.unwrap(),
&[Contact::Accepted { &[Contact::Accepted {
user_id: user_1, user_id: user_1,
should_notify: false,
busy: false, busy: false,
}], }],
); );
@ -383,7 +365,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_2).await.unwrap(), db.get_contacts(user_2).await.unwrap(),
&[Contact::Accepted { &[Contact::Accepted {
user_id: user_1, user_id: user_1,
should_notify: false,
busy: false, busy: false,
}] }]
); );
@ -391,7 +372,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_3).await.unwrap(), db.get_contacts(user_3).await.unwrap(),
&[Contact::Accepted { &[Contact::Accepted {
user_id: user_1, user_id: user_1,
should_notify: false,
busy: false, busy: false,
}], }],
); );
@ -415,7 +395,6 @@ async fn test_metrics_id(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "person1".into(), github_login: "person1".into(),
github_user_id: 101, github_user_id: 101,
invite_count: 5,
}, },
) )
.await .await
@ -431,7 +410,6 @@ async fn test_metrics_id(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "person2".into(), github_login: "person2".into(),
github_user_id: 102, github_user_id: 102,
invite_count: 5,
}, },
) )
.await .await
@ -460,7 +438,6 @@ async fn test_project_count(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "admin".into(), github_login: "admin".into(),
github_user_id: 0, github_user_id: 0,
invite_count: 0,
}, },
) )
.await .await
@ -472,7 +449,6 @@ async fn test_project_count(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "user".into(), github_login: "user".into(),
github_user_id: 1, github_user_id: 1,
invite_count: 0,
}, },
) )
.await .await
@ -554,7 +530,6 @@ async fn test_fuzzy_search_users() {
NewUserParams { NewUserParams {
github_login: github_login.into(), github_login: github_login.into(),
github_user_id: i as i32, github_user_id: i as i32,
invite_count: 0,
}, },
) )
.await .await
@ -596,7 +571,6 @@ async fn test_non_matching_release_channels(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "admin".into(), github_login: "admin".into(),
github_user_id: 0, github_user_id: 0,
invite_count: 0,
}, },
) )
.await .await
@ -608,7 +582,6 @@ async fn test_non_matching_release_channels(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: "user".into(), github_login: "user".into(),
github_user_id: 1, github_user_id: 1,
invite_count: 0,
}, },
) )
.await .await

View File

@ -18,7 +18,6 @@ async fn test_get_user_flags(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: format!("user1"), github_login: format!("user1"),
github_user_id: 1, github_user_id: 1,
invite_count: 0,
}, },
) )
.await .await
@ -32,7 +31,6 @@ async fn test_get_user_flags(db: &Arc<Database>) {
NewUserParams { NewUserParams {
github_login: format!("user2"), github_login: format!("user2"),
github_user_id: 2, github_user_id: 2,
invite_count: 0,
}, },
) )
.await .await

View File

@ -1,7 +1,9 @@
use super::new_test_user;
use crate::{ use crate::{
db::{ChannelRole, Database, MessageId, NewUserParams}, db::{ChannelRole, Database, MessageId},
test_both_dbs, test_both_dbs,
}; };
use channel::mentions_to_proto;
use std::sync::Arc; use std::sync::Arc;
use time::OffsetDateTime; use time::OffsetDateTime;
@ -12,39 +14,38 @@ test_both_dbs!(
); );
async fn test_channel_message_retrieval(db: &Arc<Database>) { async fn test_channel_message_retrieval(db: &Arc<Database>) {
let user = db let user = new_test_user(db, "user@example.com").await;
.create_user( let result = db.create_channel("channel", None, user).await.unwrap();
"user@example.com",
false,
NewUserParams {
github_login: "user".into(),
github_user_id: 1,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let channel = db.create_root_channel("channel", user).await.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32; let owner_id = db.create_server("test").await.unwrap().0 as u32;
db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user) db.join_channel_chat(
.await result.channel.id,
.unwrap(); rpc::ConnectionId { owner_id, id: 0 },
user,
)
.await
.unwrap();
let mut all_messages = Vec::new(); let mut all_messages = Vec::new();
for i in 0..10 { for i in 0..10 {
all_messages.push( all_messages.push(
db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i) db.create_channel_message(
.await result.channel.id,
.unwrap() user,
.0 &i.to_string(),
.to_proto(), &[],
OffsetDateTime::now_utc(),
i,
)
.await
.unwrap()
.message_id
.to_proto(),
); );
} }
let messages = db let messages = db
.get_channel_messages(channel, user, 3, None) .get_channel_messages(result.channel.id, user, 3, None)
.await .await
.unwrap() .unwrap()
.into_iter() .into_iter()
@ -54,7 +55,7 @@ async fn test_channel_message_retrieval(db: &Arc<Database>) {
let messages = db let messages = db
.get_channel_messages( .get_channel_messages(
channel, result.channel.id,
user, user,
4, 4,
Some(MessageId::from_proto(all_messages[6])), Some(MessageId::from_proto(all_messages[6])),
@ -74,99 +75,154 @@ test_both_dbs!(
); );
async fn test_channel_message_nonces(db: &Arc<Database>) { async fn test_channel_message_nonces(db: &Arc<Database>) {
let user = db let user_a = new_test_user(db, "user_a@example.com").await;
.create_user( let user_b = new_test_user(db, "user_b@example.com").await;
"user@example.com", let user_c = new_test_user(db, "user_c@example.com").await;
false, let channel = db.create_root_channel("channel", user_a).await.unwrap();
NewUserParams { db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)
github_login: "user".into(), .await
github_user_id: 1, .unwrap();
invite_count: 0, db.invite_channel_member(channel, user_c, user_a, ChannelRole::Member)
}, .await
.unwrap();
db.respond_to_channel_invite(channel, user_b, true)
.await
.unwrap();
db.respond_to_channel_invite(channel, user_c, true)
.await
.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32;
db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user_a)
.await
.unwrap();
db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 1 }, user_b)
.await
.unwrap();
// As user A, create messages that re-use the same nonces. The requests
// succeed, but return the same ids.
let id1 = db
.create_channel_message(
channel,
user_a,
"hi @user_b",
&mentions_to_proto(&[(3..10, user_b.to_proto())]),
OffsetDateTime::now_utc(),
100,
) )
.await .await
.unwrap() .unwrap()
.user_id; .message_id;
let channel = db.create_root_channel("channel", user).await.unwrap(); let id2 = db
.create_channel_message(
channel,
user_a,
"hello, fellow users",
&mentions_to_proto(&[]),
OffsetDateTime::now_utc(),
200,
)
.await
.unwrap()
.message_id;
let id3 = db
.create_channel_message(
channel,
user_a,
"bye @user_c (same nonce as first message)",
&mentions_to_proto(&[(4..11, user_c.to_proto())]),
OffsetDateTime::now_utc(),
100,
)
.await
.unwrap()
.message_id;
let id4 = db
.create_channel_message(
channel,
user_a,
"omg (same nonce as second message)",
&mentions_to_proto(&[]),
OffsetDateTime::now_utc(),
200,
)
.await
.unwrap()
.message_id;
let owner_id = db.create_server("test").await.unwrap().0 as u32; // As a different user, reuse one of the same nonces. This request succeeds
// and returns a different id.
let id5 = db
.create_channel_message(
channel,
user_b,
"omg @user_a (same nonce as user_a's first message)",
&mentions_to_proto(&[(4..11, user_a.to_proto())]),
OffsetDateTime::now_utc(),
100,
)
.await
.unwrap()
.message_id;
db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user) assert_ne!(id1, id2);
.await assert_eq!(id1, id3);
.unwrap(); assert_eq!(id2, id4);
assert_ne!(id5, id1);
let msg1_id = db let messages = db
.create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) .get_channel_messages(channel, user_a, 5, None)
.await .await
.unwrap(); .unwrap()
let msg2_id = db .into_iter()
.create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) .map(|m| (m.id, m.body, m.mentions))
.await .collect::<Vec<_>>();
.unwrap(); assert_eq!(
let msg3_id = db messages,
.create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) &[
.await (
.unwrap(); id1.to_proto(),
let msg4_id = db "hi @user_b".into(),
.create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) mentions_to_proto(&[(3..10, user_b.to_proto())]),
.await ),
.unwrap(); (
id2.to_proto(),
assert_ne!(msg1_id, msg2_id); "hello, fellow users".into(),
assert_eq!(msg1_id, msg3_id); mentions_to_proto(&[])
assert_eq!(msg2_id, msg4_id); ),
(
id5.to_proto(),
"omg @user_a (same nonce as user_a's first message)".into(),
mentions_to_proto(&[(4..11, user_a.to_proto())]),
),
]
);
} }
test_both_dbs!( test_both_dbs!(
test_channel_message_new_notification, test_unseen_channel_messages,
test_channel_message_new_notification_postgres, test_unseen_channel_messages_postgres,
test_channel_message_new_notification_sqlite test_unseen_channel_messages_sqlite
); );
async fn test_channel_message_new_notification(db: &Arc<Database>) { async fn test_unseen_channel_messages(db: &Arc<Database>) {
let user = db let user = new_test_user(db, "user_a@example.com").await;
.create_user( let observer = new_test_user(db, "user_b@example.com").await;
"user_a@example.com",
false,
NewUserParams {
github_login: "user_a".into(),
github_user_id: 1,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let observer = db
.create_user(
"user_b@example.com",
false,
NewUserParams {
github_login: "user_b".into(),
github_user_id: 1,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let channel_1 = db.create_root_channel("channel", user).await.unwrap(); let channel_1 = db.create_root_channel("channel", user).await.unwrap();
let channel_2 = db.create_root_channel("channel-2", user).await.unwrap(); let channel_2 = db.create_root_channel("channel-2", user).await.unwrap();
db.invite_channel_member(channel_1, observer, user, ChannelRole::Member) db.invite_channel_member(channel_1, observer, user, ChannelRole::Member)
.await .await
.unwrap(); .unwrap();
db.invite_channel_member(channel_2, observer, user, ChannelRole::Member)
.await
.unwrap();
db.respond_to_channel_invite(channel_1, observer, true) db.respond_to_channel_invite(channel_1, observer, true)
.await .await
.unwrap(); .unwrap();
db.invite_channel_member(channel_2, observer, user, ChannelRole::Member)
.await
.unwrap();
db.respond_to_channel_invite(channel_2, observer, true) db.respond_to_channel_invite(channel_2, observer, true)
.await .await
.unwrap(); .unwrap();
@ -179,28 +235,31 @@ async fn test_channel_message_new_notification(db: &Arc<Database>) {
.unwrap(); .unwrap();
let _ = db let _ = db
.create_channel_message(channel_1, user, "1_1", OffsetDateTime::now_utc(), 1) .create_channel_message(channel_1, user, "1_1", &[], OffsetDateTime::now_utc(), 1)
.await .await
.unwrap(); .unwrap();
let (second_message, _, _) = db let second_message = db
.create_channel_message(channel_1, user, "1_2", OffsetDateTime::now_utc(), 2) .create_channel_message(channel_1, user, "1_2", &[], OffsetDateTime::now_utc(), 2)
.await .await
.unwrap(); .unwrap()
.message_id;
let (third_message, _, _) = db let third_message = db
.create_channel_message(channel_1, user, "1_3", OffsetDateTime::now_utc(), 3) .create_channel_message(channel_1, user, "1_3", &[], OffsetDateTime::now_utc(), 3)
.await .await
.unwrap(); .unwrap()
.message_id;
db.join_channel_chat(channel_2, user_connection_id, user) db.join_channel_chat(channel_2, user_connection_id, user)
.await .await
.unwrap(); .unwrap();
let (fourth_message, _, _) = db let fourth_message = db
.create_channel_message(channel_2, user, "2_1", OffsetDateTime::now_utc(), 4) .create_channel_message(channel_2, user, "2_1", &[], OffsetDateTime::now_utc(), 4)
.await .await
.unwrap(); .unwrap()
.message_id;
// Check that observer has new messages // Check that observer has new messages
let unseen_messages = db let unseen_messages = db
@ -295,3 +354,101 @@ async fn test_channel_message_new_notification(db: &Arc<Database>) {
}] }]
); );
} }
test_both_dbs!(
test_channel_message_mentions,
test_channel_message_mentions_postgres,
test_channel_message_mentions_sqlite
);
async fn test_channel_message_mentions(db: &Arc<Database>) {
let user_a = new_test_user(db, "user_a@example.com").await;
let user_b = new_test_user(db, "user_b@example.com").await;
let user_c = new_test_user(db, "user_c@example.com").await;
let channel = db
.create_channel("channel", None, user_a)
.await
.unwrap()
.channel
.id;
db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)
.await
.unwrap();
db.respond_to_channel_invite(channel, user_b, true)
.await
.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32;
let connection_id = rpc::ConnectionId { owner_id, id: 0 };
db.join_channel_chat(channel, connection_id, user_a)
.await
.unwrap();
db.create_channel_message(
channel,
user_a,
"hi @user_b and @user_c",
&mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]),
OffsetDateTime::now_utc(),
1,
)
.await
.unwrap();
db.create_channel_message(
channel,
user_a,
"bye @user_c",
&mentions_to_proto(&[(4..11, user_c.to_proto())]),
OffsetDateTime::now_utc(),
2,
)
.await
.unwrap();
db.create_channel_message(
channel,
user_a,
"umm",
&mentions_to_proto(&[]),
OffsetDateTime::now_utc(),
3,
)
.await
.unwrap();
db.create_channel_message(
channel,
user_a,
"@user_b, stop.",
&mentions_to_proto(&[(0..7, user_b.to_proto())]),
OffsetDateTime::now_utc(),
4,
)
.await
.unwrap();
let messages = db
.get_channel_messages(channel, user_b, 5, None)
.await
.unwrap()
.into_iter()
.map(|m| (m.body, m.mentions))
.collect::<Vec<_>>();
assert_eq!(
&messages,
&[
(
"hi @user_b and @user_c".into(),
mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]),
),
(
"bye @user_c".into(),
mentions_to_proto(&[(4..11, user_c.to_proto())]),
),
("umm".into(), mentions_to_proto(&[]),),
(
"@user_b, stop.".into(),
mentions_to_proto(&[(0..7, user_b.to_proto())]),
),
]
);
}

View File

@ -119,7 +119,9 @@ impl AppState {
pub async fn new(config: Config) -> Result<Arc<Self>> { pub async fn new(config: Config) -> Result<Arc<Self>> {
let mut db_options = db::ConnectOptions::new(config.database_url.clone()); let mut db_options = db::ConnectOptions::new(config.database_url.clone());
db_options.max_connections(config.database_max_connections); db_options.max_connections(config.database_max_connections);
let db = Database::new(db_options, Executor::Production).await?; let mut db = Database::new(db_options, Executor::Production).await?;
db.initialize_notification_kinds().await?;
let live_kit_client = if let Some(((server, key), secret)) = config let live_kit_client = if let Some(((server, key), secret)) = config
.live_kit_server .live_kit_server
.as_ref() .as_ref()

View File

@ -3,9 +3,11 @@ mod connection_pool;
use crate::{ use crate::{
auth, auth,
db::{ db::{
self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, Database, self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
MembershipUpdated, MessageId, MoveChannelResult, ProjectId, RenameChannelResult, RoomId, CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
ServerId, SetChannelVisibilityResult, User, UserId, MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
User, UserId,
}, },
executor::Executor, executor::Executor,
AppState, Result, AppState, Result,
@ -71,6 +73,7 @@ pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
const MESSAGE_COUNT_PER_PAGE: usize = 100; const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024; const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
lazy_static! { lazy_static! {
static ref METRIC_CONNECTIONS: IntGauge = static ref METRIC_CONNECTIONS: IntGauge =
@ -271,6 +274,9 @@ impl Server {
.add_request_handler(send_channel_message) .add_request_handler(send_channel_message)
.add_request_handler(remove_channel_message) .add_request_handler(remove_channel_message)
.add_request_handler(get_channel_messages) .add_request_handler(get_channel_messages)
.add_request_handler(get_channel_messages_by_id)
.add_request_handler(get_notifications)
.add_request_handler(mark_notification_as_read)
.add_request_handler(link_channel) .add_request_handler(link_channel)
.add_request_handler(unlink_channel) .add_request_handler(unlink_channel)
.add_request_handler(move_channel) .add_request_handler(move_channel)
@ -390,7 +396,7 @@ impl Server {
let contacts = app_state.db.get_contacts(user_id).await.trace_err(); let contacts = app_state.db.get_contacts(user_id).await.trace_err();
if let Some((busy, contacts)) = busy.zip(contacts) { if let Some((busy, contacts)) = busy.zip(contacts) {
let pool = pool.lock(); let pool = pool.lock();
let updated_contact = contact_for_user(user_id, false, busy, &pool); let updated_contact = contact_for_user(user_id, busy, &pool);
for contact in contacts { for contact in contacts {
if let db::Contact::Accepted { if let db::Contact::Accepted {
user_id: contact_user_id, user_id: contact_user_id,
@ -584,7 +590,7 @@ impl Server {
let (contacts, channels_for_user, channel_invites) = future::try_join3( let (contacts, channels_for_user, channel_invites) = future::try_join3(
this.app_state.db.get_contacts(user_id), this.app_state.db.get_contacts(user_id),
this.app_state.db.get_channels_for_user(user_id), this.app_state.db.get_channels_for_user(user_id),
this.app_state.db.get_channel_invites_for_user(user_id) this.app_state.db.get_channel_invites_for_user(user_id),
).await?; ).await?;
{ {
@ -690,7 +696,7 @@ impl Server {
if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
if let Some(code) = &user.invite_code { if let Some(code) = &user.invite_code {
let pool = self.connection_pool.lock(); let pool = self.connection_pool.lock();
let invitee_contact = contact_for_user(invitee_id, true, false, &pool); let invitee_contact = contact_for_user(invitee_id, false, &pool);
for connection_id in pool.user_connection_ids(inviter_id) { for connection_id in pool.user_connection_ids(inviter_id) {
self.peer.send( self.peer.send(
connection_id, connection_id,
@ -2066,7 +2072,7 @@ async fn request_contact(
return Err(anyhow!("cannot add yourself as a contact"))?; return Err(anyhow!("cannot add yourself as a contact"))?;
} }
session let notifications = session
.db() .db()
.await .await
.send_contact_request(requester_id, responder_id) .send_contact_request(requester_id, responder_id)
@ -2089,16 +2095,14 @@ async fn request_contact(
.incoming_requests .incoming_requests
.push(proto::IncomingContactRequest { .push(proto::IncomingContactRequest {
requester_id: requester_id.to_proto(), requester_id: requester_id.to_proto(),
should_notify: true,
}); });
for connection_id in session let connection_pool = session.connection_pool().await;
.connection_pool() for connection_id in connection_pool.user_connection_ids(responder_id) {
.await
.user_connection_ids(responder_id)
{
session.peer.send(connection_id, update.clone())?; session.peer.send(connection_id, update.clone())?;
} }
send_notifications(&*connection_pool, &session.peer, notifications);
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
Ok(()) Ok(())
} }
@ -2117,7 +2121,8 @@ async fn respond_to_contact_request(
} else { } else {
let accept = request.response == proto::ContactRequestResponse::Accept as i32; let accept = request.response == proto::ContactRequestResponse::Accept as i32;
db.respond_to_contact_request(responder_id, requester_id, accept) let notifications = db
.respond_to_contact_request(responder_id, requester_id, accept)
.await?; .await?;
let requester_busy = db.is_user_busy(requester_id).await?; let requester_busy = db.is_user_busy(requester_id).await?;
let responder_busy = db.is_user_busy(responder_id).await?; let responder_busy = db.is_user_busy(responder_id).await?;
@ -2128,7 +2133,7 @@ async fn respond_to_contact_request(
if accept { if accept {
update update
.contacts .contacts
.push(contact_for_user(requester_id, false, requester_busy, &pool)); .push(contact_for_user(requester_id, requester_busy, &pool));
} }
update update
.remove_incoming_requests .remove_incoming_requests
@ -2142,14 +2147,17 @@ async fn respond_to_contact_request(
if accept { if accept {
update update
.contacts .contacts
.push(contact_for_user(responder_id, true, responder_busy, &pool)); .push(contact_for_user(responder_id, responder_busy, &pool));
} }
update update
.remove_outgoing_requests .remove_outgoing_requests
.push(responder_id.to_proto()); .push(responder_id.to_proto());
for connection_id in pool.user_connection_ids(requester_id) { for connection_id in pool.user_connection_ids(requester_id) {
session.peer.send(connection_id, update.clone())?; session.peer.send(connection_id, update.clone())?;
} }
send_notifications(&*pool, &session.peer, notifications);
} }
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
@ -2164,7 +2172,8 @@ async fn remove_contact(
let requester_id = session.user_id; let requester_id = session.user_id;
let responder_id = UserId::from_proto(request.user_id); let responder_id = UserId::from_proto(request.user_id);
let db = session.db().await; let db = session.db().await;
let contact_accepted = db.remove_contact(requester_id, responder_id).await?; let (contact_accepted, deleted_notification_id) =
db.remove_contact(requester_id, responder_id).await?;
let pool = session.connection_pool().await; let pool = session.connection_pool().await;
// Update outgoing contact requests of requester // Update outgoing contact requests of requester
@ -2191,6 +2200,14 @@ async fn remove_contact(
} }
for connection_id in pool.user_connection_ids(responder_id) { for connection_id in pool.user_connection_ids(responder_id) {
session.peer.send(connection_id, update.clone())?; session.peer.send(connection_id, update.clone())?;
if let Some(notification_id) = deleted_notification_id {
session.peer.send(
connection_id,
proto::DeleteNotification {
notification_id: notification_id.to_proto(),
},
)?;
}
} }
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
@ -2268,7 +2285,10 @@ async fn invite_channel_member(
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let invitee_id = UserId::from_proto(request.user_id); let invitee_id = UserId::from_proto(request.user_id);
let channel = db let InviteMemberResult {
channel,
notifications,
} = db
.invite_channel_member( .invite_channel_member(
channel_id, channel_id,
invitee_id, invitee_id,
@ -2282,14 +2302,13 @@ async fn invite_channel_member(
..Default::default() ..Default::default()
}; };
for connection_id in session let connection_pool = session.connection_pool().await;
.connection_pool() for connection_id in connection_pool.user_connection_ids(invitee_id) {
.await
.user_connection_ids(invitee_id)
{
session.peer.send(connection_id, update.clone())?; session.peer.send(connection_id, update.clone())?;
} }
send_notifications(&*connection_pool, &session.peer, notifications);
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
Ok(()) Ok(())
} }
@ -2303,13 +2322,33 @@ async fn remove_channel_member(
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let member_id = UserId::from_proto(request.user_id); let member_id = UserId::from_proto(request.user_id);
let membership_updated = db let RemoveChannelMemberResult {
membership_update,
notification_id,
} = db
.remove_channel_member(channel_id, member_id, session.user_id) .remove_channel_member(channel_id, member_id, session.user_id)
.await?; .await?;
dbg!(&membership_updated); let connection_pool = &session.connection_pool().await;
notify_membership_updated(
notify_membership_updated(membership_updated, member_id, &session).await?; &connection_pool,
membership_update,
member_id,
&session.peer,
);
for connection_id in connection_pool.user_connection_ids(member_id) {
if let Some(notification_id) = notification_id {
session
.peer
.send(
connection_id,
proto::DeleteNotification {
notification_id: notification_id.to_proto(),
},
)
.trace_err();
}
}
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
Ok(()) Ok(())
@ -2374,7 +2413,13 @@ async fn set_channel_member_role(
match result { match result {
db::SetMemberRoleResult::MembershipUpdated(membership_update) => { db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
notify_membership_updated(membership_update, member_id, &session).await?; let connection_pool = session.connection_pool().await;
notify_membership_updated(
&connection_pool,
membership_update,
member_id,
&session.peer,
)
} }
db::SetMemberRoleResult::InviteUpdated(channel) => { db::SetMemberRoleResult::InviteUpdated(channel) => {
let update = proto::UpdateChannels { let update = proto::UpdateChannels {
@ -2535,24 +2580,34 @@ async fn respond_to_channel_invite(
) -> Result<()> { ) -> Result<()> {
let db = session.db().await; let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let result = db let RespondToChannelInvite {
membership_update,
notifications,
} = db
.respond_to_channel_invite(channel_id, session.user_id, request.accept) .respond_to_channel_invite(channel_id, session.user_id, request.accept)
.await?; .await?;
if let Some(accept_invite_result) = result { let connection_pool = session.connection_pool().await;
notify_membership_updated(accept_invite_result, session.user_id, &session).await?; if let Some(membership_update) = membership_update {
notify_membership_updated(
&connection_pool,
membership_update,
session.user_id,
&session.peer,
);
} else { } else {
let update = proto::UpdateChannels { let update = proto::UpdateChannels {
remove_channel_invitations: vec![channel_id.to_proto()], remove_channel_invitations: vec![channel_id.to_proto()],
..Default::default() ..Default::default()
}; };
let connection_pool = session.connection_pool().await;
for connection_id in connection_pool.user_connection_ids(session.user_id) { for connection_id in connection_pool.user_connection_ids(session.user_id) {
session.peer.send(connection_id, update.clone())?; session.peer.send(connection_id, update.clone())?;
} }
}; };
send_notifications(&*connection_pool, &session.peer, notifications);
response.send(proto::Ack {})?; response.send(proto::Ack {})?;
Ok(()) Ok(())
@ -2635,8 +2690,14 @@ async fn join_channel_internal(
live_kit_connection_info, live_kit_connection_info,
})?; })?;
let connection_pool = session.connection_pool().await;
if let Some(accept_invite_result) = accept_invite_result { if let Some(accept_invite_result) = accept_invite_result {
notify_membership_updated(accept_invite_result, session.user_id, &session).await?; notify_membership_updated(
&connection_pool,
accept_invite_result,
session.user_id,
&session.peer,
);
} }
room_updated(&joined_room.room, &session.peer); room_updated(&joined_room.room, &session.peer);
@ -2805,6 +2866,29 @@ fn channel_buffer_updated<T: EnvelopedMessage>(
}); });
} }
fn send_notifications(
connection_pool: &ConnectionPool,
peer: &Peer,
notifications: db::NotificationBatch,
) {
for (user_id, notification) in notifications {
for connection_id in connection_pool.user_connection_ids(user_id) {
if let Err(error) = peer.send(
connection_id,
proto::AddNotification {
notification: Some(notification.clone()),
},
) {
tracing::error!(
"failed to send notification to {:?} {}",
connection_id,
error
);
}
}
}
}
async fn send_channel_message( async fn send_channel_message(
request: proto::SendChannelMessage, request: proto::SendChannelMessage,
response: Response<proto::SendChannelMessage>, response: Response<proto::SendChannelMessage>,
@ -2819,19 +2903,27 @@ async fn send_channel_message(
return Err(anyhow!("message can't be blank"))?; return Err(anyhow!("message can't be blank"))?;
} }
// TODO: adjust mentions if body is trimmed
let timestamp = OffsetDateTime::now_utc(); let timestamp = OffsetDateTime::now_utc();
let nonce = request let nonce = request
.nonce .nonce
.ok_or_else(|| anyhow!("nonce can't be blank"))?; .ok_or_else(|| anyhow!("nonce can't be blank"))?;
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let (message_id, connection_ids, non_participants) = session let CreatedChannelMessage {
message_id,
participant_connection_ids,
channel_members,
notifications,
} = session
.db() .db()
.await .await
.create_channel_message( .create_channel_message(
channel_id, channel_id,
session.user_id, session.user_id,
&body, &body,
&request.mentions,
timestamp, timestamp,
nonce.clone().into(), nonce.clone().into(),
) )
@ -2840,18 +2932,23 @@ async fn send_channel_message(
sender_id: session.user_id.to_proto(), sender_id: session.user_id.to_proto(),
id: message_id.to_proto(), id: message_id.to_proto(),
body, body,
mentions: request.mentions,
timestamp: timestamp.unix_timestamp() as u64, timestamp: timestamp.unix_timestamp() as u64,
nonce: Some(nonce), nonce: Some(nonce),
}; };
broadcast(Some(session.connection_id), connection_ids, |connection| { broadcast(
session.peer.send( Some(session.connection_id),
connection, participant_connection_ids,
proto::ChannelMessageSent { |connection| {
channel_id: channel_id.to_proto(), session.peer.send(
message: Some(message.clone()), connection,
}, proto::ChannelMessageSent {
) channel_id: channel_id.to_proto(),
}); message: Some(message.clone()),
},
)
},
);
response.send(proto::SendChannelMessageResponse { response.send(proto::SendChannelMessageResponse {
message: Some(message), message: Some(message),
})?; })?;
@ -2859,7 +2956,7 @@ async fn send_channel_message(
let pool = &*session.connection_pool().await; let pool = &*session.connection_pool().await;
broadcast( broadcast(
None, None,
non_participants channel_members
.iter() .iter()
.flat_map(|user_id| pool.user_connection_ids(*user_id)), .flat_map(|user_id| pool.user_connection_ids(*user_id)),
|peer_id| { |peer_id| {
@ -2875,6 +2972,7 @@ async fn send_channel_message(
) )
}, },
); );
send_notifications(pool, &session.peer, notifications);
Ok(()) Ok(())
} }
@ -2904,11 +3002,16 @@ async fn acknowledge_channel_message(
) -> Result<()> { ) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id); let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id); let message_id = MessageId::from_proto(request.message_id);
session let notifications = session
.db() .db()
.await .await
.observe_channel_message(channel_id, session.user_id, message_id) .observe_channel_message(channel_id, session.user_id, message_id)
.await?; .await?;
send_notifications(
&*session.connection_pool().await,
&session.peer,
notifications,
);
Ok(()) Ok(())
} }
@ -2983,6 +3086,72 @@ async fn get_channel_messages(
Ok(()) Ok(())
} }
async fn get_channel_messages_by_id(
request: proto::GetChannelMessagesById,
response: Response<proto::GetChannelMessagesById>,
session: Session,
) -> Result<()> {
let message_ids = request
.message_ids
.iter()
.map(|id| MessageId::from_proto(*id))
.collect::<Vec<_>>();
let messages = session
.db()
.await
.get_channel_messages_by_id(session.user_id, &message_ids)
.await?;
response.send(proto::GetChannelMessagesResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
})?;
Ok(())
}
async fn get_notifications(
request: proto::GetNotifications,
response: Response<proto::GetNotifications>,
session: Session,
) -> Result<()> {
let notifications = session
.db()
.await
.get_notifications(
session.user_id,
NOTIFICATION_COUNT_PER_PAGE,
request
.before_id
.map(|id| db::NotificationId::from_proto(id)),
)
.await?;
response.send(proto::GetNotificationsResponse {
done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
notifications,
})?;
Ok(())
}
async fn mark_notification_as_read(
request: proto::MarkNotificationRead,
response: Response<proto::MarkNotificationRead>,
session: Session,
) -> Result<()> {
let database = &session.db().await;
let notifications = database
.mark_notification_as_read_by_id(
session.user_id,
NotificationId::from_proto(request.notification_id),
)
.await?;
send_notifications(
&*session.connection_pool().await,
&session.peer,
notifications,
);
response.send(proto::Ack {})?;
Ok(())
}
async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> { async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id); let project_id = ProjectId::from_proto(request.project_id);
let project_connection_ids = session let project_connection_ids = session
@ -3052,11 +3221,12 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
} }
} }
async fn notify_membership_updated( fn notify_membership_updated(
connection_pool: &ConnectionPool,
result: MembershipUpdated, result: MembershipUpdated,
user_id: UserId, user_id: UserId,
session: &Session, peer: &Peer,
) -> Result<()> { ) {
let mut update = build_channels_update(result.new_channels, vec![]); let mut update = build_channels_update(result.new_channels, vec![]);
update.delete_channels = result update.delete_channels = result
.removed_channels .removed_channels
@ -3065,11 +3235,9 @@ async fn notify_membership_updated(
.collect(); .collect();
update.remove_channel_invitations = vec![result.channel_id.to_proto()]; update.remove_channel_invitations = vec![result.channel_id.to_proto()];
let connection_pool = session.connection_pool().await;
for connection_id in connection_pool.user_connection_ids(user_id) { for connection_id in connection_pool.user_connection_ids(user_id) {
session.peer.send(connection_id, update.clone())?; peer.send(connection_id, update.clone()).trace_err();
} }
Ok(())
} }
fn build_channels_update( fn build_channels_update(
@ -3120,42 +3288,28 @@ fn build_initial_contacts_update(
for contact in contacts { for contact in contacts {
match contact { match contact {
db::Contact::Accepted { db::Contact::Accepted { user_id, busy } => {
user_id, update.contacts.push(contact_for_user(user_id, busy, &pool));
should_notify,
busy,
} => {
update
.contacts
.push(contact_for_user(user_id, should_notify, busy, &pool));
} }
db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()), db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
db::Contact::Incoming { db::Contact::Incoming { user_id } => {
user_id, update
should_notify, .incoming_requests
} => update .push(proto::IncomingContactRequest {
.incoming_requests requester_id: user_id.to_proto(),
.push(proto::IncomingContactRequest { })
requester_id: user_id.to_proto(), }
should_notify,
}),
} }
} }
update update
} }
fn contact_for_user( fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
user_id: UserId,
should_notify: bool,
busy: bool,
pool: &ConnectionPool,
) -> proto::Contact {
proto::Contact { proto::Contact {
user_id: user_id.to_proto(), user_id: user_id.to_proto(),
online: pool.is_user_online(user_id), online: pool.is_user_online(user_id),
busy, busy,
should_notify,
} }
} }
@ -3216,7 +3370,7 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()>
let busy = db.is_user_busy(user_id).await?; let busy = db.is_user_busy(user_id).await?;
let pool = session.connection_pool().await; let pool = session.connection_pool().await;
let updated_contact = contact_for_user(user_id, false, busy, &pool); let updated_contact = contact_for_user(user_id, busy, &pool);
for contact in contacts { for contact in contacts {
if let db::Contact::Accepted { if let db::Contact::Accepted {
user_id: contact_user_id, user_id: contact_user_id,

View File

@ -6,6 +6,7 @@ mod channel_message_tests;
mod channel_tests; mod channel_tests;
mod following_tests; mod following_tests;
mod integration_tests; mod integration_tests;
mod notification_tests;
mod random_channel_buffer_tests; mod random_channel_buffer_tests;
mod random_project_collaboration_tests; mod random_project_collaboration_tests;
mod randomized_test_helpers; mod randomized_test_helpers;

View File

@ -1,27 +1,30 @@
use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
use channel::{ChannelChat, ChannelMessageId}; use channel::{ChannelChat, ChannelMessageId, MessageParams};
use collab_ui::chat_panel::ChatPanel; use collab_ui::chat_panel::ChatPanel;
use gpui::{executor::Deterministic, BorrowAppContext, ModelHandle, TestAppContext}; use gpui::{executor::Deterministic, BorrowAppContext, ModelHandle, TestAppContext};
use rpc::Notification;
use std::sync::Arc; use std::sync::Arc;
use workspace::dock::Panel; use workspace::dock::Panel;
#[gpui::test] #[gpui::test]
async fn test_basic_channel_messages( async fn test_basic_channel_messages(
deterministic: Arc<Deterministic>, deterministic: Arc<Deterministic>,
cx_a: &mut TestAppContext, mut cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext, mut cx_b: &mut TestAppContext,
mut cx_c: &mut TestAppContext,
) { ) {
deterministic.forbid_parking(); deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await; let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await; let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await; let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
let channel_id = server let channel_id = server
.make_channel( .make_channel(
"the-channel", "the-channel",
None, None,
(&client_a, cx_a), (&client_a, cx_a),
&mut [(&client_b, cx_b)], &mut [(&client_b, cx_b), (&client_c, cx_c)],
) )
.await; .await;
@ -36,8 +39,17 @@ async fn test_basic_channel_messages(
.await .await
.unwrap(); .unwrap();
channel_chat_a let message_id = channel_chat_a
.update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) .update(cx_a, |c, cx| {
c.send_message(
MessageParams {
text: "hi @user_c!".into(),
mentions: vec![(3..10, client_c.id())],
},
cx,
)
.unwrap()
})
.await .await
.unwrap(); .unwrap();
channel_chat_a channel_chat_a
@ -52,15 +64,55 @@ async fn test_basic_channel_messages(
.unwrap(); .unwrap();
deterministic.run_until_parked(); deterministic.run_until_parked();
channel_chat_a.update(cx_a, |c, _| {
let channel_chat_c = client_c
.channel_store()
.update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
for (chat, cx) in [
(&channel_chat_a, &mut cx_a),
(&channel_chat_b, &mut cx_b),
(&channel_chat_c, &mut cx_c),
] {
chat.update(*cx, |c, _| {
assert_eq!(
c.messages()
.iter()
.map(|m| (m.body.as_str(), m.mentions.as_slice()))
.collect::<Vec<_>>(),
vec![
("hi @user_c!", [(3..10, client_c.id())].as_slice()),
("two", &[]),
("three", &[])
],
"results for user {}",
c.client().id(),
);
});
}
client_c.notification_store().update(cx_c, |store, _| {
assert_eq!(store.notification_count(), 2);
assert_eq!(store.unread_notification_count(), 1);
assert_eq!( assert_eq!(
c.messages() store.notification_at(0).unwrap().notification,
.iter() Notification::ChannelMessageMention {
.map(|m| m.body.as_str()) message_id,
.collect::<Vec<_>>(), sender_id: client_a.id(),
vec!["one", "two", "three"] channel_id,
}
); );
}) assert_eq!(
store.notification_at(1).unwrap().notification,
Notification::ChannelInvitation {
channel_id,
channel_name: "the-channel".to_string(),
inviter_id: client_a.id()
}
);
});
} }
#[gpui::test] #[gpui::test]
@ -280,7 +332,7 @@ async fn test_channel_message_changes(
chat_panel_b chat_panel_b
.update(cx_b, |chat_panel, cx| { .update(cx_b, |chat_panel, cx| {
chat_panel.set_active(true, cx); chat_panel.set_active(true, cx);
chat_panel.select_channel(channel_id, cx) chat_panel.select_channel(channel_id, None, cx)
}) })
.await .await
.unwrap(); .unwrap();

View File

@ -126,8 +126,8 @@ async fn test_core_channels(
// Client B accepts the invitation. // Client B accepts the invitation.
client_b client_b
.channel_store() .channel_store()
.update(cx_b, |channels, _| { .update(cx_b, |channels, cx| {
channels.respond_to_channel_invite(channel_a_id, true) channels.respond_to_channel_invite(channel_a_id, true, cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -153,7 +153,6 @@ async fn test_core_channels(
}, },
], ],
); );
dbg!("-------");
let channel_c_id = client_a let channel_c_id = client_a
.channel_store() .channel_store()
@ -289,11 +288,17 @@ async fn test_core_channels(
// Client B no longer has access to the channel // Client B no longer has access to the channel
assert_channels(client_b.channel_store(), cx_b, &[]); assert_channels(client_b.channel_store(), cx_b, &[]);
// When disconnected, client A sees no channels.
server.forbid_connections(); server.forbid_connections();
server.disconnect_client(client_a.peer_id().unwrap()); server.disconnect_client(client_a.peer_id().unwrap());
deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
assert_channels(client_a.channel_store(), cx_a, &[]);
client_b
.channel_store()
.update(cx_b, |channel_store, cx| {
channel_store.rename(channel_a_id, "channel-a-renamed", cx)
})
.await
.unwrap();
server.allow_connections(); server.allow_connections();
deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT); deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
@ -302,7 +307,7 @@ async fn test_core_channels(
cx_a, cx_a,
&[ExpectedChannel { &[ExpectedChannel {
id: channel_a_id, id: channel_a_id,
name: "channel-a".to_string(), name: "channel-a-renamed".to_string(),
depth: 0, depth: 0,
role: ChannelRole::Admin, role: ChannelRole::Admin,
}], }],
@ -886,8 +891,8 @@ async fn test_lost_channel_creation(
// Client B accepts the invite // Client B accepts the invite
client_b client_b
.channel_store() .channel_store()
.update(cx_b, |channel_store, _| { .update(cx_b, |channel_store, cx| {
channel_store.respond_to_channel_invite(channel_id, true) channel_store.respond_to_channel_invite(channel_id, true, cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -951,16 +956,16 @@ async fn test_channel_link_notifications(
client_b client_b
.channel_store() .channel_store()
.update(cx_b, |channel_store, _| { .update(cx_b, |channel_store, cx| {
channel_store.respond_to_channel_invite(zed_channel, true) channel_store.respond_to_channel_invite(zed_channel, true, cx)
}) })
.await .await
.unwrap(); .unwrap();
client_c client_c
.channel_store() .channel_store()
.update(cx_c, |channel_store, _| { .update(cx_c, |channel_store, cx| {
channel_store.respond_to_channel_invite(zed_channel, true) channel_store.respond_to_channel_invite(zed_channel, true, cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -1162,16 +1167,16 @@ async fn test_channel_membership_notifications(
client_b client_b
.channel_store() .channel_store()
.update(cx_b, |channel_store, _| { .update(cx_b, |channel_store, cx| {
channel_store.respond_to_channel_invite(zed_channel, true) channel_store.respond_to_channel_invite(zed_channel, true, cx)
}) })
.await .await
.unwrap(); .unwrap();
client_b client_b
.channel_store() .channel_store()
.update(cx_b, |channel_store, _| { .update(cx_b, |channel_store, cx| {
channel_store.respond_to_channel_invite(vim_channel, true) channel_store.respond_to_channel_invite(vim_channel, true, cx)
}) })
.await .await
.unwrap(); .unwrap();

View File

@ -1,6 +1,6 @@
use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
use call::ActiveCall; use call::ActiveCall;
use collab_ui::project_shared_notification::ProjectSharedNotification; use collab_ui::notifications::project_shared_notification::ProjectSharedNotification;
use editor::{Editor, ExcerptRange, MultiBuffer}; use editor::{Editor, ExcerptRange, MultiBuffer};
use gpui::{executor::Deterministic, geometry::vector::vec2f, TestAppContext, ViewHandle}; use gpui::{executor::Deterministic, geometry::vector::vec2f, TestAppContext, ViewHandle};
use live_kit_client::MacOSDisplay; use live_kit_client::MacOSDisplay;

View File

@ -15,8 +15,8 @@ use gpui::{executor::Deterministic, test::EmptyView, AppContext, ModelHandle, Te
use indoc::indoc; use indoc::indoc;
use language::{ use language::{
language_settings::{AllLanguageSettings, Formatter, InlayHintSettings}, language_settings::{AllLanguageSettings, Formatter, InlayHintSettings},
tree_sitter_rust, Anchor, BundledFormatter, Diagnostic, DiagnosticEntry, FakeLspAdapter, tree_sitter_rust, Anchor, Diagnostic, DiagnosticEntry, FakeLspAdapter, Language,
Language, LanguageConfig, LineEnding, OffsetRangeExt, Point, Rope, LanguageConfig, LineEnding, OffsetRangeExt, Point, Rope,
}; };
use live_kit_client::MacOSDisplay; use live_kit_client::MacOSDisplay;
use lsp::LanguageServerId; use lsp::LanguageServerId;
@ -4530,6 +4530,7 @@ async fn test_prettier_formatting_buffer(
LanguageConfig { LanguageConfig {
name: "Rust".into(), name: "Rust".into(),
path_suffixes: vec!["rs".to_string()], path_suffixes: vec!["rs".to_string()],
prettier_parser_name: Some("test_parser".to_string()),
..Default::default() ..Default::default()
}, },
Some(tree_sitter_rust::language()), Some(tree_sitter_rust::language()),
@ -4537,10 +4538,7 @@ async fn test_prettier_formatting_buffer(
let test_plugin = "test_plugin"; let test_plugin = "test_plugin";
let mut fake_language_servers = language let mut fake_language_servers = language
.set_fake_lsp_adapter(Arc::new(FakeLspAdapter { .set_fake_lsp_adapter(Arc::new(FakeLspAdapter {
enabled_formatters: vec![BundledFormatter::Prettier { prettier_plugins: vec![test_plugin],
parser_name: Some("test_parser"),
plugin_names: vec![test_plugin],
}],
..Default::default() ..Default::default()
})) }))
.await; .await;

View File

@ -0,0 +1,159 @@
use crate::tests::TestServer;
use gpui::{executor::Deterministic, TestAppContext};
use notifications::NotificationEvent;
use parking_lot::Mutex;
use rpc::{proto, Notification};
use std::sync::Arc;
#[gpui::test]
async fn test_notifications(
deterministic: Arc<Deterministic>,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let notification_events_a = Arc::new(Mutex::new(Vec::new()));
let notification_events_b = Arc::new(Mutex::new(Vec::new()));
client_a.notification_store().update(cx_a, |_, cx| {
let events = notification_events_a.clone();
cx.subscribe(&cx.handle(), move |_, _, event, _| {
events.lock().push(event.clone());
})
.detach()
});
client_b.notification_store().update(cx_b, |_, cx| {
let events = notification_events_b.clone();
cx.subscribe(&cx.handle(), move |_, _, event, _| {
events.lock().push(event.clone());
})
.detach()
});
// Client A sends a contact request to client B.
client_a
.user_store()
.update(cx_a, |store, cx| store.request_contact(client_b.id(), cx))
.await
.unwrap();
// Client B receives a contact request notification and responds to the
// request, accepting it.
deterministic.run_until_parked();
client_b.notification_store().update(cx_b, |store, cx| {
assert_eq!(store.notification_count(), 1);
assert_eq!(store.unread_notification_count(), 1);
let entry = store.notification_at(0).unwrap();
assert_eq!(
entry.notification,
Notification::ContactRequest {
sender_id: client_a.id()
}
);
assert!(!entry.is_read);
assert_eq!(
&notification_events_b.lock()[0..],
&[
NotificationEvent::NewNotification {
entry: entry.clone(),
},
NotificationEvent::NotificationsUpdated {
old_range: 0..0,
new_count: 1
}
]
);
store.respond_to_notification(entry.notification.clone(), true, cx);
});
// Client B sees the notification is now read, and that they responded.
deterministic.run_until_parked();
client_b.notification_store().read_with(cx_b, |store, _| {
assert_eq!(store.notification_count(), 1);
assert_eq!(store.unread_notification_count(), 0);
let entry = store.notification_at(0).unwrap();
assert!(entry.is_read);
assert_eq!(entry.response, Some(true));
assert_eq!(
&notification_events_b.lock()[2..],
&[
NotificationEvent::NotificationRead {
entry: entry.clone(),
},
NotificationEvent::NotificationsUpdated {
old_range: 0..1,
new_count: 1
}
]
);
});
// Client A receives a notification that client B accepted their request.
client_a.notification_store().read_with(cx_a, |store, _| {
assert_eq!(store.notification_count(), 1);
assert_eq!(store.unread_notification_count(), 1);
let entry = store.notification_at(0).unwrap();
assert_eq!(
entry.notification,
Notification::ContactRequestAccepted {
responder_id: client_b.id()
}
);
assert!(!entry.is_read);
});
// Client A creates a channel and invites client B to be a member.
let channel_id = client_a
.channel_store()
.update(cx_a, |store, cx| {
store.create_channel("the-channel", None, cx)
})
.await
.unwrap();
client_a
.channel_store()
.update(cx_a, |store, cx| {
store.invite_member(channel_id, client_b.id(), proto::ChannelRole::Member, cx)
})
.await
.unwrap();
// Client B receives a channel invitation notification and responds to the
// invitation, accepting it.
deterministic.run_until_parked();
client_b.notification_store().update(cx_b, |store, cx| {
assert_eq!(store.notification_count(), 2);
assert_eq!(store.unread_notification_count(), 1);
let entry = store.notification_at(0).unwrap();
assert_eq!(
entry.notification,
Notification::ChannelInvitation {
channel_id,
channel_name: "the-channel".to_string(),
inviter_id: client_a.id()
}
);
assert!(!entry.is_read);
store.respond_to_notification(entry.notification.clone(), true, cx);
});
// Client B sees the notification is now read, and that they responded.
deterministic.run_until_parked();
client_b.notification_store().read_with(cx_b, |store, _| {
assert_eq!(store.notification_count(), 2);
assert_eq!(store.unread_notification_count(), 0);
let entry = store.notification_at(0).unwrap();
assert!(entry.is_read);
assert_eq!(entry.response, Some(true));
});
}

View File

@ -208,8 +208,7 @@ impl<T: RandomizedTest> TestPlan<T> {
false, false,
NewUserParams { NewUserParams {
github_login: username.clone(), github_login: username.clone(),
github_user_id: (ix + 1) as i32, github_user_id: ix as i32,
invite_count: 0,
}, },
) )
.await .await

View File

@ -16,6 +16,7 @@ use futures::{channel::oneshot, StreamExt as _};
use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle}; use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle};
use language::LanguageRegistry; use language::LanguageRegistry;
use node_runtime::FakeNodeRuntime; use node_runtime::FakeNodeRuntime;
use notifications::NotificationStore;
use parking_lot::Mutex; use parking_lot::Mutex;
use project::{Project, WorktreeId}; use project::{Project, WorktreeId};
use rpc::{proto::ChannelRole, RECEIVE_TIMEOUT}; use rpc::{proto::ChannelRole, RECEIVE_TIMEOUT};
@ -46,6 +47,7 @@ pub struct TestClient {
pub username: String, pub username: String,
pub app_state: Arc<workspace::AppState>, pub app_state: Arc<workspace::AppState>,
channel_store: ModelHandle<ChannelStore>, channel_store: ModelHandle<ChannelStore>,
notification_store: ModelHandle<NotificationStore>,
state: RefCell<TestClientState>, state: RefCell<TestClientState>,
} }
@ -138,7 +140,6 @@ impl TestServer {
NewUserParams { NewUserParams {
github_login: name.into(), github_login: name.into(),
github_user_id: 0, github_user_id: 0,
invite_count: 0,
}, },
) )
.await .await
@ -231,7 +232,8 @@ impl TestServer {
workspace::init(app_state.clone(), cx); workspace::init(app_state.clone(), cx);
audio::init((), cx); audio::init((), cx);
call::init(client.clone(), user_store.clone(), cx); call::init(client.clone(), user_store.clone(), cx);
channel::init(&client, user_store, cx); channel::init(&client, user_store.clone(), cx);
notifications::init(client.clone(), user_store, cx);
}); });
client client
@ -243,6 +245,7 @@ impl TestServer {
app_state, app_state,
username: name.to_string(), username: name.to_string(),
channel_store: cx.read(ChannelStore::global).clone(), channel_store: cx.read(ChannelStore::global).clone(),
notification_store: cx.read(NotificationStore::global).clone(),
state: Default::default(), state: Default::default(),
}; };
client.wait_for_current_user(cx).await; client.wait_for_current_user(cx).await;
@ -338,8 +341,8 @@ impl TestServer {
member_cx member_cx
.read(ChannelStore::global) .read(ChannelStore::global)
.update(*member_cx, |channels, _| { .update(*member_cx, |channels, cx| {
channels.respond_to_channel_invite(channel_id, true) channels.respond_to_channel_invite(channel_id, true, cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -448,6 +451,10 @@ impl TestClient {
&self.channel_store &self.channel_store
} }
pub fn notification_store(&self) -> &ModelHandle<NotificationStore> {
&self.notification_store
}
pub fn user_store(&self) -> &ModelHandle<UserStore> { pub fn user_store(&self) -> &ModelHandle<UserStore> {
&self.app_state.user_store &self.app_state.user_store
} }

View File

@ -37,10 +37,12 @@ fuzzy = { path = "../fuzzy" }
gpui = { path = "../gpui" } gpui = { path = "../gpui" }
language = { path = "../language" } language = { path = "../language" }
menu = { path = "../menu" } menu = { path = "../menu" }
notifications = { path = "../notifications" }
rich_text = { path = "../rich_text" } rich_text = { path = "../rich_text" }
picker = { path = "../picker" } picker = { path = "../picker" }
project = { path = "../project" } project = { path = "../project" }
recent_projects = {path = "../recent_projects"} recent_projects = { path = "../recent_projects" }
rpc = { path = "../rpc" }
settings = { path = "../settings" } settings = { path = "../settings" }
feature_flags = {path = "../feature_flags"} feature_flags = {path = "../feature_flags"}
theme = { path = "../theme" } theme = { path = "../theme" }
@ -52,6 +54,7 @@ zed-actions = {path = "../zed-actions"}
anyhow.workspace = true anyhow.workspace = true
futures.workspace = true futures.workspace = true
lazy_static.workspace = true
log.workspace = true log.workspace = true
schemars.workspace = true schemars.workspace = true
postage.workspace = true postage.workspace = true
@ -66,7 +69,12 @@ client = { path = "../client", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] }
editor = { path = "../editor", features = ["test-support"] } editor = { path = "../editor", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] }
notifications = { path = "../notifications", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] } project = { path = "../project", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"] } settings = { path = "../settings", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] } util = { path = "../util", features = ["test-support"] }
workspace = { path = "../workspace", features = ["test-support"] } workspace = { path = "../workspace", features = ["test-support"] }
pretty_assertions.workspace = true
tree-sitter-markdown.workspace = true

View File

@ -1,4 +1,6 @@
use crate::{channel_view::ChannelView, ChatPanelSettings}; use crate::{
channel_view::ChannelView, is_channels_feature_enabled, render_avatar, ChatPanelSettings,
};
use anyhow::Result; use anyhow::Result;
use call::ActiveCall; use call::ActiveCall;
use channel::{ChannelChat, ChannelChatEvent, ChannelMessageId, ChannelStore}; use channel::{ChannelChat, ChannelChatEvent, ChannelMessageId, ChannelStore};
@ -6,18 +8,18 @@ use client::Client;
use collections::HashMap; use collections::HashMap;
use db::kvp::KEY_VALUE_STORE; use db::kvp::KEY_VALUE_STORE;
use editor::Editor; use editor::Editor;
use feature_flags::{ChannelsAlpha, FeatureFlagAppExt};
use gpui::{ use gpui::{
actions, actions,
elements::*, elements::*,
platform::{CursorStyle, MouseButton}, platform::{CursorStyle, MouseButton},
serde_json, serde_json,
views::{ItemType, Select, SelectStyle}, views::{ItemType, Select, SelectStyle},
AnyViewHandle, AppContext, AsyncAppContext, Entity, ImageData, ModelHandle, Subscription, Task, AnyViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Subscription, Task, View,
View, ViewContext, ViewHandle, WeakViewHandle, ViewContext, ViewHandle, WeakViewHandle,
}; };
use language::{language_settings::SoftWrap, LanguageRegistry}; use language::LanguageRegistry;
use menu::Confirm; use menu::Confirm;
use message_editor::MessageEditor;
use project::Fs; use project::Fs;
use rich_text::RichText; use rich_text::RichText;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -31,6 +33,8 @@ use workspace::{
Workspace, Workspace,
}; };
mod message_editor;
const MESSAGE_LOADING_THRESHOLD: usize = 50; const MESSAGE_LOADING_THRESHOLD: usize = 50;
const CHAT_PANEL_KEY: &'static str = "ChatPanel"; const CHAT_PANEL_KEY: &'static str = "ChatPanel";
@ -40,7 +44,7 @@ pub struct ChatPanel {
languages: Arc<LanguageRegistry>, languages: Arc<LanguageRegistry>,
active_chat: Option<(ModelHandle<ChannelChat>, Subscription)>, active_chat: Option<(ModelHandle<ChannelChat>, Subscription)>,
message_list: ListState<ChatPanel>, message_list: ListState<ChatPanel>,
input_editor: ViewHandle<Editor>, input_editor: ViewHandle<MessageEditor>,
channel_select: ViewHandle<Select>, channel_select: ViewHandle<Select>,
local_timezone: UtcOffset, local_timezone: UtcOffset,
fs: Arc<dyn Fs>, fs: Arc<dyn Fs>,
@ -49,6 +53,7 @@ pub struct ChatPanel {
pending_serialization: Task<Option<()>>, pending_serialization: Task<Option<()>>,
subscriptions: Vec<gpui::Subscription>, subscriptions: Vec<gpui::Subscription>,
workspace: WeakViewHandle<Workspace>, workspace: WeakViewHandle<Workspace>,
is_scrolled_to_bottom: bool,
has_focus: bool, has_focus: bool,
markdown_data: HashMap<ChannelMessageId, RichText>, markdown_data: HashMap<ChannelMessageId, RichText>,
} }
@ -85,13 +90,18 @@ impl ChatPanel {
let languages = workspace.app_state().languages.clone(); let languages = workspace.app_state().languages.clone();
let input_editor = cx.add_view(|cx| { let input_editor = cx.add_view(|cx| {
let mut editor = Editor::auto_height( MessageEditor::new(
4, languages.clone(),
Some(Arc::new(|theme| theme.chat_panel.input_editor.clone())), channel_store.clone(),
cx.add_view(|cx| {
Editor::auto_height(
4,
Some(Arc::new(|theme| theme.chat_panel.input_editor.clone())),
cx,
)
}),
cx, cx,
); )
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
editor
}); });
let workspace_handle = workspace.weak_handle(); let workspace_handle = workspace.weak_handle();
@ -121,13 +131,14 @@ impl ChatPanel {
}); });
let mut message_list = let mut message_list =
ListState::<Self>::new(0, Orientation::Bottom, 1000., move |this, ix, cx| { ListState::<Self>::new(0, Orientation::Bottom, 10., move |this, ix, cx| {
this.render_message(ix, cx) this.render_message(ix, cx)
}); });
message_list.set_scroll_handler(|visible_range, this, cx| { message_list.set_scroll_handler(|visible_range, count, this, cx| {
if visible_range.start < MESSAGE_LOADING_THRESHOLD { if visible_range.start < MESSAGE_LOADING_THRESHOLD {
this.load_more_messages(&LoadMoreMessages, cx); this.load_more_messages(&LoadMoreMessages, cx);
} }
this.is_scrolled_to_bottom = visible_range.end == count;
}); });
cx.add_view(|cx| { cx.add_view(|cx| {
@ -136,7 +147,6 @@ impl ChatPanel {
client, client,
channel_store, channel_store,
languages, languages,
active_chat: Default::default(), active_chat: Default::default(),
pending_serialization: Task::ready(None), pending_serialization: Task::ready(None),
message_list, message_list,
@ -146,6 +156,7 @@ impl ChatPanel {
has_focus: false, has_focus: false,
subscriptions: Vec::new(), subscriptions: Vec::new(),
workspace: workspace_handle, workspace: workspace_handle,
is_scrolled_to_bottom: true,
active: false, active: false,
width: None, width: None,
markdown_data: Default::default(), markdown_data: Default::default(),
@ -179,35 +190,20 @@ impl ChatPanel {
.channel_at(selected_ix) .channel_at(selected_ix)
.map(|e| e.id); .map(|e| e.id);
if let Some(selected_channel_id) = selected_channel_id { if let Some(selected_channel_id) = selected_channel_id {
this.select_channel(selected_channel_id, cx) this.select_channel(selected_channel_id, None, cx)
.detach_and_log_err(cx); .detach_and_log_err(cx);
} }
}) })
.detach(); .detach();
let markdown = this.languages.language_for_name("Markdown");
cx.spawn(|this, mut cx| async move {
let markdown = markdown.await?;
this.update(&mut cx, |this, cx| {
this.input_editor.update(cx, |editor, cx| {
editor.buffer().update(cx, |multi_buffer, cx| {
multi_buffer
.as_singleton()
.unwrap()
.update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx))
})
})
})?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
this this
}) })
} }
pub fn is_scrolled_to_bottom(&self) -> bool {
self.is_scrolled_to_bottom
}
pub fn active_chat(&self) -> Option<ModelHandle<ChannelChat>> { pub fn active_chat(&self) -> Option<ModelHandle<ChannelChat>> {
self.active_chat.as_ref().map(|(chat, _)| chat.clone()) self.active_chat.as_ref().map(|(chat, _)| chat.clone())
} }
@ -267,24 +263,22 @@ impl ChatPanel {
fn set_active_chat(&mut self, chat: ModelHandle<ChannelChat>, cx: &mut ViewContext<Self>) { fn set_active_chat(&mut self, chat: ModelHandle<ChannelChat>, cx: &mut ViewContext<Self>) {
if self.active_chat.as_ref().map(|e| &e.0) != Some(&chat) { if self.active_chat.as_ref().map(|e| &e.0) != Some(&chat) {
let id = chat.read(cx).channel_id; let channel_id = chat.read(cx).channel_id;
{ {
self.markdown_data.clear();
let chat = chat.read(cx); let chat = chat.read(cx);
self.message_list.reset(chat.message_count()); self.message_list.reset(chat.message_count());
let placeholder = if let Some(channel) = chat.channel(cx) {
format!("Message #{}", channel.name) let channel_name = chat.channel(cx).map(|channel| channel.name.clone());
} else { self.input_editor.update(cx, |editor, cx| {
"Message Channel".to_string() editor.set_channel(channel_id, channel_name, cx);
};
self.input_editor.update(cx, move |editor, cx| {
editor.set_placeholder_text(placeholder, cx);
}); });
} };
let subscription = cx.subscribe(&chat, Self::channel_did_change); let subscription = cx.subscribe(&chat, Self::channel_did_change);
self.active_chat = Some((chat, subscription)); self.active_chat = Some((chat, subscription));
self.acknowledge_last_message(cx); self.acknowledge_last_message(cx);
self.channel_select.update(cx, |select, cx| { self.channel_select.update(cx, |select, cx| {
if let Some(ix) = self.channel_store.read(cx).index_of_channel(id) { if let Some(ix) = self.channel_store.read(cx).index_of_channel(channel_id) {
select.set_selected_index(ix, cx); select.set_selected_index(ix, cx);
} }
}); });
@ -323,7 +317,7 @@ impl ChatPanel {
} }
fn acknowledge_last_message(&mut self, cx: &mut ViewContext<'_, '_, ChatPanel>) { fn acknowledge_last_message(&mut self, cx: &mut ViewContext<'_, '_, ChatPanel>) {
if self.active { if self.active && self.is_scrolled_to_bottom {
if let Some((chat, _)) = &self.active_chat { if let Some((chat, _)) = &self.active_chat {
chat.update(cx, |chat, cx| { chat.update(cx, |chat, cx| {
chat.acknowledge_last_message(cx); chat.acknowledge_last_message(cx);
@ -359,33 +353,48 @@ impl ChatPanel {
} }
fn render_message(&mut self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement<Self> { fn render_message(&mut self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let (message, is_continuation, is_last, is_admin) = { let (message, is_continuation, is_last, is_admin) = self
let active_chat = self.active_chat.as_ref().unwrap().0.read(cx); .active_chat
let is_admin = self .as_ref()
.channel_store .unwrap()
.read(cx) .0
.is_channel_admin(active_chat.channel_id); .update(cx, |active_chat, cx| {
let last_message = active_chat.message(ix.saturating_sub(1)); let is_admin = self
let this_message = active_chat.message(ix); .channel_store
let is_continuation = last_message.id != this_message.id .read(cx)
&& this_message.sender.id == last_message.sender.id; .is_channel_admin(active_chat.channel_id);
( let last_message = active_chat.message(ix.saturating_sub(1));
active_chat.message(ix).clone(), let this_message = active_chat.message(ix).clone();
is_continuation, let is_continuation = last_message.id != this_message.id
active_chat.message_count() == ix + 1, && this_message.sender.id == last_message.sender.id;
is_admin,
) if let ChannelMessageId::Saved(id) = this_message.id {
}; if this_message
.mentions
.iter()
.any(|(_, user_id)| Some(*user_id) == self.client.user_id())
{
active_chat.acknowledge_message(id);
}
}
(
this_message,
is_continuation,
active_chat.message_count() == ix + 1,
is_admin,
)
});
let is_pending = message.is_pending(); let is_pending = message.is_pending();
let text = self let theme = theme::current(cx);
.markdown_data let text = self.markdown_data.entry(message.id).or_insert_with(|| {
.entry(message.id) Self::render_markdown_with_mentions(&self.languages, self.client.id(), &message)
.or_insert_with(|| rich_text::render_markdown(message.body, &self.languages, None)); });
let now = OffsetDateTime::now_utc(); let now = OffsetDateTime::now_utc();
let theme = theme::current(cx);
let style = if is_pending { let style = if is_pending {
&theme.chat_panel.pending_message &theme.chat_panel.pending_message
} else if is_continuation { } else if is_continuation {
@ -405,14 +414,13 @@ impl ChatPanel {
enum MessageBackgroundHighlight {} enum MessageBackgroundHighlight {}
MouseEventHandler::new::<MessageBackgroundHighlight, _>(ix, cx, |state, cx| { MouseEventHandler::new::<MessageBackgroundHighlight, _>(ix, cx, |state, cx| {
let container = style.container.style_for(state); let container = style.style_for(state);
if is_continuation { if is_continuation {
Flex::row() Flex::row()
.with_child( .with_child(
text.element( text.element(
theme.editor.syntax.clone(), theme.editor.syntax.clone(),
style.body.clone(), theme.chat_panel.rich_text.clone(),
theme.editor.document_highlight_read_background,
cx, cx,
) )
.flex(1., true), .flex(1., true),
@ -434,15 +442,16 @@ impl ChatPanel {
Flex::row() Flex::row()
.with_child(render_avatar( .with_child(render_avatar(
message.sender.avatar.clone(), message.sender.avatar.clone(),
&theme, &theme.chat_panel.avatar,
theme.chat_panel.avatar_container,
)) ))
.with_child( .with_child(
Label::new( Label::new(
message.sender.github_login.clone(), message.sender.github_login.clone(),
style.sender.text.clone(), theme.chat_panel.message_sender.text.clone(),
) )
.contained() .contained()
.with_style(style.sender.container), .with_style(theme.chat_panel.message_sender.container),
) )
.with_child( .with_child(
Label::new( Label::new(
@ -451,10 +460,10 @@ impl ChatPanel {
now, now,
self.local_timezone, self.local_timezone,
), ),
style.timestamp.text.clone(), theme.chat_panel.message_timestamp.text.clone(),
) )
.contained() .contained()
.with_style(style.timestamp.container), .with_style(theme.chat_panel.message_timestamp.container),
) )
.align_children_center() .align_children_center()
.flex(1., true), .flex(1., true),
@ -467,8 +476,7 @@ impl ChatPanel {
.with_child( .with_child(
text.element( text.element(
theme.editor.syntax.clone(), theme.editor.syntax.clone(),
style.body.clone(), theme.chat_panel.rich_text.clone(),
theme.editor.document_highlight_read_background,
cx, cx,
) )
.flex(1., true), .flex(1., true),
@ -489,6 +497,23 @@ impl ChatPanel {
.into_any() .into_any()
} }
fn render_markdown_with_mentions(
language_registry: &Arc<LanguageRegistry>,
current_user_id: u64,
message: &channel::ChannelMessage,
) -> RichText {
let mentions = message
.mentions
.iter()
.map(|(range, user_id)| rich_text::Mention {
range: range.clone(),
is_self_mention: *user_id == current_user_id,
})
.collect::<Vec<_>>();
rich_text::render_markdown(message.body.clone(), &mentions, language_registry, None)
}
fn render_input_box(&self, theme: &Arc<Theme>, cx: &AppContext) -> AnyElement<Self> { fn render_input_box(&self, theme: &Arc<Theme>, cx: &AppContext) -> AnyElement<Self> {
ChildView::new(&self.input_editor, cx) ChildView::new(&self.input_editor, cx)
.contained() .contained()
@ -614,14 +639,12 @@ impl ChatPanel {
fn send(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) { fn send(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
if let Some((chat, _)) = self.active_chat.as_ref() { if let Some((chat, _)) = self.active_chat.as_ref() {
let body = self.input_editor.update(cx, |editor, cx| { let message = self
let body = editor.text(cx); .input_editor
editor.clear(cx); .update(cx, |editor, cx| editor.take_message(cx));
body
});
if let Some(task) = chat if let Some(task) = chat
.update(cx, |chat, cx| chat.send_message(body, cx)) .update(cx, |chat, cx| chat.send_message(message, cx))
.log_err() .log_err()
{ {
task.detach(); task.detach();
@ -638,7 +661,9 @@ impl ChatPanel {
fn load_more_messages(&mut self, _: &LoadMoreMessages, cx: &mut ViewContext<Self>) { fn load_more_messages(&mut self, _: &LoadMoreMessages, cx: &mut ViewContext<Self>) {
if let Some((chat, _)) = self.active_chat.as_ref() { if let Some((chat, _)) = self.active_chat.as_ref() {
chat.update(cx, |channel, cx| { chat.update(cx, |channel, cx| {
channel.load_more_messages(cx); if let Some(task) = channel.load_more_messages(cx) {
task.detach();
}
}) })
} }
} }
@ -646,23 +671,46 @@ impl ChatPanel {
pub fn select_channel( pub fn select_channel(
&mut self, &mut self,
selected_channel_id: u64, selected_channel_id: u64,
scroll_to_message_id: Option<u64>,
cx: &mut ViewContext<ChatPanel>, cx: &mut ViewContext<ChatPanel>,
) -> Task<Result<()>> { ) -> Task<Result<()>> {
if let Some((chat, _)) = &self.active_chat { let open_chat = self
if chat.read(cx).channel_id == selected_channel_id { .active_chat
return Task::ready(Ok(())); .as_ref()
} .and_then(|(chat, _)| {
} (chat.read(cx).channel_id == selected_channel_id)
.then(|| Task::ready(anyhow::Ok(chat.clone())))
})
.unwrap_or_else(|| {
self.channel_store.update(cx, |store, cx| {
store.open_channel_chat(selected_channel_id, cx)
})
});
let open_chat = self.channel_store.update(cx, |store, cx| {
store.open_channel_chat(selected_channel_id, cx)
});
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
let chat = open_chat.await?; let chat = open_chat.await?;
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.markdown_data = Default::default(); this.set_active_chat(chat.clone(), cx);
this.set_active_chat(chat, cx); })?;
})
if let Some(message_id) = scroll_to_message_id {
if let Some(item_ix) =
ChannelChat::load_history_since_message(chat.clone(), message_id, cx.clone())
.await
{
this.update(&mut cx, |this, cx| {
if this.active_chat.as_ref().map_or(false, |(c, _)| *c == chat) {
this.message_list.scroll_to(ListOffset {
item_ix,
offset_in_item: 0.,
});
cx.notify();
}
})?;
}
}
Ok(())
}) })
} }
@ -685,32 +733,6 @@ impl ChatPanel {
} }
} }
fn render_avatar(avatar: Option<Arc<ImageData>>, theme: &Arc<Theme>) -> AnyElement<ChatPanel> {
let avatar_style = theme.chat_panel.avatar;
avatar
.map(|avatar| {
Image::from_data(avatar)
.with_style(avatar_style.image)
.aligned()
.contained()
.with_corner_radius(avatar_style.outer_corner_radius)
.constrained()
.with_width(avatar_style.outer_width)
.with_height(avatar_style.outer_width)
.into_any()
})
.unwrap_or_else(|| {
Empty::new()
.constrained()
.with_width(avatar_style.outer_width)
.into_any()
})
.contained()
.with_style(theme.chat_panel.avatar_container)
.into_any()
}
fn render_remove( fn render_remove(
message_id_to_remove: Option<u64>, message_id_to_remove: Option<u64>,
cx: &mut ViewContext<'_, '_, ChatPanel>, cx: &mut ViewContext<'_, '_, ChatPanel>,
@ -781,7 +803,8 @@ impl View for ChatPanel {
*self.client.status().borrow(), *self.client.status().borrow(),
client::Status::Connected { .. } client::Status::Connected { .. }
) { ) {
cx.focus(&self.input_editor); let editor = self.input_editor.read(cx).editor.clone();
cx.focus(&editor);
} }
} }
@ -820,14 +843,14 @@ impl Panel for ChatPanel {
self.active = active; self.active = active;
if active { if active {
self.acknowledge_last_message(cx); self.acknowledge_last_message(cx);
if !is_chat_feature_enabled(cx) { if !is_channels_feature_enabled(cx) {
cx.emit(Event::Dismissed); cx.emit(Event::Dismissed);
} }
} }
} }
fn icon_path(&self, cx: &gpui::WindowContext) -> Option<&'static str> { fn icon_path(&self, cx: &gpui::WindowContext) -> Option<&'static str> {
(settings::get::<ChatPanelSettings>(cx).button && is_chat_feature_enabled(cx)) (settings::get::<ChatPanelSettings>(cx).button && is_channels_feature_enabled(cx))
.then(|| "icons/conversations.svg") .then(|| "icons/conversations.svg")
} }
@ -852,10 +875,6 @@ impl Panel for ChatPanel {
} }
} }
fn is_chat_feature_enabled(cx: &gpui::WindowContext<'_>) -> bool {
cx.is_staff() || cx.has_flag::<ChannelsAlpha>()
}
fn format_timestamp( fn format_timestamp(
mut timestamp: OffsetDateTime, mut timestamp: OffsetDateTime,
mut now: OffsetDateTime, mut now: OffsetDateTime,
@ -893,3 +912,72 @@ fn render_icon_button<V: View>(style: &IconButton, svg_path: &'static str) -> im
.contained() .contained()
.with_style(style.container) .with_style(style.container)
} }
#[cfg(test)]
mod tests {
use super::*;
use gpui::fonts::HighlightStyle;
use pretty_assertions::assert_eq;
use rich_text::{BackgroundKind, Highlight, RenderedRegion};
use util::test::marked_text_ranges;
#[gpui::test]
fn test_render_markdown_with_mentions() {
let language_registry = Arc::new(LanguageRegistry::test());
let (body, ranges) = marked_text_ranges("*hi*, «@abc», let's **call** «@fgh»", false);
let message = channel::ChannelMessage {
id: ChannelMessageId::Saved(0),
body,
timestamp: OffsetDateTime::now_utc(),
sender: Arc::new(client::User {
github_login: "fgh".into(),
avatar: None,
id: 103,
}),
nonce: 5,
mentions: vec![(ranges[0].clone(), 101), (ranges[1].clone(), 102)],
};
let message = ChatPanel::render_markdown_with_mentions(&language_registry, 102, &message);
// Note that the "'" was replaced with due to smart punctuation.
let (body, ranges) = marked_text_ranges("«hi», «@abc», lets «call» «@fgh»", false);
assert_eq!(message.text, body);
assert_eq!(
message.highlights,
vec![
(
ranges[0].clone(),
HighlightStyle {
italic: Some(true),
..Default::default()
}
.into()
),
(ranges[1].clone(), Highlight::Mention),
(
ranges[2].clone(),
HighlightStyle {
weight: Some(gpui::fonts::Weight::BOLD),
..Default::default()
}
.into()
),
(ranges[3].clone(), Highlight::SelfMention)
]
);
assert_eq!(
message.regions,
vec![
RenderedRegion {
background_kind: Some(BackgroundKind::Mention),
link_url: None
},
RenderedRegion {
background_kind: Some(BackgroundKind::SelfMention),
link_url: None
},
]
);
}
}

View File

@ -0,0 +1,313 @@
use channel::{ChannelId, ChannelMembership, ChannelStore, MessageParams};
use client::UserId;
use collections::HashMap;
use editor::{AnchorRangeExt, Editor};
use gpui::{
elements::ChildView, AnyElement, AsyncAppContext, Element, Entity, ModelHandle, Task, View,
ViewContext, ViewHandle, WeakViewHandle,
};
use language::{language_settings::SoftWrap, Buffer, BufferSnapshot, LanguageRegistry};
use lazy_static::lazy_static;
use project::search::SearchQuery;
use std::{sync::Arc, time::Duration};
const MENTIONS_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(50);
lazy_static! {
static ref MENTIONS_SEARCH: SearchQuery = SearchQuery::regex(
"@[-_\\w]+",
false,
false,
Default::default(),
Default::default()
)
.unwrap();
}
pub struct MessageEditor {
pub editor: ViewHandle<Editor>,
channel_store: ModelHandle<ChannelStore>,
users: HashMap<String, UserId>,
mentions: Vec<UserId>,
mentions_task: Option<Task<()>>,
channel_id: Option<ChannelId>,
}
impl MessageEditor {
pub fn new(
language_registry: Arc<LanguageRegistry>,
channel_store: ModelHandle<ChannelStore>,
editor: ViewHandle<Editor>,
cx: &mut ViewContext<Self>,
) -> Self {
editor.update(cx, |editor, cx| {
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
});
let buffer = editor
.read(cx)
.buffer()
.read(cx)
.as_singleton()
.expect("message editor must be singleton");
cx.subscribe(&buffer, Self::on_buffer_event).detach();
let markdown = language_registry.language_for_name("Markdown");
cx.app_context()
.spawn(|mut cx| async move {
let markdown = markdown.await?;
buffer.update(&mut cx, |buffer, cx| {
buffer.set_language(Some(markdown), cx)
});
anyhow::Ok(())
})
.detach_and_log_err(cx);
Self {
editor,
channel_store,
users: HashMap::default(),
channel_id: None,
mentions: Vec::new(),
mentions_task: None,
}
}
pub fn set_channel(
&mut self,
channel_id: u64,
channel_name: Option<String>,
cx: &mut ViewContext<Self>,
) {
self.editor.update(cx, |editor, cx| {
if let Some(channel_name) = channel_name {
editor.set_placeholder_text(format!("Message #{}", channel_name), cx);
} else {
editor.set_placeholder_text(format!("Message Channel"), cx);
}
});
self.channel_id = Some(channel_id);
self.refresh_users(cx);
}
pub fn refresh_users(&mut self, cx: &mut ViewContext<Self>) {
if let Some(channel_id) = self.channel_id {
let members = self.channel_store.update(cx, |store, cx| {
store.get_channel_member_details(channel_id, cx)
});
cx.spawn(|this, mut cx| async move {
let members = members.await?;
this.update(&mut cx, |this, cx| this.set_members(members, cx))?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
}
pub fn set_members(&mut self, members: Vec<ChannelMembership>, _: &mut ViewContext<Self>) {
self.users.clear();
self.users.extend(
members
.into_iter()
.map(|member| (member.user.github_login.clone(), member.user.id)),
);
}
pub fn take_message(&mut self, cx: &mut ViewContext<Self>) -> MessageParams {
self.editor.update(cx, |editor, cx| {
let highlights = editor.text_highlights::<Self>(cx);
let text = editor.text(cx);
let snapshot = editor.buffer().read(cx).snapshot(cx);
let mentions = if let Some((_, ranges)) = highlights {
ranges
.iter()
.map(|range| range.to_offset(&snapshot))
.zip(self.mentions.iter().copied())
.collect()
} else {
Vec::new()
};
editor.clear(cx);
self.mentions.clear();
MessageParams { text, mentions }
})
}
fn on_buffer_event(
&mut self,
buffer: ModelHandle<Buffer>,
event: &language::Event,
cx: &mut ViewContext<Self>,
) {
if let language::Event::Reparsed | language::Event::Edited = event {
let buffer = buffer.read(cx).snapshot();
self.mentions_task = Some(cx.spawn(|this, cx| async move {
cx.background().timer(MENTIONS_DEBOUNCE_INTERVAL).await;
Self::find_mentions(this, buffer, cx).await;
}));
}
}
async fn find_mentions(
this: WeakViewHandle<MessageEditor>,
buffer: BufferSnapshot,
mut cx: AsyncAppContext,
) {
let (buffer, ranges) = cx
.background()
.spawn(async move {
let ranges = MENTIONS_SEARCH.search(&buffer, None).await;
(buffer, ranges)
})
.await;
this.update(&mut cx, |this, cx| {
let mut anchor_ranges = Vec::new();
let mut mentioned_user_ids = Vec::new();
let mut text = String::new();
this.editor.update(cx, |editor, cx| {
let multi_buffer = editor.buffer().read(cx).snapshot(cx);
for range in ranges {
text.clear();
text.extend(buffer.text_for_range(range.clone()));
if let Some(username) = text.strip_prefix("@") {
if let Some(user_id) = this.users.get(username) {
let start = multi_buffer.anchor_after(range.start);
let end = multi_buffer.anchor_after(range.end);
mentioned_user_ids.push(*user_id);
anchor_ranges.push(start..end);
}
}
}
editor.clear_highlights::<Self>(cx);
editor.highlight_text::<Self>(
anchor_ranges,
theme::current(cx).chat_panel.rich_text.mention_highlight,
cx,
)
});
this.mentions = mentioned_user_ids;
this.mentions_task.take();
})
.ok();
}
}
impl Entity for MessageEditor {
type Event = ();
}
impl View for MessageEditor {
fn render(&mut self, cx: &mut ViewContext<'_, '_, Self>) -> AnyElement<Self> {
ChildView::new(&self.editor, cx).into_any()
}
fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
if cx.is_self_focused() {
cx.focus(&self.editor);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use client::{Client, User, UserStore};
use gpui::{TestAppContext, WindowHandle};
use language::{Language, LanguageConfig};
use rpc::proto;
use settings::SettingsStore;
use util::{http::FakeHttpClient, test::marked_text_ranges};
#[gpui::test]
async fn test_message_editor(cx: &mut TestAppContext) {
let editor = init_test(cx);
let editor = editor.root(cx);
editor.update(cx, |editor, cx| {
editor.set_members(
vec![
ChannelMembership {
user: Arc::new(User {
github_login: "a-b".into(),
id: 101,
avatar: None,
}),
kind: proto::channel_member::Kind::Member,
role: proto::ChannelRole::Member,
},
ChannelMembership {
user: Arc::new(User {
github_login: "C_D".into(),
id: 102,
avatar: None,
}),
kind: proto::channel_member::Kind::Member,
role: proto::ChannelRole::Member,
},
],
cx,
);
editor.editor.update(cx, |editor, cx| {
editor.set_text("Hello, @a-b! Have you met @C_D?", cx)
});
});
cx.foreground().advance_clock(MENTIONS_DEBOUNCE_INTERVAL);
editor.update(cx, |editor, cx| {
let (text, ranges) = marked_text_ranges("Hello, «@a-b»! Have you met «@C_D»?", false);
assert_eq!(
editor.take_message(cx),
MessageParams {
text,
mentions: vec![(ranges[0].clone(), 101), (ranges[1].clone(), 102)],
}
);
});
}
fn init_test(cx: &mut TestAppContext) -> WindowHandle<MessageEditor> {
cx.foreground().forbid_parking();
cx.update(|cx| {
let http = FakeHttpClient::with_404_response();
let client = Client::new(http.clone(), cx);
let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
cx.set_global(SettingsStore::test(cx));
theme::init((), cx);
language::init(cx);
editor::init(cx);
client::init(&client, cx);
channel::init(&client, user_store, cx);
});
let language_registry = Arc::new(LanguageRegistry::test());
language_registry.add(Arc::new(Language::new(
LanguageConfig {
name: "Markdown".into(),
..Default::default()
},
Some(tree_sitter_markdown::language()),
)));
let editor = cx.add_window(|cx| {
MessageEditor::new(
language_registry,
ChannelStore::global(cx),
cx.add_view(|cx| Editor::auto_height(4, None, cx)),
cx,
)
});
cx.foreground().run_until_parked();
editor
}
}

View File

@ -3220,10 +3220,11 @@ impl CollabPanel {
accept: bool, accept: bool,
cx: &mut ViewContext<Self>, cx: &mut ViewContext<Self>,
) { ) {
let respond = self.channel_store.update(cx, |store, _| { self.channel_store
store.respond_to_channel_invite(channel_id, accept) .update(cx, |store, cx| {
}); store.respond_to_channel_invite(channel_id, accept, cx)
cx.foreground().spawn(respond).detach(); })
.detach();
} }
fn call( fn call(
@ -3262,7 +3263,9 @@ impl CollabPanel {
workspace.update(cx, |workspace, cx| { workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.focus_panel::<ChatPanel>(cx) { if let Some(panel) = workspace.focus_panel::<ChatPanel>(cx) {
panel.update(cx, |panel, cx| { panel.update(cx, |panel, cx| {
panel.select_channel(channel_id, cx).detach_and_log_err(cx); panel
.select_channel(channel_id, None, cx)
.detach_and_log_err(cx);
}); });
} }
}); });

View File

@ -1,10 +1,10 @@
use crate::{ use crate::{
contact_notification::ContactNotification, face_pile::FacePile, toggle_deafen, toggle_mute, face_pile::FacePile, toggle_deafen, toggle_mute, toggle_screen_sharing, LeaveCall,
toggle_screen_sharing, LeaveCall, ToggleDeafen, ToggleMute, ToggleScreenSharing, ToggleDeafen, ToggleMute, ToggleScreenSharing,
}; };
use auto_update::AutoUpdateStatus; use auto_update::AutoUpdateStatus;
use call::{ActiveCall, ParticipantLocation, Room}; use call::{ActiveCall, ParticipantLocation, Room};
use client::{proto::PeerId, Client, ContactEventKind, SignIn, SignOut, User, UserStore}; use client::{proto::PeerId, Client, SignIn, SignOut, User, UserStore};
use clock::ReplicaId; use clock::ReplicaId;
use context_menu::{ContextMenu, ContextMenuItem}; use context_menu::{ContextMenu, ContextMenuItem};
use gpui::{ use gpui::{
@ -158,28 +158,6 @@ impl CollabTitlebarItem {
this.window_activation_changed(active, cx) this.window_activation_changed(active, cx)
})); }));
subscriptions.push(cx.observe(&user_store, |_, _, cx| cx.notify())); subscriptions.push(cx.observe(&user_store, |_, _, cx| cx.notify()));
subscriptions.push(
cx.subscribe(&user_store, move |this, user_store, event, cx| {
if let Some(workspace) = this.workspace.upgrade(cx) {
workspace.update(cx, |workspace, cx| {
if let client::Event::Contact { user, kind } = event {
if let ContactEventKind::Requested | ContactEventKind::Accepted = kind {
workspace.show_notification(user.id as usize, cx, |cx| {
cx.add_view(|cx| {
ContactNotification::new(
user.clone(),
*kind,
user_store,
cx,
)
})
})
}
}
});
}
}),
);
Self { Self {
workspace: workspace.weak_handle(), workspace: workspace.weak_handle(),
@ -495,7 +473,11 @@ impl CollabTitlebarItem {
pub fn toggle_vcs_menu(&mut self, _: &ToggleVcsMenu, cx: &mut ViewContext<Self>) { pub fn toggle_vcs_menu(&mut self, _: &ToggleVcsMenu, cx: &mut ViewContext<Self>) {
if self.branch_popover.take().is_none() { if self.branch_popover.take().is_none() {
if let Some(workspace) = self.workspace.upgrade(cx) { if let Some(workspace) = self.workspace.upgrade(cx) {
let view = cx.add_view(|cx| build_branch_list(workspace, cx)); let Some(view) =
cx.add_option_view(|cx| build_branch_list(workspace, cx).log_err())
else {
return;
};
cx.subscribe(&view, |this, _, event, cx| { cx.subscribe(&view, |this, _, event, cx| {
match event { match event {
PickerEvent::Dismiss => { PickerEvent::Dismiss => {

View File

@ -2,30 +2,32 @@ pub mod channel_view;
pub mod chat_panel; pub mod chat_panel;
pub mod collab_panel; pub mod collab_panel;
mod collab_titlebar_item; mod collab_titlebar_item;
mod contact_notification;
mod face_pile; mod face_pile;
mod incoming_call_notification; pub mod notification_panel;
mod notifications; pub mod notifications;
mod panel_settings; mod panel_settings;
pub mod project_shared_notification;
mod sharing_status_indicator;
use call::{report_call_event_for_room, ActiveCall, Room}; use call::{report_call_event_for_room, ActiveCall, Room};
use feature_flags::{ChannelsAlpha, FeatureFlagAppExt};
use gpui::{ use gpui::{
actions, actions,
elements::{ContainerStyle, Empty, Image},
geometry::{ geometry::{
rect::RectF, rect::RectF,
vector::{vec2f, Vector2F}, vector::{vec2f, Vector2F},
}, },
platform::{Screen, WindowBounds, WindowKind, WindowOptions}, platform::{Screen, WindowBounds, WindowKind, WindowOptions},
AppContext, Task, AnyElement, AppContext, Element, ImageData, Task,
}; };
use std::{rc::Rc, sync::Arc}; use std::{rc::Rc, sync::Arc};
use theme::AvatarStyle;
use util::ResultExt; use util::ResultExt;
use workspace::AppState; use workspace::AppState;
pub use collab_titlebar_item::CollabTitlebarItem; pub use collab_titlebar_item::CollabTitlebarItem;
pub use panel_settings::{ChatPanelSettings, CollaborationPanelSettings}; pub use panel_settings::{
ChatPanelSettings, CollaborationPanelSettings, NotificationPanelSettings,
};
actions!( actions!(
collab, collab,
@ -35,14 +37,13 @@ actions!(
pub fn init(app_state: &Arc<AppState>, cx: &mut AppContext) { pub fn init(app_state: &Arc<AppState>, cx: &mut AppContext) {
settings::register::<CollaborationPanelSettings>(cx); settings::register::<CollaborationPanelSettings>(cx);
settings::register::<ChatPanelSettings>(cx); settings::register::<ChatPanelSettings>(cx);
settings::register::<NotificationPanelSettings>(cx);
vcs_menu::init(cx); vcs_menu::init(cx);
collab_titlebar_item::init(cx); collab_titlebar_item::init(cx);
collab_panel::init(cx); collab_panel::init(cx);
chat_panel::init(cx); chat_panel::init(cx);
incoming_call_notification::init(&app_state, cx); notifications::init(&app_state, cx);
project_shared_notification::init(&app_state, cx);
sharing_status_indicator::init(cx);
cx.add_global_action(toggle_screen_sharing); cx.add_global_action(toggle_screen_sharing);
cx.add_global_action(toggle_mute); cx.add_global_action(toggle_mute);
@ -130,3 +131,35 @@ fn notification_window_options(
screen: Some(screen), screen: Some(screen),
} }
} }
fn render_avatar<T: 'static>(
avatar: Option<Arc<ImageData>>,
avatar_style: &AvatarStyle,
container: ContainerStyle,
) -> AnyElement<T> {
avatar
.map(|avatar| {
Image::from_data(avatar)
.with_style(avatar_style.image)
.aligned()
.contained()
.with_corner_radius(avatar_style.outer_corner_radius)
.constrained()
.with_width(avatar_style.outer_width)
.with_height(avatar_style.outer_width)
.into_any()
})
.unwrap_or_else(|| {
Empty::new()
.constrained()
.with_width(avatar_style.outer_width)
.into_any()
})
.contained()
.with_style(container)
.into_any()
}
fn is_channels_feature_enabled(cx: &gpui::WindowContext<'_>) -> bool {
cx.is_staff() || cx.has_flag::<ChannelsAlpha>()
}

View File

@ -1,121 +0,0 @@
use std::sync::Arc;
use crate::notifications::render_user_notification;
use client::{ContactEventKind, User, UserStore};
use gpui::{elements::*, Entity, ModelHandle, View, ViewContext};
use workspace::notifications::Notification;
pub struct ContactNotification {
user_store: ModelHandle<UserStore>,
user: Arc<User>,
kind: client::ContactEventKind,
}
#[derive(Clone, PartialEq)]
struct Dismiss(u64);
#[derive(Clone, PartialEq)]
pub struct RespondToContactRequest {
pub user_id: u64,
pub accept: bool,
}
pub enum Event {
Dismiss,
}
impl Entity for ContactNotification {
type Event = Event;
}
impl View for ContactNotification {
fn ui_name() -> &'static str {
"ContactNotification"
}
fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
match self.kind {
ContactEventKind::Requested => render_user_notification(
self.user.clone(),
"wants to add you as a contact",
Some("They won't be alerted if you decline."),
|notification, cx| notification.dismiss(cx),
vec![
(
"Decline",
Box::new(|notification, cx| {
notification.respond_to_contact_request(false, cx)
}),
),
(
"Accept",
Box::new(|notification, cx| {
notification.respond_to_contact_request(true, cx)
}),
),
],
cx,
),
ContactEventKind::Accepted => render_user_notification(
self.user.clone(),
"accepted your contact request",
None,
|notification, cx| notification.dismiss(cx),
vec![],
cx,
),
_ => unreachable!(),
}
}
}
impl Notification for ContactNotification {
fn should_dismiss_notification_on_event(&self, event: &<Self as Entity>::Event) -> bool {
matches!(event, Event::Dismiss)
}
}
impl ContactNotification {
pub fn new(
user: Arc<User>,
kind: client::ContactEventKind,
user_store: ModelHandle<UserStore>,
cx: &mut ViewContext<Self>,
) -> Self {
cx.subscribe(&user_store, move |this, _, event, cx| {
if let client::Event::Contact {
kind: ContactEventKind::Cancelled,
user,
} = event
{
if user.id == this.user.id {
cx.emit(Event::Dismiss);
}
}
})
.detach();
Self {
user,
kind,
user_store,
}
}
fn dismiss(&mut self, cx: &mut ViewContext<Self>) {
self.user_store.update(cx, |store, cx| {
store
.dismiss_contact_request(self.user.id, cx)
.detach_and_log_err(cx);
});
cx.emit(Event::Dismiss);
}
fn respond_to_contact_request(&mut self, accept: bool, cx: &mut ViewContext<Self>) {
self.user_store
.update(cx, |store, cx| {
store.respond_to_contact_request(self.user.id, accept, cx)
})
.detach();
}
}

View File

@ -0,0 +1,884 @@
use crate::{chat_panel::ChatPanel, render_avatar, NotificationPanelSettings};
use anyhow::Result;
use channel::ChannelStore;
use client::{Client, Notification, User, UserStore};
use collections::HashMap;
use db::kvp::KEY_VALUE_STORE;
use futures::StreamExt;
use gpui::{
actions,
elements::*,
platform::{CursorStyle, MouseButton},
serde_json, AnyViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View,
ViewContext, ViewHandle, WeakViewHandle, WindowContext,
};
use notifications::{NotificationEntry, NotificationEvent, NotificationStore};
use project::Fs;
use rpc::proto;
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use std::{sync::Arc, time::Duration};
use theme::{ui, Theme};
use time::{OffsetDateTime, UtcOffset};
use util::{ResultExt, TryFutureExt};
use workspace::{
dock::{DockPosition, Panel},
Workspace,
};
const LOADING_THRESHOLD: usize = 30;
const MARK_AS_READ_DELAY: Duration = Duration::from_secs(1);
const TOAST_DURATION: Duration = Duration::from_secs(5);
const NOTIFICATION_PANEL_KEY: &'static str = "NotificationPanel";
pub struct NotificationPanel {
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
channel_store: ModelHandle<ChannelStore>,
notification_store: ModelHandle<NotificationStore>,
fs: Arc<dyn Fs>,
width: Option<f32>,
active: bool,
notification_list: ListState<Self>,
pending_serialization: Task<Option<()>>,
subscriptions: Vec<gpui::Subscription>,
workspace: WeakViewHandle<Workspace>,
current_notification_toast: Option<(u64, Task<()>)>,
local_timezone: UtcOffset,
has_focus: bool,
mark_as_read_tasks: HashMap<u64, Task<Result<()>>>,
}
#[derive(Serialize, Deserialize)]
struct SerializedNotificationPanel {
width: Option<f32>,
}
#[derive(Debug)]
pub enum Event {
DockPositionChanged,
Focus,
Dismissed,
}
pub struct NotificationPresenter {
pub actor: Option<Arc<client::User>>,
pub text: String,
pub icon: &'static str,
pub needs_response: bool,
pub can_navigate: bool,
}
actions!(notification_panel, [ToggleFocus]);
pub fn init(_cx: &mut AppContext) {}
impl NotificationPanel {
pub fn new(workspace: &mut Workspace, cx: &mut ViewContext<Workspace>) -> ViewHandle<Self> {
let fs = workspace.app_state().fs.clone();
let client = workspace.app_state().client.clone();
let user_store = workspace.app_state().user_store.clone();
let workspace_handle = workspace.weak_handle();
cx.add_view(|cx| {
let mut status = client.status();
cx.spawn(|this, mut cx| async move {
while let Some(_) = status.next().await {
if this
.update(&mut cx, |_, cx| {
cx.notify();
})
.is_err()
{
break;
}
}
})
.detach();
let mut notification_list =
ListState::<Self>::new(0, Orientation::Top, 1000., move |this, ix, cx| {
this.render_notification(ix, cx)
.unwrap_or_else(|| Empty::new().into_any())
});
notification_list.set_scroll_handler(|visible_range, count, this, cx| {
if count.saturating_sub(visible_range.end) < LOADING_THRESHOLD {
if let Some(task) = this
.notification_store
.update(cx, |store, cx| store.load_more_notifications(false, cx))
{
task.detach();
}
}
});
let mut this = Self {
fs,
client,
user_store,
local_timezone: cx.platform().local_timezone(),
channel_store: ChannelStore::global(cx),
notification_store: NotificationStore::global(cx),
notification_list,
pending_serialization: Task::ready(None),
workspace: workspace_handle,
has_focus: false,
current_notification_toast: None,
subscriptions: Vec::new(),
active: false,
mark_as_read_tasks: HashMap::default(),
width: None,
};
let mut old_dock_position = this.position(cx);
this.subscriptions.extend([
cx.observe(&this.notification_store, |_, _, cx| cx.notify()),
cx.subscribe(&this.notification_store, Self::on_notification_event),
cx.observe_global::<SettingsStore, _>(move |this: &mut Self, cx| {
let new_dock_position = this.position(cx);
if new_dock_position != old_dock_position {
old_dock_position = new_dock_position;
cx.emit(Event::DockPositionChanged);
}
cx.notify();
}),
]);
this
})
}
pub fn load(
workspace: WeakViewHandle<Workspace>,
cx: AsyncAppContext,
) -> Task<Result<ViewHandle<Self>>> {
cx.spawn(|mut cx| async move {
let serialized_panel = if let Some(panel) = cx
.background()
.spawn(async move { KEY_VALUE_STORE.read_kvp(NOTIFICATION_PANEL_KEY) })
.await
.log_err()
.flatten()
{
Some(serde_json::from_str::<SerializedNotificationPanel>(&panel)?)
} else {
None
};
workspace.update(&mut cx, |workspace, cx| {
let panel = Self::new(workspace, cx);
if let Some(serialized_panel) = serialized_panel {
panel.update(cx, |panel, cx| {
panel.width = serialized_panel.width;
cx.notify();
});
}
panel
})
})
}
fn serialize(&mut self, cx: &mut ViewContext<Self>) {
let width = self.width;
self.pending_serialization = cx.background().spawn(
async move {
KEY_VALUE_STORE
.write_kvp(
NOTIFICATION_PANEL_KEY.into(),
serde_json::to_string(&SerializedNotificationPanel { width })?,
)
.await?;
anyhow::Ok(())
}
.log_err(),
);
}
fn render_notification(
&mut self,
ix: usize,
cx: &mut ViewContext<Self>,
) -> Option<AnyElement<Self>> {
let entry = self.notification_store.read(cx).notification_at(ix)?;
let notification_id = entry.id;
let now = OffsetDateTime::now_utc();
let timestamp = entry.timestamp;
let NotificationPresenter {
actor,
text,
needs_response,
can_navigate,
..
} = self.present_notification(entry, cx)?;
let theme = theme::current(cx);
let style = &theme.notification_panel;
let response = entry.response;
let notification = entry.notification.clone();
let message_style = if entry.is_read {
style.read_text.clone()
} else {
style.unread_text.clone()
};
if self.active && !entry.is_read {
self.did_render_notification(notification_id, &notification, cx);
}
enum Decline {}
enum Accept {}
Some(
MouseEventHandler::new::<NotificationEntry, _>(ix, cx, |_, cx| {
let container = message_style.container;
Flex::row()
.with_children(actor.map(|actor| {
render_avatar(actor.avatar.clone(), &style.avatar, style.avatar_container)
}))
.with_child(
Flex::column()
.with_child(Text::new(text, message_style.text.clone()))
.with_child(
Flex::row()
.with_child(
Label::new(
format_timestamp(timestamp, now, self.local_timezone),
style.timestamp.text.clone(),
)
.contained()
.with_style(style.timestamp.container),
)
.with_children(if let Some(is_accepted) = response {
Some(
Label::new(
if is_accepted {
"You accepted"
} else {
"You declined"
},
style.read_text.text.clone(),
)
.flex_float()
.into_any(),
)
} else if needs_response {
Some(
Flex::row()
.with_children([
MouseEventHandler::new::<Decline, _>(
ix,
cx,
|state, _| {
let button =
style.button.style_for(state);
Label::new(
"Decline",
button.text.clone(),
)
.contained()
.with_style(button.container)
},
)
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, {
let notification = notification.clone();
move |_, view, cx| {
view.respond_to_notification(
notification.clone(),
false,
cx,
);
}
}),
MouseEventHandler::new::<Accept, _>(
ix,
cx,
|state, _| {
let button =
style.button.style_for(state);
Label::new(
"Accept",
button.text.clone(),
)
.contained()
.with_style(button.container)
},
)
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, {
let notification = notification.clone();
move |_, view, cx| {
view.respond_to_notification(
notification.clone(),
true,
cx,
);
}
}),
])
.flex_float()
.into_any(),
)
} else {
None
}),
)
.flex(1.0, true),
)
.contained()
.with_style(container)
.into_any()
})
.with_cursor_style(if can_navigate {
CursorStyle::PointingHand
} else {
CursorStyle::default()
})
.on_click(MouseButton::Left, {
let notification = notification.clone();
move |_, this, cx| this.did_click_notification(&notification, cx)
})
.into_any(),
)
}
fn present_notification(
&self,
entry: &NotificationEntry,
cx: &AppContext,
) -> Option<NotificationPresenter> {
let user_store = self.user_store.read(cx);
let channel_store = self.channel_store.read(cx);
match entry.notification {
Notification::ContactRequest { sender_id } => {
let requester = user_store.get_cached_user(sender_id)?;
Some(NotificationPresenter {
icon: "icons/plus.svg",
text: format!("{} wants to add you as a contact", requester.github_login),
needs_response: user_store.has_incoming_contact_request(requester.id),
actor: Some(requester),
can_navigate: false,
})
}
Notification::ContactRequestAccepted { responder_id } => {
let responder = user_store.get_cached_user(responder_id)?;
Some(NotificationPresenter {
icon: "icons/plus.svg",
text: format!("{} accepted your contact invite", responder.github_login),
needs_response: false,
actor: Some(responder),
can_navigate: false,
})
}
Notification::ChannelInvitation {
ref channel_name,
channel_id,
inviter_id,
} => {
let inviter = user_store.get_cached_user(inviter_id)?;
Some(NotificationPresenter {
icon: "icons/hash.svg",
text: format!(
"{} invited you to join the #{channel_name} channel",
inviter.github_login
),
needs_response: channel_store.has_channel_invitation(channel_id),
actor: Some(inviter),
can_navigate: false,
})
}
Notification::ChannelMessageMention {
sender_id,
channel_id,
message_id,
} => {
let sender = user_store.get_cached_user(sender_id)?;
let channel = channel_store.channel_for_id(channel_id)?;
let message = self
.notification_store
.read(cx)
.channel_message_for_id(message_id)?;
Some(NotificationPresenter {
icon: "icons/conversations.svg",
text: format!(
"{} mentioned you in #{}:\n{}",
sender.github_login, channel.name, message.body,
),
needs_response: false,
actor: Some(sender),
can_navigate: true,
})
}
}
}
fn did_render_notification(
&mut self,
notification_id: u64,
notification: &Notification,
cx: &mut ViewContext<Self>,
) {
let should_mark_as_read = match notification {
Notification::ContactRequestAccepted { .. } => true,
Notification::ContactRequest { .. }
| Notification::ChannelInvitation { .. }
| Notification::ChannelMessageMention { .. } => false,
};
if should_mark_as_read {
self.mark_as_read_tasks
.entry(notification_id)
.or_insert_with(|| {
let client = self.client.clone();
cx.spawn(|this, mut cx| async move {
cx.background().timer(MARK_AS_READ_DELAY).await;
client
.request(proto::MarkNotificationRead { notification_id })
.await?;
this.update(&mut cx, |this, _| {
this.mark_as_read_tasks.remove(&notification_id);
})?;
Ok(())
})
});
}
}
fn did_click_notification(&mut self, notification: &Notification, cx: &mut ViewContext<Self>) {
if let Notification::ChannelMessageMention {
message_id,
channel_id,
..
} = notification.clone()
{
if let Some(workspace) = self.workspace.upgrade(cx) {
cx.app_context().defer(move |cx| {
workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.focus_panel::<ChatPanel>(cx) {
panel.update(cx, |panel, cx| {
panel
.select_channel(channel_id, Some(message_id), cx)
.detach_and_log_err(cx);
});
}
});
});
}
}
}
fn is_showing_notification(&self, notification: &Notification, cx: &AppContext) -> bool {
if let Notification::ChannelMessageMention { channel_id, .. } = &notification {
if let Some(workspace) = self.workspace.upgrade(cx) {
return workspace
.read_with(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<ChatPanel>(cx) {
return panel.read_with(cx, |panel, cx| {
panel.is_scrolled_to_bottom()
&& panel.active_chat().map_or(false, |chat| {
chat.read(cx).channel_id == *channel_id
})
});
}
false
})
.unwrap_or_default();
}
}
false
}
fn render_sign_in_prompt(
&self,
theme: &Arc<Theme>,
cx: &mut ViewContext<Self>,
) -> AnyElement<Self> {
enum SignInPromptLabel {}
MouseEventHandler::new::<SignInPromptLabel, _>(0, cx, |mouse_state, _| {
Label::new(
"Sign in to view your notifications".to_string(),
theme
.chat_panel
.sign_in_prompt
.style_for(mouse_state)
.clone(),
)
})
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, move |_, this, cx| {
let client = this.client.clone();
cx.spawn(|_, cx| async move {
client.authenticate_and_connect(true, &cx).log_err().await;
})
.detach();
})
.aligned()
.into_any()
}
fn render_empty_state(
&self,
theme: &Arc<Theme>,
_cx: &mut ViewContext<Self>,
) -> AnyElement<Self> {
Label::new(
"You have no notifications".to_string(),
theme.chat_panel.sign_in_prompt.default.clone(),
)
.aligned()
.into_any()
}
fn on_notification_event(
&mut self,
_: ModelHandle<NotificationStore>,
event: &NotificationEvent,
cx: &mut ViewContext<Self>,
) {
match event {
NotificationEvent::NewNotification { entry } => self.add_toast(entry, cx),
NotificationEvent::NotificationRemoved { entry }
| NotificationEvent::NotificationRead { entry } => self.remove_toast(entry.id, cx),
NotificationEvent::NotificationsUpdated {
old_range,
new_count,
} => {
self.notification_list.splice(old_range.clone(), *new_count);
cx.notify();
}
}
}
fn add_toast(&mut self, entry: &NotificationEntry, cx: &mut ViewContext<Self>) {
if self.is_showing_notification(&entry.notification, cx) {
return;
}
let Some(NotificationPresenter { actor, text, .. }) = self.present_notification(entry, cx)
else {
return;
};
let notification_id = entry.id;
self.current_notification_toast = Some((
notification_id,
cx.spawn(|this, mut cx| async move {
cx.background().timer(TOAST_DURATION).await;
this.update(&mut cx, |this, cx| this.remove_toast(notification_id, cx))
.ok();
}),
));
self.workspace
.update(cx, |workspace, cx| {
workspace.dismiss_notification::<NotificationToast>(0, cx);
workspace.show_notification(0, cx, |cx| {
let workspace = cx.weak_handle();
cx.add_view(|_| NotificationToast {
notification_id,
actor,
text,
workspace,
})
})
})
.ok();
}
fn remove_toast(&mut self, notification_id: u64, cx: &mut ViewContext<Self>) {
if let Some((current_id, _)) = &self.current_notification_toast {
if *current_id == notification_id {
self.current_notification_toast.take();
self.workspace
.update(cx, |workspace, cx| {
workspace.dismiss_notification::<NotificationToast>(0, cx)
})
.ok();
}
}
}
fn respond_to_notification(
&mut self,
notification: Notification,
response: bool,
cx: &mut ViewContext<Self>,
) {
self.notification_store.update(cx, |store, cx| {
store.respond_to_notification(notification, response, cx);
});
}
}
impl Entity for NotificationPanel {
type Event = Event;
}
impl View for NotificationPanel {
fn ui_name() -> &'static str {
"NotificationPanel"
}
fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let theme = theme::current(cx);
let style = &theme.notification_panel;
let element = if self.client.user_id().is_none() {
self.render_sign_in_prompt(&theme, cx)
} else if self.notification_list.item_count() == 0 {
self.render_empty_state(&theme, cx)
} else {
Flex::column()
.with_child(
Flex::row()
.with_child(Label::new("Notifications", style.title.text.clone()))
.with_child(ui::svg(&style.title_icon).flex_float())
.align_children_center()
.contained()
.with_style(style.title.container)
.constrained()
.with_height(style.title_height),
)
.with_child(
List::new(self.notification_list.clone())
.contained()
.with_style(style.list)
.flex(1., true),
)
.into_any()
};
element
.contained()
.with_style(style.container)
.constrained()
.with_min_width(150.)
.into_any()
}
fn focus_in(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
self.has_focus = true;
}
fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
self.has_focus = false;
}
}
impl Panel for NotificationPanel {
fn position(&self, cx: &gpui::WindowContext) -> DockPosition {
settings::get::<NotificationPanelSettings>(cx).dock
}
fn position_is_valid(&self, position: DockPosition) -> bool {
matches!(position, DockPosition::Left | DockPosition::Right)
}
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
settings::update_settings_file::<NotificationPanelSettings>(
self.fs.clone(),
cx,
move |settings| settings.dock = Some(position),
);
}
fn size(&self, cx: &gpui::WindowContext) -> f32 {
self.width
.unwrap_or_else(|| settings::get::<NotificationPanelSettings>(cx).default_width)
}
fn set_size(&mut self, size: Option<f32>, cx: &mut ViewContext<Self>) {
self.width = size;
self.serialize(cx);
cx.notify();
}
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
self.active = active;
if self.notification_store.read(cx).notification_count() == 0 {
cx.emit(Event::Dismissed);
}
}
fn icon_path(&self, cx: &gpui::WindowContext) -> Option<&'static str> {
(settings::get::<NotificationPanelSettings>(cx).button
&& self.notification_store.read(cx).notification_count() > 0)
.then(|| "icons/bell.svg")
}
fn icon_tooltip(&self) -> (String, Option<Box<dyn gpui::Action>>) {
(
"Notification Panel".to_string(),
Some(Box::new(ToggleFocus)),
)
}
fn icon_label(&self, cx: &WindowContext) -> Option<String> {
let count = self.notification_store.read(cx).unread_notification_count();
if count == 0 {
None
} else {
Some(count.to_string())
}
}
fn should_change_position_on_event(event: &Self::Event) -> bool {
matches!(event, Event::DockPositionChanged)
}
fn should_close_on_event(event: &Self::Event) -> bool {
matches!(event, Event::Dismissed)
}
fn has_focus(&self, _cx: &gpui::WindowContext) -> bool {
self.has_focus
}
fn is_focus_event(event: &Self::Event) -> bool {
matches!(event, Event::Focus)
}
}
pub struct NotificationToast {
notification_id: u64,
actor: Option<Arc<User>>,
text: String,
workspace: WeakViewHandle<Workspace>,
}
pub enum ToastEvent {
Dismiss,
}
impl NotificationToast {
fn focus_notification_panel(&self, cx: &mut AppContext) {
let workspace = self.workspace.clone();
let notification_id = self.notification_id;
cx.defer(move |cx| {
workspace
.update(cx, |workspace, cx| {
if let Some(panel) = workspace.focus_panel::<NotificationPanel>(cx) {
panel.update(cx, |panel, cx| {
let store = panel.notification_store.read(cx);
if let Some(entry) = store.notification_for_id(notification_id) {
panel.did_click_notification(&entry.clone().notification, cx);
}
});
}
})
.ok();
})
}
}
impl Entity for NotificationToast {
type Event = ToastEvent;
}
impl View for NotificationToast {
fn ui_name() -> &'static str {
"ContactNotification"
}
fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let user = self.actor.clone();
let theme = theme::current(cx).clone();
let theme = &theme.contact_notification;
MouseEventHandler::new::<Self, _>(0, cx, |_, cx| {
Flex::row()
.with_children(user.and_then(|user| {
Some(
Image::from_data(user.avatar.clone()?)
.with_style(theme.header_avatar)
.aligned()
.constrained()
.with_height(
cx.font_cache()
.line_height(theme.header_message.text.font_size),
)
.aligned()
.top(),
)
}))
.with_child(
Text::new(self.text.clone(), theme.header_message.text.clone())
.contained()
.with_style(theme.header_message.container)
.aligned()
.top()
.left()
.flex(1., true),
)
.with_child(
MouseEventHandler::new::<ToastEvent, _>(0, cx, |state, _| {
let style = theme.dismiss_button.style_for(state);
Svg::new("icons/x.svg")
.with_color(style.color)
.constrained()
.with_width(style.icon_width)
.aligned()
.contained()
.with_style(style.container)
.constrained()
.with_width(style.button_width)
.with_height(style.button_width)
})
.with_cursor_style(CursorStyle::PointingHand)
.with_padding(Padding::uniform(5.))
.on_click(MouseButton::Left, move |_, _, cx| {
cx.emit(ToastEvent::Dismiss)
})
.aligned()
.constrained()
.with_height(
cx.font_cache()
.line_height(theme.header_message.text.font_size),
)
.aligned()
.top()
.flex_float(),
)
.contained()
})
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, move |_, this, cx| {
this.focus_notification_panel(cx);
cx.emit(ToastEvent::Dismiss);
})
.into_any()
}
}
impl workspace::notifications::Notification for NotificationToast {
fn should_dismiss_notification_on_event(&self, event: &<Self as Entity>::Event) -> bool {
matches!(event, ToastEvent::Dismiss)
}
}
fn format_timestamp(
mut timestamp: OffsetDateTime,
mut now: OffsetDateTime,
local_timezone: UtcOffset,
) -> String {
timestamp = timestamp.to_offset(local_timezone);
now = now.to_offset(local_timezone);
let today = now.date();
let date = timestamp.date();
if date == today {
let difference = now - timestamp;
if difference >= Duration::from_secs(3600) {
format!("{}h", difference.whole_seconds() / 3600)
} else if difference >= Duration::from_secs(60) {
format!("{}m", difference.whole_seconds() / 60)
} else {
"just now".to_string()
}
} else if date.next_day() == Some(today) {
format!("yesterday")
} else {
format!("{:02}/{}/{}", date.month() as u32, date.day(), date.year())
}
}

View File

@ -1,110 +1,11 @@
use client::User; use gpui::AppContext;
use gpui::{
elements::*,
platform::{CursorStyle, MouseButton},
AnyElement, Element, ViewContext,
};
use std::sync::Arc; use std::sync::Arc;
use workspace::AppState;
enum Dismiss {} pub mod incoming_call_notification;
enum Button {} pub mod project_shared_notification;
pub fn render_user_notification<F, V: 'static>( pub fn init(app_state: &Arc<AppState>, cx: &mut AppContext) {
user: Arc<User>, incoming_call_notification::init(app_state, cx);
title: &'static str, project_shared_notification::init(app_state, cx);
body: Option<&'static str>,
on_dismiss: F,
buttons: Vec<(&'static str, Box<dyn Fn(&mut V, &mut ViewContext<V>)>)>,
cx: &mut ViewContext<V>,
) -> AnyElement<V>
where
F: 'static + Fn(&mut V, &mut ViewContext<V>),
{
let theme = theme::current(cx).clone();
let theme = &theme.contact_notification;
Flex::column()
.with_child(
Flex::row()
.with_children(user.avatar.clone().map(|avatar| {
Image::from_data(avatar)
.with_style(theme.header_avatar)
.aligned()
.constrained()
.with_height(
cx.font_cache()
.line_height(theme.header_message.text.font_size),
)
.aligned()
.top()
}))
.with_child(
Text::new(
format!("{} {}", user.github_login, title),
theme.header_message.text.clone(),
)
.contained()
.with_style(theme.header_message.container)
.aligned()
.top()
.left()
.flex(1., true),
)
.with_child(
MouseEventHandler::new::<Dismiss, _>(user.id as usize, cx, |state, _| {
let style = theme.dismiss_button.style_for(state);
Svg::new("icons/x.svg")
.with_color(style.color)
.constrained()
.with_width(style.icon_width)
.aligned()
.contained()
.with_style(style.container)
.constrained()
.with_width(style.button_width)
.with_height(style.button_width)
})
.with_cursor_style(CursorStyle::PointingHand)
.with_padding(Padding::uniform(5.))
.on_click(MouseButton::Left, move |_, view, cx| on_dismiss(view, cx))
.aligned()
.constrained()
.with_height(
cx.font_cache()
.line_height(theme.header_message.text.font_size),
)
.aligned()
.top()
.flex_float(),
)
.into_any_named("contact notification header"),
)
.with_children(body.map(|body| {
Label::new(body, theme.body_message.text.clone())
.contained()
.with_style(theme.body_message.container)
}))
.with_children(if buttons.is_empty() {
None
} else {
Some(
Flex::row()
.with_children(buttons.into_iter().enumerate().map(
|(ix, (message, handler))| {
MouseEventHandler::new::<Button, _>(ix, cx, |state, _| {
let button = theme.button.style_for(state);
Label::new(message, button.text.clone())
.contained()
.with_style(button.container)
})
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, move |_, view, cx| handler(view, cx))
},
))
.aligned()
.right(),
)
})
.contained()
.into_any()
} }

View File

@ -18,6 +18,13 @@ pub struct ChatPanelSettings {
pub default_width: f32, pub default_width: f32,
} }
#[derive(Deserialize, Debug)]
pub struct NotificationPanelSettings {
pub button: bool,
pub dock: DockPosition,
pub default_width: f32,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] #[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct PanelSettingsContent { pub struct PanelSettingsContent {
pub button: Option<bool>, pub button: Option<bool>,
@ -27,9 +34,7 @@ pub struct PanelSettingsContent {
impl Setting for CollaborationPanelSettings { impl Setting for CollaborationPanelSettings {
const KEY: Option<&'static str> = Some("collaboration_panel"); const KEY: Option<&'static str> = Some("collaboration_panel");
type FileContent = PanelSettingsContent; type FileContent = PanelSettingsContent;
fn load( fn load(
default_value: &Self::FileContent, default_value: &Self::FileContent,
user_values: &[&Self::FileContent], user_values: &[&Self::FileContent],
@ -41,9 +46,19 @@ impl Setting for CollaborationPanelSettings {
impl Setting for ChatPanelSettings { impl Setting for ChatPanelSettings {
const KEY: Option<&'static str> = Some("chat_panel"); const KEY: Option<&'static str> = Some("chat_panel");
type FileContent = PanelSettingsContent; type FileContent = PanelSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &gpui::AppContext,
) -> anyhow::Result<Self> {
Self::load_via_json_merge(default_value, user_values)
}
}
impl Setting for NotificationPanelSettings {
const KEY: Option<&'static str> = Some("notification_panel");
type FileContent = PanelSettingsContent;
fn load( fn load(
default_value: &Self::FileContent, default_value: &Self::FileContent,
user_values: &[&Self::FileContent], user_values: &[&Self::FileContent],

View File

@ -1,62 +0,0 @@
use crate::toggle_screen_sharing;
use call::ActiveCall;
use gpui::{
color::Color,
elements::{MouseEventHandler, Svg},
platform::{Appearance, MouseButton},
AnyElement, AppContext, Element, Entity, View, ViewContext,
};
use workspace::WorkspaceSettings;
pub fn init(cx: &mut AppContext) {
let active_call = ActiveCall::global(cx);
let mut status_indicator = None;
cx.observe(&active_call, move |call, cx| {
if let Some(room) = call.read(cx).room() {
if room.read(cx).is_screen_sharing() {
if status_indicator.is_none()
&& settings::get::<WorkspaceSettings>(cx).show_call_status_icon
{
status_indicator = Some(cx.add_status_bar_item(|_| SharingStatusIndicator));
}
} else if let Some(window) = status_indicator.take() {
window.update(cx, |cx| cx.remove_window());
}
} else if let Some(window) = status_indicator.take() {
window.update(cx, |cx| cx.remove_window());
}
})
.detach();
}
pub struct SharingStatusIndicator;
impl Entity for SharingStatusIndicator {
type Event = ();
}
impl View for SharingStatusIndicator {
fn ui_name() -> &'static str {
"SharingStatusIndicator"
}
fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let color = match cx.window_appearance() {
Appearance::Light | Appearance::VibrantLight => Color::black(),
Appearance::Dark | Appearance::VibrantDark => Color::white(),
};
MouseEventHandler::new::<Self, _>(0, cx, |_, _| {
Svg::new("icons/desktop.svg")
.with_color(color)
.constrained()
.with_width(18.)
.aligned()
})
.on_click(MouseButton::Left, |_, _, cx| {
toggle_screen_sharing(&Default::default(), cx)
})
.into_any()
}
}

View File

@ -5,22 +5,24 @@ mod tab_map;
mod wrap_map; mod wrap_map;
use crate::{ use crate::{
link_go_to_definition::InlayHighlight, Anchor, AnchorRangeExt, InlayId, MultiBuffer, link_go_to_definition::InlayHighlight, movement::TextLayoutDetails, Anchor, AnchorRangeExt,
MultiBufferSnapshot, ToOffset, ToPoint, EditorStyle, InlayId, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
}; };
pub use block_map::{BlockMap, BlockPoint}; pub use block_map::{BlockMap, BlockPoint};
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use fold_map::FoldMap; use fold_map::FoldMap;
use gpui::{ use gpui::{
color::Color, color::Color,
fonts::{FontId, HighlightStyle}, fonts::{FontId, HighlightStyle, Underline},
text_layout::{Line, RunStyle},
Entity, ModelContext, ModelHandle, Entity, ModelContext, ModelHandle,
}; };
use inlay_map::InlayMap; use inlay_map::InlayMap;
use language::{ use language::{
language_settings::language_settings, OffsetUtf16, Point, Subscription as BufferSubscription, language_settings::language_settings, OffsetUtf16, Point, Subscription as BufferSubscription,
}; };
use std::{any::TypeId, fmt::Debug, num::NonZeroU32, ops::Range, sync::Arc}; use lsp::DiagnosticSeverity;
use std::{any::TypeId, borrow::Cow, fmt::Debug, num::NonZeroU32, ops::Range, sync::Arc};
use sum_tree::{Bias, TreeMap}; use sum_tree::{Bias, TreeMap};
use tab_map::TabMap; use tab_map::TabMap;
use wrap_map::WrapMap; use wrap_map::WrapMap;
@ -316,6 +318,12 @@ pub struct Highlights<'a> {
pub suggestion_highlight_style: Option<HighlightStyle>, pub suggestion_highlight_style: Option<HighlightStyle>,
} }
pub struct HighlightedChunk<'a> {
pub chunk: &'a str,
pub style: Option<HighlightStyle>,
pub is_tab: bool,
}
pub struct DisplaySnapshot { pub struct DisplaySnapshot {
pub buffer_snapshot: MultiBufferSnapshot, pub buffer_snapshot: MultiBufferSnapshot,
pub fold_snapshot: fold_map::FoldSnapshot, pub fold_snapshot: fold_map::FoldSnapshot,
@ -485,7 +493,7 @@ impl DisplaySnapshot {
language_aware: bool, language_aware: bool,
inlay_highlight_style: Option<HighlightStyle>, inlay_highlight_style: Option<HighlightStyle>,
suggestion_highlight_style: Option<HighlightStyle>, suggestion_highlight_style: Option<HighlightStyle>,
) -> DisplayChunks<'_> { ) -> DisplayChunks<'a> {
self.block_snapshot.chunks( self.block_snapshot.chunks(
display_rows, display_rows,
language_aware, language_aware,
@ -498,6 +506,140 @@ impl DisplaySnapshot {
) )
} }
pub fn highlighted_chunks<'a>(
&'a self,
display_rows: Range<u32>,
language_aware: bool,
style: &'a EditorStyle,
) -> impl Iterator<Item = HighlightedChunk<'a>> {
self.chunks(
display_rows,
language_aware,
Some(style.theme.hint),
Some(style.theme.suggestion),
)
.map(|chunk| {
let mut highlight_style = chunk
.syntax_highlight_id
.and_then(|id| id.style(&style.syntax));
if let Some(chunk_highlight) = chunk.highlight_style {
if let Some(highlight_style) = highlight_style.as_mut() {
highlight_style.highlight(chunk_highlight);
} else {
highlight_style = Some(chunk_highlight);
}
}
let mut diagnostic_highlight = HighlightStyle::default();
if chunk.is_unnecessary {
diagnostic_highlight.fade_out = Some(style.unnecessary_code_fade);
}
if let Some(severity) = chunk.diagnostic_severity {
// Omit underlines for HINT/INFO diagnostics on 'unnecessary' code.
if severity <= DiagnosticSeverity::WARNING || !chunk.is_unnecessary {
let diagnostic_style = super::diagnostic_style(severity, true, style);
diagnostic_highlight.underline = Some(Underline {
color: Some(diagnostic_style.message.text.color),
thickness: 1.0.into(),
squiggly: true,
});
}
}
if let Some(highlight_style) = highlight_style.as_mut() {
highlight_style.highlight(diagnostic_highlight);
} else {
highlight_style = Some(diagnostic_highlight);
}
HighlightedChunk {
chunk: chunk.text,
style: highlight_style,
is_tab: chunk.is_tab,
}
})
}
pub fn lay_out_line_for_row(
&self,
display_row: u32,
TextLayoutDetails {
font_cache,
text_layout_cache,
editor_style,
}: &TextLayoutDetails,
) -> Line {
let mut styles = Vec::new();
let mut line = String::new();
let mut ended_in_newline = false;
let range = display_row..display_row + 1;
for chunk in self.highlighted_chunks(range, false, editor_style) {
line.push_str(chunk.chunk);
let text_style = if let Some(style) = chunk.style {
editor_style
.text
.clone()
.highlight(style, font_cache)
.map(Cow::Owned)
.unwrap_or_else(|_| Cow::Borrowed(&editor_style.text))
} else {
Cow::Borrowed(&editor_style.text)
};
ended_in_newline = chunk.chunk.ends_with("\n");
styles.push((
chunk.chunk.len(),
RunStyle {
font_id: text_style.font_id,
color: text_style.color,
underline: text_style.underline,
},
));
}
// our pixel positioning logic assumes each line ends in \n,
// this is almost always true except for the last line which
// may have no trailing newline.
if !ended_in_newline && display_row == self.max_point().row() {
line.push_str("\n");
styles.push((
"\n".len(),
RunStyle {
font_id: editor_style.text.font_id,
color: editor_style.text_color,
underline: editor_style.text.underline,
},
));
}
text_layout_cache.layout_str(&line, editor_style.text.font_size, &styles)
}
pub fn x_for_point(
&self,
display_point: DisplayPoint,
text_layout_details: &TextLayoutDetails,
) -> f32 {
let layout_line = self.lay_out_line_for_row(display_point.row(), text_layout_details);
layout_line.x_for_index(display_point.column() as usize)
}
pub fn column_for_x(
&self,
display_row: u32,
x_coordinate: f32,
text_layout_details: &TextLayoutDetails,
) -> u32 {
let layout_line = self.lay_out_line_for_row(display_row, text_layout_details);
layout_line.closest_index_for_x(x_coordinate) as u32
}
pub fn chars_at( pub fn chars_at(
&self, &self,
mut point: DisplayPoint, mut point: DisplayPoint,
@ -869,12 +1011,16 @@ pub fn next_rows(display_row: u32, display_map: &DisplaySnapshot) -> impl Iterat
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use super::*; use super::*;
use crate::{movement, test::marked_display_snapshot}; use crate::{
movement,
test::{editor_test_context::EditorTestContext, marked_display_snapshot},
};
use gpui::{color::Color, elements::*, test::observe, AppContext}; use gpui::{color::Color, elements::*, test::observe, AppContext};
use language::{ use language::{
language_settings::{AllLanguageSettings, AllLanguageSettingsContent}, language_settings::{AllLanguageSettings, AllLanguageSettingsContent},
Buffer, Language, LanguageConfig, SelectionGoal, Buffer, Language, LanguageConfig, SelectionGoal,
}; };
use project::Project;
use rand::{prelude::*, Rng}; use rand::{prelude::*, Rng};
use settings::SettingsStore; use settings::SettingsStore;
use smol::stream::StreamExt; use smol::stream::StreamExt;
@ -1148,95 +1294,120 @@ pub mod tests {
} }
#[gpui::test(retries = 5)] #[gpui::test(retries = 5)]
fn test_soft_wraps(cx: &mut AppContext) { async fn test_soft_wraps(cx: &mut gpui::TestAppContext) {
cx.foreground().set_block_on_ticks(usize::MAX..=usize::MAX); cx.foreground().set_block_on_ticks(usize::MAX..=usize::MAX);
init_test(cx, |_| {}); cx.update(|cx| {
init_test(cx, |_| {});
let font_cache = cx.font_cache();
let family_id = font_cache
.load_family(&["Helvetica"], &Default::default())
.unwrap();
let font_id = font_cache
.select_font(family_id, &Default::default())
.unwrap();
let font_size = 12.0;
let wrap_width = Some(64.);
let text = "one two three four five\nsix seven eight";
let buffer = MultiBuffer::build_simple(text, cx);
let map = cx.add_model(|cx| {
DisplayMap::new(buffer.clone(), font_id, font_size, wrap_width, 1, 1, cx)
}); });
let snapshot = map.update(cx, |map, cx| map.snapshot(cx)); let mut cx = EditorTestContext::new(cx).await;
assert_eq!( let editor = cx.editor.clone();
snapshot.text_chunks(0).collect::<String>(), let window = cx.window.clone();
"one two \nthree four \nfive\nsix seven \neight"
);
assert_eq!(
snapshot.clip_point(DisplayPoint::new(0, 8), Bias::Left),
DisplayPoint::new(0, 7)
);
assert_eq!(
snapshot.clip_point(DisplayPoint::new(0, 8), Bias::Right),
DisplayPoint::new(1, 0)
);
assert_eq!(
movement::right(&snapshot, DisplayPoint::new(0, 7)),
DisplayPoint::new(1, 0)
);
assert_eq!(
movement::left(&snapshot, DisplayPoint::new(1, 0)),
DisplayPoint::new(0, 7)
);
assert_eq!(
movement::up(
&snapshot,
DisplayPoint::new(1, 10),
SelectionGoal::None,
false
),
(DisplayPoint::new(0, 7), SelectionGoal::Column(10))
);
assert_eq!(
movement::down(
&snapshot,
DisplayPoint::new(0, 7),
SelectionGoal::Column(10),
false
),
(DisplayPoint::new(1, 10), SelectionGoal::Column(10))
);
assert_eq!(
movement::down(
&snapshot,
DisplayPoint::new(1, 10),
SelectionGoal::Column(10),
false
),
(DisplayPoint::new(2, 4), SelectionGoal::Column(10))
);
let ix = snapshot.buffer_snapshot.text().find("seven").unwrap(); cx.update_window(window, |cx| {
buffer.update(cx, |buffer, cx| { let text_layout_details =
buffer.edit([(ix..ix, "and ")], None, cx); editor.read_with(cx, |editor, cx| editor.text_layout_details(cx));
let font_cache = cx.font_cache().clone();
let family_id = font_cache
.load_family(&["Helvetica"], &Default::default())
.unwrap();
let font_id = font_cache
.select_font(family_id, &Default::default())
.unwrap();
let font_size = 12.0;
let wrap_width = Some(64.);
let text = "one two three four five\nsix seven eight";
let buffer = MultiBuffer::build_simple(text, cx);
let map = cx.add_model(|cx| {
DisplayMap::new(buffer.clone(), font_id, font_size, wrap_width, 1, 1, cx)
});
let snapshot = map.update(cx, |map, cx| map.snapshot(cx));
assert_eq!(
snapshot.text_chunks(0).collect::<String>(),
"one two \nthree four \nfive\nsix seven \neight"
);
assert_eq!(
snapshot.clip_point(DisplayPoint::new(0, 8), Bias::Left),
DisplayPoint::new(0, 7)
);
assert_eq!(
snapshot.clip_point(DisplayPoint::new(0, 8), Bias::Right),
DisplayPoint::new(1, 0)
);
assert_eq!(
movement::right(&snapshot, DisplayPoint::new(0, 7)),
DisplayPoint::new(1, 0)
);
assert_eq!(
movement::left(&snapshot, DisplayPoint::new(1, 0)),
DisplayPoint::new(0, 7)
);
let x = snapshot.x_for_point(DisplayPoint::new(1, 10), &text_layout_details);
assert_eq!(
movement::up(
&snapshot,
DisplayPoint::new(1, 10),
SelectionGoal::None,
false,
&text_layout_details,
),
(
DisplayPoint::new(0, 7),
SelectionGoal::HorizontalPosition(x)
)
);
assert_eq!(
movement::down(
&snapshot,
DisplayPoint::new(0, 7),
SelectionGoal::HorizontalPosition(x),
false,
&text_layout_details
),
(
DisplayPoint::new(1, 10),
SelectionGoal::HorizontalPosition(x)
)
);
assert_eq!(
movement::down(
&snapshot,
DisplayPoint::new(1, 10),
SelectionGoal::HorizontalPosition(x),
false,
&text_layout_details
),
(
DisplayPoint::new(2, 4),
SelectionGoal::HorizontalPosition(x)
)
);
let ix = snapshot.buffer_snapshot.text().find("seven").unwrap();
buffer.update(cx, |buffer, cx| {
buffer.edit([(ix..ix, "and ")], None, cx);
});
let snapshot = map.update(cx, |map, cx| map.snapshot(cx));
assert_eq!(
snapshot.text_chunks(1).collect::<String>(),
"three four \nfive\nsix and \nseven eight"
);
// Re-wrap on font size changes
map.update(cx, |map, cx| map.set_font(font_id, font_size + 3., cx));
let snapshot = map.update(cx, |map, cx| map.snapshot(cx));
assert_eq!(
snapshot.text_chunks(1).collect::<String>(),
"three \nfour five\nsix and \nseven \neight"
)
}); });
let snapshot = map.update(cx, |map, cx| map.snapshot(cx));
assert_eq!(
snapshot.text_chunks(1).collect::<String>(),
"three four \nfive\nsix and \nseven eight"
);
// Re-wrap on font size changes
map.update(cx, |map, cx| map.set_font(font_id, font_size + 3., cx));
let snapshot = map.update(cx, |map, cx| map.snapshot(cx));
assert_eq!(
snapshot.text_chunks(1).collect::<String>(),
"three \nfour five\nsix and \nseven \neight"
)
} }
#[gpui::test] #[gpui::test]
@ -1731,6 +1902,9 @@ pub mod tests {
cx.foreground().forbid_parking(); cx.foreground().forbid_parking();
cx.set_global(SettingsStore::test(cx)); cx.set_global(SettingsStore::test(cx));
language::init(cx); language::init(cx);
crate::init(cx);
Project::init_settings(cx);
theme::init((), cx);
cx.update_global::<SettingsStore, _, _>(|store, cx| { cx.update_global::<SettingsStore, _, _>(|store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, f); store.update_user_settings::<AllLanguageSettings>(cx, f);
}); });

View File

@ -71,6 +71,7 @@ use link_go_to_definition::{
}; };
use log::error; use log::error;
use lsp::LanguageServerId; use lsp::LanguageServerId;
use movement::TextLayoutDetails;
use multi_buffer::ToOffsetUtf16; use multi_buffer::ToOffsetUtf16;
pub use multi_buffer::{ pub use multi_buffer::{
Anchor, AnchorRangeExt, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, ToOffset, Anchor, AnchorRangeExt, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, ToOffset,
@ -3286,8 +3287,10 @@ impl Editor {
i = 0; i = 0;
} else if pair_state.range.start.to_offset(buffer) > range.end { } else if pair_state.range.start.to_offset(buffer) > range.end {
break; break;
} else if pair_state.selection_id == selection.id { } else {
enclosing = Some(pair_state); if pair_state.selection_id == selection.id {
enclosing = Some(pair_state);
}
i += 1; i += 1;
} }
} }
@ -3474,6 +3477,14 @@ impl Editor {
.collect() .collect()
} }
pub fn text_layout_details(&self, cx: &WindowContext) -> TextLayoutDetails {
TextLayoutDetails {
font_cache: cx.font_cache().clone(),
text_layout_cache: cx.text_layout_cache().clone(),
editor_style: self.style(cx),
}
}
fn splice_inlay_hints( fn splice_inlay_hints(
&self, &self,
to_remove: Vec<InlayId>, to_remove: Vec<InlayId>,
@ -5408,6 +5419,7 @@ impl Editor {
} }
pub fn transpose(&mut self, _: &Transpose, cx: &mut ViewContext<Self>) { pub fn transpose(&mut self, _: &Transpose, cx: &mut ViewContext<Self>) {
let text_layout_details = &self.text_layout_details(cx);
self.transact(cx, |this, cx| { self.transact(cx, |this, cx| {
let edits = this.change_selections(Some(Autoscroll::fit()), cx, |s| { let edits = this.change_selections(Some(Autoscroll::fit()), cx, |s| {
let mut edits: Vec<(Range<usize>, String)> = Default::default(); let mut edits: Vec<(Range<usize>, String)> = Default::default();
@ -5431,7 +5443,10 @@ impl Editor {
*head.column_mut() += 1; *head.column_mut() += 1;
head = display_map.clip_point(head, Bias::Right); head = display_map.clip_point(head, Bias::Right);
selection.collapse_to(head, SelectionGoal::Column(head.column())); let goal = SelectionGoal::HorizontalPosition(
display_map.x_for_point(head, &text_layout_details),
);
selection.collapse_to(head, goal);
let transpose_start = display_map let transpose_start = display_map
.buffer_snapshot .buffer_snapshot
@ -5695,13 +5710,21 @@ impl Editor {
return; return;
} }
let text_layout_details = &self.text_layout_details(cx);
self.change_selections(Some(Autoscroll::fit()), cx, |s| { self.change_selections(Some(Autoscroll::fit()), cx, |s| {
let line_mode = s.line_mode; let line_mode = s.line_mode;
s.move_with(|map, selection| { s.move_with(|map, selection| {
if !selection.is_empty() && !line_mode { if !selection.is_empty() && !line_mode {
selection.goal = SelectionGoal::None; selection.goal = SelectionGoal::None;
} }
let (cursor, goal) = movement::up(map, selection.start, selection.goal, false); let (cursor, goal) = movement::up(
map,
selection.start,
selection.goal,
false,
&text_layout_details,
);
selection.collapse_to(cursor, goal); selection.collapse_to(cursor, goal);
}); });
}) })
@ -5729,22 +5752,33 @@ impl Editor {
Autoscroll::fit() Autoscroll::fit()
}; };
let text_layout_details = &self.text_layout_details(cx);
self.change_selections(Some(autoscroll), cx, |s| { self.change_selections(Some(autoscroll), cx, |s| {
let line_mode = s.line_mode; let line_mode = s.line_mode;
s.move_with(|map, selection| { s.move_with(|map, selection| {
if !selection.is_empty() && !line_mode { if !selection.is_empty() && !line_mode {
selection.goal = SelectionGoal::None; selection.goal = SelectionGoal::None;
} }
let (cursor, goal) = let (cursor, goal) = movement::up_by_rows(
movement::up_by_rows(map, selection.end, row_count, selection.goal, false); map,
selection.end,
row_count,
selection.goal,
false,
&text_layout_details,
);
selection.collapse_to(cursor, goal); selection.collapse_to(cursor, goal);
}); });
}); });
} }
pub fn select_up(&mut self, _: &SelectUp, cx: &mut ViewContext<Self>) { pub fn select_up(&mut self, _: &SelectUp, cx: &mut ViewContext<Self>) {
let text_layout_details = &self.text_layout_details(cx);
self.change_selections(Some(Autoscroll::fit()), cx, |s| { self.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.move_heads_with(|map, head, goal| movement::up(map, head, goal, false)) s.move_heads_with(|map, head, goal| {
movement::up(map, head, goal, false, &text_layout_details)
})
}) })
} }
@ -5756,13 +5790,20 @@ impl Editor {
return; return;
} }
let text_layout_details = &self.text_layout_details(cx);
self.change_selections(Some(Autoscroll::fit()), cx, |s| { self.change_selections(Some(Autoscroll::fit()), cx, |s| {
let line_mode = s.line_mode; let line_mode = s.line_mode;
s.move_with(|map, selection| { s.move_with(|map, selection| {
if !selection.is_empty() && !line_mode { if !selection.is_empty() && !line_mode {
selection.goal = SelectionGoal::None; selection.goal = SelectionGoal::None;
} }
let (cursor, goal) = movement::down(map, selection.end, selection.goal, false); let (cursor, goal) = movement::down(
map,
selection.end,
selection.goal,
false,
&text_layout_details,
);
selection.collapse_to(cursor, goal); selection.collapse_to(cursor, goal);
}); });
}); });
@ -5800,22 +5841,32 @@ impl Editor {
Autoscroll::fit() Autoscroll::fit()
}; };
let text_layout_details = &self.text_layout_details(cx);
self.change_selections(Some(autoscroll), cx, |s| { self.change_selections(Some(autoscroll), cx, |s| {
let line_mode = s.line_mode; let line_mode = s.line_mode;
s.move_with(|map, selection| { s.move_with(|map, selection| {
if !selection.is_empty() && !line_mode { if !selection.is_empty() && !line_mode {
selection.goal = SelectionGoal::None; selection.goal = SelectionGoal::None;
} }
let (cursor, goal) = let (cursor, goal) = movement::down_by_rows(
movement::down_by_rows(map, selection.end, row_count, selection.goal, false); map,
selection.end,
row_count,
selection.goal,
false,
&text_layout_details,
);
selection.collapse_to(cursor, goal); selection.collapse_to(cursor, goal);
}); });
}); });
} }
pub fn select_down(&mut self, _: &SelectDown, cx: &mut ViewContext<Self>) { pub fn select_down(&mut self, _: &SelectDown, cx: &mut ViewContext<Self>) {
let text_layout_details = &self.text_layout_details(cx);
self.change_selections(Some(Autoscroll::fit()), cx, |s| { self.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.move_heads_with(|map, head, goal| movement::down(map, head, goal, false)) s.move_heads_with(|map, head, goal| {
movement::down(map, head, goal, false, &text_layout_details)
})
}); });
} }
@ -6334,11 +6385,14 @@ impl Editor {
fn add_selection(&mut self, above: bool, cx: &mut ViewContext<Self>) { fn add_selection(&mut self, above: bool, cx: &mut ViewContext<Self>) {
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx)); let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
let mut selections = self.selections.all::<Point>(cx); let mut selections = self.selections.all::<Point>(cx);
let text_layout_details = self.text_layout_details(cx);
let mut state = self.add_selections_state.take().unwrap_or_else(|| { let mut state = self.add_selections_state.take().unwrap_or_else(|| {
let oldest_selection = selections.iter().min_by_key(|s| s.id).unwrap().clone(); let oldest_selection = selections.iter().min_by_key(|s| s.id).unwrap().clone();
let range = oldest_selection.display_range(&display_map).sorted(); let range = oldest_selection.display_range(&display_map).sorted();
let columns = cmp::min(range.start.column(), range.end.column())
..cmp::max(range.start.column(), range.end.column()); let start_x = display_map.x_for_point(range.start, &text_layout_details);
let end_x = display_map.x_for_point(range.end, &text_layout_details);
let positions = start_x.min(end_x)..start_x.max(end_x);
selections.clear(); selections.clear();
let mut stack = Vec::new(); let mut stack = Vec::new();
@ -6346,8 +6400,9 @@ impl Editor {
if let Some(selection) = self.selections.build_columnar_selection( if let Some(selection) = self.selections.build_columnar_selection(
&display_map, &display_map,
row, row,
&columns, &positions,
oldest_selection.reversed, oldest_selection.reversed,
&text_layout_details,
) { ) {
stack.push(selection.id); stack.push(selection.id);
selections.push(selection); selections.push(selection);
@ -6375,12 +6430,15 @@ impl Editor {
let range = selection.display_range(&display_map).sorted(); let range = selection.display_range(&display_map).sorted();
debug_assert_eq!(range.start.row(), range.end.row()); debug_assert_eq!(range.start.row(), range.end.row());
let mut row = range.start.row(); let mut row = range.start.row();
let columns = if let SelectionGoal::ColumnRange { start, end } = selection.goal let positions = if let SelectionGoal::HorizontalRange { start, end } =
selection.goal
{ {
start..end start..end
} else { } else {
cmp::min(range.start.column(), range.end.column()) let start_x = display_map.x_for_point(range.start, &text_layout_details);
..cmp::max(range.start.column(), range.end.column()) let end_x = display_map.x_for_point(range.end, &text_layout_details);
start_x.min(end_x)..start_x.max(end_x)
}; };
while row != end_row { while row != end_row {
@ -6393,8 +6451,9 @@ impl Editor {
if let Some(new_selection) = self.selections.build_columnar_selection( if let Some(new_selection) = self.selections.build_columnar_selection(
&display_map, &display_map,
row, row,
&columns, &positions,
selection.reversed, selection.reversed,
&text_layout_details,
) { ) {
state.stack.push(new_selection.id); state.stack.push(new_selection.id);
if above { if above {
@ -6688,6 +6747,7 @@ impl Editor {
} }
pub fn toggle_comments(&mut self, action: &ToggleComments, cx: &mut ViewContext<Self>) { pub fn toggle_comments(&mut self, action: &ToggleComments, cx: &mut ViewContext<Self>) {
let text_layout_details = &self.text_layout_details(cx);
self.transact(cx, |this, cx| { self.transact(cx, |this, cx| {
let mut selections = this.selections.all::<Point>(cx); let mut selections = this.selections.all::<Point>(cx);
let mut edits = Vec::new(); let mut edits = Vec::new();
@ -6930,7 +6990,10 @@ impl Editor {
point.row += 1; point.row += 1;
point = snapshot.clip_point(point, Bias::Left); point = snapshot.clip_point(point, Bias::Left);
let display_point = point.to_display_point(display_snapshot); let display_point = point.to_display_point(display_snapshot);
(display_point, SelectionGoal::Column(display_point.column())) let goal = SelectionGoal::HorizontalPosition(
display_snapshot.x_for_point(display_point, &text_layout_details),
);
(display_point, goal)
}) })
}); });
} }

View File

@ -19,8 +19,8 @@ use gpui::{
use indoc::indoc; use indoc::indoc;
use language::{ use language::{
language_settings::{AllLanguageSettings, AllLanguageSettingsContent, LanguageSettingsContent}, language_settings::{AllLanguageSettings, AllLanguageSettingsContent, LanguageSettingsContent},
BracketPairConfig, BundledFormatter, FakeLspAdapter, LanguageConfig, LanguageConfigOverride, BracketPairConfig, FakeLspAdapter, LanguageConfig, LanguageConfigOverride, LanguageRegistry,
LanguageRegistry, Override, Point, Override, Point,
}; };
use parking_lot::Mutex; use parking_lot::Mutex;
use project::project_settings::{LspSettings, ProjectSettings}; use project::project_settings::{LspSettings, ProjectSettings};
@ -851,7 +851,7 @@ fn test_move_cursor_multibyte(cx: &mut TestAppContext) {
let view = cx let view = cx
.add_window(|cx| { .add_window(|cx| {
let buffer = MultiBuffer::build_simple("ⓐⓑⓒⓓⓔ\nabcde\nαβγδε\n", cx); let buffer = MultiBuffer::build_simple("ⓐⓑⓒⓓⓔ\nabcde\nαβγδε", cx);
build_editor(buffer.clone(), cx) build_editor(buffer.clone(), cx)
}) })
.root(cx); .root(cx);
@ -869,7 +869,7 @@ fn test_move_cursor_multibyte(cx: &mut TestAppContext) {
true, true,
cx, cx,
); );
assert_eq!(view.display_text(cx), "ⓐⓑ⋯ⓔ\nab⋯e\nαβ⋯ε\n"); assert_eq!(view.display_text(cx), "ⓐⓑ⋯ⓔ\nab⋯e\nαβ⋯ε");
view.move_right(&MoveRight, cx); view.move_right(&MoveRight, cx);
assert_eq!( assert_eq!(
@ -888,6 +888,11 @@ fn test_move_cursor_multibyte(cx: &mut TestAppContext) {
); );
view.move_down(&MoveDown, cx); view.move_down(&MoveDown, cx);
assert_eq!(
view.selections.display_ranges(cx),
&[empty_range(1, "ab⋯e".len())]
);
view.move_left(&MoveLeft, cx);
assert_eq!( assert_eq!(
view.selections.display_ranges(cx), view.selections.display_ranges(cx),
&[empty_range(1, "ab⋯".len())] &[empty_range(1, "ab⋯".len())]
@ -929,17 +934,18 @@ fn test_move_cursor_multibyte(cx: &mut TestAppContext) {
view.selections.display_ranges(cx), view.selections.display_ranges(cx),
&[empty_range(1, "ab⋯e".len())] &[empty_range(1, "ab⋯e".len())]
); );
view.move_down(&MoveDown, cx);
assert_eq!(
view.selections.display_ranges(cx),
&[empty_range(2, "αβ⋯ε".len())]
);
view.move_up(&MoveUp, cx); view.move_up(&MoveUp, cx);
assert_eq!( assert_eq!(
view.selections.display_ranges(cx), view.selections.display_ranges(cx),
&[empty_range(0, "ⓐⓑ⋯ⓔ".len())] &[empty_range(1, "ab⋯e".len())]
); );
view.move_left(&MoveLeft, cx);
assert_eq!( view.move_up(&MoveUp, cx);
view.selections.display_ranges(cx),
&[empty_range(0, "ⓐⓑ⋯".len())]
);
view.move_left(&MoveLeft, cx);
assert_eq!( assert_eq!(
view.selections.display_ranges(cx), view.selections.display_ranges(cx),
&[empty_range(0, "ⓐⓑ".len())] &[empty_range(0, "ⓐⓑ".len())]
@ -949,6 +955,11 @@ fn test_move_cursor_multibyte(cx: &mut TestAppContext) {
view.selections.display_ranges(cx), view.selections.display_ranges(cx),
&[empty_range(0, "".len())] &[empty_range(0, "".len())]
); );
view.move_left(&MoveLeft, cx);
assert_eq!(
view.selections.display_ranges(cx),
&[empty_range(0, "".len())]
);
}); });
} }
@ -5084,6 +5095,9 @@ async fn test_document_format_manual_trigger(cx: &mut gpui::TestAppContext) {
LanguageConfig { LanguageConfig {
name: "Rust".into(), name: "Rust".into(),
path_suffixes: vec!["rs".to_string()], path_suffixes: vec!["rs".to_string()],
// Enable Prettier formatting for the same buffer, and ensure
// LSP is called instead of Prettier.
prettier_parser_name: Some("test_parser".to_string()),
..Default::default() ..Default::default()
}, },
Some(tree_sitter_rust::language()), Some(tree_sitter_rust::language()),
@ -5094,12 +5108,6 @@ async fn test_document_format_manual_trigger(cx: &mut gpui::TestAppContext) {
document_formatting_provider: Some(lsp::OneOf::Left(true)), document_formatting_provider: Some(lsp::OneOf::Left(true)),
..Default::default() ..Default::default()
}, },
// Enable Prettier formatting for the same buffer, and ensure
// LSP is called instead of Prettier.
enabled_formatters: vec![BundledFormatter::Prettier {
parser_name: Some("test_parser"),
plugin_names: Vec::new(),
}],
..Default::default() ..Default::default()
})) }))
.await; .await;
@ -7838,6 +7846,7 @@ async fn test_document_format_with_prettier(cx: &mut gpui::TestAppContext) {
LanguageConfig { LanguageConfig {
name: "Rust".into(), name: "Rust".into(),
path_suffixes: vec!["rs".to_string()], path_suffixes: vec!["rs".to_string()],
prettier_parser_name: Some("test_parser".to_string()),
..Default::default() ..Default::default()
}, },
Some(tree_sitter_rust::language()), Some(tree_sitter_rust::language()),
@ -7846,10 +7855,7 @@ async fn test_document_format_with_prettier(cx: &mut gpui::TestAppContext) {
let test_plugin = "test_plugin"; let test_plugin = "test_plugin";
let _ = language let _ = language
.set_fake_lsp_adapter(Arc::new(FakeLspAdapter { .set_fake_lsp_adapter(Arc::new(FakeLspAdapter {
enabled_formatters: vec![BundledFormatter::Prettier { prettier_plugins: vec![test_plugin],
parser_name: Some("test_parser"),
plugin_names: vec![test_plugin],
}],
..Default::default() ..Default::default()
})) }))
.await; .await;

View File

@ -4,7 +4,7 @@ use super::{
MAX_LINE_LEN, MAX_LINE_LEN,
}; };
use crate::{ use crate::{
display_map::{BlockStyle, DisplaySnapshot, FoldStatus, TransformBlock}, display_map::{BlockStyle, DisplaySnapshot, FoldStatus, HighlightedChunk, TransformBlock},
editor_settings::ShowScrollbar, editor_settings::ShowScrollbar,
git::{diff_hunk_to_display, DisplayDiffHunk}, git::{diff_hunk_to_display, DisplayDiffHunk},
hover_popover::{ hover_popover::{
@ -22,7 +22,7 @@ use git::diff::DiffHunkStatus;
use gpui::{ use gpui::{
color::Color, color::Color,
elements::*, elements::*,
fonts::{HighlightStyle, TextStyle, Underline}, fonts::TextStyle,
geometry::{ geometry::{
rect::RectF, rect::RectF,
vector::{vec2f, Vector2F}, vector::{vec2f, Vector2F},
@ -37,8 +37,7 @@ use gpui::{
use itertools::Itertools; use itertools::Itertools;
use json::json; use json::json;
use language::{ use language::{
language_settings::ShowWhitespaceSetting, Bias, CursorShape, DiagnosticSeverity, OffsetUtf16, language_settings::ShowWhitespaceSetting, Bias, CursorShape, OffsetUtf16, Selection,
Selection,
}; };
use project::{ use project::{
project_settings::{GitGutterSetting, ProjectSettings}, project_settings::{GitGutterSetting, ProjectSettings},
@ -1584,56 +1583,7 @@ impl EditorElement {
.collect() .collect()
} else { } else {
let style = &self.style; let style = &self.style;
let chunks = snapshot let chunks = snapshot.highlighted_chunks(rows.clone(), true, style);
.chunks(
rows.clone(),
true,
Some(style.theme.hint),
Some(style.theme.suggestion),
)
.map(|chunk| {
let mut highlight_style = chunk
.syntax_highlight_id
.and_then(|id| id.style(&style.syntax));
if let Some(chunk_highlight) = chunk.highlight_style {
if let Some(highlight_style) = highlight_style.as_mut() {
highlight_style.highlight(chunk_highlight);
} else {
highlight_style = Some(chunk_highlight);
}
}
let mut diagnostic_highlight = HighlightStyle::default();
if chunk.is_unnecessary {
diagnostic_highlight.fade_out = Some(style.unnecessary_code_fade);
}
if let Some(severity) = chunk.diagnostic_severity {
// Omit underlines for HINT/INFO diagnostics on 'unnecessary' code.
if severity <= DiagnosticSeverity::WARNING || !chunk.is_unnecessary {
let diagnostic_style = super::diagnostic_style(severity, true, style);
diagnostic_highlight.underline = Some(Underline {
color: Some(diagnostic_style.message.text.color),
thickness: 1.0.into(),
squiggly: true,
});
}
}
if let Some(highlight_style) = highlight_style.as_mut() {
highlight_style.highlight(diagnostic_highlight);
} else {
highlight_style = Some(diagnostic_highlight);
}
HighlightedChunk {
chunk: chunk.text,
style: highlight_style,
is_tab: chunk.is_tab,
}
});
LineWithInvisibles::from_chunks( LineWithInvisibles::from_chunks(
chunks, chunks,
@ -1870,12 +1820,6 @@ impl EditorElement {
} }
} }
struct HighlightedChunk<'a> {
chunk: &'a str,
style: Option<HighlightStyle>,
is_tab: bool,
}
#[derive(Debug)] #[derive(Debug)]
pub struct LineWithInvisibles { pub struct LineWithInvisibles {
pub line: Line, pub line: Line,

View File

@ -2138,7 +2138,7 @@ pub mod tests {
}); });
} }
#[gpui::test] #[gpui::test(iterations = 10)]
async fn test_large_buffer_inlay_requests_split(cx: &mut gpui::TestAppContext) { async fn test_large_buffer_inlay_requests_split(cx: &mut gpui::TestAppContext) {
init_test(cx, |settings| { init_test(cx, |settings| {
settings.defaults.inlay_hints = Some(InlayHintSettings { settings.defaults.inlay_hints = Some(InlayHintSettings {
@ -2400,11 +2400,13 @@ pub mod tests {
)); ));
cx.foreground().run_until_parked(); cx.foreground().run_until_parked();
editor.update(cx, |editor, cx| { editor.update(cx, |editor, cx| {
let ranges = lsp_request_ranges.lock().drain(..).collect::<Vec<_>>(); let mut ranges = lsp_request_ranges.lock().drain(..).collect::<Vec<_>>();
ranges.sort_by_key(|r| r.start);
assert_eq!(ranges.len(), 3, assert_eq!(ranges.len(), 3,
"On edit, should scroll to selection and query a range around it: visible + same range above and below. Instead, got query ranges {ranges:?}"); "On edit, should scroll to selection and query a range around it: visible + same range above and below. Instead, got query ranges {ranges:?}");
let visible_query_range = &ranges[0]; let above_query_range = &ranges[0];
let above_query_range = &ranges[1]; let visible_query_range = &ranges[1];
let below_query_range = &ranges[2]; let below_query_range = &ranges[2];
assert!(above_query_range.end.character < visible_query_range.start.character || above_query_range.end.line + 1 == visible_query_range.start.line, assert!(above_query_range.end.character < visible_query_range.start.character || above_query_range.end.line + 1 == visible_query_range.start.line,
"Above range {above_query_range:?} should be before visible range {visible_query_range:?}"); "Above range {above_query_range:?} should be before visible range {visible_query_range:?}");

View File

@ -1,7 +1,8 @@
use super::{Bias, DisplayPoint, DisplaySnapshot, SelectionGoal, ToDisplayPoint}; use super::{Bias, DisplayPoint, DisplaySnapshot, SelectionGoal, ToDisplayPoint};
use crate::{char_kind, CharKind, ToOffset, ToPoint}; use crate::{char_kind, CharKind, EditorStyle, ToOffset, ToPoint};
use gpui::{FontCache, TextLayoutCache};
use language::Point; use language::Point;
use std::ops::Range; use std::{ops::Range, sync::Arc};
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum FindRange { pub enum FindRange {
@ -9,6 +10,14 @@ pub enum FindRange {
MultiLine, MultiLine,
} }
/// TextLayoutDetails encompasses everything we need to move vertically
/// taking into account variable width characters.
pub struct TextLayoutDetails {
pub font_cache: Arc<FontCache>,
pub text_layout_cache: Arc<TextLayoutCache>,
pub editor_style: EditorStyle,
}
pub fn left(map: &DisplaySnapshot, mut point: DisplayPoint) -> DisplayPoint { pub fn left(map: &DisplaySnapshot, mut point: DisplayPoint) -> DisplayPoint {
if point.column() > 0 { if point.column() > 0 {
*point.column_mut() -= 1; *point.column_mut() -= 1;
@ -47,8 +56,16 @@ pub fn up(
start: DisplayPoint, start: DisplayPoint,
goal: SelectionGoal, goal: SelectionGoal,
preserve_column_at_start: bool, preserve_column_at_start: bool,
text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) { ) -> (DisplayPoint, SelectionGoal) {
up_by_rows(map, start, 1, goal, preserve_column_at_start) up_by_rows(
map,
start,
1,
goal,
preserve_column_at_start,
text_layout_details,
)
} }
pub fn down( pub fn down(
@ -56,8 +73,16 @@ pub fn down(
start: DisplayPoint, start: DisplayPoint,
goal: SelectionGoal, goal: SelectionGoal,
preserve_column_at_end: bool, preserve_column_at_end: bool,
text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) { ) -> (DisplayPoint, SelectionGoal) {
down_by_rows(map, start, 1, goal, preserve_column_at_end) down_by_rows(
map,
start,
1,
goal,
preserve_column_at_end,
text_layout_details,
)
} }
pub fn up_by_rows( pub fn up_by_rows(
@ -66,11 +91,13 @@ pub fn up_by_rows(
row_count: u32, row_count: u32,
goal: SelectionGoal, goal: SelectionGoal,
preserve_column_at_start: bool, preserve_column_at_start: bool,
text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) { ) -> (DisplayPoint, SelectionGoal) {
let mut goal_column = match goal { let mut goal_x = match goal {
SelectionGoal::Column(column) => column, SelectionGoal::HorizontalPosition(x) => x,
SelectionGoal::ColumnRange { end, .. } => end, SelectionGoal::WrappedHorizontalPosition((_, x)) => x,
_ => map.column_to_chars(start.row(), start.column()), SelectionGoal::HorizontalRange { end, .. } => end,
_ => map.x_for_point(start, text_layout_details),
}; };
let prev_row = start.row().saturating_sub(row_count); let prev_row = start.row().saturating_sub(row_count);
@ -79,19 +106,19 @@ pub fn up_by_rows(
Bias::Left, Bias::Left,
); );
if point.row() < start.row() { if point.row() < start.row() {
*point.column_mut() = map.column_from_chars(point.row(), goal_column); *point.column_mut() = map.column_for_x(point.row(), goal_x, text_layout_details)
} else if preserve_column_at_start { } else if preserve_column_at_start {
return (start, goal); return (start, goal);
} else { } else {
point = DisplayPoint::new(0, 0); point = DisplayPoint::new(0, 0);
goal_column = 0; goal_x = 0.0;
} }
let mut clipped_point = map.clip_point(point, Bias::Left); let mut clipped_point = map.clip_point(point, Bias::Left);
if clipped_point.row() < point.row() { if clipped_point.row() < point.row() {
clipped_point = map.clip_point(point, Bias::Right); clipped_point = map.clip_point(point, Bias::Right);
} }
(clipped_point, SelectionGoal::Column(goal_column)) (clipped_point, SelectionGoal::HorizontalPosition(goal_x))
} }
pub fn down_by_rows( pub fn down_by_rows(
@ -100,29 +127,31 @@ pub fn down_by_rows(
row_count: u32, row_count: u32,
goal: SelectionGoal, goal: SelectionGoal,
preserve_column_at_end: bool, preserve_column_at_end: bool,
text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) { ) -> (DisplayPoint, SelectionGoal) {
let mut goal_column = match goal { let mut goal_x = match goal {
SelectionGoal::Column(column) => column, SelectionGoal::HorizontalPosition(x) => x,
SelectionGoal::ColumnRange { end, .. } => end, SelectionGoal::WrappedHorizontalPosition((_, x)) => x,
_ => map.column_to_chars(start.row(), start.column()), SelectionGoal::HorizontalRange { end, .. } => end,
_ => map.x_for_point(start, text_layout_details),
}; };
let new_row = start.row() + row_count; let new_row = start.row() + row_count;
let mut point = map.clip_point(DisplayPoint::new(new_row, 0), Bias::Right); let mut point = map.clip_point(DisplayPoint::new(new_row, 0), Bias::Right);
if point.row() > start.row() { if point.row() > start.row() {
*point.column_mut() = map.column_from_chars(point.row(), goal_column); *point.column_mut() = map.column_for_x(point.row(), goal_x, text_layout_details)
} else if preserve_column_at_end { } else if preserve_column_at_end {
return (start, goal); return (start, goal);
} else { } else {
point = map.max_point(); point = map.max_point();
goal_column = map.column_to_chars(point.row(), point.column()) goal_x = map.x_for_point(point, text_layout_details)
} }
let mut clipped_point = map.clip_point(point, Bias::Right); let mut clipped_point = map.clip_point(point, Bias::Right);
if clipped_point.row() > point.row() { if clipped_point.row() > point.row() {
clipped_point = map.clip_point(point, Bias::Left); clipped_point = map.clip_point(point, Bias::Left);
} }
(clipped_point, SelectionGoal::Column(goal_column)) (clipped_point, SelectionGoal::HorizontalPosition(goal_x))
} }
pub fn line_beginning( pub fn line_beginning(
@ -396,9 +425,11 @@ pub fn split_display_range_by_lines(
mod tests { mod tests {
use super::*; use super::*;
use crate::{ use crate::{
display_map::Inlay, test::marked_display_snapshot, Buffer, DisplayMap, ExcerptRange, display_map::Inlay,
InlayId, MultiBuffer, test::{editor_test_context::EditorTestContext, marked_display_snapshot},
Buffer, DisplayMap, ExcerptRange, InlayId, MultiBuffer,
}; };
use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use util::post_inc; use util::post_inc;
@ -691,123 +722,173 @@ mod tests {
} }
#[gpui::test] #[gpui::test]
fn test_move_up_and_down_with_excerpts(cx: &mut gpui::AppContext) { async fn test_move_up_and_down_with_excerpts(cx: &mut gpui::TestAppContext) {
init_test(cx); cx.update(|cx| {
init_test(cx);
let family_id = cx
.font_cache()
.load_family(&["Helvetica"], &Default::default())
.unwrap();
let font_id = cx
.font_cache()
.select_font(family_id, &Default::default())
.unwrap();
let buffer =
cx.add_model(|cx| Buffer::new(0, cx.model_id() as u64, "abc\ndefg\nhijkl\nmn"));
let multibuffer = cx.add_model(|cx| {
let mut multibuffer = MultiBuffer::new(0);
multibuffer.push_excerpts(
buffer.clone(),
[
ExcerptRange {
context: Point::new(0, 0)..Point::new(1, 4),
primary: None,
},
ExcerptRange {
context: Point::new(2, 0)..Point::new(3, 2),
primary: None,
},
],
cx,
);
multibuffer
}); });
let display_map =
cx.add_model(|cx| DisplayMap::new(multibuffer, font_id, 14.0, None, 2, 2, cx));
let snapshot = display_map.update(cx, |map, cx| map.snapshot(cx));
assert_eq!(snapshot.text(), "\n\nabc\ndefg\n\n\nhijkl\nmn"); let mut cx = EditorTestContext::new(cx).await;
let editor = cx.editor.clone();
let window = cx.window.clone();
cx.update_window(window, |cx| {
let text_layout_details =
editor.read_with(cx, |editor, cx| editor.text_layout_details(cx));
// Can't move up into the first excerpt's header let family_id = cx
assert_eq!( .font_cache()
up( .load_family(&["Helvetica"], &Default::default())
&snapshot, .unwrap();
DisplayPoint::new(2, 2), let font_id = cx
SelectionGoal::Column(2), .font_cache()
false .select_font(family_id, &Default::default())
), .unwrap();
(DisplayPoint::new(2, 0), SelectionGoal::Column(0)),
);
assert_eq!(
up(
&snapshot,
DisplayPoint::new(2, 0),
SelectionGoal::None,
false
),
(DisplayPoint::new(2, 0), SelectionGoal::Column(0)),
);
// Move up and down within first excerpt let buffer =
assert_eq!( cx.add_model(|cx| Buffer::new(0, cx.model_id() as u64, "abc\ndefg\nhijkl\nmn"));
up( let multibuffer = cx.add_model(|cx| {
&snapshot, let mut multibuffer = MultiBuffer::new(0);
DisplayPoint::new(3, 4), multibuffer.push_excerpts(
SelectionGoal::Column(4), buffer.clone(),
false [
), ExcerptRange {
(DisplayPoint::new(2, 3), SelectionGoal::Column(4)), context: Point::new(0, 0)..Point::new(1, 4),
); primary: None,
assert_eq!( },
down( ExcerptRange {
&snapshot, context: Point::new(2, 0)..Point::new(3, 2),
DisplayPoint::new(2, 3), primary: None,
SelectionGoal::Column(4), },
false ],
), cx,
(DisplayPoint::new(3, 4), SelectionGoal::Column(4)), );
); multibuffer
});
let display_map =
cx.add_model(|cx| DisplayMap::new(multibuffer, font_id, 14.0, None, 2, 2, cx));
let snapshot = display_map.update(cx, |map, cx| map.snapshot(cx));
// Move up and down across second excerpt's header assert_eq!(snapshot.text(), "\n\nabc\ndefg\n\n\nhijkl\nmn");
assert_eq!(
up(
&snapshot,
DisplayPoint::new(6, 5),
SelectionGoal::Column(5),
false
),
(DisplayPoint::new(3, 4), SelectionGoal::Column(5)),
);
assert_eq!(
down(
&snapshot,
DisplayPoint::new(3, 4),
SelectionGoal::Column(5),
false
),
(DisplayPoint::new(6, 5), SelectionGoal::Column(5)),
);
// Can't move down off the end let col_2_x = snapshot.x_for_point(DisplayPoint::new(2, 2), &text_layout_details);
assert_eq!(
down( // Can't move up into the first excerpt's header
&snapshot, assert_eq!(
DisplayPoint::new(7, 0), up(
SelectionGoal::Column(0), &snapshot,
false DisplayPoint::new(2, 2),
), SelectionGoal::HorizontalPosition(col_2_x),
(DisplayPoint::new(7, 2), SelectionGoal::Column(2)), false,
); &text_layout_details
assert_eq!( ),
down( (
&snapshot, DisplayPoint::new(2, 0),
DisplayPoint::new(7, 2), SelectionGoal::HorizontalPosition(0.0)
SelectionGoal::Column(2), ),
false );
), assert_eq!(
(DisplayPoint::new(7, 2), SelectionGoal::Column(2)), up(
); &snapshot,
DisplayPoint::new(2, 0),
SelectionGoal::None,
false,
&text_layout_details
),
(
DisplayPoint::new(2, 0),
SelectionGoal::HorizontalPosition(0.0)
),
);
let col_4_x = snapshot.x_for_point(DisplayPoint::new(3, 4), &text_layout_details);
// Move up and down within first excerpt
assert_eq!(
up(
&snapshot,
DisplayPoint::new(3, 4),
SelectionGoal::HorizontalPosition(col_4_x),
false,
&text_layout_details
),
(
DisplayPoint::new(2, 3),
SelectionGoal::HorizontalPosition(col_4_x)
),
);
assert_eq!(
down(
&snapshot,
DisplayPoint::new(2, 3),
SelectionGoal::HorizontalPosition(col_4_x),
false,
&text_layout_details
),
(
DisplayPoint::new(3, 4),
SelectionGoal::HorizontalPosition(col_4_x)
),
);
let col_5_x = snapshot.x_for_point(DisplayPoint::new(6, 5), &text_layout_details);
// Move up and down across second excerpt's header
assert_eq!(
up(
&snapshot,
DisplayPoint::new(6, 5),
SelectionGoal::HorizontalPosition(col_5_x),
false,
&text_layout_details
),
(
DisplayPoint::new(3, 4),
SelectionGoal::HorizontalPosition(col_5_x)
),
);
assert_eq!(
down(
&snapshot,
DisplayPoint::new(3, 4),
SelectionGoal::HorizontalPosition(col_5_x),
false,
&text_layout_details
),
(
DisplayPoint::new(6, 5),
SelectionGoal::HorizontalPosition(col_5_x)
),
);
let max_point_x = snapshot.x_for_point(DisplayPoint::new(7, 2), &text_layout_details);
// Can't move down off the end
assert_eq!(
down(
&snapshot,
DisplayPoint::new(7, 0),
SelectionGoal::HorizontalPosition(0.0),
false,
&text_layout_details
),
(
DisplayPoint::new(7, 2),
SelectionGoal::HorizontalPosition(max_point_x)
),
);
assert_eq!(
down(
&snapshot,
DisplayPoint::new(7, 2),
SelectionGoal::HorizontalPosition(max_point_x),
false,
&text_layout_details
),
(
DisplayPoint::new(7, 2),
SelectionGoal::HorizontalPosition(max_point_x)
),
);
});
} }
fn init_test(cx: &mut gpui::AppContext) { fn init_test(cx: &mut gpui::AppContext) {
@ -815,5 +896,6 @@ mod tests {
theme::init((), cx); theme::init((), cx);
language::init(cx); language::init(cx);
crate::init(cx); crate::init(cx);
Project::init_settings(cx);
} }
} }

View File

@ -1,6 +1,6 @@
use std::{ use std::{
cell::Ref, cell::Ref,
cmp, iter, mem, iter, mem,
ops::{Deref, DerefMut, Range, Sub}, ops::{Deref, DerefMut, Range, Sub},
sync::Arc, sync::Arc,
}; };
@ -13,6 +13,7 @@ use util::post_inc;
use crate::{ use crate::{
display_map::{DisplayMap, DisplaySnapshot, ToDisplayPoint}, display_map::{DisplayMap, DisplaySnapshot, ToDisplayPoint},
movement::TextLayoutDetails,
Anchor, DisplayPoint, ExcerptId, MultiBuffer, MultiBufferSnapshot, SelectMode, ToOffset, Anchor, DisplayPoint, ExcerptId, MultiBuffer, MultiBufferSnapshot, SelectMode, ToOffset,
}; };
@ -305,23 +306,29 @@ impl SelectionsCollection {
&mut self, &mut self,
display_map: &DisplaySnapshot, display_map: &DisplaySnapshot,
row: u32, row: u32,
columns: &Range<u32>, positions: &Range<f32>,
reversed: bool, reversed: bool,
text_layout_details: &TextLayoutDetails,
) -> Option<Selection<Point>> { ) -> Option<Selection<Point>> {
let is_empty = columns.start == columns.end; let is_empty = positions.start == positions.end;
let line_len = display_map.line_len(row); let line_len = display_map.line_len(row);
if columns.start < line_len || (is_empty && columns.start == line_len) {
let start = DisplayPoint::new(row, columns.start); let layed_out_line = display_map.lay_out_line_for_row(row, &text_layout_details);
let end = DisplayPoint::new(row, cmp::min(columns.end, line_len));
let start_col = layed_out_line.closest_index_for_x(positions.start) as u32;
if start_col < line_len || (is_empty && positions.start == layed_out_line.width()) {
let start = DisplayPoint::new(row, start_col);
let end_col = layed_out_line.closest_index_for_x(positions.end) as u32;
let end = DisplayPoint::new(row, end_col);
Some(Selection { Some(Selection {
id: post_inc(&mut self.next_selection_id), id: post_inc(&mut self.next_selection_id),
start: start.to_point(display_map), start: start.to_point(display_map),
end: end.to_point(display_map), end: end.to_point(display_map),
reversed, reversed,
goal: SelectionGoal::ColumnRange { goal: SelectionGoal::HorizontalRange {
start: columns.start, start: positions.start,
end: columns.end, end: positions.end,
}, },
}) })
} else { } else {

View File

@ -30,7 +30,7 @@ struct StateInner<V> {
orientation: Orientation, orientation: Orientation,
overdraw: f32, overdraw: f32,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
scroll_handler: Option<Box<dyn FnMut(Range<usize>, &mut V, &mut ViewContext<V>)>>, scroll_handler: Option<Box<dyn FnMut(Range<usize>, usize, &mut V, &mut ViewContext<V>)>>,
} }
#[derive(Clone, Copy, Debug, Default, PartialEq)] #[derive(Clone, Copy, Debug, Default, PartialEq)]
@ -378,6 +378,10 @@ impl<V: 'static> ListState<V> {
.extend((0..element_count).map(|_| ListItem::Unrendered), &()); .extend((0..element_count).map(|_| ListItem::Unrendered), &());
} }
pub fn item_count(&self) -> usize {
self.0.borrow().items.summary().count
}
pub fn splice(&self, old_range: Range<usize>, count: usize) { pub fn splice(&self, old_range: Range<usize>, count: usize) {
let state = &mut *self.0.borrow_mut(); let state = &mut *self.0.borrow_mut();
@ -416,7 +420,7 @@ impl<V: 'static> ListState<V> {
pub fn set_scroll_handler( pub fn set_scroll_handler(
&mut self, &mut self,
handler: impl FnMut(Range<usize>, &mut V, &mut ViewContext<V>) + 'static, handler: impl FnMut(Range<usize>, usize, &mut V, &mut ViewContext<V>) + 'static,
) { ) {
self.0.borrow_mut().scroll_handler = Some(Box::new(handler)) self.0.borrow_mut().scroll_handler = Some(Box::new(handler))
} }
@ -529,7 +533,12 @@ impl<V: 'static> StateInner<V> {
if self.scroll_handler.is_some() { if self.scroll_handler.is_some() {
let visible_range = self.visible_range(height, scroll_top); let visible_range = self.visible_range(height, scroll_top);
self.scroll_handler.as_mut().unwrap()(visible_range, view, cx); self.scroll_handler.as_mut().unwrap()(
visible_range,
self.items.summary().count,
view,
cx,
);
} }
cx.notify(); cx.notify();

View File

@ -266,6 +266,8 @@ impl Line {
self.layout.len == 0 self.layout.len == 0
} }
/// index_for_x returns the character containing the given x coordinate.
/// (e.g. to handle a mouse-click)
pub fn index_for_x(&self, x: f32) -> Option<usize> { pub fn index_for_x(&self, x: f32) -> Option<usize> {
if x >= self.layout.width { if x >= self.layout.width {
None None
@ -281,6 +283,28 @@ impl Line {
} }
} }
/// closest_index_for_x returns the character boundary closest to the given x coordinate
/// (e.g. to handle aligning up/down arrow keys)
pub fn closest_index_for_x(&self, x: f32) -> usize {
let mut prev_index = 0;
let mut prev_x = 0.0;
for run in self.layout.runs.iter() {
for glyph in run.glyphs.iter() {
if glyph.position.x() >= x {
if glyph.position.x() - x < x - prev_x {
return glyph.index;
} else {
return prev_index;
}
}
prev_index = glyph.index;
prev_x = glyph.position.x();
}
}
prev_index
}
pub fn paint( pub fn paint(
&self, &self,
origin: Vector2F, origin: Vector2F,

View File

@ -201,7 +201,7 @@ pub struct CodeAction {
pub lsp_action: lsp::CodeAction, pub lsp_action: lsp::CodeAction,
} }
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq)]
pub enum Operation { pub enum Operation {
Buffer(text::Operation), Buffer(text::Operation),
@ -224,7 +224,7 @@ pub enum Operation {
}, },
} }
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq)]
pub enum Event { pub enum Event {
Operation(Operation), Operation(Operation),
Edited, Edited,

View File

@ -226,8 +226,8 @@ impl CachedLspAdapter {
self.adapter.label_for_symbol(name, kind, language).await self.adapter.label_for_symbol(name, kind, language).await
} }
pub fn enabled_formatters(&self) -> Vec<BundledFormatter> { pub fn prettier_plugins(&self) -> &[&'static str] {
self.adapter.enabled_formatters() self.adapter.prettier_plugins()
} }
} }
@ -336,31 +336,8 @@ pub trait LspAdapter: 'static + Send + Sync {
Default::default() Default::default()
} }
fn enabled_formatters(&self) -> Vec<BundledFormatter> { fn prettier_plugins(&self) -> &[&'static str] {
Vec::new() &[]
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BundledFormatter {
Prettier {
// See https://prettier.io/docs/en/options.html#parser for a list of valid values.
// Usually, every language has a single parser (standard or plugin-provided), hence `Some("parser_name")` can be used.
// There can not be multiple parsers for a single language, in case of a conflict, we would attempt to select the one with most plugins.
//
// But exceptions like Tailwind CSS exist, which uses standard parsers for CSS/JS/HTML/etc. but require an extra plugin to be installed.
// For those cases, `None` will install the plugin but apply other, regular parser defined for the language, and this would not be a conflict.
parser_name: Option<&'static str>,
plugin_names: Vec<&'static str>,
},
}
impl BundledFormatter {
pub fn prettier(parser_name: &'static str) -> Self {
Self::Prettier {
parser_name: Some(parser_name),
plugin_names: Vec::new(),
}
} }
} }
@ -398,6 +375,8 @@ pub struct LanguageConfig {
pub overrides: HashMap<String, LanguageConfigOverride>, pub overrides: HashMap<String, LanguageConfigOverride>,
#[serde(default)] #[serde(default)]
pub word_characters: HashSet<char>, pub word_characters: HashSet<char>,
#[serde(default)]
pub prettier_parser_name: Option<String>,
} }
#[derive(Debug, Default)] #[derive(Debug, Default)]
@ -471,6 +450,7 @@ impl Default for LanguageConfig {
overrides: Default::default(), overrides: Default::default(),
collapsed_placeholder: Default::default(), collapsed_placeholder: Default::default(),
word_characters: Default::default(), word_characters: Default::default(),
prettier_parser_name: None,
} }
} }
} }
@ -496,7 +476,7 @@ pub struct FakeLspAdapter {
pub initializer: Option<Box<dyn 'static + Send + Sync + Fn(&mut lsp::FakeLanguageServer)>>, pub initializer: Option<Box<dyn 'static + Send + Sync + Fn(&mut lsp::FakeLanguageServer)>>,
pub disk_based_diagnostics_progress_token: Option<String>, pub disk_based_diagnostics_progress_token: Option<String>,
pub disk_based_diagnostics_sources: Vec<String>, pub disk_based_diagnostics_sources: Vec<String>,
pub enabled_formatters: Vec<BundledFormatter>, pub prettier_plugins: Vec<&'static str>,
} }
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
@ -1597,6 +1577,10 @@ impl Language {
override_id: None, override_id: None,
} }
} }
pub fn prettier_parser_name(&self) -> Option<&str> {
self.config.prettier_parser_name.as_deref()
}
} }
impl LanguageScope { impl LanguageScope {
@ -1759,7 +1743,7 @@ impl Default for FakeLspAdapter {
disk_based_diagnostics_progress_token: None, disk_based_diagnostics_progress_token: None,
initialization_options: None, initialization_options: None,
disk_based_diagnostics_sources: Vec::new(), disk_based_diagnostics_sources: Vec::new(),
enabled_formatters: Vec::new(), prettier_plugins: Vec::new(),
} }
} }
} }
@ -1817,8 +1801,8 @@ impl LspAdapter for Arc<FakeLspAdapter> {
self.initialization_options.clone() self.initialization_options.clone()
} }
fn enabled_formatters(&self) -> Vec<BundledFormatter> { fn prettier_plugins(&self) -> &[&'static str] {
self.enabled_formatters.clone() &self.prettier_plugins
} }
} }

View File

@ -0,0 +1,42 @@
[package]
name = "notifications"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
path = "src/notification_store.rs"
doctest = false
[features]
test-support = [
"channel/test-support",
"collections/test-support",
"gpui/test-support",
"rpc/test-support",
]
[dependencies]
channel = { path = "../channel" }
client = { path = "../client" }
clock = { path = "../clock" }
collections = { path = "../collections" }
db = { path = "../db" }
feature_flags = { path = "../feature_flags" }
gpui = { path = "../gpui" }
rpc = { path = "../rpc" }
settings = { path = "../settings" }
sum_tree = { path = "../sum_tree" }
text = { path = "../text" }
util = { path = "../util" }
anyhow.workspace = true
time.workspace = true
[dev-dependencies]
client = { path = "../client", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }

View File

@ -0,0 +1,459 @@
use anyhow::Result;
use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
use client::{Client, UserStore};
use collections::HashMap;
use db::smol::stream::StreamExt;
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
use rpc::{proto, Notification, TypedEnvelope};
use std::{ops::Range, sync::Arc};
use sum_tree::{Bias, SumTree};
use time::OffsetDateTime;
use util::ResultExt;
pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
let notification_store = cx.add_model(|cx| NotificationStore::new(client, user_store, cx));
cx.set_global(notification_store);
}
pub struct NotificationStore {
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
channel_messages: HashMap<u64, ChannelMessage>,
channel_store: ModelHandle<ChannelStore>,
notifications: SumTree<NotificationEntry>,
loaded_all_notifications: bool,
_watch_connection_status: Task<Option<()>>,
_subscriptions: Vec<client::Subscription>,
}
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum NotificationEvent {
NotificationsUpdated {
old_range: Range<usize>,
new_count: usize,
},
NewNotification {
entry: NotificationEntry,
},
NotificationRemoved {
entry: NotificationEntry,
},
NotificationRead {
entry: NotificationEntry,
},
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct NotificationEntry {
pub id: u64,
pub notification: Notification,
pub timestamp: OffsetDateTime,
pub is_read: bool,
pub response: Option<bool>,
}
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
max_id: u64,
count: usize,
unread_count: usize,
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct UnreadCount(usize);
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
impl NotificationStore {
pub fn global(cx: &AppContext) -> ModelHandle<Self> {
cx.global::<ModelHandle<Self>>().clone()
}
pub fn new(
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
cx: &mut ModelContext<Self>,
) -> Self {
let mut connection_status = client.status();
let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
while let Some(status) = connection_status.next().await {
let this = this.upgrade(&cx)?;
match status {
client::Status::Connected { .. } => {
if let Some(task) = this.update(&mut cx, |this, cx| this.handle_connect(cx))
{
task.await.log_err()?;
}
}
_ => this.update(&mut cx, |this, cx| this.handle_disconnect(cx)),
}
}
Some(())
});
Self {
channel_store: ChannelStore::global(cx),
notifications: Default::default(),
loaded_all_notifications: false,
channel_messages: Default::default(),
_watch_connection_status: watch_connection_status,
_subscriptions: vec![
client.add_message_handler(cx.handle(), Self::handle_new_notification),
client.add_message_handler(cx.handle(), Self::handle_delete_notification),
],
user_store,
client,
}
}
pub fn notification_count(&self) -> usize {
self.notifications.summary().count
}
pub fn unread_notification_count(&self) -> usize {
self.notifications.summary().unread_count
}
pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
self.channel_messages.get(&id)
}
// Get the nth newest notification.
pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
let count = self.notifications.summary().count;
if ix >= count {
return None;
}
let ix = count - 1 - ix;
let mut cursor = self.notifications.cursor::<Count>();
cursor.seek(&Count(ix), Bias::Right, &());
cursor.item()
}
pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
let mut cursor = self.notifications.cursor::<NotificationId>();
cursor.seek(&NotificationId(id), Bias::Left, &());
if let Some(item) = cursor.item() {
if item.id == id {
return Some(item);
}
}
None
}
pub fn load_more_notifications(
&self,
clear_old: bool,
cx: &mut ModelContext<Self>,
) -> Option<Task<Result<()>>> {
if self.loaded_all_notifications && !clear_old {
return None;
}
let before_id = if clear_old {
None
} else {
self.notifications.first().map(|entry| entry.id)
};
let request = self.client.request(proto::GetNotifications { before_id });
Some(cx.spawn(|this, mut cx| async move {
let response = request.await?;
this.update(&mut cx, |this, _| {
this.loaded_all_notifications = response.done
});
Self::add_notifications(
this,
response.notifications,
AddNotificationsOptions {
is_new: false,
clear_old,
includes_first: response.done,
},
cx,
)
.await?;
Ok(())
}))
}
fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
self.notifications = Default::default();
self.channel_messages = Default::default();
cx.notify();
self.load_more_notifications(true, cx)
}
fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
cx.notify()
}
async fn handle_new_notification(
this: ModelHandle<Self>,
envelope: TypedEnvelope<proto::AddNotification>,
_: Arc<Client>,
cx: AsyncAppContext,
) -> Result<()> {
Self::add_notifications(
this,
envelope.payload.notification.into_iter().collect(),
AddNotificationsOptions {
is_new: true,
clear_old: false,
includes_first: false,
},
cx,
)
.await
}
async fn handle_delete_notification(
this: ModelHandle<Self>,
envelope: TypedEnvelope<proto::DeleteNotification>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
Ok(())
})
}
async fn add_notifications(
this: ModelHandle<Self>,
notifications: Vec<proto::Notification>,
options: AddNotificationsOptions,
mut cx: AsyncAppContext,
) -> Result<()> {
let mut user_ids = Vec::new();
let mut message_ids = Vec::new();
let notifications = notifications
.into_iter()
.filter_map(|message| {
Some(NotificationEntry {
id: message.id,
is_read: message.is_read,
timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
.ok()?,
notification: Notification::from_proto(&message)?,
response: message.response,
})
})
.collect::<Vec<_>>();
if notifications.is_empty() {
return Ok(());
}
for entry in &notifications {
match entry.notification {
Notification::ChannelInvitation { inviter_id, .. } => {
user_ids.push(inviter_id);
}
Notification::ContactRequest {
sender_id: requester_id,
} => {
user_ids.push(requester_id);
}
Notification::ContactRequestAccepted {
responder_id: contact_id,
} => {
user_ids.push(contact_id);
}
Notification::ChannelMessageMention {
sender_id,
message_id,
..
} => {
user_ids.push(sender_id);
message_ids.push(message_id);
}
}
}
let (user_store, channel_store) = this.read_with(&cx, |this, _| {
(this.user_store.clone(), this.channel_store.clone())
});
user_store
.update(&mut cx, |store, cx| store.get_users(user_ids, cx))
.await?;
let messages = channel_store
.update(&mut cx, |store, cx| {
store.fetch_channel_messages(message_ids, cx)
})
.await?;
this.update(&mut cx, |this, cx| {
if options.clear_old {
cx.emit(NotificationEvent::NotificationsUpdated {
old_range: 0..this.notifications.summary().count,
new_count: 0,
});
this.notifications = SumTree::default();
this.channel_messages.clear();
this.loaded_all_notifications = false;
}
if options.includes_first {
this.loaded_all_notifications = true;
}
this.channel_messages
.extend(messages.into_iter().filter_map(|message| {
if let ChannelMessageId::Saved(id) = message.id {
Some((id, message))
} else {
None
}
}));
this.splice_notifications(
notifications
.into_iter()
.map(|notification| (notification.id, Some(notification))),
options.is_new,
cx,
);
});
Ok(())
}
fn splice_notifications(
&mut self,
notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
is_new: bool,
cx: &mut ModelContext<'_, NotificationStore>,
) {
let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
let mut new_notifications = SumTree::new();
let mut old_range = 0..0;
for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
if i == 0 {
old_range.start = cursor.start().1 .0;
}
let old_notification = cursor.item();
if let Some(old_notification) = old_notification {
if old_notification.id == id {
cursor.next(&());
if let Some(new_notification) = &new_notification {
if new_notification.is_read {
cx.emit(NotificationEvent::NotificationRead {
entry: new_notification.clone(),
});
}
} else {
cx.emit(NotificationEvent::NotificationRemoved {
entry: old_notification.clone(),
});
}
}
} else if let Some(new_notification) = &new_notification {
if is_new {
cx.emit(NotificationEvent::NewNotification {
entry: new_notification.clone(),
});
}
}
if let Some(notification) = new_notification {
new_notifications.push(notification, &());
}
}
old_range.end = cursor.start().1 .0;
let new_count = new_notifications.summary().count - old_range.start;
new_notifications.append(cursor.suffix(&()), &());
drop(cursor);
self.notifications = new_notifications;
cx.emit(NotificationEvent::NotificationsUpdated {
old_range,
new_count,
});
}
pub fn respond_to_notification(
&mut self,
notification: Notification,
response: bool,
cx: &mut ModelContext<Self>,
) {
match notification {
Notification::ContactRequest { sender_id } => {
self.user_store
.update(cx, |store, cx| {
store.respond_to_contact_request(sender_id, response, cx)
})
.detach();
}
Notification::ChannelInvitation { channel_id, .. } => {
self.channel_store
.update(cx, |store, cx| {
store.respond_to_channel_invite(channel_id, response, cx)
})
.detach();
}
_ => {}
}
}
}
impl Entity for NotificationStore {
type Event = NotificationEvent;
}
impl sum_tree::Item for NotificationEntry {
type Summary = NotificationSummary;
fn summary(&self) -> Self::Summary {
NotificationSummary {
max_id: self.id,
count: 1,
unread_count: if self.is_read { 0 } else { 1 },
}
}
}
impl sum_tree::Summary for NotificationSummary {
type Context = ();
fn add_summary(&mut self, summary: &Self, _: &()) {
self.max_id = self.max_id.max(summary.max_id);
self.count += summary.count;
self.unread_count += summary.unread_count;
}
}
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
debug_assert!(summary.max_id > self.0);
self.0 = summary.max_id;
}
}
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
self.0 += summary.count;
}
}
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
self.0 += summary.unread_count;
}
}
struct AddNotificationsOptions {
is_new: bool,
clear_old: bool,
includes_first: bool,
}

View File

@ -3,11 +3,11 @@ use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use anyhow::Context; use anyhow::Context;
use collections::{HashMap, HashSet}; use collections::HashMap;
use fs::Fs; use fs::Fs;
use gpui::{AsyncAppContext, ModelHandle}; use gpui::{AsyncAppContext, ModelHandle};
use language::language_settings::language_settings; use language::language_settings::language_settings;
use language::{Buffer, BundledFormatter, Diff}; use language::{Buffer, Diff};
use lsp::{LanguageServer, LanguageServerId}; use lsp::{LanguageServer, LanguageServerId};
use node_runtime::NodeRuntime; use node_runtime::NodeRuntime;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -242,40 +242,16 @@ impl Prettier {
Self::Real(local) => { Self::Real(local) => {
let params = buffer.read_with(cx, |buffer, cx| { let params = buffer.read_with(cx, |buffer, cx| {
let buffer_language = buffer.language(); let buffer_language = buffer.language();
let parsers_with_plugins = buffer_language let parser_with_plugins = buffer_language.and_then(|l| {
.into_iter() let prettier_parser = l.prettier_parser_name()?;
.flat_map(|language| { let mut prettier_plugins = l
language .lsp_adapters()
.lsp_adapters() .iter()
.iter() .flat_map(|adapter| adapter.prettier_plugins())
.flat_map(|adapter| adapter.enabled_formatters()) .collect::<Vec<_>>();
.filter_map(|formatter| match formatter { prettier_plugins.dedup();
BundledFormatter::Prettier { Some((prettier_parser, prettier_plugins))
parser_name, });
plugin_names,
} => Some((parser_name, plugin_names)),
})
})
.fold(
HashMap::default(),
|mut parsers_with_plugins, (parser_name, plugins)| {
match parser_name {
Some(parser_name) => parsers_with_plugins
.entry(parser_name)
.or_insert_with(HashSet::default)
.extend(plugins),
None => parsers_with_plugins.values_mut().for_each(|existing_plugins| {
existing_plugins.extend(plugins.iter());
}),
}
parsers_with_plugins
},
);
let selected_parser_with_plugins = parsers_with_plugins.iter().max_by_key(|(_, plugins)| plugins.len());
if parsers_with_plugins.len() > 1 {
log::warn!("Found multiple parsers with plugins {parsers_with_plugins:?}, will select only one: {selected_parser_with_plugins:?}");
}
let prettier_node_modules = self.prettier_dir().join("node_modules"); let prettier_node_modules = self.prettier_dir().join("node_modules");
anyhow::ensure!(prettier_node_modules.is_dir(), "Prettier node_modules dir does not exist: {prettier_node_modules:?}"); anyhow::ensure!(prettier_node_modules.is_dir(), "Prettier node_modules dir does not exist: {prettier_node_modules:?}");
@ -296,7 +272,7 @@ impl Prettier {
} }
None None
}; };
let (parser, located_plugins) = match selected_parser_with_plugins { let (parser, located_plugins) = match parser_with_plugins {
Some((parser, plugins)) => { Some((parser, plugins)) => {
// Tailwind plugin requires being added last // Tailwind plugin requires being added last
// https://github.com/tailwindlabs/prettier-plugin-tailwindcss#compatibility-with-other-prettier-plugins // https://github.com/tailwindlabs/prettier-plugin-tailwindcss#compatibility-with-other-prettier-plugins

View File

@ -39,11 +39,11 @@ use language::{
deserialize_anchor, deserialize_fingerprint, deserialize_line_ending, deserialize_version, deserialize_anchor, deserialize_fingerprint, deserialize_line_ending, deserialize_version,
serialize_anchor, serialize_version, split_operations, serialize_anchor, serialize_version, split_operations,
}, },
range_from_lsp, range_to_lsp, Bias, Buffer, BufferSnapshot, BundledFormatter, CachedLspAdapter, range_from_lsp, range_to_lsp, Bias, Buffer, BufferSnapshot, CachedLspAdapter, CodeAction,
CodeAction, CodeLabel, Completion, Diagnostic, DiagnosticEntry, DiagnosticSet, Diff, CodeLabel, Completion, Diagnostic, DiagnosticEntry, DiagnosticSet, Diff, Event as BufferEvent,
Event as BufferEvent, File as _, Language, LanguageRegistry, LanguageServerName, LocalFile, File as _, Language, LanguageRegistry, LanguageServerName, LocalFile, LspAdapterDelegate,
LspAdapterDelegate, OffsetRangeExt, Operation, Patch, PendingLanguageServer, PointUtf16, OffsetRangeExt, Operation, Patch, PendingLanguageServer, PointUtf16, TextBufferSnapshot,
TextBufferSnapshot, ToOffset, ToPointUtf16, Transaction, Unclipped, ToOffset, ToPointUtf16, Transaction, Unclipped,
}; };
use log::error; use log::error;
use lsp::{ use lsp::{
@ -8352,12 +8352,7 @@ impl Project {
let Some(buffer_language) = buffer.language() else { let Some(buffer_language) = buffer.language() else {
return Task::ready(None); return Task::ready(None);
}; };
if !buffer_language if buffer_language.prettier_parser_name().is_none() {
.lsp_adapters()
.iter()
.flat_map(|adapter| adapter.enabled_formatters())
.any(|formatter| matches!(formatter, BundledFormatter::Prettier { .. }))
{
return Task::ready(None); return Task::ready(None);
} }
@ -8510,16 +8505,15 @@ impl Project {
}; };
let mut prettier_plugins = None; let mut prettier_plugins = None;
for formatter in new_language if new_language.prettier_parser_name().is_some() {
.lsp_adapters() prettier_plugins
.into_iter() .get_or_insert_with(|| HashSet::default())
.flat_map(|adapter| adapter.enabled_formatters()) .extend(
{ new_language
match formatter { .lsp_adapters()
BundledFormatter::Prettier { plugin_names, .. } => prettier_plugins .iter()
.get_or_insert_with(|| HashSet::default()) .flat_map(|adapter| adapter.prettier_plugins()),
.extend(plugin_names), )
}
} }
let Some(prettier_plugins) = prettier_plugins else { let Some(prettier_plugins) = prettier_plugins else {
return Task::ready(Ok(())); return Task::ready(Ok(()));

View File

@ -1,20 +1,35 @@
use std::{ops::Range, sync::Arc}; use std::{ops::Range, sync::Arc};
use anyhow::bail;
use futures::FutureExt; use futures::FutureExt;
use gpui::{ use gpui::{
color::Color,
elements::Text, elements::Text,
fonts::{HighlightStyle, TextStyle, Underline, Weight}, fonts::{HighlightStyle, Underline, Weight},
platform::{CursorStyle, MouseButton}, platform::{CursorStyle, MouseButton},
AnyElement, CursorRegion, Element, MouseRegion, ViewContext, AnyElement, CursorRegion, Element, MouseRegion, ViewContext,
}; };
use language::{HighlightId, Language, LanguageRegistry}; use language::{HighlightId, Language, LanguageRegistry};
use theme::SyntaxTheme; use theme::{RichTextStyle, SyntaxTheme};
use util::RangeExt;
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum Highlight { pub enum Highlight {
Id(HighlightId), Id(HighlightId),
Highlight(HighlightStyle), Highlight(HighlightStyle),
Mention,
SelfMention,
}
impl From<HighlightStyle> for Highlight {
fn from(style: HighlightStyle) -> Self {
Self::Highlight(style)
}
}
impl From<HighlightId> for Highlight {
fn from(style: HighlightId) -> Self {
Self::Id(style)
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -25,18 +40,32 @@ pub struct RichText {
pub regions: Vec<RenderedRegion>, pub regions: Vec<RenderedRegion>,
} }
#[derive(Debug, Clone)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BackgroundKind {
Code,
/// A mention background for non-self user.
Mention,
SelfMention,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RenderedRegion { pub struct RenderedRegion {
code: bool, pub background_kind: Option<BackgroundKind>,
link_url: Option<String>, pub link_url: Option<String>,
}
/// Allows one to specify extra links to the rendered markdown, which can be used
/// for e.g. mentions.
pub struct Mention {
pub range: Range<usize>,
pub is_self_mention: bool,
} }
impl RichText { impl RichText {
pub fn element<V: 'static>( pub fn element<V: 'static>(
&self, &self,
syntax: Arc<SyntaxTheme>, syntax: Arc<SyntaxTheme>,
style: TextStyle, style: RichTextStyle,
code_span_background_color: Color,
cx: &mut ViewContext<V>, cx: &mut ViewContext<V>,
) -> AnyElement<V> { ) -> AnyElement<V> {
let mut region_id = 0; let mut region_id = 0;
@ -45,7 +74,7 @@ impl RichText {
let regions = self.regions.clone(); let regions = self.regions.clone();
enum Markdown {} enum Markdown {}
Text::new(self.text.clone(), style.clone()) Text::new(self.text.clone(), style.text.clone())
.with_highlights( .with_highlights(
self.highlights self.highlights
.iter() .iter()
@ -53,6 +82,8 @@ impl RichText {
let style = match highlight { let style = match highlight {
Highlight::Id(id) => id.style(&syntax)?, Highlight::Id(id) => id.style(&syntax)?,
Highlight::Highlight(style) => style.clone(), Highlight::Highlight(style) => style.clone(),
Highlight::Mention => style.mention_highlight,
Highlight::SelfMention => style.self_mention_highlight,
}; };
Some((range.clone(), style)) Some((range.clone(), style))
}) })
@ -73,22 +104,55 @@ impl RichText {
}), }),
); );
} }
if region.code { if let Some(region_kind) = &region.background_kind {
cx.scene().push_quad(gpui::Quad { let background = match region_kind {
bounds, BackgroundKind::Code => style.code_background,
background: Some(code_span_background_color), BackgroundKind::Mention => style.mention_background,
border: Default::default(), BackgroundKind::SelfMention => style.self_mention_background,
corner_radii: (2.0).into(), };
}); if background.is_some() {
cx.scene().push_quad(gpui::Quad {
bounds,
background,
border: Default::default(),
corner_radii: (2.0).into(),
});
}
} }
}) })
.with_soft_wrap(true) .with_soft_wrap(true)
.into_any() .into_any()
} }
pub fn add_mention(
&mut self,
range: Range<usize>,
is_current_user: bool,
mention_style: HighlightStyle,
) -> anyhow::Result<()> {
if range.end > self.text.len() {
bail!(
"Mention in range {range:?} is outside of bounds for a message of length {}",
self.text.len()
);
}
if is_current_user {
self.region_ranges.push(range.clone());
self.regions.push(RenderedRegion {
background_kind: Some(BackgroundKind::Mention),
link_url: None,
});
}
self.highlights
.push((range, Highlight::Highlight(mention_style)));
Ok(())
}
} }
pub fn render_markdown_mut( pub fn render_markdown_mut(
block: &str, block: &str,
mut mentions: &[Mention],
language_registry: &Arc<LanguageRegistry>, language_registry: &Arc<LanguageRegistry>,
language: Option<&Arc<Language>>, language: Option<&Arc<Language>>,
data: &mut RichText, data: &mut RichText,
@ -101,15 +165,40 @@ pub fn render_markdown_mut(
let mut current_language = None; let mut current_language = None;
let mut list_stack = Vec::new(); let mut list_stack = Vec::new();
for event in Parser::new_ext(&block, Options::all()) { let options = Options::all();
for (event, source_range) in Parser::new_ext(&block, options).into_offset_iter() {
let prev_len = data.text.len(); let prev_len = data.text.len();
match event { match event {
Event::Text(t) => { Event::Text(t) => {
if let Some(language) = &current_language { if let Some(language) = &current_language {
render_code(&mut data.text, &mut data.highlights, t.as_ref(), language); render_code(&mut data.text, &mut data.highlights, t.as_ref(), language);
} else { } else {
data.text.push_str(t.as_ref()); if let Some(mention) = mentions.first() {
if source_range.contains_inclusive(&mention.range) {
mentions = &mentions[1..];
let range = (prev_len + mention.range.start - source_range.start)
..(prev_len + mention.range.end - source_range.start);
data.highlights.push((
range.clone(),
if mention.is_self_mention {
Highlight::SelfMention
} else {
Highlight::Mention
},
));
data.region_ranges.push(range);
data.regions.push(RenderedRegion {
background_kind: Some(if mention.is_self_mention {
BackgroundKind::SelfMention
} else {
BackgroundKind::Mention
}),
link_url: None,
});
}
}
data.text.push_str(t.as_ref());
let mut style = HighlightStyle::default(); let mut style = HighlightStyle::default();
if bold_depth > 0 { if bold_depth > 0 {
style.weight = Some(Weight::BOLD); style.weight = Some(Weight::BOLD);
@ -121,7 +210,7 @@ pub fn render_markdown_mut(
data.region_ranges.push(prev_len..data.text.len()); data.region_ranges.push(prev_len..data.text.len());
data.regions.push(RenderedRegion { data.regions.push(RenderedRegion {
link_url: Some(link_url), link_url: Some(link_url),
code: false, background_kind: None,
}); });
style.underline = Some(Underline { style.underline = Some(Underline {
thickness: 1.0.into(), thickness: 1.0.into(),
@ -162,7 +251,7 @@ pub fn render_markdown_mut(
)); ));
} }
data.regions.push(RenderedRegion { data.regions.push(RenderedRegion {
code: true, background_kind: Some(BackgroundKind::Code),
link_url: link_url.clone(), link_url: link_url.clone(),
}); });
} }
@ -228,6 +317,7 @@ pub fn render_markdown_mut(
pub fn render_markdown( pub fn render_markdown(
block: String, block: String,
mentions: &[Mention],
language_registry: &Arc<LanguageRegistry>, language_registry: &Arc<LanguageRegistry>,
language: Option<&Arc<Language>>, language: Option<&Arc<Language>>,
) -> RichText { ) -> RichText {
@ -238,7 +328,7 @@ pub fn render_markdown(
regions: Default::default(), regions: Default::default(),
}; };
render_markdown_mut(&block, language_registry, language, &mut data); render_markdown_mut(&block, mentions, language_registry, language, &mut data);
data.text = data.text.trim().to_string(); data.text = data.text.trim().to_string();

View File

@ -17,6 +17,7 @@ clock = { path = "../clock" }
collections = { path = "../collections" } collections = { path = "../collections" }
gpui = { path = "../gpui", optional = true } gpui = { path = "../gpui", optional = true }
util = { path = "../util" } util = { path = "../util" }
anyhow.workspace = true anyhow.workspace = true
async-lock = "2.4" async-lock = "2.4"
async-tungstenite = "0.16" async-tungstenite = "0.16"
@ -27,8 +28,10 @@ prost.workspace = true
rand.workspace = true rand.workspace = true
rsa = "0.4" rsa = "0.4"
serde.workspace = true serde.workspace = true
serde_json.workspace = true
serde_derive.workspace = true serde_derive.workspace = true
smol-timeout = "0.6" smol-timeout = "0.6"
strum.workspace = true
tracing = { version = "0.1.34", features = ["log"] } tracing = { version = "0.1.34", features = ["log"] }
zstd = "0.11" zstd = "0.11"

View File

@ -157,23 +157,30 @@ message Envelope {
UpdateChannelBufferCollaborators update_channel_buffer_collaborators = 130; UpdateChannelBufferCollaborators update_channel_buffer_collaborators = 130;
RejoinChannelBuffers rejoin_channel_buffers = 131; RejoinChannelBuffers rejoin_channel_buffers = 131;
RejoinChannelBuffersResponse rejoin_channel_buffers_response = 132; RejoinChannelBuffersResponse rejoin_channel_buffers_response = 132;
AckBufferOperation ack_buffer_operation = 145; AckBufferOperation ack_buffer_operation = 133;
JoinChannelChat join_channel_chat = 133; JoinChannelChat join_channel_chat = 134;
JoinChannelChatResponse join_channel_chat_response = 134; JoinChannelChatResponse join_channel_chat_response = 135;
LeaveChannelChat leave_channel_chat = 135; LeaveChannelChat leave_channel_chat = 136;
SendChannelMessage send_channel_message = 136; SendChannelMessage send_channel_message = 137;
SendChannelMessageResponse send_channel_message_response = 137; SendChannelMessageResponse send_channel_message_response = 138;
ChannelMessageSent channel_message_sent = 138; ChannelMessageSent channel_message_sent = 139;
GetChannelMessages get_channel_messages = 139; GetChannelMessages get_channel_messages = 140;
GetChannelMessagesResponse get_channel_messages_response = 140; GetChannelMessagesResponse get_channel_messages_response = 141;
RemoveChannelMessage remove_channel_message = 141; RemoveChannelMessage remove_channel_message = 142;
AckChannelMessage ack_channel_message = 146; AckChannelMessage ack_channel_message = 143;
GetChannelMessagesById get_channel_messages_by_id = 144;
LinkChannel link_channel = 142; LinkChannel link_channel = 145;
UnlinkChannel unlink_channel = 143; UnlinkChannel unlink_channel = 146;
MoveChannel move_channel = 144; MoveChannel move_channel = 147;
SetChannelVisibility set_channel_visibility = 147; // current max: 147 SetChannelVisibility set_channel_visibility = 148;
AddNotification add_notification = 149;
GetNotifications get_notifications = 150;
GetNotificationsResponse get_notifications_response = 151;
DeleteNotification delete_notification = 152;
MarkNotificationRead mark_notification_read = 153; // Current max
} }
} }
@ -1094,6 +1101,7 @@ message SendChannelMessage {
uint64 channel_id = 1; uint64 channel_id = 1;
string body = 2; string body = 2;
Nonce nonce = 3; Nonce nonce = 3;
repeated ChatMention mentions = 4;
} }
message RemoveChannelMessage { message RemoveChannelMessage {
@ -1125,6 +1133,10 @@ message GetChannelMessagesResponse {
bool done = 2; bool done = 2;
} }
message GetChannelMessagesById {
repeated uint64 message_ids = 1;
}
message LinkChannel { message LinkChannel {
uint64 channel_id = 1; uint64 channel_id = 1;
uint64 to = 2; uint64 to = 2;
@ -1151,6 +1163,12 @@ message ChannelMessage {
uint64 timestamp = 3; uint64 timestamp = 3;
uint64 sender_id = 4; uint64 sender_id = 4;
Nonce nonce = 5; Nonce nonce = 5;
repeated ChatMention mentions = 6;
}
message ChatMention {
Range range = 1;
uint64 user_id = 2;
} }
message RejoinChannelBuffers { message RejoinChannelBuffers {
@ -1242,7 +1260,6 @@ message ShowContacts {}
message IncomingContactRequest { message IncomingContactRequest {
uint64 requester_id = 1; uint64 requester_id = 1;
bool should_notify = 2;
} }
message UpdateDiagnostics { message UpdateDiagnostics {
@ -1575,7 +1592,6 @@ message Contact {
uint64 user_id = 1; uint64 user_id = 1;
bool online = 2; bool online = 2;
bool busy = 3; bool busy = 3;
bool should_notify = 4;
} }
message WorktreeMetadata { message WorktreeMetadata {
@ -1590,3 +1606,34 @@ message UpdateDiffBase {
uint64 buffer_id = 2; uint64 buffer_id = 2;
optional string diff_base = 3; optional string diff_base = 3;
} }
message GetNotifications {
optional uint64 before_id = 1;
}
message AddNotification {
Notification notification = 1;
}
message GetNotificationsResponse {
repeated Notification notifications = 1;
bool done = 2;
}
message DeleteNotification {
uint64 notification_id = 1;
}
message MarkNotificationRead {
uint64 notification_id = 1;
}
message Notification {
uint64 id = 1;
uint64 timestamp = 2;
string kind = 3;
optional uint64 entity_id = 4;
string content = 5;
bool is_read = 6;
optional bool response = 7;
}

View File

@ -0,0 +1,105 @@
use crate::proto;
use serde::{Deserialize, Serialize};
use serde_json::{map, Value};
use strum::{EnumVariantNames, VariantNames as _};
const KIND: &'static str = "kind";
const ENTITY_ID: &'static str = "entity_id";
/// A notification that can be stored, associated with a given recipient.
///
/// This struct is stored in the collab database as JSON, so it shouldn't be
/// changed in a backward-incompatible way. For example, when renaming a
/// variant, add a serde alias for the old name.
///
/// Most notification types have a special field which is aliased to
/// `entity_id`. This field is stored in its own database column, and can
/// be used to query the notification.
#[derive(Debug, Clone, PartialEq, Eq, EnumVariantNames, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum Notification {
ContactRequest {
#[serde(rename = "entity_id")]
sender_id: u64,
},
ContactRequestAccepted {
#[serde(rename = "entity_id")]
responder_id: u64,
},
ChannelInvitation {
#[serde(rename = "entity_id")]
channel_id: u64,
channel_name: String,
inviter_id: u64,
},
ChannelMessageMention {
#[serde(rename = "entity_id")]
message_id: u64,
sender_id: u64,
channel_id: u64,
},
}
impl Notification {
pub fn to_proto(&self) -> proto::Notification {
let mut value = serde_json::to_value(self).unwrap();
let mut entity_id = None;
let value = value.as_object_mut().unwrap();
let Some(Value::String(kind)) = value.remove(KIND) else {
unreachable!("kind is the enum tag")
};
if let map::Entry::Occupied(e) = value.entry(ENTITY_ID) {
if e.get().is_u64() {
entity_id = e.remove().as_u64();
}
}
proto::Notification {
kind,
entity_id,
content: serde_json::to_string(&value).unwrap(),
..Default::default()
}
}
pub fn from_proto(notification: &proto::Notification) -> Option<Self> {
let mut value = serde_json::from_str::<Value>(&notification.content).ok()?;
let object = value.as_object_mut()?;
object.insert(KIND.into(), notification.kind.to_string().into());
if let Some(entity_id) = notification.entity_id {
object.insert(ENTITY_ID.into(), entity_id.into());
}
serde_json::from_value(value).ok()
}
pub fn all_variant_names() -> &'static [&'static str] {
Self::VARIANTS
}
}
#[test]
fn test_notification() {
// Notifications can be serialized and deserialized.
for notification in [
Notification::ContactRequest { sender_id: 1 },
Notification::ContactRequestAccepted { responder_id: 2 },
Notification::ChannelInvitation {
channel_id: 100,
channel_name: "the-channel".into(),
inviter_id: 50,
},
Notification::ChannelMessageMention {
sender_id: 200,
channel_id: 30,
message_id: 1,
},
] {
let message = notification.to_proto();
let deserialized = Notification::from_proto(&message).unwrap();
assert_eq!(deserialized, notification);
}
// When notifications are serialized, the `kind` and `actor_id` fields are
// stored separately, and do not appear redundantly in the JSON.
let notification = Notification::ContactRequest { sender_id: 1 };
assert_eq!(notification.to_proto().content, "{}");
}

View File

@ -133,6 +133,9 @@ impl fmt::Display for PeerId {
messages!( messages!(
(Ack, Foreground), (Ack, Foreground),
(AckBufferOperation, Background),
(AckChannelMessage, Background),
(AddNotification, Foreground),
(AddProjectCollaborator, Foreground), (AddProjectCollaborator, Foreground),
(ApplyCodeAction, Background), (ApplyCodeAction, Background),
(ApplyCodeActionResponse, Background), (ApplyCodeActionResponse, Background),
@ -143,57 +146,75 @@ messages!(
(Call, Foreground), (Call, Foreground),
(CallCanceled, Foreground), (CallCanceled, Foreground),
(CancelCall, Foreground), (CancelCall, Foreground),
(ChannelMessageSent, Foreground),
(CopyProjectEntry, Foreground), (CopyProjectEntry, Foreground),
(CreateBufferForPeer, Foreground), (CreateBufferForPeer, Foreground),
(CreateChannel, Foreground), (CreateChannel, Foreground),
(CreateChannelResponse, Foreground), (CreateChannelResponse, Foreground),
(ChannelMessageSent, Foreground),
(CreateProjectEntry, Foreground), (CreateProjectEntry, Foreground),
(CreateRoom, Foreground), (CreateRoom, Foreground),
(CreateRoomResponse, Foreground), (CreateRoomResponse, Foreground),
(DeclineCall, Foreground), (DeclineCall, Foreground),
(DeleteChannel, Foreground),
(DeleteNotification, Foreground),
(DeleteProjectEntry, Foreground), (DeleteProjectEntry, Foreground),
(Error, Foreground), (Error, Foreground),
(ExpandProjectEntry, Foreground), (ExpandProjectEntry, Foreground),
(ExpandProjectEntryResponse, Foreground),
(Follow, Foreground), (Follow, Foreground),
(FollowResponse, Foreground), (FollowResponse, Foreground),
(FormatBuffers, Foreground), (FormatBuffers, Foreground),
(FormatBuffersResponse, Foreground), (FormatBuffersResponse, Foreground),
(FuzzySearchUsers, Foreground), (FuzzySearchUsers, Foreground),
(GetChannelMembers, Foreground),
(GetChannelMembersResponse, Foreground),
(GetChannelMessages, Background),
(GetChannelMessagesById, Background),
(GetChannelMessagesResponse, Background),
(GetCodeActions, Background), (GetCodeActions, Background),
(GetCodeActionsResponse, Background), (GetCodeActionsResponse, Background),
(GetHover, Background),
(GetHoverResponse, Background),
(GetChannelMessages, Background),
(GetChannelMessagesResponse, Background),
(SendChannelMessage, Background),
(SendChannelMessageResponse, Background),
(GetCompletions, Background), (GetCompletions, Background),
(GetCompletionsResponse, Background), (GetCompletionsResponse, Background),
(GetDefinition, Background), (GetDefinition, Background),
(GetDefinitionResponse, Background), (GetDefinitionResponse, Background),
(GetTypeDefinition, Background),
(GetTypeDefinitionResponse, Background),
(GetDocumentHighlights, Background), (GetDocumentHighlights, Background),
(GetDocumentHighlightsResponse, Background), (GetDocumentHighlightsResponse, Background),
(GetReferences, Background), (GetHover, Background),
(GetReferencesResponse, Background), (GetHoverResponse, Background),
(GetNotifications, Foreground),
(GetNotificationsResponse, Foreground),
(GetPrivateUserInfo, Foreground),
(GetPrivateUserInfoResponse, Foreground),
(GetProjectSymbols, Background), (GetProjectSymbols, Background),
(GetProjectSymbolsResponse, Background), (GetProjectSymbolsResponse, Background),
(GetReferences, Background),
(GetReferencesResponse, Background),
(GetTypeDefinition, Background),
(GetTypeDefinitionResponse, Background),
(GetUsers, Foreground), (GetUsers, Foreground),
(Hello, Foreground), (Hello, Foreground),
(IncomingCall, Foreground), (IncomingCall, Foreground),
(InlayHints, Background),
(InlayHintsResponse, Background),
(InviteChannelMember, Foreground), (InviteChannelMember, Foreground),
(UsersResponse, Foreground), (JoinChannel, Foreground),
(JoinChannelBuffer, Foreground),
(JoinChannelBufferResponse, Foreground),
(JoinChannelChat, Foreground),
(JoinChannelChatResponse, Foreground),
(JoinProject, Foreground), (JoinProject, Foreground),
(JoinProjectResponse, Foreground), (JoinProjectResponse, Foreground),
(JoinRoom, Foreground), (JoinRoom, Foreground),
(JoinRoomResponse, Foreground), (JoinRoomResponse, Foreground),
(JoinChannelChat, Foreground), (LeaveChannelBuffer, Background),
(JoinChannelChatResponse, Foreground),
(LeaveChannelChat, Foreground), (LeaveChannelChat, Foreground),
(LeaveProject, Foreground), (LeaveProject, Foreground),
(LeaveRoom, Foreground), (LeaveRoom, Foreground),
(LinkChannel, Foreground),
(MarkNotificationRead, Foreground),
(MoveChannel, Foreground),
(OnTypeFormatting, Background),
(OnTypeFormattingResponse, Background),
(OpenBufferById, Background), (OpenBufferById, Background),
(OpenBufferByPath, Background), (OpenBufferByPath, Background),
(OpenBufferForSymbol, Background), (OpenBufferForSymbol, Background),
@ -201,61 +222,57 @@ messages!(
(OpenBufferResponse, Background), (OpenBufferResponse, Background),
(PerformRename, Background), (PerformRename, Background),
(PerformRenameResponse, Background), (PerformRenameResponse, Background),
(OnTypeFormatting, Background), (Ping, Foreground),
(OnTypeFormattingResponse, Background), (PrepareRename, Background),
(InlayHints, Background), (PrepareRenameResponse, Background),
(InlayHintsResponse, Background), (ProjectEntryResponse, Foreground),
(RefreshInlayHints, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
(RejoinRoom, Foreground),
(RejoinRoomResponse, Foreground),
(ReloadBuffers, Foreground),
(ReloadBuffersResponse, Foreground),
(RemoveChannelMember, Foreground),
(RemoveChannelMessage, Foreground),
(RemoveContact, Foreground),
(RemoveProjectCollaborator, Foreground),
(RenameChannel, Foreground),
(RenameChannelResponse, Foreground),
(RenameProjectEntry, Foreground),
(RequestContact, Foreground),
(ResolveCompletionDocumentation, Background), (ResolveCompletionDocumentation, Background),
(ResolveCompletionDocumentationResponse, Background), (ResolveCompletionDocumentationResponse, Background),
(ResolveInlayHint, Background), (ResolveInlayHint, Background),
(ResolveInlayHintResponse, Background), (ResolveInlayHintResponse, Background),
(RefreshInlayHints, Foreground),
(Ping, Foreground),
(PrepareRename, Background),
(PrepareRenameResponse, Background),
(ExpandProjectEntryResponse, Foreground),
(ProjectEntryResponse, Foreground),
(RejoinRoom, Foreground),
(RejoinRoomResponse, Foreground),
(RemoveContact, Foreground),
(RemoveChannelMember, Foreground),
(RemoveChannelMessage, Foreground),
(ReloadBuffers, Foreground),
(ReloadBuffersResponse, Foreground),
(RemoveProjectCollaborator, Foreground),
(RenameProjectEntry, Foreground),
(RequestContact, Foreground),
(RespondToContactRequest, Foreground),
(RespondToChannelInvite, Foreground), (RespondToChannelInvite, Foreground),
(JoinChannel, Foreground), (RespondToContactRequest, Foreground),
(RoomUpdated, Foreground), (RoomUpdated, Foreground),
(SaveBuffer, Foreground), (SaveBuffer, Foreground),
(RenameChannel, Foreground),
(RenameChannelResponse, Foreground),
(SetChannelMemberRole, Foreground), (SetChannelMemberRole, Foreground),
(SetChannelVisibility, Foreground), (SetChannelVisibility, Foreground),
(SearchProject, Background), (SearchProject, Background),
(SearchProjectResponse, Background), (SearchProjectResponse, Background),
(SendChannelMessage, Background),
(SendChannelMessageResponse, Background),
(ShareProject, Foreground), (ShareProject, Foreground),
(ShareProjectResponse, Foreground), (ShareProjectResponse, Foreground),
(ShowContacts, Foreground), (ShowContacts, Foreground),
(StartLanguageServer, Foreground), (StartLanguageServer, Foreground),
(SynchronizeBuffers, Foreground), (SynchronizeBuffers, Foreground),
(SynchronizeBuffersResponse, Foreground), (SynchronizeBuffersResponse, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
(Test, Foreground), (Test, Foreground),
(Unfollow, Foreground), (Unfollow, Foreground),
(UnlinkChannel, Foreground),
(UnshareProject, Foreground), (UnshareProject, Foreground),
(UpdateBuffer, Foreground), (UpdateBuffer, Foreground),
(UpdateBufferFile, Foreground), (UpdateBufferFile, Foreground),
(UpdateContacts, Foreground), (UpdateChannelBuffer, Foreground),
(DeleteChannel, Foreground), (UpdateChannelBufferCollaborators, Foreground),
(MoveChannel, Foreground),
(LinkChannel, Foreground),
(UnlinkChannel, Foreground),
(UpdateChannels, Foreground), (UpdateChannels, Foreground),
(UpdateContacts, Foreground),
(UpdateDiagnosticSummary, Foreground), (UpdateDiagnosticSummary, Foreground),
(UpdateDiffBase, Foreground),
(UpdateFollowers, Foreground), (UpdateFollowers, Foreground),
(UpdateInviteInfo, Foreground), (UpdateInviteInfo, Foreground),
(UpdateLanguageServer, Foreground), (UpdateLanguageServer, Foreground),
@ -264,18 +281,7 @@ messages!(
(UpdateProjectCollaborator, Foreground), (UpdateProjectCollaborator, Foreground),
(UpdateWorktree, Foreground), (UpdateWorktree, Foreground),
(UpdateWorktreeSettings, Foreground), (UpdateWorktreeSettings, Foreground),
(UpdateDiffBase, Foreground), (UsersResponse, Foreground),
(GetPrivateUserInfo, Foreground),
(GetPrivateUserInfoResponse, Foreground),
(GetChannelMembers, Foreground),
(GetChannelMembersResponse, Foreground),
(JoinChannelBuffer, Foreground),
(JoinChannelBufferResponse, Foreground),
(LeaveChannelBuffer, Background),
(UpdateChannelBuffer, Foreground),
(UpdateChannelBufferCollaborators, Foreground),
(AckBufferOperation, Background),
(AckChannelMessage, Background),
); );
request_messages!( request_messages!(
@ -287,77 +293,80 @@ request_messages!(
(Call, Ack), (Call, Ack),
(CancelCall, Ack), (CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse), (CopyProjectEntry, ProjectEntryResponse),
(CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse), (CreateProjectEntry, ProjectEntryResponse),
(CreateRoom, CreateRoomResponse), (CreateRoom, CreateRoomResponse),
(CreateChannel, CreateChannelResponse),
(DeclineCall, Ack), (DeclineCall, Ack),
(DeleteChannel, Ack),
(DeleteProjectEntry, ProjectEntryResponse), (DeleteProjectEntry, ProjectEntryResponse),
(ExpandProjectEntry, ExpandProjectEntryResponse), (ExpandProjectEntry, ExpandProjectEntryResponse),
(Follow, FollowResponse), (Follow, FollowResponse),
(FormatBuffers, FormatBuffersResponse), (FormatBuffers, FormatBuffersResponse),
(FuzzySearchUsers, UsersResponse),
(GetChannelMembers, GetChannelMembersResponse),
(GetChannelMessages, GetChannelMessagesResponse),
(GetChannelMessagesById, GetChannelMessagesResponse),
(GetCodeActions, GetCodeActionsResponse), (GetCodeActions, GetCodeActionsResponse),
(GetHover, GetHoverResponse),
(GetCompletions, GetCompletionsResponse), (GetCompletions, GetCompletionsResponse),
(GetDefinition, GetDefinitionResponse), (GetDefinition, GetDefinitionResponse),
(GetTypeDefinition, GetTypeDefinitionResponse),
(GetDocumentHighlights, GetDocumentHighlightsResponse), (GetDocumentHighlights, GetDocumentHighlightsResponse),
(GetReferences, GetReferencesResponse), (GetHover, GetHoverResponse),
(GetNotifications, GetNotificationsResponse),
(GetPrivateUserInfo, GetPrivateUserInfoResponse), (GetPrivateUserInfo, GetPrivateUserInfoResponse),
(GetProjectSymbols, GetProjectSymbolsResponse), (GetProjectSymbols, GetProjectSymbolsResponse),
(FuzzySearchUsers, UsersResponse), (GetReferences, GetReferencesResponse),
(GetTypeDefinition, GetTypeDefinitionResponse),
(GetUsers, UsersResponse), (GetUsers, UsersResponse),
(IncomingCall, Ack),
(InlayHints, InlayHintsResponse),
(InviteChannelMember, Ack), (InviteChannelMember, Ack),
(JoinChannel, JoinRoomResponse),
(JoinChannelBuffer, JoinChannelBufferResponse),
(JoinChannelChat, JoinChannelChatResponse),
(JoinProject, JoinProjectResponse), (JoinProject, JoinProjectResponse),
(JoinRoom, JoinRoomResponse), (JoinRoom, JoinRoomResponse),
(JoinChannelChat, JoinChannelChatResponse), (LeaveChannelBuffer, Ack),
(LeaveRoom, Ack), (LeaveRoom, Ack),
(RejoinRoom, RejoinRoomResponse), (LinkChannel, Ack),
(IncomingCall, Ack), (MarkNotificationRead, Ack),
(MoveChannel, Ack),
(OnTypeFormatting, OnTypeFormattingResponse),
(OpenBufferById, OpenBufferResponse), (OpenBufferById, OpenBufferResponse),
(OpenBufferByPath, OpenBufferResponse), (OpenBufferByPath, OpenBufferResponse),
(OpenBufferForSymbol, OpenBufferForSymbolResponse), (OpenBufferForSymbol, OpenBufferForSymbolResponse),
(Ping, Ack),
(PerformRename, PerformRenameResponse), (PerformRename, PerformRenameResponse),
(Ping, Ack),
(PrepareRename, PrepareRenameResponse), (PrepareRename, PrepareRenameResponse),
(OnTypeFormatting, OnTypeFormattingResponse), (RefreshInlayHints, Ack),
(InlayHints, InlayHintsResponse), (RejoinChannelBuffers, RejoinChannelBuffersResponse),
(RejoinRoom, RejoinRoomResponse),
(ReloadBuffers, ReloadBuffersResponse),
(RemoveChannelMember, Ack),
(RemoveChannelMessage, Ack),
(RemoveContact, Ack),
(RenameChannel, RenameChannelResponse),
(RenameProjectEntry, ProjectEntryResponse),
(RequestContact, Ack),
( (
ResolveCompletionDocumentation, ResolveCompletionDocumentation,
ResolveCompletionDocumentationResponse ResolveCompletionDocumentationResponse
), ),
(ResolveInlayHint, ResolveInlayHintResponse), (ResolveInlayHint, ResolveInlayHintResponse),
(RefreshInlayHints, Ack),
(ReloadBuffers, ReloadBuffersResponse),
(RequestContact, Ack),
(RemoveChannelMember, Ack),
(RemoveContact, Ack),
(RespondToContactRequest, Ack),
(RespondToChannelInvite, Ack), (RespondToChannelInvite, Ack),
(SetChannelMemberRole, Ack), (RespondToContactRequest, Ack),
(SetChannelVisibility, Ack),
(SendChannelMessage, SendChannelMessageResponse),
(GetChannelMessages, GetChannelMessagesResponse),
(GetChannelMembers, GetChannelMembersResponse),
(JoinChannel, JoinRoomResponse),
(RemoveChannelMessage, Ack),
(DeleteChannel, Ack),
(RenameProjectEntry, ProjectEntryResponse),
(RenameChannel, RenameChannelResponse),
(LinkChannel, Ack),
(UnlinkChannel, Ack),
(MoveChannel, Ack),
(SaveBuffer, BufferSaved), (SaveBuffer, BufferSaved),
(SearchProject, SearchProjectResponse), (SearchProject, SearchProjectResponse),
(SendChannelMessage, SendChannelMessageResponse),
(SetChannelMemberRole, Ack),
(SetChannelVisibility, Ack),
(ShareProject, ShareProjectResponse), (ShareProject, ShareProjectResponse),
(SynchronizeBuffers, SynchronizeBuffersResponse), (SynchronizeBuffers, SynchronizeBuffersResponse),
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
(Test, Test), (Test, Test),
(UnlinkChannel, Ack),
(UpdateBuffer, Ack), (UpdateBuffer, Ack),
(UpdateParticipantLocation, Ack), (UpdateParticipantLocation, Ack),
(UpdateProject, Ack), (UpdateProject, Ack),
(UpdateWorktree, Ack), (UpdateWorktree, Ack),
(JoinChannelBuffer, JoinChannelBufferResponse),
(LeaveChannelBuffer, Ack)
); );
entity_messages!( entity_messages!(
@ -376,26 +385,26 @@ entity_messages!(
GetCodeActions, GetCodeActions,
GetCompletions, GetCompletions,
GetDefinition, GetDefinition,
GetTypeDefinition,
GetDocumentHighlights, GetDocumentHighlights,
GetHover, GetHover,
GetReferences,
GetProjectSymbols, GetProjectSymbols,
GetReferences,
GetTypeDefinition,
InlayHints,
JoinProject, JoinProject,
LeaveProject, LeaveProject,
OnTypeFormatting,
OpenBufferById, OpenBufferById,
OpenBufferByPath, OpenBufferByPath,
OpenBufferForSymbol, OpenBufferForSymbol,
PerformRename, PerformRename,
OnTypeFormatting,
InlayHints,
ResolveCompletionDocumentation,
ResolveInlayHint,
RefreshInlayHints,
PrepareRename, PrepareRename,
RefreshInlayHints,
ReloadBuffers, ReloadBuffers,
RemoveProjectCollaborator, RemoveProjectCollaborator,
RenameProjectEntry, RenameProjectEntry,
ResolveCompletionDocumentation,
ResolveInlayHint,
SaveBuffer, SaveBuffer,
SearchProject, SearchProject,
StartLanguageServer, StartLanguageServer,
@ -404,19 +413,19 @@ entity_messages!(
UpdateBuffer, UpdateBuffer,
UpdateBufferFile, UpdateBufferFile,
UpdateDiagnosticSummary, UpdateDiagnosticSummary,
UpdateDiffBase,
UpdateLanguageServer, UpdateLanguageServer,
UpdateProject, UpdateProject,
UpdateProjectCollaborator, UpdateProjectCollaborator,
UpdateWorktree, UpdateWorktree,
UpdateWorktreeSettings, UpdateWorktreeSettings,
UpdateDiffBase
); );
entity_messages!( entity_messages!(
channel_id, channel_id,
ChannelMessageSent, ChannelMessageSent,
UpdateChannelBuffer,
RemoveChannelMessage, RemoveChannelMessage,
UpdateChannelBuffer,
UpdateChannelBufferCollaborators, UpdateChannelBufferCollaborators,
); );

View File

@ -1,8 +1,11 @@
pub mod auth; pub mod auth;
mod conn; mod conn;
mod notification;
mod peer; mod peer;
pub mod proto; pub mod proto;
pub use conn::Connection; pub use conn::Connection;
pub use notification::*;
pub use peer::*; pub use peer::*;
mod macros; mod macros;

View File

@ -537,6 +537,7 @@ impl BufferSearchBar {
self.active_searchable_item self.active_searchable_item
.as_ref() .as_ref()
.map(|searchable_item| searchable_item.query_suggestion(cx)) .map(|searchable_item| searchable_item.query_suggestion(cx))
.filter(|suggestion| !suggestion.is_empty())
} }
pub fn set_replacement(&mut self, replacement: Option<&str>, cx: &mut ViewContext<Self>) { pub fn set_replacement(&mut self, replacement: Option<&str>, cx: &mut ViewContext<Self>) {

View File

@ -351,33 +351,32 @@ impl View for ProjectSearchView {
SemanticIndexStatus::NotAuthenticated => { SemanticIndexStatus::NotAuthenticated => {
major_text = Cow::Borrowed("Not Authenticated"); major_text = Cow::Borrowed("Not Authenticated");
show_minor_text = false; show_minor_text = false;
Some( Some(vec![
"API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables" "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables."
.to_string(), .to_string(), "If you authenticated using the Assistant Panel, please restart Zed to Authenticate.".to_string()])
)
} }
SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()), SemanticIndexStatus::Indexed => Some(vec!["Indexing complete".to_string()]),
SemanticIndexStatus::Indexing { SemanticIndexStatus::Indexing {
remaining_files, remaining_files,
rate_limit_expiry, rate_limit_expiry,
} => { } => {
if remaining_files == 0 { if remaining_files == 0 {
Some(format!("Indexing...")) Some(vec![format!("Indexing...")])
} else { } else {
if let Some(rate_limit_expiry) = rate_limit_expiry { if let Some(rate_limit_expiry) = rate_limit_expiry {
let remaining_seconds = let remaining_seconds =
rate_limit_expiry.duration_since(Instant::now()); rate_limit_expiry.duration_since(Instant::now());
if remaining_seconds > Duration::from_secs(0) { if remaining_seconds > Duration::from_secs(0) {
Some(format!( Some(vec![format!(
"Remaining files to index (rate limit resets in {}s): {}", "Remaining files to index (rate limit resets in {}s): {}",
remaining_seconds.as_secs(), remaining_seconds.as_secs(),
remaining_files remaining_files
)) )])
} else { } else {
Some(format!("Remaining files to index: {}", remaining_files)) Some(vec![format!("Remaining files to index: {}", remaining_files)])
} }
} else { } else {
Some(format!("Remaining files to index: {}", remaining_files)) Some(vec![format!("Remaining files to index: {}", remaining_files)])
} }
} }
} }
@ -394,9 +393,11 @@ impl View for ProjectSearchView {
} else { } else {
match current_mode { match current_mode {
SearchMode::Semantic => { SearchMode::Semantic => {
let mut minor_text = Vec::new(); let mut minor_text: Vec<String> = Vec::new();
minor_text.push("".into()); minor_text.push("".into());
minor_text.extend(semantic_status); if let Some(semantic_status) = semantic_status {
minor_text.extend(semantic_status);
}
if show_minor_text { if show_minor_text {
minor_text minor_text
.push("Simply explain the code you are looking to find.".into()); .push("Simply explain the code you are looking to find.".into());

View File

@ -7,7 +7,10 @@ pub mod semantic_index_settings;
mod semantic_index_tests; mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings; use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings}; use ai::{
completion::OPENAI_API_URL,
embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet}; use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase; use db::VectorDatabase;
@ -55,6 +58,19 @@ pub fn init(
.join(Path::new(RELEASE_CHANNEL_NAME.as_str())) .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
.join("embeddings_db"); .join("embeddings_db");
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
Some(api_key)
} else if let Some((_, api_key)) = cx
.platform()
.read_credentials(OPENAI_API_URL)
.log_err()
.flatten()
{
String::from_utf8(api_key).log_err()
} else {
None
};
cx.subscribe_global::<WorkspaceCreated, _>({ cx.subscribe_global::<WorkspaceCreated, _>({
move |event, cx| { move |event, cx| {
let Some(semantic_index) = SemanticIndex::global(cx) else { let Some(semantic_index) = SemanticIndex::global(cx) else {
@ -88,7 +104,7 @@ pub fn init(
let semantic_index = SemanticIndex::new( let semantic_index = SemanticIndex::new(
fs, fs,
db_file_path, db_file_path,
Arc::new(OpenAIEmbeddings::new(http_client, cx.background())), Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
language_registry, language_registry,
cx.clone(), cx.clone(),
) )

View File

@ -2,14 +2,15 @@ use crate::{Anchor, BufferSnapshot, TextDimension};
use std::cmp::Ordering; use std::cmp::Ordering;
use std::ops::Range; use std::ops::Range;
#[derive(Copy, Clone, Debug, Eq, PartialEq)] #[derive(Copy, Clone, Debug, PartialEq)]
pub enum SelectionGoal { pub enum SelectionGoal {
None, None,
Column(u32), HorizontalPosition(f32),
ColumnRange { start: u32, end: u32 }, HorizontalRange { start: f32, end: f32 },
WrappedHorizontalPosition((u32, f32)),
} }
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct Selection<T> { pub struct Selection<T> {
pub id: usize, pub id: usize,
pub start: T, pub start: T,

View File

@ -53,6 +53,7 @@ pub struct Theme {
pub collab_panel: CollabPanel, pub collab_panel: CollabPanel,
pub project_panel: ProjectPanel, pub project_panel: ProjectPanel,
pub chat_panel: ChatPanel, pub chat_panel: ChatPanel,
pub notification_panel: NotificationPanel,
pub command_palette: CommandPalette, pub command_palette: CommandPalette,
pub picker: Picker, pub picker: Picker,
pub editor: Editor, pub editor: Editor,
@ -638,21 +639,43 @@ pub struct ChatPanel {
pub input_editor: FieldEditor, pub input_editor: FieldEditor,
pub avatar: AvatarStyle, pub avatar: AvatarStyle,
pub avatar_container: ContainerStyle, pub avatar_container: ContainerStyle,
pub message: ChatMessage, pub rich_text: RichTextStyle,
pub continuation_message: ChatMessage, pub message_sender: ContainedText,
pub message_timestamp: ContainedText,
pub message: Interactive<ContainerStyle>,
pub continuation_message: Interactive<ContainerStyle>,
pub pending_message: Interactive<ContainerStyle>,
pub last_message_bottom_spacing: f32, pub last_message_bottom_spacing: f32,
pub pending_message: ChatMessage,
pub sign_in_prompt: Interactive<TextStyle>, pub sign_in_prompt: Interactive<TextStyle>,
pub icon_button: Interactive<IconButton>, pub icon_button: Interactive<IconButton>,
} }
#[derive(Clone, Deserialize, Default, JsonSchema)]
pub struct RichTextStyle {
pub text: TextStyle,
pub mention_highlight: HighlightStyle,
pub mention_background: Option<Color>,
pub self_mention_highlight: HighlightStyle,
pub self_mention_background: Option<Color>,
pub code_background: Option<Color>,
}
#[derive(Deserialize, Default, JsonSchema)] #[derive(Deserialize, Default, JsonSchema)]
pub struct ChatMessage { pub struct NotificationPanel {
#[serde(flatten)] #[serde(flatten)]
pub container: Interactive<ContainerStyle>, pub container: ContainerStyle,
pub body: TextStyle, pub title: ContainedText,
pub sender: ContainedText, pub title_icon: SvgStyle,
pub title_height: f32,
pub list: ContainerStyle,
pub avatar: AvatarStyle,
pub avatar_container: ContainerStyle,
pub sign_in_prompt: Interactive<TextStyle>,
pub icon_button: Interactive<IconButton>,
pub unread_text: ContainedText,
pub read_text: ContainedText,
pub timestamp: ContainedText, pub timestamp: ContainedText,
pub button: Interactive<ContainedText>,
} }
#[derive(Deserialize, Default, JsonSchema)] #[derive(Deserialize, Default, JsonSchema)]

View File

@ -7,6 +7,7 @@ publish = false
[dependencies] [dependencies]
fuzzy = {path = "../fuzzy"} fuzzy = {path = "../fuzzy"}
fs = {path = "../fs"}
gpui = {path = "../gpui"} gpui = {path = "../gpui"}
picker = {path = "../picker"} picker = {path = "../picker"}
util = {path = "../util"} util = {path = "../util"}

View File

@ -1,4 +1,5 @@
use anyhow::{anyhow, bail, Result}; use anyhow::{anyhow, bail, Result};
use fs::repository::Branch;
use fuzzy::{StringMatch, StringMatchCandidate}; use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{ use gpui::{
actions, actions,
@ -22,18 +23,9 @@ pub type BranchList = Picker<BranchListDelegate>;
pub fn build_branch_list( pub fn build_branch_list(
workspace: ViewHandle<Workspace>, workspace: ViewHandle<Workspace>,
cx: &mut ViewContext<BranchList>, cx: &mut ViewContext<BranchList>,
) -> BranchList { ) -> Result<BranchList> {
Picker::new( Ok(Picker::new(BranchListDelegate::new(workspace, 29, cx)?, cx)
BranchListDelegate { .with_theme(|theme| theme.picker.clone()))
matches: vec![],
workspace,
selected_index: 0,
last_query: String::default(),
branch_name_trailoff_after: 29,
},
cx,
)
.with_theme(|theme| theme.picker.clone())
} }
fn toggle( fn toggle(
@ -43,31 +35,24 @@ fn toggle(
) -> Option<Task<Result<()>>> { ) -> Option<Task<Result<()>>> {
Some(cx.spawn(|workspace, mut cx| async move { Some(cx.spawn(|workspace, mut cx| async move {
workspace.update(&mut cx, |workspace, cx| { workspace.update(&mut cx, |workspace, cx| {
// Modal branch picker has a longer trailoff than a popover one.
let delegate = BranchListDelegate::new(cx.handle(), 70, cx)?;
workspace.toggle_modal(cx, |_, cx| { workspace.toggle_modal(cx, |_, cx| {
let workspace = cx.handle();
cx.add_view(|cx| { cx.add_view(|cx| {
Picker::new( Picker::new(delegate, cx)
BranchListDelegate { .with_theme(|theme| theme.picker.clone())
matches: vec![], .with_max_size(800., 1200.)
workspace,
selected_index: 0,
last_query: String::default(),
/// Modal branch picker has a longer trailoff than a popover one.
branch_name_trailoff_after: 70,
},
cx,
)
.with_theme(|theme| theme.picker.clone())
.with_max_size(800., 1200.)
}) })
}); });
})?; Ok::<_, anyhow::Error>(())
})??;
Ok(()) Ok(())
})) }))
} }
pub struct BranchListDelegate { pub struct BranchListDelegate {
matches: Vec<StringMatch>, matches: Vec<StringMatch>,
all_branches: Vec<Branch>,
workspace: ViewHandle<Workspace>, workspace: ViewHandle<Workspace>,
selected_index: usize, selected_index: usize,
last_query: String, last_query: String,
@ -76,6 +61,31 @@ pub struct BranchListDelegate {
} }
impl BranchListDelegate { impl BranchListDelegate {
fn new(
workspace: ViewHandle<Workspace>,
branch_name_trailoff_after: usize,
cx: &AppContext,
) -> Result<Self> {
let project = workspace.read(cx).project().read(&cx);
let Some(worktree) = project.visible_worktrees(cx).next() else {
bail!("Cannot update branch list as there are no visible worktrees")
};
let mut cwd = worktree.read(cx).abs_path().to_path_buf();
cwd.push(".git");
let Some(repo) = project.fs().open_repo(&cwd) else {
bail!("Project does not have associated git repository.")
};
let all_branches = repo.lock().branches()?;
Ok(Self {
matches: vec![],
workspace,
all_branches,
selected_index: 0,
last_query: Default::default(),
branch_name_trailoff_after,
})
}
fn display_error_toast(&self, message: String, cx: &mut ViewContext<BranchList>) { fn display_error_toast(&self, message: String, cx: &mut ViewContext<BranchList>) {
const GIT_CHECKOUT_FAILURE_ID: usize = 2048; const GIT_CHECKOUT_FAILURE_ID: usize = 2048;
self.workspace.update(cx, |model, ctx| { self.workspace.update(cx, |model, ctx| {
@ -83,6 +93,7 @@ impl BranchListDelegate {
}); });
} }
} }
impl PickerDelegate for BranchListDelegate { impl PickerDelegate for BranchListDelegate {
fn placeholder_text(&self) -> Arc<str> { fn placeholder_text(&self) -> Arc<str> {
"Select branch...".into() "Select branch...".into()
@ -102,45 +113,28 @@ impl PickerDelegate for BranchListDelegate {
fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> { fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
cx.spawn(move |picker, mut cx| async move { cx.spawn(move |picker, mut cx| async move {
let Some(candidates) = picker let candidates = picker.read_with(&mut cx, |view, _| {
.read_with(&mut cx, |view, cx| { const RECENT_BRANCHES_COUNT: usize = 10;
let delegate = view.delegate(); let mut branches = view.delegate().all_branches.clone();
let project = delegate.workspace.read(cx).project().read(&cx); if query.is_empty() && branches.len() > RECENT_BRANCHES_COUNT {
// Truncate list of recent branches
let Some(worktree) = project.visible_worktrees(cx).next() else { // Do a partial sort to show recent-ish branches first.
bail!("Cannot update branch list as there are no visible worktrees") branches.select_nth_unstable_by(RECENT_BRANCHES_COUNT - 1, |lhs, rhs| {
}; rhs.unix_timestamp.cmp(&lhs.unix_timestamp)
let mut cwd = worktree.read(cx).abs_path().to_path_buf(); });
cwd.push(".git"); branches.truncate(RECENT_BRANCHES_COUNT);
let Some(repo) = project.fs().open_repo(&cwd) else { branches.sort_unstable_by(|lhs, rhs| lhs.name.cmp(&rhs.name));
bail!("Project does not have associated git repository.") }
}; branches
let mut branches = repo.lock().branches()?; .into_iter()
const RECENT_BRANCHES_COUNT: usize = 10; .enumerate()
if query.is_empty() && branches.len() > RECENT_BRANCHES_COUNT { .map(|(ix, command)| StringMatchCandidate {
// Truncate list of recent branches id: ix,
// Do a partial sort to show recent-ish branches first. char_bag: command.name.chars().collect(),
branches.select_nth_unstable_by(RECENT_BRANCHES_COUNT - 1, |lhs, rhs| { string: command.name.into(),
rhs.unix_timestamp.cmp(&lhs.unix_timestamp) })
}); .collect::<Vec<StringMatchCandidate>>()
branches.truncate(RECENT_BRANCHES_COUNT); });
branches.sort_unstable_by(|lhs, rhs| lhs.name.cmp(&rhs.name));
}
Ok(branches
.iter()
.cloned()
.enumerate()
.map(|(ix, command)| StringMatchCandidate {
id: ix,
char_bag: command.name.chars().collect(),
string: command.name.into(),
})
.collect::<Vec<_>>())
})
.log_err()
else {
return;
};
let Some(candidates) = candidates.log_err() else { let Some(candidates) = candidates.log_err() else {
return; return;
}; };

View File

@ -1,9 +1,7 @@
use std::cmp;
use editor::{ use editor::{
char_kind, char_kind,
display_map::{DisplaySnapshot, FoldPoint, ToDisplayPoint}, display_map::{DisplaySnapshot, FoldPoint, ToDisplayPoint},
movement::{self, find_boundary, find_preceding_boundary, FindRange}, movement::{self, find_boundary, find_preceding_boundary, FindRange, TextLayoutDetails},
Bias, CharKind, DisplayPoint, ToOffset, Bias, CharKind, DisplayPoint, ToOffset,
}; };
use gpui::{actions, impl_actions, AppContext, WindowContext}; use gpui::{actions, impl_actions, AppContext, WindowContext};
@ -361,6 +359,7 @@ impl Motion {
point: DisplayPoint, point: DisplayPoint,
goal: SelectionGoal, goal: SelectionGoal,
maybe_times: Option<usize>, maybe_times: Option<usize>,
text_layout_details: &TextLayoutDetails,
) -> Option<(DisplayPoint, SelectionGoal)> { ) -> Option<(DisplayPoint, SelectionGoal)> {
let times = maybe_times.unwrap_or(1); let times = maybe_times.unwrap_or(1);
use Motion::*; use Motion::*;
@ -370,16 +369,16 @@ impl Motion {
Backspace => (backspace(map, point, times), SelectionGoal::None), Backspace => (backspace(map, point, times), SelectionGoal::None),
Down { Down {
display_lines: false, display_lines: false,
} => down(map, point, goal, times), } => up_down_buffer_rows(map, point, goal, times as isize, &text_layout_details),
Down { Down {
display_lines: true, display_lines: true,
} => down_display(map, point, goal, times), } => down_display(map, point, goal, times, &text_layout_details),
Up { Up {
display_lines: false, display_lines: false,
} => up(map, point, goal, times), } => up_down_buffer_rows(map, point, goal, 0 - times as isize, &text_layout_details),
Up { Up {
display_lines: true, display_lines: true,
} => up_display(map, point, goal, times), } => up_display(map, point, goal, times, &text_layout_details),
Right => (right(map, point, times), SelectionGoal::None), Right => (right(map, point, times), SelectionGoal::None),
NextWordStart { ignore_punctuation } => ( NextWordStart { ignore_punctuation } => (
next_word_start(map, point, *ignore_punctuation, times), next_word_start(map, point, *ignore_punctuation, times),
@ -442,10 +441,15 @@ impl Motion {
selection: &mut Selection<DisplayPoint>, selection: &mut Selection<DisplayPoint>,
times: Option<usize>, times: Option<usize>,
expand_to_surrounding_newline: bool, expand_to_surrounding_newline: bool,
text_layout_details: &TextLayoutDetails,
) -> bool { ) -> bool {
if let Some((new_head, goal)) = if let Some((new_head, goal)) = self.move_point(
self.move_point(map, selection.head(), selection.goal, times) map,
{ selection.head(),
selection.goal,
times,
&text_layout_details,
) {
selection.set_head(new_head, goal); selection.set_head(new_head, goal);
if self.linewise() { if self.linewise() {
@ -530,35 +534,85 @@ fn backspace(map: &DisplaySnapshot, mut point: DisplayPoint, times: usize) -> Di
point point
} }
fn down( pub(crate) fn start_of_relative_buffer_row(
map: &DisplaySnapshot,
point: DisplayPoint,
times: isize,
) -> DisplayPoint {
let start = map.display_point_to_fold_point(point, Bias::Left);
let target = start.row() as isize + times;
let new_row = (target.max(0) as u32).min(map.fold_snapshot.max_point().row());
map.clip_point(
map.fold_point_to_display_point(
map.fold_snapshot
.clip_point(FoldPoint::new(new_row, 0), Bias::Right),
),
Bias::Right,
)
}
fn up_down_buffer_rows(
map: &DisplaySnapshot, map: &DisplaySnapshot,
point: DisplayPoint, point: DisplayPoint,
mut goal: SelectionGoal, mut goal: SelectionGoal,
times: usize, times: isize,
text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) { ) -> (DisplayPoint, SelectionGoal) {
let start = map.display_point_to_fold_point(point, Bias::Left); let start = map.display_point_to_fold_point(point, Bias::Left);
let begin_folded_line = map.fold_point_to_display_point(
map.fold_snapshot
.clip_point(FoldPoint::new(start.row(), 0), Bias::Left),
);
let select_nth_wrapped_row = point.row() - begin_folded_line.row();
let goal_column = match goal { let (goal_wrap, goal_x) = match goal {
SelectionGoal::Column(column) => column, SelectionGoal::WrappedHorizontalPosition((row, x)) => (row, x),
SelectionGoal::ColumnRange { end, .. } => end, SelectionGoal::HorizontalRange { end, .. } => (select_nth_wrapped_row, end),
SelectionGoal::HorizontalPosition(x) => (select_nth_wrapped_row, x),
_ => { _ => {
goal = SelectionGoal::Column(start.column()); let x = map.x_for_point(point, text_layout_details);
start.column() goal = SelectionGoal::WrappedHorizontalPosition((select_nth_wrapped_row, x));
(select_nth_wrapped_row, x)
} }
}; };
let new_row = cmp::min( let target = start.row() as isize + times;
start.row() + times as u32, let new_row = (target.max(0) as u32).min(map.fold_snapshot.max_point().row());
map.fold_snapshot.max_point().row(),
); let mut begin_folded_line = map.fold_point_to_display_point(
let new_col = cmp::min(goal_column, map.fold_snapshot.line_len(new_row));
let point = map.fold_point_to_display_point(
map.fold_snapshot map.fold_snapshot
.clip_point(FoldPoint::new(new_row, new_col), Bias::Left), .clip_point(FoldPoint::new(new_row, 0), Bias::Left),
); );
// clip twice to "clip at end of line" let mut i = 0;
(map.clip_point(point, Bias::Left), goal) while i < goal_wrap && begin_folded_line.row() < map.max_point().row() {
let next_folded_line = DisplayPoint::new(begin_folded_line.row() + 1, 0);
if map
.display_point_to_fold_point(next_folded_line, Bias::Right)
.row()
== new_row
{
i += 1;
begin_folded_line = next_folded_line;
} else {
break;
}
}
let new_col = if i == goal_wrap {
map.column_for_x(begin_folded_line.row(), goal_x, text_layout_details)
} else {
map.line_len(begin_folded_line.row())
};
(
map.clip_point(
DisplayPoint::new(begin_folded_line.row(), new_col),
Bias::Left,
),
goal,
)
} }
fn down_display( fn down_display(
@ -566,49 +620,24 @@ fn down_display(
mut point: DisplayPoint, mut point: DisplayPoint,
mut goal: SelectionGoal, mut goal: SelectionGoal,
times: usize, times: usize,
text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) { ) -> (DisplayPoint, SelectionGoal) {
for _ in 0..times { for _ in 0..times {
(point, goal) = movement::down(map, point, goal, true); (point, goal) = movement::down(map, point, goal, true, text_layout_details);
} }
(point, goal) (point, goal)
} }
pub(crate) fn up(
map: &DisplaySnapshot,
point: DisplayPoint,
mut goal: SelectionGoal,
times: usize,
) -> (DisplayPoint, SelectionGoal) {
let start = map.display_point_to_fold_point(point, Bias::Left);
let goal_column = match goal {
SelectionGoal::Column(column) => column,
SelectionGoal::ColumnRange { end, .. } => end,
_ => {
goal = SelectionGoal::Column(start.column());
start.column()
}
};
let new_row = start.row().saturating_sub(times as u32);
let new_col = cmp::min(goal_column, map.fold_snapshot.line_len(new_row));
let point = map.fold_point_to_display_point(
map.fold_snapshot
.clip_point(FoldPoint::new(new_row, new_col), Bias::Left),
);
(map.clip_point(point, Bias::Left), goal)
}
fn up_display( fn up_display(
map: &DisplaySnapshot, map: &DisplaySnapshot,
mut point: DisplayPoint, mut point: DisplayPoint,
mut goal: SelectionGoal, mut goal: SelectionGoal,
times: usize, times: usize,
text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) { ) -> (DisplayPoint, SelectionGoal) {
for _ in 0..times { for _ in 0..times {
(point, goal) = movement::up(map, point, goal, true); (point, goal) = movement::up(map, point, goal, true, &text_layout_details);
} }
(point, goal) (point, goal)
@ -707,7 +736,7 @@ fn previous_word_start(
point point
} }
fn first_non_whitespace( pub(crate) fn first_non_whitespace(
map: &DisplaySnapshot, map: &DisplaySnapshot,
display_lines: bool, display_lines: bool,
from: DisplayPoint, from: DisplayPoint,
@ -886,13 +915,17 @@ fn find_backward(
} }
fn next_line_start(map: &DisplaySnapshot, point: DisplayPoint, times: usize) -> DisplayPoint { fn next_line_start(map: &DisplaySnapshot, point: DisplayPoint, times: usize) -> DisplayPoint {
let correct_line = down(map, point, SelectionGoal::None, times).0; let correct_line = start_of_relative_buffer_row(map, point, times as isize);
first_non_whitespace(map, false, correct_line) first_non_whitespace(map, false, correct_line)
} }
fn next_line_end(map: &DisplaySnapshot, mut point: DisplayPoint, times: usize) -> DisplayPoint { pub(crate) fn next_line_end(
map: &DisplaySnapshot,
mut point: DisplayPoint,
times: usize,
) -> DisplayPoint {
if times > 1 { if times > 1 {
point = down(map, point, SelectionGoal::None, times - 1).0; point = start_of_relative_buffer_row(map, point, times as isize - 1);
} }
end_of_line(map, false, point) end_of_line(map, false, point)
} }

View File

@ -12,7 +12,7 @@ mod yank;
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{
motion::{self, Motion}, motion::{self, first_non_whitespace, next_line_end, right, Motion},
object::Object, object::Object,
state::{Mode, Operator}, state::{Mode, Operator},
Vim, Vim,
@ -179,10 +179,11 @@ pub(crate) fn move_cursor(
cx: &mut WindowContext, cx: &mut WindowContext,
) { ) {
vim.update_active_editor(cx, |editor, cx| { vim.update_active_editor(cx, |editor, cx| {
let text_layout_details = editor.text_layout_details(cx);
editor.change_selections(Some(Autoscroll::fit()), cx, |s| { editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.move_cursors_with(|map, cursor, goal| { s.move_cursors_with(|map, cursor, goal| {
motion motion
.move_point(map, cursor, goal, times) .move_point(map, cursor, goal, times, &text_layout_details)
.unwrap_or((cursor, goal)) .unwrap_or((cursor, goal))
}) })
}) })
@ -195,9 +196,7 @@ fn insert_after(_: &mut Workspace, _: &InsertAfter, cx: &mut ViewContext<Workspa
vim.switch_mode(Mode::Insert, false, cx); vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| { vim.update_active_editor(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| { editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.maybe_move_cursors_with(|map, cursor, goal| { s.move_cursors_with(|map, cursor, _| (right(map, cursor, 1), SelectionGoal::None));
Motion::Right.move_point(map, cursor, goal, None)
});
}); });
}); });
}); });
@ -220,11 +219,11 @@ fn insert_first_non_whitespace(
vim.switch_mode(Mode::Insert, false, cx); vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| { vim.update_active_editor(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| { editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.maybe_move_cursors_with(|map, cursor, goal| { s.move_cursors_with(|map, cursor, _| {
Motion::FirstNonWhitespace { (
display_lines: false, first_non_whitespace(map, false, cursor),
} SelectionGoal::None,
.move_point(map, cursor, goal, None) )
}); });
}); });
}); });
@ -237,8 +236,8 @@ fn insert_end_of_line(_: &mut Workspace, _: &InsertEndOfLine, cx: &mut ViewConte
vim.switch_mode(Mode::Insert, false, cx); vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| { vim.update_active_editor(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| { editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.maybe_move_cursors_with(|map, cursor, goal| { s.move_cursors_with(|map, cursor, _| {
Motion::CurrentLine.move_point(map, cursor, goal, None) (next_line_end(map, cursor, 1), SelectionGoal::None)
}); });
}); });
}); });
@ -268,7 +267,7 @@ fn insert_line_above(_: &mut Workspace, _: &InsertLineAbove, cx: &mut ViewContex
editor.edit_with_autoindent(edits, cx); editor.edit_with_autoindent(edits, cx);
editor.change_selections(Some(Autoscroll::fit()), cx, |s| { editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.move_cursors_with(|map, cursor, _| { s.move_cursors_with(|map, cursor, _| {
let previous_line = motion::up(map, cursor, SelectionGoal::None, 1).0; let previous_line = motion::start_of_relative_buffer_row(map, cursor, -1);
let insert_point = motion::end_of_line(map, false, previous_line); let insert_point = motion::end_of_line(map, false, previous_line);
(insert_point, SelectionGoal::None) (insert_point, SelectionGoal::None)
}); });
@ -283,6 +282,7 @@ fn insert_line_below(_: &mut Workspace, _: &InsertLineBelow, cx: &mut ViewContex
vim.start_recording(cx); vim.start_recording(cx);
vim.switch_mode(Mode::Insert, false, cx); vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| { vim.update_active_editor(cx, |editor, cx| {
let text_layout_details = editor.text_layout_details(cx);
editor.transact(cx, |editor, cx| { editor.transact(cx, |editor, cx| {
let (map, old_selections) = editor.selections.all_display(cx); let (map, old_selections) = editor.selections.all_display(cx);
@ -301,7 +301,13 @@ fn insert_line_below(_: &mut Workspace, _: &InsertLineBelow, cx: &mut ViewContex
}); });
editor.change_selections(Some(Autoscroll::fit()), cx, |s| { editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.maybe_move_cursors_with(|map, cursor, goal| { s.maybe_move_cursors_with(|map, cursor, goal| {
Motion::CurrentLine.move_point(map, cursor, goal, None) Motion::CurrentLine.move_point(
map,
cursor,
goal,
None,
&text_layout_details,
)
}); });
}); });
editor.edit_with_autoindent(edits, cx); editor.edit_with_autoindent(edits, cx);
@ -399,12 +405,26 @@ mod test {
#[gpui::test] #[gpui::test]
async fn test_j(cx: &mut gpui::TestAppContext) { async fn test_j(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await.binding(["j"]); let mut cx = NeovimBackedTestContext::new(cx).await;
cx.assert_all(indoc! {"
ˇThe qˇuick broˇwn cx.set_shared_state(indoc! {"
ˇfox jumps" aaˇaa
😃😃"
}) })
.await; .await;
cx.simulate_shared_keystrokes(["j"]).await;
cx.assert_shared_state(indoc! {"
aaaa
😃ˇ😃"
})
.await;
for marked_position in cx.each_marked_position(indoc! {"
ˇThe qˇuick broˇwn
ˇfox jumps"
}) {
cx.assert_neovim_compatible(&marked_position, ["j"]).await;
}
} }
#[gpui::test] #[gpui::test]

View File

@ -2,7 +2,7 @@ use crate::{motion::Motion, object::Object, state::Mode, utils::copy_selections_
use editor::{ use editor::{
char_kind, char_kind,
display_map::DisplaySnapshot, display_map::DisplaySnapshot,
movement::{self, FindRange}, movement::{self, FindRange, TextLayoutDetails},
scroll::autoscroll::Autoscroll, scroll::autoscroll::Autoscroll,
CharKind, DisplayPoint, CharKind, DisplayPoint,
}; };
@ -20,6 +20,7 @@ pub fn change_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &m
| Motion::StartOfLine { .. } | Motion::StartOfLine { .. }
); );
vim.update_active_editor(cx, |editor, cx| { vim.update_active_editor(cx, |editor, cx| {
let text_layout_details = editor.text_layout_details(cx);
editor.transact(cx, |editor, cx| { editor.transact(cx, |editor, cx| {
// We are swapping to insert mode anyway. Just set the line end clipping behavior now // We are swapping to insert mode anyway. Just set the line end clipping behavior now
editor.set_clip_at_line_ends(false, cx); editor.set_clip_at_line_ends(false, cx);
@ -27,9 +28,15 @@ pub fn change_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &m
s.move_with(|map, selection| { s.move_with(|map, selection| {
motion_succeeded |= if let Motion::NextWordStart { ignore_punctuation } = motion motion_succeeded |= if let Motion::NextWordStart { ignore_punctuation } = motion
{ {
expand_changed_word_selection(map, selection, times, ignore_punctuation) expand_changed_word_selection(
map,
selection,
times,
ignore_punctuation,
&text_layout_details,
)
} else { } else {
motion.expand_selection(map, selection, times, false) motion.expand_selection(map, selection, times, false, &text_layout_details)
}; };
}); });
}); });
@ -81,6 +88,7 @@ fn expand_changed_word_selection(
selection: &mut Selection<DisplayPoint>, selection: &mut Selection<DisplayPoint>,
times: Option<usize>, times: Option<usize>,
ignore_punctuation: bool, ignore_punctuation: bool,
text_layout_details: &TextLayoutDetails,
) -> bool { ) -> bool {
if times.is_none() || times.unwrap() == 1 { if times.is_none() || times.unwrap() == 1 {
let scope = map let scope = map
@ -103,11 +111,22 @@ fn expand_changed_word_selection(
}); });
true true
} else { } else {
Motion::NextWordStart { ignore_punctuation } Motion::NextWordStart { ignore_punctuation }.expand_selection(
.expand_selection(map, selection, None, false) map,
selection,
None,
false,
&text_layout_details,
)
} }
} else { } else {
Motion::NextWordStart { ignore_punctuation }.expand_selection(map, selection, times, false) Motion::NextWordStart { ignore_punctuation }.expand_selection(
map,
selection,
times,
false,
&text_layout_details,
)
} }
} }

View File

@ -7,6 +7,7 @@ use language::Point;
pub fn delete_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &mut WindowContext) { pub fn delete_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &mut WindowContext) {
vim.stop_recording(); vim.stop_recording();
vim.update_active_editor(cx, |editor, cx| { vim.update_active_editor(cx, |editor, cx| {
let text_layout_details = editor.text_layout_details(cx);
editor.transact(cx, |editor, cx| { editor.transact(cx, |editor, cx| {
editor.set_clip_at_line_ends(false, cx); editor.set_clip_at_line_ends(false, cx);
let mut original_columns: HashMap<_, _> = Default::default(); let mut original_columns: HashMap<_, _> = Default::default();
@ -14,7 +15,7 @@ pub fn delete_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &m
s.move_with(|map, selection| { s.move_with(|map, selection| {
let original_head = selection.head(); let original_head = selection.head();
original_columns.insert(selection.id, original_head.column()); original_columns.insert(selection.id, original_head.column());
motion.expand_selection(map, selection, times, true); motion.expand_selection(map, selection, times, true, &text_layout_details);
// Motion::NextWordStart on an empty line should delete it. // Motion::NextWordStart on an empty line should delete it.
if let Motion::NextWordStart { if let Motion::NextWordStart {

Some files were not shown because too many files have changed in this diff Show More