diff --git a/app/gui/src/controller/graph.rs b/app/gui/src/controller/graph.rs index 574285f6062..02f8cfe12cc 100644 --- a/app/gui/src/controller/graph.rs +++ b/app/gui/src/controller/graph.rs @@ -283,6 +283,16 @@ impl Connections { target: Self::convert_endpoint(&connection.target), } } + + /// Return all connections that involve the given node. + pub fn with_node(&self, node: node::Id) -> impl Iterator { + self.connections + .iter() + .filter(move |conn| conn.source.node == node || conn.target.node == node) + .copied() + .collect_vec() + .into_iter() + } } diff --git a/app/gui/src/controller/graph/executed.rs b/app/gui/src/controller/graph/executed.rs index 48c63c972bf..b74950d787f 100644 --- a/app/gui/src/controller/graph/executed.rs +++ b/app/gui/src/controller/graph/executed.rs @@ -451,6 +451,21 @@ impl Handle { } } + /// Remove all the connections from the graph. This is a convenience method that calls + /// [`disconnect`] for each connection. If any of the calls fails, the first error is + /// propagated, but all the connections are attempted to be disconnected. + pub fn disconnect_all(&self, connections: impl Iterator) -> FallibleResult { + let errors = + connections.map(|c| self.disconnect(&c)).filter_map(|r| r.err()).collect::>(); + // Failure has no good way to propagate multiple errors with `Failure`. So we propagate + // only the first one. + if let Some(error) = errors.into_iter().next() { + Err(error) + } else { + Ok(()) + } + } + /// Set the execution environment. pub async fn set_execution_environment( &self, diff --git a/app/gui/src/presenter/graph.rs b/app/gui/src/presenter/graph.rs index a88576fdbf9..29090a92d9e 100644 --- a/app/gui/src/presenter/graph.rs +++ b/app/gui/src/presenter/graph.rs @@ -292,6 +292,17 @@ impl Model { || { let ast_id = self.state.update_from_view().remove_node(id)?; self.widget.remove_all_node_widgets(ast_id); + + let connections = self.controller.connections(); + let node_connections = connections.map(|c| c.with_node(ast_id)); + let disconnect_result = node_connections.map(|c| self.controller.disconnect_all(c)); + if let Err(e) = disconnect_result { + warn!( + "Failed to disconnect all connections from node {:?} because of {:?}", + ast_id, e + ); + } + Some(self.controller.graph().remove_node(ast_id)) }, "remove node",