Spaces:
Running
on
Zero
Running
on
Zero
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />

const { start } = require("../utils");
const lg = require("../utils/litegraph");

describe("extensions", () => {
  // Install the litegraph globals before each test and remove them afterwards so tests stay isolated.
  beforeEach(() => {
    lg.setup(global);
  });

  afterEach(() => {
    lg.teardown(global);
  });

  it("calls each extension hook", async () => {
    const mockExtension = {
      name: "TestExtension",
      init: jest.fn(),
      setup: jest.fn(),
      addCustomNodeDefs: jest.fn(),
      getCustomWidgets: jest.fn(),
      beforeRegisterNodeDef: jest.fn(),
      registerCustomNodes: jest.fn(),
      loadedGraphNode: jest.fn(),
      nodeCreated: jest.fn(),
      beforeConfigureGraph: jest.fn(),
      afterConfigureGraph: jest.fn(),
    };

    const { app, ez, graph } = await start({
      async preSetup(app) {
        app.registerExtension(mockExtension);
      },
    });

    // Basic initialisation hooks should be called once, with app
    expect(mockExtension.init).toHaveBeenCalledTimes(1);
    expect(mockExtension.init).toHaveBeenCalledWith(app);

    // Adding custom node defs should be passed the full list of nodes and the app instance.
    // `toBe` asserts it is the very same app object, not merely a structurally equal one.
    expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
    expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toBe(app);
    const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
    expect(defs).toHaveProperty("KSampler");
    expect(defs).toHaveProperty("LoadImage");

    // Get custom widgets is called once and should return new widget types
    expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
    expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);

    // Before register node def will be called once per node type
    const nodeNames = Object.keys(defs);
    const nodeCount = nodeNames.length;
    expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
    // Check every recorded registration rather than a fixed number of them: a fixed bound would
    // index past `mock.calls` when fewer nodes exist and would skip mismatches later in the list.
    mockExtension.beforeRegisterNodeDef.mock.calls.forEach(([nodeClass, nodeDef], i) => {
      // It should be sent the JS class and the original JSON definition
      expect(nodeClass.name).toBe("ComfyNode");
      expect(nodeClass.comfyClass).toBe(nodeNames[i]);
      expect(nodeDef.name).toBe(nodeNames[i]);
      expect(nodeDef).toHaveProperty("input");
      expect(nodeDef).toHaveProperty("output");
    });

    // Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
    expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);

    // Before configure graph will be called here as the default graph is being loaded
    expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
    // it gets sent the graph data that is going to be loaded
    const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];
    const loadedNodeCount = graphData.nodes.length;

    // A node created is fired for each node constructor that is called
    expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(loadedNodeCount);
    graphData.nodes.forEach((nodeData, i) => {
      expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(nodeData.type);
    });

    // Each node then calls loadedGraphNode to allow them to be updated
    expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(loadedNodeCount);
    graphData.nodes.forEach((nodeData, i) => {
      expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(nodeData.type);
    });

    // After configure is then called once all the setup is done
    expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);
    expect(mockExtension.setup).toHaveBeenCalledTimes(1);
    expect(mockExtension.setup).toHaveBeenCalledWith(app);

    // Ensure hooks are called in the correct order (compares each hook's first invocation)
    const callOrder = [
      "init",
      "addCustomNodeDefs",
      "getCustomWidgets",
      "beforeRegisterNodeDef",
      "registerCustomNodes",
      "beforeConfigureGraph",
      "nodeCreated",
      "loadedGraphNode",
      "afterConfigureGraph",
      "setup",
    ];
    for (let i = 1; i < callOrder.length; i++) {
      const earlier = mockExtension[callOrder[i - 1]];
      const later = mockExtension[callOrder[i]];
      expect(earlier.mock.invocationCallOrder[0]).toBeLessThan(later.mock.invocationCallOrder[0]);
    }

    graph.clear();

    // Ensure adding a new node calls the correct callback
    ez.LoadImage();
    expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(loadedNodeCount);
    expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(loadedNodeCount + 1);
    expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");

    // Reload the graph to ensure correct hooks are fired
    // (the graph now holds only the LoadImage node added above)
    await graph.reload();

    // These hooks should not be fired again
    expect(mockExtension.init).toHaveBeenCalledTimes(1);
    expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
    expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
    expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
    expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
    expect(mockExtension.setup).toHaveBeenCalledTimes(1);

    // These should be called again
    expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
    expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(loadedNodeCount + 2);
    expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(loadedNodeCount + 1);
    expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
  }, 15000);

  it("allows custom nodeDefs and widgets to be registered", async () => {
    // Custom widget factory. It only runs when a TestNode is constructed, which happens after
    // `start` resolves below (the zero-call assertion before `ez.TestNode()` confirms this),
    // so the `app` binding it compares against is initialised by the time it is invoked.
    const widgetMock = jest.fn((node, inputName, inputData, widgetApp) => {
      expect(node.constructor.comfyClass).toBe("TestNode");
      expect(inputName).toBe("test_input");
      expect(inputData[0]).toBe("CUSTOMWIDGET");
      expect(inputData[1]?.hello).toBe("world");
      // Must be the running app instance; comparing the parameter with itself would always pass.
      expect(widgetApp).toBe(app);
      return {
        widget: node.addWidget("button", inputName, "hello", () => {}),
      };
    });

    // Register our extension that adds a custom node + widget type
    const mockExtension = {
      name: "TestExtension",
      addCustomNodeDefs: (nodeDefs) => {
        nodeDefs["TestNode"] = {
          output: [],
          output_name: [],
          output_is_list: [],
          name: "TestNode",
          display_name: "TestNode",
          category: "Test",
          input: {
            required: {
              test_input: ["CUSTOMWIDGET", { hello: "world" }],
            },
          },
        };
      },
      getCustomWidgets: jest.fn(() => {
        return {
          CUSTOMWIDGET: widgetMock,
        };
      }),
    };

    const { app, graph, ez } = await start({
      async preSetup(app) {
        app.registerExtension(mockExtension);
      },
    });

    expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);

    graph.clear();
    expect(widgetMock).toHaveBeenCalledTimes(0);
    const node = ez.TestNode();
    expect(widgetMock).toHaveBeenCalledTimes(1);

    // Ensure our custom widget is created
    expect(node.inputs.length).toBe(0);
    expect(node.widgets.length).toBe(1);
    const w = node.widgets[0].widget;
    expect(w.name).toBe("test_input");
    expect(w.type).toBe("button");
  });
});