Introduce Context Retrieval in Inline Assistant (#3097)

This PR introduces a new Inline Assistant feature, "Retrieve Context", which
dynamically fills your generation prompt with relevant results returned by a
semantic search for the prompt.
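
As a rough sketch of how retrieved snippets end up in the prompt: search results become `PromptCodeSnippet`s, and `generate_content_prompt` inserts them into the prompt under a token budget that reserves room for the generation itself. The standalone sketch below is simplified and illustrative only; `Snippet` and `count_tokens` are stand-ins for `PromptCodeSnippet` and the tiktoken-rs token counting used in the actual code.

```rust
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;

// Stand-in for `PromptCodeSnippet`; the real type also wraps the content in a
// fenced code block when rendered into the prompt.
struct Snippet {
    path: String,
    language: String,
    content: String,
}

impl Snippet {
    fn to_prompt(&self) -> String {
        format!(
            "The below code snippet may be relevant from file: {} ({}):\n{}",
            self.path, self.language, self.content
        )
    }
}

// Crude stand-in for a real tokenizer (the actual code uses tiktoken-rs BPE).
fn count_tokens(text: &str) -> usize {
    text.split_whitespace().count()
}

fn inject_snippets(
    mut prompts: Vec<String>,
    mut snippet_position: usize,
    snippets: Vec<Snippet>,
    max_token_count: usize,
) -> Vec<String> {
    // Budget = model context size - tokens already used - room reserved for the reply.
    let current_token_count: usize = prompts.iter().map(|p| count_tokens(p)).sum();
    let mut remaining = max_token_count
        .saturating_sub(current_token_count)
        .saturating_sub(RESERVED_TOKENS_FOR_GENERATION);

    // The repository preamble is emitted only before the first snippet.
    let mut template =
        "You are working inside a large repository, here are a few code snippets that may be useful";
    for snippet in snippets {
        let snippet_prompt = if template.is_empty() {
            snippet.to_prompt()
        } else {
            format!("{template}\n{}", snippet.to_prompt())
        };
        let token_count = count_tokens(&snippet_prompt);
        if token_count > remaining {
            break;
        }
        if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
            prompts.insert(snippet_position, snippet_prompt);
            snippet_position += 1;
            remaining -= token_count;
            template = "";
        }
    }
    prompts
}

fn main() {
    let prompts = vec![
        "You're an expert Rust engineer.\n".to_string(),
        "The file you are currently working on has the following content:\n".to_string(),
    ];
    let snippets = vec![Snippet {
        path: "src/lib.rs".into(),
        language: "rust".into(),
        content: "pub fn add(a: i32, b: i32) -> i32 { a + b }".into(),
    }];
    // Insert snippets right after the preamble (index 1), within a 4096-token budget.
    for section in inject_snippets(prompts, 1, snippets, 4096) {
        println!("{section}\n");
    }
}
```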

Release Notes:

- Introduce a "Retrieve Context" button in the Inline Assistant
Kyle Caverly 2023-10-17 15:04:36 -04:00 committed by GitHub
commit 2795091f0c
8 changed files with 734 additions and 139 deletions

Cargo.lock

@@ -103,7 +103,7 @@ dependencies = [
"rusqlite",
"serde",
"serde_json",
"tiktoken-rs 0.5.4",
"tiktoken-rs",
"util",
]
@@ -316,12 +316,13 @@ dependencies = [
"regex",
"schemars",
"search",
"semantic_index",
"serde",
"serde_json",
"settings",
"smol",
"theme",
"tiktoken-rs 0.4.5",
"tiktoken-rs",
"util",
"uuid 1.4.1",
"workspace",
@@ -6942,7 +6943,7 @@ dependencies = [
"smol",
"tempdir",
"theme",
"tiktoken-rs 0.5.4",
"tiktoken-rs",
"tree-sitter",
"tree-sitter-cpp",
"tree-sitter-elixir",
@@ -8117,21 +8118,6 @@ dependencies = [
"weezl",
]
[[package]]
name = "tiktoken-rs"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614"
dependencies = [
"anyhow",
"base64 0.21.4",
"bstr",
"fancy-regex",
"lazy_static",
"parking_lot 0.12.1",
"rustc-hash",
]
[[package]]
name = "tiktoken-rs"
version = "0.5.4"

assets/icons/update.svg

@@ -0,0 +1,8 @@
<svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
<path
fill-rule="evenodd"
clip-rule="evenodd"
d="M1.90321 7.29677C1.90321 10.341 4.11041 12.4147 6.58893 12.8439C6.87255 12.893 7.06266 13.1627 7.01355 13.4464C6.96444 13.73 6.69471 13.9201 6.41109 13.871C3.49942 13.3668 0.86084 10.9127 0.86084 7.29677C0.860839 5.76009 1.55996 4.55245 2.37639 3.63377C2.96124 2.97568 3.63034 2.44135 4.16846 2.03202L2.53205 2.03202C2.25591 2.03202 2.03205 1.80816 2.03205 1.53202C2.03205 1.25588 2.25591 1.03202 2.53205 1.03202L5.53205 1.03202C5.80819 1.03202 6.03205 1.25588 6.03205 1.53202L6.03205 4.53202C6.03205 4.80816 5.80819 5.03202 5.53205 5.03202C5.25591 5.03202 5.03205 4.80816 5.03205 4.53202L5.03205 2.68645L5.03054 2.68759L5.03045 2.68766L5.03044 2.68767L5.03043 2.68767C4.45896 3.11868 3.76059 3.64538 3.15554 4.3262C2.44102 5.13021 1.90321 6.10154 1.90321 7.29677ZM13.0109 7.70321C13.0109 4.69115 10.8505 2.6296 8.40384 2.17029C8.12093 2.11718 7.93465 1.84479 7.98776 1.56188C8.04087 1.27898 8.31326 1.0927 8.59616 1.14581C11.4704 1.68541 14.0532 4.12605 14.0532 7.70321C14.0532 9.23988 13.3541 10.4475 12.5377 11.3662C11.9528 12.0243 11.2837 12.5586 10.7456 12.968L12.3821 12.968C12.6582 12.968 12.8821 13.1918 12.8821 13.468C12.8821 13.7441 12.6582 13.968 12.3821 13.968L9.38205 13.968C9.10591 13.968 8.88205 13.7441 8.88205 13.468L8.88205 10.468C8.88205 10.1918 9.10591 9.96796 9.38205 9.96796C9.65819 9.96796 9.88205 10.1918 9.88205 10.468L9.88205 12.3135L9.88362 12.3123C10.4551 11.8813 11.1535 11.3546 11.7585 10.6738C12.4731 9.86976 13.0109 8.89844 13.0109 7.70321Z"
fill="currentColor"
/>
</svg>



@@ -85,25 +85,6 @@ impl Embedding {
}
}
// impl FromSql for Embedding {
// fn column_result(value: ValueRef) -> FromSqlResult<Self> {
// let bytes = value.as_blob()?;
// let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
// if embedding.is_err() {
// return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
// }
// Ok(Embedding(embedding.unwrap()))
// }
// }
// impl ToSql for Embedding {
// fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
// let bytes = bincode::serialize(&self.0)
// .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
// Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
// }
// }
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
@@ -300,6 +281,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
request_timeout,
)
.await?;
request_number += 1;
match response.status() {


@@ -22,8 +22,11 @@ settings = { path = "../settings" }
theme = { path = "../theme" }
util = { path = "../util" }
workspace = { path = "../workspace" }
uuid.workspace = true
semantic_index = { path = "../semantic_index" }
project = { path = "../project" }
uuid.workspace = true
log.workspace = true
anyhow.workspace = true
chrono = { version = "0.4", features = ["serde"] }
futures.workspace = true
@@ -36,7 +39,7 @@ schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
tiktoken-rs = "0.4"
tiktoken-rs = "0.5"
[dev-dependencies]
editor = { path = "../editor", features = ["test-support"] }


@@ -1,7 +1,7 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
codegen::{self, Codegen, CodegenKind},
prompts::generate_content_prompt,
prompts::{generate_content_prompt, PromptCodeSnippet},
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
@@ -29,13 +29,15 @@ use gpui::{
},
fonts::HighlightStyle,
geometry::vector::{vec2f, Vector2F},
platform::{CursorStyle, MouseButton},
platform::{CursorStyle, MouseButton, PromptLevel},
Action, AnyElement, AppContext, AsyncAppContext, ClipboardItem, Element, Entity, ModelContext,
ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle,
WindowContext,
ModelHandle, SizeConstraint, Subscription, Task, View, ViewContext, ViewHandle,
WeakModelHandle, WeakViewHandle, WindowContext,
};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
use project::Project;
use search::BufferSearchBar;
use semantic_index::{SemanticIndex, SemanticIndexStatus};
use settings::SettingsStore;
use std::{
cell::{Cell, RefCell},
@@ -46,7 +48,7 @@ use std::{
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
time::Duration,
time::{Duration, Instant},
};
use theme::{
components::{action_button::Button, ComponentExt},
@@ -72,6 +74,7 @@ actions!(
ResetKey,
InlineAssist,
ToggleIncludeConversation,
ToggleRetrieveContext,
]
);
@@ -108,6 +111,7 @@ pub fn init(cx: &mut AppContext) {
cx.add_action(InlineAssistant::confirm);
cx.add_action(InlineAssistant::cancel);
cx.add_action(InlineAssistant::toggle_include_conversation);
cx.add_action(InlineAssistant::toggle_retrieve_context);
cx.add_action(InlineAssistant::move_up);
cx.add_action(InlineAssistant::move_down);
}
@@ -145,6 +149,8 @@ pub struct AssistantPanel {
include_conversation_in_next_inline_assist: bool,
inline_prompt_history: VecDeque<String>,
_watch_saved_conversations: Task<Result<()>>,
semantic_index: Option<ModelHandle<SemanticIndex>>,
retrieve_context_in_next_inline_assist: bool,
}
impl AssistantPanel {
@@ -191,6 +197,9 @@ impl AssistantPanel {
toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
toolbar
});
let semantic_index = SemanticIndex::global(cx);
let mut this = Self {
workspace: workspace_handle,
active_editor_index: Default::default(),
@@ -215,6 +224,8 @@ impl AssistantPanel {
include_conversation_in_next_inline_assist: false,
inline_prompt_history: Default::default(),
_watch_saved_conversations,
semantic_index,
retrieve_context_in_next_inline_assist: false,
};
let mut old_dock_position = this.position(cx);
@@ -262,12 +273,19 @@ impl AssistantPanel {
return;
};
let project = workspace.project();
this.update(cx, |assistant, cx| {
assistant.new_inline_assist(&active_editor, cx)
assistant.new_inline_assist(&active_editor, cx, project)
});
}
fn new_inline_assist(&mut self, editor: &ViewHandle<Editor>, cx: &mut ViewContext<Self>) {
fn new_inline_assist(
&mut self,
editor: &ViewHandle<Editor>,
cx: &mut ViewContext<Self>,
project: &ModelHandle<Project>,
) {
let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
api_key
} else {
@@ -312,6 +330,27 @@ impl AssistantPanel {
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
});
if let Some(semantic_index) = self.semantic_index.clone() {
let project = project.clone();
cx.spawn(|_, mut cx| async move {
let previously_indexed = semantic_index
.update(&mut cx, |index, cx| {
index.project_previously_indexed(&project, cx)
})
.await
.unwrap_or(false);
if previously_indexed {
let _ = semantic_index
.update(&mut cx, |index, cx| {
index.index_project(project.clone(), cx)
})
.await;
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
let inline_assistant = cx.add_view(|cx| {
let assistant = InlineAssistant::new(
@@ -322,6 +361,9 @@ impl AssistantPanel {
codegen.clone(),
self.workspace.clone(),
cx,
self.retrieve_context_in_next_inline_assist,
self.semantic_index.clone(),
project.clone(),
);
cx.focus_self();
assistant
@@ -362,6 +404,7 @@ impl AssistantPanel {
editor: editor.downgrade(),
inline_assistant: Some((block_id, inline_assistant.clone())),
codegen: codegen.clone(),
project: project.downgrade(),
_subscriptions: vec![
cx.subscribe(&inline_assistant, Self::handle_inline_assistant_event),
cx.subscribe(editor, {
@@ -440,8 +483,15 @@ impl AssistantPanel {
InlineAssistantEvent::Confirmed {
prompt,
include_conversation,
retrieve_context,
} => {
self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
self.confirm_inline_assist(
assist_id,
prompt,
*include_conversation,
cx,
*retrieve_context,
);
}
InlineAssistantEvent::Canceled => {
self.finish_inline_assist(assist_id, true, cx);
@@ -454,6 +504,9 @@ impl AssistantPanel {
} => {
self.include_conversation_in_next_inline_assist = *include_conversation;
}
InlineAssistantEvent::RetrieveContextToggled { retrieve_context } => {
self.retrieve_context_in_next_inline_assist = *retrieve_context
}
}
}
@@ -532,6 +585,7 @@ impl AssistantPanel {
user_prompt: &str,
include_conversation: bool,
cx: &mut ViewContext<Self>,
retrieve_context: bool,
) {
let conversation = if include_conversation {
self.active_editor()
@@ -553,6 +607,8 @@ impl AssistantPanel {
return;
};
let project = pending_assist.project.clone();
self.inline_prompt_history
.retain(|prompt| prompt != user_prompt);
self.inline_prompt_history.push_back(user_prompt.into());
@@ -593,10 +649,62 @@ impl AssistantPanel {
let codegen_kind = codegen.read(cx).kind().clone();
let user_prompt = user_prompt.to_string();
let mut messages = Vec::new();
let snippets = if retrieve_context {
let Some(project) = project.upgrade(cx) else {
return;
};
let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
let search_results = semantic_index.update(cx, |this, cx| {
this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
});
cx.background()
.spawn(async move { search_results.await.unwrap_or_default() })
} else {
Task::ready(Vec::new())
};
let snippets = cx.spawn(|_, cx| async move {
let mut snippets = Vec::new();
for result in search_results.await {
snippets.push(PromptCodeSnippet::new(result, &cx));
// snippets.push(result.buffer.read_with(&cx, |buffer, _| {
// buffer
// .snapshot()
// .text_for_range(result.range)
// .collect::<String>()
// }));
}
snippets
});
snippets
} else {
Task::ready(Vec::new())
};
let mut model = settings::get::<AssistantSettings>(cx)
.default_open_ai_model
.clone();
let model_name = model.full_name();
let prompt = cx.background().spawn(async move {
let snippets = snippets.await;
let language_name = language_name.as_deref();
generate_content_prompt(
user_prompt,
language_name,
&buffer,
range,
codegen_kind,
snippets,
model_name,
)
});
let mut messages = Vec::new();
if let Some(conversation) = conversation {
let conversation = conversation.read(cx);
let buffer = conversation.buffer.read(cx);
@@ -608,11 +716,6 @@ impl AssistantPanel {
model = conversation.model.clone();
}
let prompt = cx.background().spawn(async move {
let language_name = language_name.as_deref();
generate_content_prompt(user_prompt, language_name, &buffer, range, codegen_kind)
});
cx.spawn(|_, mut cx| async move {
let prompt = prompt.await;
@@ -1514,12 +1617,14 @@ impl Conversation {
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: self
.buffer
.read(cx)
.text_for_range(message.offset_range)
.collect(),
content: Some(
self.buffer
.read(cx)
.text_for_range(message.offset_range)
.collect(),
),
name: None,
function_call: None,
})
})
.collect::<Vec<_>>();
@@ -2638,12 +2743,16 @@ enum InlineAssistantEvent {
Confirmed {
prompt: String,
include_conversation: bool,
retrieve_context: bool,
},
Canceled,
Dismissed,
IncludeConversationToggled {
include_conversation: bool,
},
RetrieveContextToggled {
retrieve_context: bool,
},
}
struct InlineAssistant {
@@ -2659,6 +2768,11 @@ struct InlineAssistant {
pending_prompt: String,
codegen: ModelHandle<Codegen>,
_subscriptions: Vec<Subscription>,
retrieve_context: bool,
semantic_index: Option<ModelHandle<SemanticIndex>>,
semantic_permissioned: Option<bool>,
project: WeakModelHandle<Project>,
maintain_rate_limit: Option<Task<()>>,
}
impl Entity for InlineAssistant {
@@ -2675,51 +2789,65 @@ impl View for InlineAssistant {
let theme = theme::current(cx);
Flex::row()
.with_child(
Flex::row()
.with_child(
Button::action(ToggleIncludeConversation)
.with_tooltip("Include Conversation", theme.tooltip.clone())
.with_children([Flex::row()
.with_child(
Button::action(ToggleIncludeConversation)
.with_tooltip("Include Conversation", theme.tooltip.clone())
.with_id(self.id)
.with_contents(theme::components::svg::Svg::new("icons/ai.svg"))
.toggleable(self.include_conversation)
.with_style(theme.assistant.inline.include_conversation.clone())
.element()
.aligned(),
)
.with_children(if SemanticIndex::enabled(cx) {
Some(
Button::action(ToggleRetrieveContext)
.with_tooltip("Retrieve Context", theme.tooltip.clone())
.with_id(self.id)
.with_contents(theme::components::svg::Svg::new("icons/ai.svg"))
.toggleable(self.include_conversation)
.with_style(theme.assistant.inline.include_conversation.clone())
.with_contents(theme::components::svg::Svg::new(
"icons/magnifying_glass.svg",
))
.toggleable(self.retrieve_context)
.with_style(theme.assistant.inline.retrieve_context.clone())
.element()
.aligned(),
)
.with_children(if let Some(error) = self.codegen.read(cx).error() {
Some(
Svg::new("icons/error.svg")
.with_color(theme.assistant.error_icon.color)
.constrained()
.with_width(theme.assistant.error_icon.width)
.contained()
.with_style(theme.assistant.error_icon.container)
.with_tooltip::<ErrorIcon>(
self.id,
error.to_string(),
None,
theme.tooltip.clone(),
cx,
)
.aligned(),
)
} else {
None
})
.aligned()
.constrained()
.dynamically({
let measurements = self.measurements.clone();
move |constraint, _, _| {
let measurements = measurements.get();
SizeConstraint {
min: vec2f(measurements.gutter_width, constraint.min.y()),
max: vec2f(measurements.gutter_width, constraint.max.y()),
}
} else {
None
})
.with_children(if let Some(error) = self.codegen.read(cx).error() {
Some(
Svg::new("icons/error.svg")
.with_color(theme.assistant.error_icon.color)
.constrained()
.with_width(theme.assistant.error_icon.width)
.contained()
.with_style(theme.assistant.error_icon.container)
.with_tooltip::<ErrorIcon>(
self.id,
error.to_string(),
None,
theme.tooltip.clone(),
cx,
)
.aligned(),
)
} else {
None
})
.aligned()
.constrained()
.dynamically({
let measurements = self.measurements.clone();
move |constraint, _, _| {
let measurements = measurements.get();
SizeConstraint {
min: vec2f(measurements.gutter_width, constraint.min.y()),
max: vec2f(measurements.gutter_width, constraint.max.y()),
}
}),
)
}
})])
.with_child(Empty::new().constrained().dynamically({
let measurements = self.measurements.clone();
move |constraint, _, _| {
@@ -2742,6 +2870,16 @@ impl View for InlineAssistant {
.left()
.flex(1., true),
)
.with_children(if self.retrieve_context {
Some(
Flex::row()
.with_children(self.retrieve_context_status(cx))
.flex(1., true)
.aligned(),
)
} else {
None
})
.contained()
.with_style(theme.assistant.inline.container)
.into_any()
@@ -2767,6 +2905,9 @@ impl InlineAssistant {
codegen: ModelHandle<Codegen>,
workspace: WeakViewHandle<Workspace>,
cx: &mut ViewContext<Self>,
retrieve_context: bool,
semantic_index: Option<ModelHandle<SemanticIndex>>,
project: ModelHandle<Project>,
) -> Self {
let prompt_editor = cx.add_view(|cx| {
let mut editor = Editor::single_line(
@@ -2780,11 +2921,16 @@ impl InlineAssistant {
editor.set_placeholder_text(placeholder, cx);
editor
});
let subscriptions = vec![
let mut subscriptions = vec![
cx.observe(&codegen, Self::handle_codegen_changed),
cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
];
Self {
if let Some(semantic_index) = semantic_index.clone() {
subscriptions.push(cx.observe(&semantic_index, Self::semantic_index_changed));
}
let assistant = Self {
id,
prompt_editor,
workspace,
@@ -2797,7 +2943,33 @@ impl InlineAssistant {
pending_prompt: String::new(),
codegen,
_subscriptions: subscriptions,
retrieve_context,
semantic_permissioned: None,
semantic_index,
project: project.downgrade(),
maintain_rate_limit: None,
};
assistant.index_project(cx).log_err();
assistant
}
fn semantic_permissioned(&self, cx: &mut ViewContext<Self>) -> Task<Result<bool>> {
if let Some(value) = self.semantic_permissioned {
return Task::ready(Ok(value));
}
let Some(project) = self.project.upgrade(cx) else {
return Task::ready(Err(anyhow!("project was dropped")));
};
self.semantic_index
.as_ref()
.map(|semantic| {
semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx))
})
.unwrap_or(Task::ready(Ok(false)))
}
fn handle_prompt_editor_events(
@@ -2812,6 +2984,37 @@ impl InlineAssistant {
}
}
fn semantic_index_changed(
&mut self,
semantic_index: ModelHandle<SemanticIndex>,
cx: &mut ViewContext<Self>,
) {
let Some(project) = self.project.upgrade(cx) else {
return;
};
let status = semantic_index.read(cx).status(&project);
match status {
SemanticIndexStatus::Indexing {
rate_limit_expiry: Some(_),
..
} => {
if self.maintain_rate_limit.is_none() {
self.maintain_rate_limit = Some(cx.spawn(|this, mut cx| async move {
loop {
cx.background().timer(Duration::from_secs(1)).await;
this.update(&mut cx, |_, cx| cx.notify()).log_err();
}
}));
}
return;
}
_ => {
self.maintain_rate_limit = None;
}
}
}
fn handle_codegen_changed(&mut self, _: ModelHandle<Codegen>, cx: &mut ViewContext<Self>) {
let is_read_only = !self.codegen.read(cx).idle();
self.prompt_editor.update(cx, |editor, cx| {
@@ -2861,12 +3064,241 @@ impl InlineAssistant {
cx.emit(InlineAssistantEvent::Confirmed {
prompt,
include_conversation: self.include_conversation,
retrieve_context: self.retrieve_context,
});
self.confirmed = true;
cx.notify();
}
}
fn toggle_retrieve_context(&mut self, _: &ToggleRetrieveContext, cx: &mut ViewContext<Self>) {
let semantic_permissioned = self.semantic_permissioned(cx);
let Some(project) = self.project.upgrade(cx) else {
return;
};
let project_name = project
.read(cx)
.worktree_root_names(cx)
.collect::<Vec<&str>>()
.join("/");
let is_plural = project_name.chars().filter(|letter| *letter == '/').count() > 0;
let prompt_text = format!("Would you like to index the '{}' project{} for context retrieval? This requires sending code to the OpenAI API", project_name,
if is_plural {
"s"
} else {""});
cx.spawn(|this, mut cx| async move {
// If necessary, prompt the user
if !semantic_permissioned.await.unwrap_or(false) {
let mut answer = this.update(&mut cx, |_, cx| {
cx.prompt(
PromptLevel::Info,
prompt_text.as_str(),
&["Continue", "Cancel"],
)
})?;
if answer.next().await == Some(0) {
this.update(&mut cx, |this, _| {
this.semantic_permissioned = Some(true);
})?;
} else {
return anyhow::Ok(());
}
}
// If permissioned, update context appropriately
this.update(&mut cx, |this, cx| {
this.retrieve_context = !this.retrieve_context;
cx.emit(InlineAssistantEvent::RetrieveContextToggled {
retrieve_context: this.retrieve_context,
});
if this.retrieve_context {
this.index_project(cx).log_err();
}
cx.notify();
})?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
fn index_project(&self, cx: &mut ViewContext<Self>) -> anyhow::Result<()> {
let Some(project) = self.project.upgrade(cx) else {
return Err(anyhow!("project was dropped!"));
};
let semantic_permissioned = self.semantic_permissioned(cx);
if let Some(semantic_index) = SemanticIndex::global(cx) {
cx.spawn(|_, mut cx| async move {
// This has to be updated to accommodate semantic_permissions
if semantic_permissioned.await.unwrap_or(false) {
semantic_index
.update(&mut cx, |index, cx| index.index_project(project, cx))
.await
} else {
Err(anyhow!("project is not permissioned for semantic indexing"))
}
})
.detach_and_log_err(cx);
}
anyhow::Ok(())
}
fn retrieve_context_status(
&self,
cx: &mut ViewContext<Self>,
) -> Option<AnyElement<InlineAssistant>> {
enum ContextStatusIcon {}
let Some(project) = self.project.upgrade(cx) else {
return None;
};
if let Some(semantic_index) = SemanticIndex::global(cx) {
let status = semantic_index.update(cx, |index, _| index.status(&project));
let theme = theme::current(cx);
match status {
SemanticIndexStatus::NotAuthenticated {} => Some(
Svg::new("icons/error.svg")
.with_color(theme.assistant.error_icon.color)
.constrained()
.with_width(theme.assistant.error_icon.width)
.contained()
.with_style(theme.assistant.error_icon.container)
.with_tooltip::<ContextStatusIcon>(
self.id,
"Not Authenticated. Please ensure you have a valid 'OPENAI_API_KEY' in your environment variables.",
None,
theme.tooltip.clone(),
cx,
)
.aligned()
.into_any(),
),
SemanticIndexStatus::NotIndexed {} => Some(
Svg::new("icons/error.svg")
.with_color(theme.assistant.inline.context_status.error_icon.color)
.constrained()
.with_width(theme.assistant.inline.context_status.error_icon.width)
.contained()
.with_style(theme.assistant.inline.context_status.error_icon.container)
.with_tooltip::<ContextStatusIcon>(
self.id,
"Not Indexed",
None,
theme.tooltip.clone(),
cx,
)
.aligned()
.into_any(),
),
SemanticIndexStatus::Indexing {
remaining_files,
rate_limit_expiry,
} => {
let mut status_text = if remaining_files == 0 {
"Indexing...".to_string()
} else {
format!("Remaining files to index: {remaining_files}")
};
if let Some(rate_limit_expiry) = rate_limit_expiry {
let remaining_seconds = rate_limit_expiry.duration_since(Instant::now());
if remaining_seconds > Duration::from_secs(0) && remaining_files > 0 {
write!(
status_text,
" (rate limit expires in {}s)",
remaining_seconds.as_secs()
)
.unwrap();
}
}
Some(
Svg::new("icons/update.svg")
.with_color(theme.assistant.inline.context_status.in_progress_icon.color)
.constrained()
.with_width(theme.assistant.inline.context_status.in_progress_icon.width)
.contained()
.with_style(theme.assistant.inline.context_status.in_progress_icon.container)
.with_tooltip::<ContextStatusIcon>(
self.id,
status_text,
None,
theme.tooltip.clone(),
cx,
)
.aligned()
.into_any(),
)
}
SemanticIndexStatus::Indexed {} => Some(
Svg::new("icons/check.svg")
.with_color(theme.assistant.inline.context_status.complete_icon.color)
.constrained()
.with_width(theme.assistant.inline.context_status.complete_icon.width)
.contained()
.with_style(theme.assistant.inline.context_status.complete_icon.container)
.with_tooltip::<ContextStatusIcon>(
self.id,
"Index up to date",
None,
theme.tooltip.clone(),
cx,
)
.aligned()
.into_any(),
),
}
} else {
None
}
}
// fn retrieve_context_status(&self, cx: &mut ViewContext<Self>) -> String {
// let project = self.project.clone();
// if let Some(semantic_index) = self.semantic_index.clone() {
// let status = semantic_index.update(cx, |index, cx| index.status(&project));
// return match status {
// // This theoretically shouldnt be a valid code path
// // As the inline assistant cant be launched without an API key
// // We keep it here for safety
// semantic_index::SemanticIndexStatus::NotAuthenticated => {
// "Not Authenticated!\nPlease ensure you have an `OPENAI_API_KEY` in your environment variables.".to_string()
// }
// semantic_index::SemanticIndexStatus::Indexed => {
// "Indexing Complete!".to_string()
// }
// semantic_index::SemanticIndexStatus::Indexing { remaining_files, rate_limit_expiry } => {
// let mut status = format!("Remaining files to index for Context Retrieval: {remaining_files}");
// if let Some(rate_limit_expiry) = rate_limit_expiry {
// let remaining_seconds =
// rate_limit_expiry.duration_since(Instant::now());
// if remaining_seconds > Duration::from_secs(0) {
// write!(status, " (rate limit resets in {}s)", remaining_seconds.as_secs()).unwrap();
// }
// }
// status
// }
// semantic_index::SemanticIndexStatus::NotIndexed => {
// "Not Indexed for Context Retrieval".to_string()
// }
// };
// }
// "".to_string()
// }
fn toggle_include_conversation(
&mut self,
_: &ToggleIncludeConversation,
@@ -2929,6 +3361,7 @@ struct PendingInlineAssist {
inline_assistant: Option<(BlockId, ViewHandle<InlineAssistant>)>,
codegen: ModelHandle<Codegen>,
_subscriptions: Vec<Subscription>,
project: WeakModelHandle<Project>,
}
fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {


@@ -1,8 +1,60 @@
use crate::codegen::CodegenKind;
use gpui::AsyncAppContext;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use semantic_index::SearchResult;
use std::cmp::{self, Reverse};
use std::fmt::Write;
use std::ops::Range;
use std::path::PathBuf;
use tiktoken_rs::ChatCompletionRequestMessage;
pub struct PromptCodeSnippet {
path: Option<PathBuf>,
language_name: Option<String>,
content: String,
}
impl PromptCodeSnippet {
pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
let (content, language_name, file_path) =
search_result.buffer.read_with(cx, |buffer, _| {
let snapshot = buffer.snapshot();
let content = snapshot
.text_for_range(search_result.range.clone())
.collect::<String>();
let language_name = buffer
.language()
.and_then(|language| Some(language.name().to_string()));
let file_path = buffer
.file()
.and_then(|file| Some(file.path().to_path_buf()));
(content, language_name, file_path)
});
PromptCodeSnippet {
path: file_path,
language_name,
content,
}
}
}
impl ToString for PromptCodeSnippet {
fn to_string(&self) -> String {
let path = self
.path
.as_ref()
.and_then(|path| Some(path.to_string_lossy().to_string()))
.unwrap_or("".to_string());
let language_name = self.language_name.clone().unwrap_or("".to_string());
let content = self.content.clone();
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
}
}
#[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
@@ -121,17 +173,25 @@ pub fn generate_content_prompt(
buffer: &BufferSnapshot,
range: Range<impl ToOffset>,
kind: CodegenKind,
search_results: Vec<PromptCodeSnippet>,
model: &str,
) -> String {
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
let mut prompts = Vec::new();
let range = range.to_offset(buffer);
let mut prompt = String::new();
// General Preamble
if let Some(language_name) = language_name {
writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
prompts.push(format!("You're an expert {language_name} engineer.\n"));
} else {
writeln!(prompt, "You're an expert engineer.\n").unwrap();
prompts.push("You're an expert engineer.\n".to_string());
}
// Snippets
let mut snippet_position = prompts.len() - 1;
let mut content = String::new();
content.extend(buffer.text_for_range(0..range.start));
if range.start == range.end {
@@ -145,59 +205,99 @@ pub fn generate_content_prompt(
}
content.extend(buffer.text_for_range(range.end..buffer.len()));
writeln!(
prompt,
"The file you are currently working on has the following content:"
)
.unwrap();
prompts.push("The file you are currently working on has the following content:\n".to_string());
if let Some(language_name) = language_name {
let language_name = language_name.to_lowercase();
writeln!(prompt, "```{language_name}\n{content}\n```").unwrap();
prompts.push(format!("```{language_name}\n{content}\n```"));
} else {
writeln!(prompt, "```\n{content}\n```").unwrap();
prompts.push(format!("```\n{content}\n```"));
}
match kind {
CodegenKind::Generate { position: _ } => {
writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
writeln!(
prompt,
"Assume the cursor is located where the `<|START|` marker is."
)
.unwrap();
writeln!(
prompt,
prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
prompts
.push("Assume the cursor is located where the `<|START|` marker is.".to_string());
prompts.push(
"Text can't be replaced, so assume your answer will be inserted at the cursor."
)
.unwrap();
writeln!(
prompt,
.to_string(),
);
prompts.push(format!(
"Generate text based on the users prompt: {user_prompt}"
)
.unwrap();
));
}
CodegenKind::Transform { range: _ } => {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
writeln!(
prompt,
"Modify the users code selected text based upon the users prompt: {user_prompt}"
)
.unwrap();
writeln!(
prompt,
"You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
)
.unwrap();
prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
prompts.push(format!(
"Modify the users code selected text based upon the users prompt: '{user_prompt}'"
));
prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
}
}
if let Some(language_name) = language_name {
writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
prompts.push(format!(
"Your answer MUST always and only be valid {language_name}"
));
}
writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
writeln!(prompt, "Never make remarks about the output.").unwrap();
prompts.push("Never make remarks about the output.".to_string());
prompts.push("Do not return any text, except the generated code.".to_string());
prompts.push("Do not wrap your text in a Markdown block".to_string());
prompt
let current_messages = [ChatCompletionRequestMessage {
role: "user".to_string(),
content: Some(prompts.join("\n")),
function_call: None,
name: None,
}];
let mut remaining_token_count = if let Ok(current_token_count) =
tiktoken_rs::num_tokens_from_messages(model, &current_messages)
{
let max_token_count = tiktoken_rs::model::get_context_size(model);
let intermediate_token_count = max_token_count - current_token_count;
if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
0
} else {
intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
}
} else {
// If tiktoken fails to produce a token count, assume we have no space remaining.
0
};
// TODO:
// - add repository name to snippet
// - add file path
// - add language
if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
for search_result in search_results {
let mut snippet_prompt = template.to_string();
let snippet = search_result.to_string();
writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
let token_count = encoding
.encode_with_special_tokens(snippet_prompt.as_str())
.len();
if token_count <= remaining_token_count {
if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
prompts.insert(snippet_position, snippet_prompt);
snippet_position += 1;
remaining_token_count -= token_count;
// If you have already added the template to the prompt, remove the template.
template = "";
}
} else {
break;
}
}
}
prompts.join("\n")
}
#[cfg(test)]


@@ -1199,6 +1199,15 @@ pub struct InlineAssistantStyle {
pub disabled_editor: FieldEditor,
pub pending_edit_background: Color,
pub include_conversation: ToggleIconButtonStyle,
pub retrieve_context: ToggleIconButtonStyle,
pub context_status: ContextStatusStyle,
}
#[derive(Clone, Deserialize, Default, JsonSchema)]
pub struct ContextStatusStyle {
pub error_icon: Icon,
pub in_progress_icon: Icon,
pub complete_icon: Icon,
}
#[derive(Clone, Deserialize, Default, JsonSchema)]


@@ -79,6 +79,80 @@ export default function assistant(): any {
},
},
pending_edit_background: background(theme.highest, "positive"),
context_status: {
error_icon: {
margin: { left: 8, right: 18 },
color: foreground(theme.highest, "negative"),
width: 12,
},
in_progress_icon: {
margin: { left: 8, right: 18 },
color: foreground(theme.highest, "positive"),
width: 12,
},
complete_icon: {
margin: { left: 8, right: 18 },
color: foreground(theme.highest, "positive"),
width: 12,
}
},
retrieve_context: toggleable({
base: interactive({
base: {
icon_size: 12,
color: foreground(theme.highest, "variant"),
button_width: 12,
background: background(theme.highest, "on"),
corner_radius: 2,
border: {
width: 1., color: background(theme.highest, "on")
},
margin: { left: 2 },
padding: {
left: 4,
right: 4,
top: 4,
bottom: 4,
},
},
state: {
hovered: {
...text(theme.highest, "mono", "variant", "hovered"),
background: background(theme.highest, "on", "hovered"),
border: {
width: 1., color: background(theme.highest, "on", "hovered")
},
},
clicked: {
...text(theme.highest, "mono", "variant", "pressed"),
background: background(theme.highest, "on", "pressed"),
border: {
width: 1., color: background(theme.highest, "on", "pressed")
},
},
},
}),
state: {
active: {
default: {
icon_size: 12,
button_width: 12,
color: foreground(theme.highest, "variant"),
background: background(theme.highest, "accent"),
border: border(theme.highest, "accent"),
},
hovered: {
background: background(theme.highest, "accent", "hovered"),
border: border(theme.highest, "accent", "hovered"),
},
clicked: {
background: background(theme.highest, "accent", "pressed"),
border: border(theme.highest, "accent", "pressed"),
},
},
},
}),
include_conversation: toggleable({
base: interactive({
base: {