chore: add AriaSnapshot internal type (#33631)

This commit is contained in:
Pavel Feldman 2024-11-15 13:48:43 -08:00 committed by GitHub
parent 44cd1d03cc
commit d127255881
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 250 additions and 17 deletions

View File

@ -20,15 +20,36 @@ import { escapeRegExp, longestCommonSubstring } from '@isomorphic/stringUtils';
import { yamlEscapeKeyIfNeeded, yamlEscapeValueIfNeeded } from './yaml';
import type { AriaProps, AriaRole, AriaTemplateNode, AriaTemplateRoleNode, AriaTemplateTextNode } from '@isomorphic/ariaSnapshot';
type AriaNode = AriaProps & {
export type AriaNode = AriaProps & {
role: AriaRole | 'fragment';
name: string;
children: (AriaNode | string)[];
element: Element;
};
export function generateAriaTree(rootElement: Element): AriaNode {
export type AriaSnapshot = {
root: AriaNode;
elements: Map<number, Element>;
ids: Map<Element, number>;
};
export function generateAriaTree(rootElement: Element): AriaSnapshot {
const visited = new Set<Node>();
const snapshot: AriaSnapshot = {
root: { role: 'fragment', name: '', children: [], element: rootElement },
elements: new Map<number, Element>(),
ids: new Map<Element, number>(),
};
const addElement = (element: Element) => {
const id = snapshot.elements.size + 1;
snapshot.elements.set(id, element);
snapshot.ids.set(element, id);
};
addElement(rootElement);
const visit = (ariaNode: AriaNode, node: Node) => {
if (visited.has(node))
return;
@ -58,6 +79,7 @@ export function generateAriaTree(rootElement: Element): AriaNode {
}
}
addElement(element);
const childAriaNode = toAriaNode(element);
if (childAriaNode)
ariaNode.children.push(childAriaNode);
@ -100,15 +122,14 @@ export function generateAriaTree(rootElement: Element): AriaNode {
}
roleUtils.beginAriaCaches();
const ariaRoot: AriaNode = { role: 'fragment', name: '', children: [], element: rootElement };
try {
visit(ariaRoot, rootElement);
visit(snapshot.root, rootElement);
} finally {
roleUtils.endAriaCaches();
}
normalizeStringChildren(ariaRoot);
return ariaRoot;
normalizeStringChildren(snapshot.root);
return snapshot;
}
function toAriaNode(element: Element): AriaNode | null {
@ -143,10 +164,6 @@ function toAriaNode(element: Element): AriaNode | null {
return result;
}
export function renderedAriaTree(rootElement: Element, options?: { mode?: 'raw' | 'regex' }): string {
return renderAriaTree(generateAriaTree(rootElement), options);
}
function normalizeStringChildren(rootA11yNode: AriaNode) {
const flushChildren = (buffer: string[], normalizedChildren: (AriaNode | string)[]) => {
if (!buffer.length)
@ -203,7 +220,7 @@ export type MatcherReceived = {
};
export function matchesAriaTree(rootElement: Element, template: AriaTemplateNode): { matches: AriaNode[], received: MatcherReceived } {
const root = generateAriaTree(rootElement);
const root = generateAriaTree(rootElement).root;
const matches = matchesNodeDeep(root, template, false);
return {
matches,
@ -215,7 +232,7 @@ export function matchesAriaTree(rootElement: Element, template: AriaTemplateNode
}
export function getAllByAria(rootElement: Element, template: AriaTemplateNode): Element[] {
const root = generateAriaTree(rootElement);
const root = generateAriaTree(rootElement).root;
const matches = matchesNodeDeep(root, template, true);
return matches.map(n => n.element);
}
@ -285,7 +302,7 @@ function matchesNodeDeep(root: AriaNode, template: AriaTemplateNode, collectAll:
return results;
}
export function renderAriaTree(ariaNode: AriaNode, options?: { mode?: 'raw' | 'regex' }): string {
export function renderAriaTree(ariaNode: AriaNode, options?: { mode?: 'raw' | 'regex', ids?: Map<Element, number> }): string {
const lines: string[] = [];
const includeText = options?.mode === 'regex' ? textContributesInfo : () => true;
const renderString = options?.mode === 'regex' ? convertToBestGuessRegex : (str: string) => str;
@ -324,6 +341,11 @@ export function renderAriaTree(ariaNode: AriaNode, options?: { mode?: 'raw' | 'r
key += ` [pressed]`;
if (ariaNode.selected === true)
key += ` [selected]`;
if (options?.ids) {
const id = options?.ids.get(ariaNode.element);
if (id)
key += ` [id=${id}]`;
}
const escapedKey = indent + '- ' + yamlEscapeKeyIfNeeded(key);
if (!ariaNode.children.length) {

View File

@ -34,7 +34,8 @@ import { kLayoutSelectorNames, type LayoutSelectorName, layoutSelectorScore } fr
import { asLocator } from '../../utils/isomorphic/locatorGenerators';
import type { Language } from '../../utils/isomorphic/locatorGenerators';
import { cacheNormalizedWhitespaces, normalizeWhiteSpace, trimStringWithEllipsis } from '../../utils/isomorphic/stringUtils';
import { matchesAriaTree, renderedAriaTree, getAllByAria } from './ariaSnapshot';
import { matchesAriaTree, getAllByAria, generateAriaTree, renderAriaTree } from './ariaSnapshot';
import type { AriaNode, AriaSnapshot } from './ariaSnapshot';
import type { AriaTemplateNode } from '@isomorphic/ariaSnapshot';
import { parseYamlTemplate } from '@isomorphic/ariaSnapshot';
@ -215,10 +216,27 @@ export class InjectedScript {
return new Set<Element>(result.map(r => r.element));
}
ariaSnapshot(node: Node, options?: { mode?: 'raw' | 'regex' }): string {
ariaSnapshot(node: Node, options?: { mode?: 'raw' | 'regex', id?: boolean }): string {
if (node.nodeType !== Node.ELEMENT_NODE)
throw this.createStacklessError('Can only capture aria snapshot of Element nodes.');
return renderedAriaTree(node as Element, options);
const ariaSnapshot = generateAriaTree(node as Element);
return renderAriaTree(ariaSnapshot.root, options);
}
ariaSnapshotAsObject(node: Node): AriaSnapshot {
return generateAriaTree(node as Element);
}
ariaSnapshotElement(snapshot: AriaSnapshot, elementId: number): Element | null {
return snapshot.elements.get(elementId) || null;
}
renderAriaTree(ariaNode: AriaNode, options?: { mode?: 'raw' | 'regex', id?: boolean}): string {
return renderAriaTree(ariaNode, options);
}
renderAriaSnapshotWithIds(ariaSnapshot: AriaSnapshot): string {
return renderAriaTree(ariaSnapshot.root, { ids: ariaSnapshot.ids });
}
getAllByAria(document: Document, template: AriaTemplateNode): Element[] {

View File

@ -132,6 +132,10 @@ export class Recorder implements InstrumentationListener, IRecorder {
this._contextRecorder.clearScript();
return;
}
if (data.event === 'runTask') {
this._contextRecorder.runTask(data.params.task);
return;
}
});
await Promise.all([

View File

@ -0,0 +1,184 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { WebSocketTransport } from '../transport';
import type { ConnectionTransport, ProtocolResponse } from '../transport';
export type ChatMessage = {
content: string;
user: 'user' | 'assistant';
};
export class Chat {
private _history: ChatMessage[] = [];
private _connectionPromise: Promise<Connection> | undefined;
private _chatSinks = new Map<string, (chunk: string) => void>();
private _wsEndpoint: string;
constructor(wsEndpoint: string) {
this._wsEndpoint = wsEndpoint;
}
clearHistory() {
this._history = [];
}
async post<T>(prompt: string): Promise<T | null> {
await this._append('user', prompt);
let text = await asString(await this._post());
if (text.startsWith('```json') && text.endsWith('```'))
text = text.substring('```json'.length, text.length - '```'.length);
for (let i = 0; i < 3; ++i) {
try {
return JSON.parse(text);
} catch (e) {
await this._append('user', String(e));
}
}
throw new Error('Failed to parse response: ' + text);
}
private async _append(user: ChatMessage['user'], content: string) {
this._history.push({ user, content });
}
private async _connection(): Promise<Connection> {
if (!this._connectionPromise) {
this._connectionPromise = WebSocketTransport.connect(undefined, this._wsEndpoint).then(transport => {
return new Connection(transport, (method, params) => this._dispatchEvent(method, params), () => {});
});
}
return this._connectionPromise;
}
private _dispatchEvent(method: string, params: any) {
if (method === 'chatChunk') {
const { chatId, chunk } = params;
const chunkSink = this._chatSinks.get(chatId)!;
chunkSink(chunk);
if (!chunk)
this._chatSinks.delete(chatId);
}
}
private async _post(): Promise<AsyncIterable<string>> {
const connection = await this._connection();
const result = await connection.send('chat', { history: this._history });
const { chatId } = result;
const { iterable, addChunk } = iterablePump();
this._chatSinks.set(chatId, addChunk);
return iterable;
}
}
export async function asString(stream: AsyncIterable<string>): Promise<string> {
let result = '';
for await (const chunk of stream)
result += chunk;
return result;
}
type ChunkIterator = {
iterable: AsyncIterable<string>;
addChunk: (chunk: string) => void;
};
function iterablePump(): ChunkIterator {
let controller: ReadableStreamDefaultController<string>;
const stream = new ReadableStream<string>({ start: c => controller = c });
const iterable = (async function* () {
const reader = stream.getReader();
while (true) {
const { done, value } = await reader.read();
if (done)
break;
yield value!;
}
})();
return {
iterable,
addChunk: chunk => {
if (chunk)
controller.enqueue(chunk);
else
controller.close();
}
};
}
class Connection {
private readonly _transport: ConnectionTransport;
private _lastId = 0;
private _closed = false;
private _pending = new Map<number, { resolve: (result: any) => void; reject: (error: any) => void; }>();
private _onEvent: (method: string, params: any) => void;
private _onClose: () => void;
constructor(transport: ConnectionTransport, onEvent: (method: string, params: any) => void, onClose: () => void) {
this._transport = transport;
this._onEvent = onEvent;
this._onClose = onClose;
this._transport.onmessage = this._dispatchMessage.bind(this);
this._transport.onclose = this._close.bind(this);
}
send(method: string, params: any): Promise<any> {
const id = this._lastId++;
const message = { id, method, params };
this._transport.send(message);
return new Promise((resolve, reject) => {
this._pending.set(id, { resolve, reject });
});
}
private _dispatchMessage(message: ProtocolResponse) {
if (message.id === undefined) {
this._onEvent(message.method!, message.params);
return;
}
const callback = this._pending.get(message.id);
this._pending.delete(message.id);
if (!callback)
return;
if (message.error) {
callback.reject(new Error(message.error.message));
return;
}
callback.resolve(message.result);
}
_close() {
this._closed = true;
this._transport.onmessage = undefined;
this._transport.onclose = undefined;
for (const { reject } of this._pending.values())
reject(new Error('Connection closed'));
this._onClose();
}
isClosed() {
return this._closed;
}
close() {
if (!this._closed)
this._transport.close();
}
}

View File

@ -208,6 +208,10 @@ export class ContextRecorder extends EventEmitter {
}
}
runTask(task: string): void {
// TODO: implement
}
private _describeMainFrame(page: Page): actions.FrameDescription {
return {
pageAlias: this._pageAliases.get(page)!,

View File

@ -29,7 +29,8 @@ const debugLoggerColorMap = {
'channel': 33, // blue
'server': 45, // cyan
'server:channel': 34, // green
'server:metadata': 33, // blue
'server:metadata': 33, // blue,
'recorder': 45, // cyan
};
export type LogName = keyof typeof debugLoggerColorMap;