diff --git a/packages/pyright-internal/src/analyzer/parseTreeUtils.ts b/packages/pyright-internal/src/analyzer/parseTreeUtils.ts index 7eccb7f7f..13e33fd59 100644 --- a/packages/pyright-internal/src/analyzer/parseTreeUtils.ts +++ b/packages/pyright-internal/src/analyzer/parseTreeUtils.ts @@ -97,7 +97,7 @@ export function findNodeByPosition( // Returns the deepest node that contains the specified offset. export function findNodeByOffset(node: ParseNode, offset: number): ParseNode | undefined { - if (offset < node.start || offset > TextRange.getEnd(node)) { + if (!TextRange.overlaps(node, offset)) { return undefined; } @@ -109,8 +109,26 @@ export function findNodeByOffset(node: ParseNode, offset: number): ParseNode | u // when there are many siblings, such as statements in a module/suite // or expressions in a list, etc. Otherwise, we will have to traverse // every sibling before finding the correct one. - const index = getIndexContaining(children, offset); + let index = getIndexContaining(children, offset, TextRange.overlaps); + if (index >= 0) { + // Find first sibling that overlaps with the offset. This ensures that + // our binary search result matches what we would have returned via a + // linear search. + let searchIndex = index - 1; + while (searchIndex >= 0) { + const previousChild = children[searchIndex]; + if (previousChild) { + if (TextRange.overlaps(previousChild, offset)) { + index = searchIndex; + } else { + break; + } + } + + searchIndex--; + } + children = [children[index]]; } } diff --git a/packages/pyright-internal/src/common/textRangeCollection.ts b/packages/pyright-internal/src/common/textRangeCollection.ts index 5012f3726..8fdbdf261 100644 --- a/packages/pyright-internal/src/common/textRangeCollection.ts +++ b/packages/pyright-internal/src/common/textRangeCollection.ts @@ -100,7 +100,11 @@ export class TextRangeCollection { } } -export function getIndexContaining(arr: (T | undefined)[], position: number) { +export function getIndexContaining( + arr: (T | undefined)[], + position: number, + inRange: (item: T, position: number) => boolean = TextRange.contains +) { if (arr.length === 0) { return -1; } @@ -109,25 +113,25 @@ export function getIndexContaining(arr: (T | undefined)[], let max = arr.length - 1; while (min <= max) { const mid = Math.floor(min + (max - min) / 2); - const item = findNonNullElement(arr, mid, min, max); - if (item === undefined) { + const element = findNonNullElement(arr, mid, min, max); + if (element === undefined) { return -1; } - if (TextRange.contains(item, position)) { - return mid; + if (inRange(element.item, position)) { + return element.index; } - const nextItem = findNonNullElement(arr, mid + 1, mid + 1, max); - if (nextItem === undefined) { + const nextElement = findNonNullElement(arr, mid + 1, mid + 1, max); + if (nextElement === undefined) { return -1; } - if (mid < arr.length - 1 && TextRange.getEnd(item) <= position && position < nextItem.start) { + if (mid < arr.length - 1 && TextRange.getEnd(element.item) <= position && position < nextElement.item.start) { return -1; } - if (position < item.start) { + if (position < element.item.start) { max = mid - 1; } else { min = mid + 1; @@ -142,24 +146,24 @@ function findNonNullElement( position: number, min: number, max: number -): T | undefined { +): { index: number; item: T } | undefined { const item = arr[position]; if (item) { - return item; + return { index: position, item }; } // Search forward and backward until it finds non-null value. for (let i = position + 1; i <= max; i++) { - const item = arr[position]; + const item = arr[i]; if (item) { - return item; + return { index: i, item }; } } for (let i = position - 1; i >= min; i--) { - const item = arr[position]; + const item = arr[i]; if (item) { - return item; + return { index: i, item }; } } diff --git a/packages/pyright-internal/src/tests/parseTreeUtils.test.ts b/packages/pyright-internal/src/tests/parseTreeUtils.test.ts index 58073e258..3b3388f23 100644 --- a/packages/pyright-internal/src/tests/parseTreeUtils.test.ts +++ b/packages/pyright-internal/src/tests/parseTreeUtils.test.ts @@ -9,6 +9,7 @@ import assert from 'assert'; import { + findNodeByOffset, getDottedName, getDottedNameWithGivenNodeAsLastName, getFirstAncestorOrSelfOfKind, @@ -318,6 +319,104 @@ test('printExpression', () => { } }); +test('findNodeByOffset', () => { + const code = ` +//// class A: +//// def read(self): pass +//// +//// class B(A): +//// x1 = 1 +//// def r[|/*marker*/|] +//// + `; + + const state = parseAndGetTestState(code).state; + const range = state.getRangeByMarkerName('marker')!; + const sourceFile = state.program.getBoundSourceFile(range.marker!.fileUri)!; + + const node = findNodeByOffset(sourceFile.getParseResults()!.parserOutput.parseTree, range.pos); + assert.strictEqual(node?.nodeType, ParseNodeType.Name); + assert.strictEqual((node as NameNode).value, 'r'); +}); + +test('findNodeByOffset with binary search', () => { + const code = ` +//// class A: +//// def read(self): pass +//// +//// class B(A): +//// x1 = 1 +//// x2 = 2 +//// x3 = 3 +//// x4 = 4 +//// x5 = 5 +//// x6 = 6 +//// x7 = 7 +//// x8 = 8 +//// x9 = 9 +//// x10 = 10 +//// x11 = 11 +//// x12 = 12 +//// x13 = 13 +//// x14 = 14 +//// x15 = 15 +//// x16 = 16 +//// x17 = 17 +//// x18 = 18 +//// x19 = 19 +//// def r[|/*marker*/|] +//// + `; + + const state = parseAndGetTestState(code).state; + const range = state.getRangeByMarkerName('marker')!; + const sourceFile = state.program.getBoundSourceFile(range.marker!.fileUri)!; + + const node = findNodeByOffset(sourceFile.getParseResults()!.parserOutput.parseTree, range.pos); + assert.strictEqual(node?.nodeType, ParseNodeType.Name); + assert.strictEqual((node as NameNode).value, 'r'); +}); + +test('findNodeByOffset with binary search choose earliest match', () => { + const code = ` +//// class A: +//// def read(self): pass +//// +//// class B(A): +//// x1 = 1 +//// x2 = 2 +//// x3 = 3 +//// x4 = 4 +//// x5 = 5 +//// x6 = 6 +//// x7 = 7 +//// x8 = 8 +//// x9 = 9 +//// x10 = 10 +//// x11 = 11 +//// x12 = 12 +//// x13 = 13 +//// x14 = 14 +//// x15 = 15 +//// x16 = 16 +//// x17 = 17 +//// x18 = 18 +//// x19 = 19 +//// def r[|/*marker*/|] +//// x20 = 20 +//// x21 = 21 +//// + `; + + const state = parseAndGetTestState(code).state; + const range = state.getRangeByMarkerName('marker')!; + const sourceFile = state.program.getBoundSourceFile(range.marker!.fileUri)!; + + const node = findNodeByOffset(sourceFile.getParseResults()!.parserOutput.parseTree, range.pos); + assert.strictEqual(node?.nodeType, ParseNodeType.Name); + assert.strictEqual((node as NameNode).value, 'r'); +}); + function testNodeRange(state: TestState, markerName: string, type: ParseNodeType, includeTrailingBlankLines = false) { const range = state.getRangeByMarkerName(markerName)!; const sourceFile = state.program.getBoundSourceFile(range.marker!.fileUri)!;