Collapsing nodes (#8686)

Closes #8067

Also fixes `nodeRects` map, as it no longer stores invlisible nodes.

https://github.com/enso-org/enso/assets/6566674/ba66c99f-df74-497b-8924-dc779cce8ef5

# Important Notes
Positioning of newly created nodes is not handled yet, as it requires fixes in the Ast editing API.
This commit is contained in:
Ilya Bogdanov 2024-01-12 18:08:17 +04:00 committed by GitHub
parent 1b3c9638ea
commit 58cf4e5244
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 321 additions and 54 deletions

View File

@ -1,4 +1,5 @@
import {
averagePositionPlacement,
mouseDictatedPlacement,
nonDictatedPlacement,
previousNodeDictatedPlacement,
@ -447,6 +448,54 @@ describe('Mouse dictated placement', () => {
})
})
describe('Average position placement', () => {
function environment(selectedNodeRects: Rect[], nonSelectedNodeRects: Rect[]): Environment {
return {
screenBounds,
nodeRects: [...selectedNodeRects, ...nonSelectedNodeRects],
selectedNodeRects,
get mousePosition() {
return getMousePosition()
},
}
}
function options(): { horizontalGap: number; verticalGap: number } {
return {
get horizontalGap() {
return getHorizontalGap()
},
get verticalGap() {
return getVerticalGap()
},
}
}
test('One selected, no other nodes', () => {
const X = 1100
const Y = 700
const selectedNodeRects = [rectAt(X, Y)]
const result = averagePositionPlacement(nodeSize, environment(selectedNodeRects, []), options())
expect(result).toEqual({ position: new Vec2(X, Y), pan: undefined })
})
test('Multiple selected, no other nodes', () => {
const selectedNodeRects = [rectAt(1000, 600), rectAt(1300, 800)]
const result = averagePositionPlacement(nodeSize, environment(selectedNodeRects, []), options())
expect(result).toEqual({ position: new Vec2(1150, 700), pan: undefined })
})
test('Average position occupied', () => {
const selectedNodeRects = [rectAt(1000, 600), rectAt(1300, 800)]
const result = averagePositionPlacement(
nodeSize,
environment(selectedNodeRects, [rectAt(1150, 700)]),
options(),
)
expect(result).toEqual({ position: new Vec2(1150, 744), pan: undefined })
})
})
// === Helpers for debugging ===
function generateVueCodeForNonDictatedPlacement(newNode: Rect, rects: Rect[]) {

View File

@ -128,3 +128,54 @@ export function mouseDictatedPlacement(
const nodeRadius = nodeSize.y / 2
return { position: mousePosition.add(new Vec2(nodeRadius, nodeRadius)) }
}
/** The new node should appear at the average position of selected nodes.
*
* If the desired place is already occupied by non-selected node, it should be moved down to the closest free space.
*
* Specifically, this code, in order:
* - calculates the average position of selected nodes
* - searches for all vertical spans below the initial position,
* that horizontally intersect the initial position (no horizontal gap is required between
* the new node and old nodes)
* - shifts the node down (if required) until there is sufficient vertical space -
* the height of the node, in addition to the specified gap both above and below the node.
*/
export function averagePositionPlacement(
nodeSize: Vec2,
{ screenBounds, selectedNodeRects, nodeRects }: Environment,
{ verticalGap = theme.node.vertical_gap }: PlacementOptions = {},
): Placement {
let totalPosition = new Vec2(0, 0)
let selectedNodeRectsCount = 0
for (const rect of selectedNodeRects) {
totalPosition = totalPosition.add(rect.pos)
selectedNodeRectsCount++
}
const initialPosition = totalPosition.scale(1.0 / selectedNodeRectsCount)
const nonSelectedNodeRects = []
outer: for (const rect of nodeRects) {
for (const sel of selectedNodeRects) {
if (sel.equals(rect)) {
continue outer
}
}
nonSelectedNodeRects.push(rect)
}
let top = initialPosition.y
const initialRect = new Rect(initialPosition, nodeSize)
const nodeRectsSorted = Array.from(nonSelectedNodeRects).sort((a, b) => a.top - b.top)
for (const rect of nodeRectsSorted) {
if (initialRect.intersectsX(rect) && rect.bottom + verticalGap > top) {
if (rect.top - (top + nodeSize.y) < verticalGap) {
top = rect.bottom + verticalGap
}
}
}
const finalPosition = new Vec2(initialPosition.x, top)
if (new Rect(finalPosition, nodeSize).within(screenBounds)) {
return { position: finalPosition }
} else {
return { position: finalPosition, pan: finalPosition.sub(initialPosition) }
}
}

View File

@ -26,6 +26,8 @@ import { useGraphStore } from '@/stores/graph'
import type { RequiredImport } from '@/stores/graph/imports'
import { useProjectStore } from '@/stores/project'
import { groupColorVar, useSuggestionDbStore } from '@/stores/suggestionDatabase'
import { assert, bail } from '@/util/assert'
import { BodyBlock } from '@/util/ast/abstract'
import { colorFromString } from '@/util/colors'
import { Rect } from '@/util/data/rect'
import { Vec2 } from '@/util/data/vec2'
@ -255,13 +257,23 @@ const graphBindingsHandler = graphBindings.handler({
},
collapse() {
if (keyboardBusy()) return false
const selected = nodeSelection.selected
const selected = new Set(nodeSelection.selected)
if (selected.size == 0) return
try {
const info = prepareCollapsedInfo(nodeSelection.selected, graphStore.db)
performCollapse(info)
const info = prepareCollapsedInfo(selected, graphStore.db)
const currentMethod = projectStore.executionContext.getStackTop()
const currentMethodName = graphStore.db.stackItemToMethodName(currentMethod)
if (currentMethodName == null) {
bail(`Cannot get the method name for the current execution stack item. ${currentMethod}`)
}
graphStore.editAst((module) => {
if (graphStore.moduleRoot == null) bail(`Module root is missing.`)
const topLevel = module.get(graphStore.moduleRoot)
assert(topLevel instanceof BodyBlock)
return performCollapse(info, module, topLevel, graphStore.db, currentMethodName)
})
} catch (err) {
console.log(`Error while collapsing, this is not normal. ${err}`)
console.log('Error while collapsing, this is not normal.', err)
}
},
enterNode() {

View File

@ -20,7 +20,7 @@ import { Vec2 } from '@/util/data/vec2'
import { displayedIconOf } from '@/util/getIconName'
import { setIfUndefined } from 'lib0/map'
import type { ExprId, VisualizationIdentifier } from 'shared/yjsModel'
import { computed, ref, watch, watchEffect } from 'vue'
import { computed, onUnmounted, ref, watch, watchEffect } from 'vue'
const MAXIMUM_CLICK_LENGTH_MS = 300
const MAXIMUM_CLICK_DISTANCE_SQ = 50
@ -73,6 +73,8 @@ const outputPortsSet = computed(() => {
const widthOverridePx = ref<number>()
const nodeId = computed(() => props.node.rootSpan.exprId)
onUnmounted(() => graph.unregisterNodeRect(nodeId.value))
const rootNode = ref<HTMLElement>()
const contentNode = ref<HTMLElement>()
const nodeSize = useResizeObserver(rootNode)

View File

@ -1,5 +1,6 @@
import { GraphDb } from '@/stores/graph/graphDatabase'
import { Ast } from '@/util/ast'
import { moduleMethodNames } from '@/util/ast/abstract'
import { unwrap } from '@/util/data/result'
import { tryIdentifier, type Identifier } from '@/util/qualifiedName'
import assert from 'assert'
@ -40,8 +41,8 @@ interface RefactoredInfo {
id: ExprId
/** The pattern of the refactored node. Included for convinience, collapsing does not affect it. */
pattern: string
/** The new expression of the refactored node. A call to the extracted function with the list of necessary arguments. */
expression: string
/** The list of necessary arguments for a call of the collapsed function. */
arguments: Identifier[]
}
// === prepareCollapsedInfo ===
@ -55,19 +56,20 @@ export function prepareCollapsedInfo(selected: Set<ExprId>, graphDb: GraphDb): C
const leaves = new Set([...selected])
const inputs: Identifier[] = []
let output: Output | null = null
for (const [targetExprId, sourceExprIds] of graphDb.connections.allReverse()) {
for (const [targetExprId, sourceExprIds] of graphDb.allConnections.allReverse()) {
const target = graphDb.getExpressionNodeId(targetExprId)
if (target == null) throw new Error(`Connection target node for id ${targetExprId} not found.`)
if (target == null) continue
for (const sourceExprId of sourceExprIds) {
const source = graphDb.getPatternExpressionNodeId(sourceExprId)
if (source == null)
throw new Error(`Connection source node for id ${sourceExprId} not found.`)
const startsInside = selected.has(source)
const startsInside = source != null && selected.has(source)
const endsInside = selected.has(target)
const stringIdentifier = graphDb.getOutputPortIdentifier(sourceExprId)
if (stringIdentifier == null) throw new Error(`Source node (${source}) has no pattern.`)
if (stringIdentifier == null)
throw new Error(`Source node (${source}) has no output identifier.`)
const identifier = unwrap(tryIdentifier(stringIdentifier))
leaves.delete(source)
if (source != null) {
leaves.delete(source)
}
if (!startsInside && endsInside) {
inputs.push(identifier)
} else if (startsInside && !endsInside) {
@ -105,21 +107,109 @@ export function prepareCollapsedInfo(selected: Set<ExprId>, graphDb: GraphDb): C
refactored: {
id: output.node,
pattern,
expression: 'Main.collapsed' + (inputs.length > 0 ? ' ' : '') + inputs.join(' '),
arguments: inputs,
},
}
}
// === performRefactoring ===
/** Generate a safe method name for a collapsed function using `baseName` as a prefix. */
function findSafeMethodName(module: Ast.Module, baseName: string): string {
const allIdentifiers = moduleMethodNames(module)
if (!allIdentifiers.has(baseName)) {
return baseName
}
let index = 1
while (allIdentifiers.has(`${baseName}${index}`)) {
index++
}
return `${baseName}${index}`
}
// === performCollapse ===
// We support working inside `Main` module of the project at the moment.
const MODULE_NAME = 'Main'
const COLLAPSED_FUNCTION_NAME = 'collapsed'
/** Perform the actual AST refactoring for collapsing nodes. */
export function performCollapse(_info: CollapsedInfo) {
// The general flow of this function:
// 1. Create a new function with a unique name and a list of arguments from the `ExtractedInfo`.
// 2. Move all nodes with `ids` from the `ExtractedInfo` into this new function. Use the order of their original definition.
// 3. Use a single identifier `output.identifier` as the return value of the function.
// 4. Change the expression of the `RefactoredInfo.id` node to the `RefactoredINfo.expression`
throw new Error('Not yet implemented, requires AST editing.')
export function performCollapse(
info: CollapsedInfo,
module: Ast.Module,
topLevel: Ast.BodyBlock,
db: GraphDb,
currentMethodName: string,
): Ast.MutableModule {
const functionAst = Ast.findModuleMethod(module, currentMethodName)
if (!(functionAst instanceof Ast.Function) || !(functionAst.body instanceof Ast.BodyBlock)) {
throw new Error(`Expected a collapsable function, found ${functionAst}.`)
}
const functionBlock = functionAst.body
const posToInsert = findInsertionPos(module, topLevel, currentMethodName)
const collapsedName = findSafeMethodName(module, COLLAPSED_FUNCTION_NAME)
const astIdsToExtract = new Set(
[...info.extracted.ids].map((nodeId) => db.nodeIdToNode.get(nodeId)?.outerExprId),
)
const astIdToReplace = db.nodeIdToNode.get(info.refactored.id)?.outerExprId
const collapsed = []
const refactored = []
const edit = module.edit()
const lines = functionBlock.lines()
for (const line of lines) {
const astId = line.expression?.node.exprId
const ast = astId != null ? module.get(astId) : null
if (ast == null) continue
if (astIdsToExtract.has(astId)) {
collapsed.push(ast)
if (astId === astIdToReplace) {
const newAst = collapsedCallAst(info, collapsedName, edit)
refactored.push({ expression: { node: newAst } })
}
} else {
refactored.push({ expression: { node: ast } })
}
}
const outputIdentifier = info.extracted.output?.identifier
if (outputIdentifier != null) {
collapsed.push(Ast.Ident.new(edit, outputIdentifier))
}
// Update the definiton of refactored function.
const refactoredBlock = Ast.BodyBlock.new(refactored, edit)
edit.replaceRef(functionBlock.exprId, refactoredBlock)
// new Ast.BodyBlock(edit, functionBlock.exprId, refactored)
const args: Ast.Ast[] = info.extracted.inputs.map((arg) => Ast.Ident.new(edit, arg))
const collapsedFunction = Ast.Function.new(edit, collapsedName, args, collapsed, true)
topLevel.insert(edit, posToInsert, collapsedFunction)
return edit
}
/** Prepare a method call expression for collapsed method. */
function collapsedCallAst(
info: CollapsedInfo,
collapsedName: string,
edit: Ast.MutableModule,
): Ast.Ast {
const pattern = info.refactored.pattern
const args = info.refactored.arguments
const functionName = `${MODULE_NAME}.${collapsedName}`
const expression = functionName + (args.length > 0 ? ' ' : '') + args.join(' ')
const assignment = Ast.Assignment.new(edit, pattern, Ast.parse(expression, edit))
return assignment
}
/** Find the position before the current method to insert a collapsed one. */
function findInsertionPos(
module: Ast.Module,
topLevel: Ast.BodyBlock,
currentMethodName: string,
): number {
const currentFuncPosition = topLevel.lines().findIndex((line) => {
const node = line.expression?.node
const expr = node ? module.get(node.exprId)?.innerExpression() : null
return expr instanceof Ast.Function && expr.name?.code() === currentMethodName
})
return currentFuncPosition === -1 ? 0 : currentFuncPosition
}
// === Tests ===
@ -148,7 +238,7 @@ if (import.meta.vitest) {
}
refactored: {
replace: string
with: { pattern: string; expression: string }
with: { pattern: string; arguments: string[] }
}
}
}
@ -166,7 +256,7 @@ if (import.meta.vitest) {
},
refactored: {
replace: 'c = A + B',
with: { pattern: 'c', expression: 'Main.collapsed a' },
with: { pattern: 'c', arguments: ['a'] },
},
},
},
@ -182,7 +272,7 @@ if (import.meta.vitest) {
},
refactored: {
replace: 'd = a + b',
with: { pattern: 'd', expression: 'Main.collapsed a b' },
with: { pattern: 'd', arguments: ['a', 'b'] },
},
},
},
@ -198,7 +288,7 @@ if (import.meta.vitest) {
},
refactored: {
replace: 'c = 50 + d',
with: { pattern: 'c', expression: 'Main.collapsed' },
with: { pattern: 'c', arguments: [] },
},
},
},
@ -219,7 +309,7 @@ if (import.meta.vitest) {
},
refactored: {
replace: 'vector = range.to_vector',
with: { pattern: 'vector', expression: 'Main.collapsed number1 number2' },
with: { pattern: 'vector', arguments: ['number1', 'number2'] },
},
},
},
@ -261,6 +351,6 @@ if (import.meta.vitest) {
expect(extracted.ids).toEqual(new Set(expectedIds))
expect(refactored.id).toEqual(expectedRefactoredId)
expect(refactored.pattern).toEqual(expectedRefactored.with.pattern)
expect(refactored.expression).toEqual(expectedRefactored.with.expression)
expect(refactored.arguments).toEqual(expectedRefactored.with.arguments)
})
}

