File size: 3,988 Bytes
8a37e0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import type { Templates } from 'features/nodes/store/types';
import { validateConnection } from 'features/nodes/store/util/validateConnection';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { map } from 'lodash-es';
import type { Connection, Edge } from 'reactflow';
/**
 * Resolves a possibly-incomplete connection request into a concrete connection.
 *
 * - A self-connection (`source === target`) is rejected.
 * - When both handles are supplied, the connection is returned unchanged; this function does NOT
 *   validate it in that case.
 * - When exactly one handle is supplied, the missing handle is filled with the first field that
 *   `getTargetCandidateFields` / `getSourceCandidateFields` reports as valid on the other node.
 * - When neither handle is supplied, nothing is resolved.
 *
 * @param source The source (node id)
 * @param sourceHandle The source handle (field name), if any
 * @param target The target (node id)
 * @param targetHandle The target handle (field name), if any
 * @param nodes The current nodes
 * @param edges The current edges
 * @param templates The current templates
 * @param edgePendingUpdate The edge pending update, if any
 * @returns The resolved connection, or `null` for a self-connection, when no handle is given, or when
 *   no valid candidate field exists for the missing handle.
 */
export const getFirstValidConnection = (
  source: string,
  sourceHandle: string | null,
  target: string,
  targetHandle: string | null,
  nodes: AnyNode[],
  edges: InvocationNodeEdge[],
  templates: Templates,
  edgePendingUpdate: Edge | null
): Connection | null => {
  // A node is never connected to itself.
  if (source === target) {
    return null;
  }

  if (sourceHandle) {
    if (targetHandle) {
      // Fully specified: hand it back as-is.
      return { source, sourceHandle, target, targetHandle };
    }
    // Only the output is known: take the first compatible input on the target node.
    const [bestInput] = getTargetCandidateFields(
      source,
      sourceHandle,
      target,
      nodes,
      edges,
      templates,
      edgePendingUpdate
    );
    return bestInput ? { source, sourceHandle, target, targetHandle: bestInput.name } : null;
  }

  if (targetHandle) {
    // Only the input is known: take the first compatible output on the source node.
    const [bestOutput] = getSourceCandidateFields(
      target,
      targetHandle,
      source,
      nodes,
      edges,
      templates,
      edgePendingUpdate
    );
    return bestOutput ? { source, sourceHandle: bestOutput.name, target, targetHandle } : null;
  }

  // Neither handle was provided.
  return null;
};
/**
 * Lists the input fields on the target node that the given source output may be connected to.
 *
 * Each input of the target node's template is checked with `validateConnection` (called with `true`
 * as its final argument). Inputs it reports as valid are kept, in template iteration order.
 *
 * @param source The source node id
 * @param sourceHandle The source output field name
 * @param target The target node id
 * @param nodes The current nodes
 * @param edges The current edges
 * @param templates The current templates, looked up by each node's `data.type`
 * @param edgePendingUpdate The edge pending update, if any
 * @returns The valid target input templates. The result is empty when either node, either template,
 *   or the source output field cannot be found.
 */
export const getTargetCandidateFields = (
  source: string,
  sourceHandle: string,
  target: string,
  nodes: AnyNode[],
  edges: Edge[],
  templates: Templates,
  edgePendingUpdate: Edge | null
): FieldInputTemplate[] => {
  const sourceNode = nodes.find(({ id }) => id === source);
  const targetNode = nodes.find(({ id }) => id === target);
  if (!(sourceNode && targetNode)) {
    return [];
  }

  const sourceTemplate = templates[sourceNode.data.type];
  const targetTemplate = templates[targetNode.data.type];
  if (!(sourceTemplate && targetTemplate)) {
    return [];
  }

  // The source output must exist; its template is only used as an existence check.
  if (!sourceTemplate.outputs[sourceHandle]) {
    return [];
  }

  const candidates: FieldInputTemplate[] = [];
  for (const inputField of map(targetTemplate.inputs)) {
    const connection = { source, sourceHandle, target, targetHandle: inputField.name };
    if (validateConnection(connection, nodes, edges, templates, edgePendingUpdate, true).isValid) {
      candidates.push(inputField);
    }
  }
  return candidates;
};
/**
 * Lists the output fields on the source node that may be connected to the given target input.
 *
 * Each output of the source node's template is checked with `validateConnection` (called with `true`
 * as its final argument). Outputs it reports as valid are kept, in template iteration order.
 *
 * @param target The target node id
 * @param targetHandle The target input field name
 * @param source The source node id
 * @param nodes The current nodes
 * @param edges The current edges
 * @param templates The current templates, looked up by each node's `data.type`
 * @param edgePendingUpdate The edge pending update, if any
 * @returns The valid source output templates. The result is empty when either node, either template,
 *   or the target input field cannot be found.
 */
export const getSourceCandidateFields = (
  target: string,
  targetHandle: string,
  source: string,
  nodes: AnyNode[],
  edges: Edge[],
  templates: Templates,
  edgePendingUpdate: Edge | null
): FieldOutputTemplate[] => {
  const targetNode = nodes.find(({ id }) => id === target);
  const sourceNode = nodes.find(({ id }) => id === source);
  if (!(sourceNode && targetNode)) {
    return [];
  }

  const sourceTemplate = templates[sourceNode.data.type];
  const targetTemplate = templates[targetNode.data.type];
  if (!(sourceTemplate && targetTemplate)) {
    return [];
  }

  // The target input must exist; its template is only used as an existence check.
  if (!targetTemplate.inputs[targetHandle]) {
    return [];
  }

  const candidates: FieldOutputTemplate[] = [];
  for (const outputField of map(sourceTemplate.outputs)) {
    const connection = { source, sourceHandle: outputField.name, target, targetHandle };
    if (validateConnection(connection, nodes, edges, templates, edgePendingUpdate, true).isValid) {
      candidates.push(outputField);
    }
  }
  return candidates;
};
|