diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 7bd6378f40..8109e35a00 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -229,6 +229,7 @@ pub struct Project { search_history: SearchHistory, snippets: Model, yarn: Model, + cached_shell_environments: HashMap>, } pub enum LanguageServerToQuery { @@ -827,6 +828,7 @@ impl Project { hosted_project_id: None, dev_server_project_id: None, search_history: Self::new_search_history(), + cached_shell_environments: HashMap::default(), } }) } @@ -1021,6 +1023,7 @@ impl Project { .dev_server_project_id .map(|dev_server_project_id| DevServerProjectId(dev_server_project_id)), search_history: Self::new_search_history(), + cached_shell_environments: HashMap::default(), }; this.set_role(role, cx); for worktree in worktrees { @@ -1201,6 +1204,15 @@ impl Project { }) .await .unwrap(); + + project.update(cx, |project, cx| { + let tree_id = tree.read(cx).id(); + // In tests we always populate the environment to be empty so we don't run the shell + project + .cached_shell_environments + .insert(tree_id, HashMap::default()); + }); + tree.update(cx, |tree, _| tree.as_local().unwrap().scan_complete()) .await; } @@ -7886,6 +7898,7 @@ impl Project { } self.diagnostics.remove(&id_to_remove); self.diagnostic_summaries.remove(&id_to_remove); + self.cached_shell_environments.remove(&id_to_remove); let mut servers_to_remove = HashMap::default(); let mut servers_to_preserve = HashSet::default(); @@ -9286,6 +9299,7 @@ impl Project { })?; let task_context = context_task.await.unwrap_or_default(); Ok(proto::TaskContext { + project_env: task_context.project_env.into_iter().collect(), cwd: task_context .cwd .map(|cwd| cwd.to_string_lossy().to_string()), @@ -10260,7 +10274,14 @@ impl Project { cx: &mut ModelContext<'_, Project>, ) -> Task> { if self.is_local() { - let cwd = self.task_cwd(cx).log_err().flatten(); + let (worktree_id, cwd) = if let Some(worktree) = self.task_worktree(cx) { + ( + Some(worktree.read(cx).id()), + Some(self.task_cwd(worktree, cx)), + ) + } else { + (None, None) + }; cx.spawn(|project, cx| async move { let mut task_variables = cx @@ -10277,7 +10298,17 @@ impl Project { .flatten()?; // Remove all custom entries starting with _, as they're not intended for use by the end user. task_variables.sweep(); + + let mut project_env = None; + if let Some((worktree_id, cwd)) = worktree_id.zip(cwd.as_ref()) { + let env = Self::get_worktree_shell_env(project, worktree_id, cwd, cx).await; + if let Some(env) = env { + project_env.replace(env); + } + }; + Some(TaskContext { + project_env: project_env.unwrap_or_default(), cwd, task_variables, }) @@ -10297,6 +10328,7 @@ impl Project { cx.background_executor().spawn(async move { let task_context = task_context.await.log_err()?; Some(TaskContext { + project_env: task_context.project_env.into_iter().collect(), cwd: task_context.cwd.map(PathBuf::from), task_variables: task_context .task_variables @@ -10318,6 +10350,50 @@ impl Project { } } + async fn get_worktree_shell_env( + this: WeakModel, + worktree_id: WorktreeId, + cwd: &PathBuf, + mut cx: AsyncAppContext, + ) -> Option> { + let cached_env = this + .update(&mut cx, |project, _| { + project.cached_shell_environments.get(&worktree_id).cloned() + }) + .ok()?; + + if let Some(env) = cached_env { + Some(env) + } else { + let load_direnv = this + .update(&mut cx, |_, cx| { + ProjectSettings::get_global(cx).load_direnv.clone() + }) + .ok()?; + + let shell_env = cx + .background_executor() + .spawn({ + let cwd = cwd.clone(); + async move { + load_shell_environment(&cwd, &load_direnv) + .await + .unwrap_or_default() + } + }) + .await; + + this.update(&mut cx, |project, _| { + project + .cached_shell_environments + .insert(worktree_id, shell_env.clone()); + }) + .ok()?; + + Some(shell_env) + } + } + pub fn task_templates( &self, worktree: Option, @@ -10441,7 +10517,7 @@ impl Project { }) } - fn task_cwd(&self, cx: &AppContext) -> anyhow::Result> { + fn task_worktree(&self, cx: &AppContext) -> Option> { let available_worktrees = self .worktrees(cx) .filter(|worktree| { @@ -10451,28 +10527,24 @@ impl Project { && worktree.root_entry().map_or(false, |e| e.is_dir()) }) .collect::>(); - let cwd = match available_worktrees.len() { + + match available_worktrees.len() { 0 => None, - 1 => Some(available_worktrees[0].read(cx).abs_path()), - _ => { - let cwd_for_active_entry = self.active_entry().and_then(|entry_id| { - available_worktrees.into_iter().find_map(|worktree| { - let worktree = worktree.read(cx); - if worktree.contains_entry(entry_id) { - Some(worktree.abs_path()) - } else { - None - } - }) - }); - anyhow::ensure!( - cwd_for_active_entry.is_some(), - "Cannot determine task cwd for multiple worktrees" - ); - cwd_for_active_entry - } - }; - Ok(cwd.map(|path| path.to_path_buf())) + 1 => Some(available_worktrees[0].clone()), + _ => self.active_entry().and_then(|entry_id| { + available_worktrees.into_iter().find_map(|worktree| { + if worktree.read(cx).contains_entry(entry_id) { + Some(worktree) + } else { + None + } + }) + }), + } + } + + fn task_cwd(&self, worktree: Model, cx: &AppContext) -> PathBuf { + worktree.read(cx).abs_path().to_path_buf() } } diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index d80b7127ee..60f8d01558 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -2257,6 +2257,7 @@ message TaskContextForLocation { message TaskContext { optional string cwd = 1; map task_variables = 2; + map project_env = 3; } message TaskTemplates { diff --git a/crates/task/src/lib.rs b/crates/task/src/lib.rs index 77c62638e2..5e1a9309d9 100644 --- a/crates/task/src/lib.rs +++ b/crates/task/src/lib.rs @@ -271,6 +271,10 @@ pub struct TaskContext { pub cwd: Option, /// Additional environment variables associated with a given task. pub task_variables: TaskVariables, + /// Environment variables obtained when loading the project into Zed. + /// This is the environment one would get when `cd`ing in a terminal + /// into the project's root directory. + pub project_env: HashMap, } /// This is a new type representing a 'tag' on a 'runnable symbol', typically a test of main() function, found via treesitter. diff --git a/crates/task/src/task_template.rs b/crates/task/src/task_template.rs index 403bab1b35..55f13f6bca 100644 --- a/crates/task/src/task_template.rs +++ b/crates/task/src/task_template.rs @@ -184,13 +184,27 @@ impl TaskTemplate { .context("hashing task variables") .log_err()?; let id = TaskId(format!("{id_base}_{task_hash}_{variables_hash}")); - let mut env = substitute_all_template_variables_in_map( - &self.env, - &task_variables, - &variable_names, - &mut substituted_variables, - )?; - env.extend(task_variables.into_iter().map(|(k, v)| (k, v.to_owned()))); + + let env = { + // Start with the project environment as the base. + let mut env = cx.project_env.clone(); + + // Extend that environment with what's defined in the TaskTemplate + env.extend(self.env.clone()); + + // Then we replace all task variables that could be set in environment variables + let mut env = substitute_all_template_variables_in_map( + &env, + &task_variables, + &variable_names, + &mut substituted_variables, + )?; + + // Last step: set the task variables as environment variables too + env.extend(task_variables.into_iter().map(|(k, v)| (k, v.to_owned()))); + env + }; + Some(ResolvedTask { id: id.clone(), substituted_variables, @@ -392,6 +406,7 @@ mod tests { let cx = TaskContext { cwd: None, task_variables: TaskVariables::default(), + project_env: HashMap::default(), }; assert_eq!( resolved_task(&task_without_cwd, &cx).cwd, @@ -403,6 +418,7 @@ mod tests { let cx = TaskContext { cwd: Some(context_cwd.clone()), task_variables: TaskVariables::default(), + project_env: HashMap::default(), }; assert_eq!( resolved_task(&task_without_cwd, &cx) @@ -421,6 +437,7 @@ mod tests { let cx = TaskContext { cwd: None, task_variables: TaskVariables::default(), + project_env: HashMap::default(), }; assert_eq!( resolved_task(&task_with_cwd, &cx) @@ -434,6 +451,7 @@ mod tests { let cx = TaskContext { cwd: Some(context_cwd.clone()), task_variables: TaskVariables::default(), + project_env: HashMap::default(), }; assert_eq!( resolved_task(&task_with_cwd, &cx) @@ -512,6 +530,7 @@ mod tests { &TaskContext { cwd: None, task_variables: TaskVariables::from_iter(all_variables.clone()), + project_env: HashMap::default(), }, ).unwrap_or_else(|| panic!("Should successfully resolve task {task_with_all_variables:?} with variables {all_variables:?}")); @@ -599,6 +618,7 @@ mod tests { &TaskContext { cwd: None, task_variables: TaskVariables::from_iter(not_all_variables), + project_env: HashMap::default(), }, ); assert_eq!(resolved_task_attempt, None, "If any of the Zed task variables is not substituted, the task should not be resolved, but got some resolution without the variable {removed_variable:?} (index {i})"); @@ -651,6 +671,7 @@ mod tests { VariableName::Symbol, "test_symbol".to_string(), ))), + project_env: HashMap::default(), }; for (i, symbol_dependent_task) in [ @@ -725,4 +746,74 @@ mod tests { .insert(VariableName::Symbol, "my-symbol".to_string()); assert!(faulty_go_test.resolve_task("base", &context).is_some()); } + + #[test] + fn test_project_env() { + let all_variables = [ + (VariableName::Row, "1234".to_string()), + (VariableName::Column, "5678".to_string()), + (VariableName::File, "test_file".to_string()), + (VariableName::Symbol, "my symbol".to_string()), + ]; + + let template = TaskTemplate { + label: "my task".to_string(), + command: format!( + "echo {} {}", + VariableName::File.template_value(), + VariableName::Symbol.template_value(), + ), + args: vec![], + env: HashMap::from_iter([ + ( + "TASK_ENV_VAR1".to_string(), + "TASK_ENV_VAR1_VALUE".to_string(), + ), + ( + "TASK_ENV_VAR2".to_string(), + format!( + "env_var_2 {} {}", + VariableName::Row.template_value(), + VariableName::Column.template_value() + ), + ), + ( + "PROJECT_ENV_WILL_BE_OVERWRITTEN".to_string(), + "overwritten".to_string(), + ), + ]), + ..TaskTemplate::default() + }; + + let project_env = HashMap::from_iter([ + ( + "PROJECT_ENV_VAR1".to_string(), + "PROJECT_ENV_VAR1_VALUE".to_string(), + ), + ( + "PROJECT_ENV_WILL_BE_OVERWRITTEN".to_string(), + "PROJECT_ENV_WILL_BE_OVERWRITTEN_VALUE".to_string(), + ), + ]); + + let context = TaskContext { + cwd: None, + task_variables: TaskVariables::from_iter(all_variables.clone()), + project_env, + }; + + let resolved = template + .resolve_task(TEST_ID_BASE, &context) + .unwrap() + .resolved + .unwrap(); + + assert_eq!(resolved.env["TASK_ENV_VAR1"], "TASK_ENV_VAR1_VALUE"); + assert_eq!(resolved.env["TASK_ENV_VAR2"], "env_var_2 1234 5678"); + assert_eq!(resolved.env["PROJECT_ENV_VAR1"], "PROJECT_ENV_VAR1_VALUE"); + assert_eq!( + resolved.env["PROJECT_ENV_WILL_BE_OVERWRITTEN"], + "overwritten" + ); + } } diff --git a/crates/tasks_ui/src/lib.rs b/crates/tasks_ui/src/lib.rs index 3d4e062cb2..d12d3c6bd2 100644 --- a/crates/tasks_ui/src/lib.rs +++ b/crates/tasks_ui/src/lib.rs @@ -180,7 +180,7 @@ fn active_item_selection_properties( #[cfg(test)] mod tests { - use std::sync::Arc; + use std::{collections::HashMap, sync::Arc}; use editor::Editor; use gpui::{Entity, TestAppContext}; @@ -306,7 +306,8 @@ mod tests { (VariableName::WorktreeRoot, "/dir".into()), (VariableName::Row, "1".into()), (VariableName::Column, "1".into()), - ]) + ]), + project_env: HashMap::default(), } ); @@ -332,7 +333,8 @@ mod tests { (VariableName::Column, "15".into()), (VariableName::SelectedText, "is_i".into()), (VariableName::Symbol, "this_is_a_rust_file".into()), - ]) + ]), + project_env: HashMap::default(), } ); @@ -356,7 +358,8 @@ mod tests { (VariableName::Row, "1".into()), (VariableName::Column, "1".into()), (VariableName::Symbol, "this_is_a_test".into()), - ]) + ]), + project_env: HashMap::default(), } ); }