View File

@ -30,16 +30,7 @@ export function useStackNavigator() {
})
function stackItemToLabel(item: StackItem): string {
switch (item.type) {
case 'ExplicitCall': {
return item.methodPointer.name
}
case 'LocalCall': {
const exprId = item.expressionId
const info = graphStore.db.getExpressionInfo(exprId)
return info?.methodCall?.methodPointer.name ?? 'unknown'
}
}
return graphStore.db.stackItemToMethodName(item) ?? 'unknown'
}
function handleBreadcrumbClick(index: number) {

View File

@ -12,7 +12,7 @@ import { Vec2 } from '@/util/data/vec2'
import { ReactiveDb, ReactiveIndex, ReactiveMapping } from '@/util/database/reactiveDb'
import * as random from 'lib0/random'
import * as set from 'lib0/set'
import { methodPointerEquals, type MethodCall } from 'shared/languageServerTypes'
import { methodPointerEquals, type MethodCall, type StackItem } from 'shared/languageServerTypes'
import {
IdMap,
visMetadataEquals,
@ -137,17 +137,30 @@ export class GraphDb {
// Display connection starting from existing node.
//TODO[ao]: When implementing input nodes, they should be taken into account here.
if (srcNode == null) return []
function* allTargets(db: GraphDb): Generator<[ExprId, ExprId]> {
for (const usage of info.usages) {
const targetNode = db.getExpressionNodeId(usage)
// Display only connections to existing targets and different than source node
if (targetNode == null || targetNode === srcNode) continue
yield [alias, usage]
}
}
return Array.from(allTargets(this))
return Array.from(this.connectionsFromBindings(info, alias, srcNode))
})
/** Same as {@link GraphDb.connections}, but also includes connections without source node,
* e.g. input arguments of the collapsed function.
*/
allConnections = new ReactiveIndex(this.bindings.bindings, (alias, info) => {
const srcNode = this.getPatternExpressionNodeId(alias)
return Array.from(this.connectionsFromBindings(info, alias, srcNode))
})
private *connectionsFromBindings(
info: BindingInfo,
alias: ExprId,
srcNode: ExprId | undefined,
): Generator<[ExprId, ExprId]> {
for (const usage of info.usages) {
const targetNode = this.getExpressionNodeId(usage)
// Display only connections to existing targets and different than source node.
if (targetNode == null || targetNode === srcNode) continue
yield [alias, usage]
}
}
/** Output port bindings of the node. Lists all bindings that can be dragged out from a node. */
nodeOutputPorts = new ReactiveIndex(this.nodeIdToNode, (id, entry) => {
if (entry.pattern == null) return []
@ -206,6 +219,10 @@ export class GraphDb {
return this.bindings.bindings.get(source)?.identifier
}
allIdentifiers(): string[] {
return [...this.bindings.identifierToBindingId.allForward()].map(([ident, _]) => ident)
}
identifierUsed(ident: string): boolean {
return this.bindings.identifierToBindingId.hasKey(ident)
}
@ -249,6 +266,20 @@ export class GraphDb {
this.nodeIdToNode.moveToLast(id)
}
/** Get the method name from the stack item. */
stackItemToMethodName(item: StackItem): string | undefined {
switch (item.type) {
case 'ExplicitCall': {
return item.methodPointer.name
}
case 'LocalCall': {
const exprId = item.expressionId
const info = this.getExpressionInfo(exprId)
return info?.methodCall?.methodPointer.name
}
}
}
readFunctionAst(functionAst_: Ast.Function, getMeta: (id: ExprId) => NodeMetadata | undefined) {
const currentNodeIds = new Set<ExprId>()
for (const nodeAst of functionAst_.bodyExpressions()) {

View File

@ -63,6 +63,8 @@ export const useGraphStore = defineStore('graph', () => {
const astModule: Module = MutableModule.Observable()
const moduleRoot = ref<AstId>()
let moduleDirty = false
const nodeRects = reactive(new Map<ExprId, Rect>())
const vizRects = reactive(new Map<ExprId, Rect>())
// Initialize text and idmap once module is loaded (data != null)
watch(data, () => {
@ -78,8 +80,6 @@ export const useGraphStore = defineStore('graph', () => {
toRef(suggestionDb, 'groups'),
proj.computedValueRegistry,
)
const nodeRects = reactive(new Map<ExprId, Rect>())
const vizRects = reactive(new Map<ExprId, Rect>())
const portInstances = reactive(new Map<PortId, Set<PortViewInstance>>())
const editedNodeInfo = ref<NodeEditInfo>()
const imports = ref<{ import: Import; span: SourceRange }[]>([])
@ -224,11 +224,15 @@ export const useGraphStore = defineStore('graph', () => {
commitEdit(edit, new Map([[rhs.exprId, meta]]))
}
function editAst(cb: (module: Ast.Module) => Ast.MutableModule) {
const edit = cb(astModule)
commitEdit(edit)
}
function deleteNode(id: ExprId) {
const node = db.nodeIdToNode.get(id)
if (!node) return
proj.module?.doc.metadata.delete(node.outerExprId)
nodeRects.delete(id)
const root = moduleRoot.value
if (!root) {
console.error(`BUG: Cannot delete node: No module root.`)
@ -316,6 +320,11 @@ export const useGraphStore = defineStore('graph', () => {
else vizRects.delete(id)
}
function unregisterNodeRect(id: ExprId) {
nodeRects.delete(id)
vizRects.delete(id)
}
function addPortInstance(id: PortId, instance: PortViewInstance) {
map.setIfUndefined(portInstances, id, set.create).add(instance)
}
@ -413,12 +422,15 @@ export const useGraphStore = defineStore('graph', () => {
moduleCode,
nodeRects,
vizRects,
unregisterNodeRect,
methodAst,
editAst,
astModule,
createEdgeFromOutput,
disconnectSource,
disconnectTarget,
clearUnconnected,
moduleRoot,
createNode,
deleteNode,
setNodeContent,

View File

@ -1,7 +1,6 @@
import { useProjectStore } from '@/stores/project'
import { entryQn, type SuggestionEntry, type SuggestionId } from '@/stores/suggestionDatabase/entry'
import { applyUpdates, entryFromLs } from '@/stores/suggestionDatabase/lsUpdate'
import { type Opt } from '@/util/data/opt'
import { ReactiveDb, ReactiveIndex } from '@/util/database/reactiveDb'
import { AsyncQueue, rpcWithRetries } from '@/util/net'
import { qnJoin, qnParent, tryQualifiedName, type QualifiedName } from '@/util/qualifiedName'

View File

@ -1017,6 +1017,26 @@ export class Function extends Ast {
this.body_ = body
setParent(module, this.exprId, ...this.concreteChildren())
}
static new(
module: MutableModule,
name: string,
args: Ast[],
exprs: Ast[],
trailingNewline?: boolean,
): Function {
const id = newAstId()
const exprs_: BlockLine[] = exprs.map((expr) => ({ expression: { node: expr } }))
if (trailingNewline) {
exprs_.push({ newline: { node: Token.new('\n') }, expression: null })
}
const body = BodyBlock.new(exprs_, module)
const args_ = args.map((arg) => [{ node: makeChild(module, arg, id) }])
const ident = { node: Ident.new(module, name).exprId }
const equals = { node: Token.new('=') }
return new Function(module, id, ident, args_, equals, { node: body.exprId })
}
*concreteChildren(): IterableIterator<NodeChild> {
yield this.name_
for (const arg of this.args_) yield* arg
@ -1582,6 +1602,16 @@ export function tokenTreeWithIds(root: Ast): TokenTree {
]
}
export function moduleMethodNames(module: Module): Set<string> {
const result = new Set<string>()
for (const node of module.raw.nodes.values()) {
if (node instanceof Function && node.name) {
result.add(node.name.code())
}
}
return result
}
// FIXME: We should use alias analysis to handle ambiguous names correctly.
export function findModuleMethod(module: Module, name: string): Function | null {
for (const node of module.raw.nodes.values()) {

View File

@ -7,7 +7,7 @@ import diff from 'fast-diff'
import * as json from 'lib0/json'
import * as Y from 'yjs'
import { TextEdit } from '../shared/languageServerTypes'
import { IdMap, ModuleDoc, type NodeMetadata, type VisualizationMetadata } from '../shared/yjsModel'
import { ModuleDoc, type NodeMetadata, type VisualizationMetadata } from '../shared/yjsModel'
import * as fileFormat from './fileFormat'
import { serializeIdMap } from './serialization'