Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 51 additions & 49 deletions src/areas/workflows/nodes/ExtensionNode.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,8 @@ export default function ExtensionNode({ id, data, selected }: { id: string; data
const { updateNodeData } = useReactFlow()
const running = useWorkflowRunStore((s) => s.activeNodeId === id)

// Refs for handle alignment — support up to 2 inputs
const ioRowRef = useRef<HTMLDivElement>(null)
const ioRow2Ref = useRef<HTMLDivElement>(null)
const [handleTop, setHandleTop] = useState('50%')
const [handle2Top, setHandle2Top] = useState('50%')
const handleRefs = useRef<HTMLDivElement[]>([])
const [handleTops, setHandleTops] = useState<string[]>([])

const { modelExtensions, processExtensions } = useExtensionsStore()
const allExtensions = buildAllWorkflowExtensions(modelExtensions, processExtensions)
Expand All @@ -138,15 +135,14 @@ export default function ExtensionNode({ id, data, selected }: { id: string; data

// Align handles with their respective IO rows after mount
useLayoutEffect(() => {
if (ioRowRef.current) {
const center = ioRowRef.current.offsetTop + ioRowRef.current.offsetHeight / 2
setHandleTop(`${center}px`)
}
if (ioRow2Ref.current) {
const center = ioRow2Ref.current.offsetTop + ioRow2Ref.current.offsetHeight / 2
setHandle2Top(`${center}px`)
}
}, [isMulti])
setHandleTops(handleRefs.current.map((ref) => {
if (ref) {
const center = ref.offsetTop + ref.offsetHeight / 2
return `${center}px`
}
return '50%'
}))
}, [isMulti, inputs?.length])

const patchParam = useCallback((key: string, val: number | string) => {
updateNodeData(id, { params: { ...data.params, [key]: val } })
Expand All @@ -166,30 +162,27 @@ export default function ExtensionNode({ id, data, selected }: { id: string; data
const ioSubheader = isMulti ? (
// Multi-input layout: one row per input, output on first row
<div className="flex flex-col divide-y divide-zinc-800/40">
<div ref={ioRowRef} className="flex items-center justify-between px-3 py-2">
<span className={`inline-flex items-center px-1.5 py-0.5 rounded text-[9px] font-medium border ${TAG_CLS[inputs[0]] ?? 'border-zinc-700 bg-zinc-800 text-zinc-400'}`}>
{inputs[0]}
</span>
{!isTerminal && (
<>
<svg width="8" height="8" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" className="text-zinc-600 shrink-0">
<line x1="5" y1="12" x2="19" y2="12"/><polyline points="12 5 19 12 12 19"/>
</svg>
<span className={`inline-flex items-center px-1.5 py-0.5 rounded text-[9px] font-medium border ${TAG_CLS[ext?.output ?? ''] ?? 'border-zinc-700 bg-zinc-800 text-zinc-400'}`}>
{ext?.output ?? '—'}
</span>
</>
)}
</div>
<div ref={ioRow2Ref} className="flex items-center px-3 py-2">
<span className={`inline-flex items-center px-1.5 py-0.5 rounded text-[9px] font-medium border ${TAG_CLS[inputs[1]] ?? 'border-zinc-700 bg-zinc-800 text-zinc-400'}`}>
{inputs[1]}
</span>
</div>
{inputs.map((inputType, i) => (
<div key={i} ref={(el) => { if (el) handleRefs.current[i] = el }} className="flex items-center justify-between px-3 py-2">
<span className={`inline-flex items-center px-1.5 py-0.5 rounded text-[9px] font-medium border ${TAG_CLS[inputType] ?? 'border-zinc-700 bg-zinc-800 text-zinc-400'}`}>
{inputType}
</span>
{i === 0 && !isTerminal && (
<>
<svg width="8" height="8" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" className="text-zinc-600 shrink-0">
<line x1="5" y1="12" x2="19" y2="12"/><polyline points="12 5 19 12 12 19"/>
</svg>
<span className={`inline-flex items-center px-1.5 py-0.5 rounded text-[9px] font-medium border ${TAG_CLS[ext?.output ?? ''] ?? 'border-zinc-700 bg-zinc-800 text-zinc-400'}`}>
{ext?.output ?? '—'}
</span>
</>
)}
</div>
))}
</div>
) : (
// Single-input layout (existing behavior)
<div ref={ioRowRef} className="flex items-center justify-between px-3 py-2">
<div ref={(el) => { if (el) handleRefs.current[0] = el }} className="flex items-center justify-between px-3 py-2">
<span className={`inline-flex items-center px-1.5 py-0.5 rounded text-[9px] font-medium border ${TAG_CLS[ext?.input ?? ''] ?? 'border-zinc-700 bg-zinc-800 text-zinc-400'}`}>
{ext?.input ?? '—'}
</span>
Expand All @@ -207,31 +200,40 @@ export default function ExtensionNode({ id, data, selected }: { id: string; data
)

// ── Handles ──────────────────────────────────────────────────────────────
const handlesEl = (
const handlesEl = isMulti ? (
<>
{/* Primary input handle */}
<Handle
id="input-0"
type="target"
position={Position.Left}
style={{ background: HANDLE_COLOR[isMulti ? inputs[0] : (ext?.input ?? 'image')], width: 14, height: 14, border: '2.5px solid #18181b', top: handleTop }}
/>
{/* Secondary input handle (multi-input only) */}
{isMulti && (
{inputs.map((inputType, i) => (
<Handle
id="input-1"
key={i}
id={`input-${i}`}
type="target"
position={Position.Left}
style={{ background: HANDLE_COLOR[inputs[1]], width: 14, height: 14, border: '2.5px solid #18181b', top: handle2Top }}
style={{ background: HANDLE_COLOR[inputType], width: 14, height: 14, border: '2.5px solid #18181b', top: handleTops[i] ?? '50%' }}
/>
))}
{!isTerminal && (
<Handle
id="output"
type="source"
position={Position.Right}
style={{ background: outputColor, width: 14, height: 14, border: '2.5px solid #18181b', top: handleTops[0] ?? '50%' }}
/>
)}
{/* Output handle */}
</>
) : (
<>
<Handle
id="input-0"
type="target"
position={Position.Left}
style={{ background: HANDLE_COLOR[ext?.input ?? 'image'], width: 14, height: 14, border: '2.5px solid #18181b', top: handleTops[0] ?? '50%' }}
/>
{!isTerminal && (
<Handle
id="output"
type="source"
position={Position.Right}
style={{ background: outputColor, width: 14, height: 14, border: '2.5px solid #18181b', top: handleTop }}
style={{ background: outputColor, width: 14, height: 14, border: '2.5px solid #18181b', top: handleTops[0] ?? '50%' }}
/>
)}
</>
Expand Down
57 changes: 43 additions & 14 deletions src/areas/workflows/workflowRunStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,20 +150,41 @@ async function executeExtensionNode(
return realId ? nodeOutputs.get(realId) : undefined
}

let nodeInputPath: string | undefined
let nodeInputText: string | undefined
let nodeInputMeshPath: string | undefined
let nodeInputPath: string | undefined
let nodeInputText: string | undefined
let nodeInputMeshPath: string | undefined
const extraImagePaths: string[] = []

const incomingEdges = workflow.edges.filter((e) => e.target === node.id)

if (ext?.inputs && ext.inputs.length > 1) {
const inputTypes = ext.inputs
const inputPaths = new Array<string | undefined>(inputTypes.length).fill(undefined)

for (const edge of incomingEdges) {
const src = resolveSource(edge.source)
if (!src) continue
if (src.outputType === 'mesh') nodeInputMeshPath = src.filePath
else if (src.outputType === 'image') nodeInputPath = src.filePath
else if (src.filePath !== undefined) nodeInputPath = src.filePath
if (src.text !== undefined) nodeInputText = src.text
if (!src || !src.filePath) continue
let slot = 0
if (edge.targetHandle?.startsWith('input-')) {
slot = parseInt(edge.targetHandle.slice(6), 10)
}
if (slot >= 0 && slot < inputTypes.length) {
inputPaths[slot] = src.filePath
}
}

for (let i = 0; i < inputTypes.length; i++) {
const fp = inputPaths[i]
if (!fp) continue
if (inputTypes[i] === 'mesh') {
nodeInputMeshPath = fp
} else if (inputTypes[i] === 'image') {
if (!nodeInputPath) {
nodeInputPath = fp
} else {
extraImagePaths.push(fp)
}
}
}
} else {
for (const edge of incomingEdges) {
Expand Down Expand Up @@ -194,6 +215,9 @@ async function executeExtensionNode(
? norm.slice(workspaceDir.length).replace(/^\//, '')
: norm
}
if (extraImagePaths.length > 0) {
extraParams.extra_image_paths = extraImagePaths
}

const schemaDefaults = Object.fromEntries(
(ext.params ?? []).map((p) => [p.id, p.default]),
Expand Down Expand Up @@ -242,17 +266,22 @@ async function executeExtensionNode(
useAppStore.getState().updateCurrentJob({ status: 'generating', progress: st.progress, step: st.step })
}
} else {
if (ext?.input === 'mesh' && !nodeInputPath) throw new Error(`${ext.name} needs an incoming mesh connection`)
if (ext?.input === 'image' && !nodeInputPath) throw new Error(`${ext.name} needs an incoming image connection`)
if (ext?.input === 'text' && !nodeInputText) throw new Error(`${ext.name} needs an incoming text connection`)

const parts = (node.data.extensionId ?? '').split('/')
const extId = parts[0]
const nid = parts[1] ?? ''

const processParams: Record<string, unknown> = { ...(node.data.params as Record<string, unknown>) }
if (nodeInputMeshPath && nodeInputPath) {
// Texture node: mesh is filePath, all images in extra_image_paths
processParams.extra_image_paths = [nodeInputPath, ...extraImagePaths]
} else if (extraImagePaths.length > 0) {
processParams.extra_image_paths = extraImagePaths
}

const result = await window.electron.extensions.runProcess(
extId,
{ filePath: nodeInputPath, text: nodeInputText, nodeId: nid },
node.data.params as Record<string, unknown>,
{ filePath: nodeInputMeshPath ?? nodeInputPath, text: nodeInputText, nodeId: nid },
processParams,
)
if (!result.success) throw new Error(result.error ?? 'Process extension failed')
nodeInputPath = result.result?.filePath ?? nodeInputPath
Expand Down