node-handler.tsx 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import { IFlowNode, IFlowNodeInput, IFlowNodeOutput, IFlowNodeParameter } from '@/types/flow';
  2. import { FLOW_NODES_KEY } from '@/utils';
  3. import { InfoCircleOutlined, PlusOutlined } from '@ant-design/icons';
  4. import { Popconfirm, Tooltip, Typography, message } from 'antd';
  5. import classNames from 'classnames';
  6. import React from 'react';
  7. import { useTranslation } from 'react-i18next';
  8. import { Connection, Handle, Position, useReactFlow } from 'reactflow';
  9. import RequiredIcon from './required-icon';
  10. import StaticNodes from './static-nodes';
  /** Props for `NodeHandler`: a single handle row rendered inside a flow node. */
  type NodeHandlerProps = {
    /** Node that owns the handle; its `id` prefixes the generated handle id. */
    node: IFlowNode;
    /** The input, parameter or output entry displayed by this row. */
    data: IFlowNodeInput | IFlowNodeParameter | IFlowNodeOutput;
    /** React Flow handle role passed straight to `<Handle type>`. */
    type: 'source' | 'target';
    /** Which collection of the node `data` comes from. */
    label: 'inputs' | 'outputs' | 'parameters';
    /** Position of `data` inside that collection; used in the handle id. */
    index: number;
  };
  /**
   * Split a React Flow handle id built as `${nodeId}|${label}|${index}` into its
   * label and index parts. The id is read from the end so that a node id which
   * itself contains `|` does not shift the fields. Returns empty fields when the
   * id is missing or has fewer than three `|`-separated segments.
   */
  function parseHandleId(handleId: string | null | undefined): { label?: string; index?: string } {
    if (!handleId) {
      return {};
    }
    const parts = handleId.split('|');
    if (parts.length < 3) {
      return {};
    }
    return { label: parts[parts.length - 2], index: parts[parts.length - 1] };
  }

  /**
   * Renders one row of a flow node: a React Flow handle for an input, parameter
   * or output, the entry's name and optional description, and a "+" button that
   * lists the cached static nodes which could be connected to it.
   */
  const NodeHandler: React.FC<NodeHandlerProps> = ({ node, data, type, label, index }) => {
    const { t } = useTranslation();
    const reactflow = useReactFlow();
    const [relatedNodes, setRelatedNodes] = React.useState<IFlowNode[]>([]);

    // Resolve the inputs/outputs/parameters entry that a handle id points at on a node.
    // Returns undefined instead of throwing when the node, the id, or the entry is missing.
    function getHandleItem(flowNode: ReturnType<typeof reactflow.getNode>, handleId: string | null) {
      const { label: handleLabel, index: handleIndex } = parseHandleId(handleId);
      if (!flowNode || handleLabel === undefined || handleIndex === undefined) {
        return undefined;
      }
      return flowNode.data?.[handleLabel]?.[handleIndex];
    }

    /**
     * Decide whether React Flow may create `connection`.
     * - operator -> operator: both entries must share `type_cls` and `is_list`.
     * - resource -> operator/resource: the resource's `parent_cls` must include
     *   the target entry's `type_cls`.
     * Any other pairing shows a warning and is rejected. Within the two allowed
     * pairings, handles that cannot be resolved to an entry are rejected.
     */
    function isValidConnection(connection: Connection) {
      const { sourceHandle, targetHandle, source, target } = connection;
      const sourceNode = source ? reactflow.getNode(source) : undefined;
      const targetNode = target ? reactflow.getNode(target) : undefined;
      const sourceFlowType = sourceNode?.data?.flow_type;
      const targetFlowType = targetNode?.data?.flow_type;

      if (sourceFlowType === 'operator' && targetFlowType === 'operator') {
        const sourceItem = getHandleItem(sourceNode, sourceHandle);
        const targetItem = getHandleItem(targetNode, targetHandle);
        if (!sourceItem || !targetItem) {
          return false;
        }
        return sourceItem.type_cls === targetItem.type_cls && sourceItem.is_list === targetItem.is_list;
      }

      if (sourceFlowType === 'resource' && (targetFlowType === 'operator' || targetFlowType === 'resource')) {
        const targetItem = getHandleItem(targetNode, targetHandle);
        if (!targetItem) {
          return false;
        }
        // parent_cls may be absent on a resource; treat that as "no match" rather than crashing.
        return Boolean(sourceNode?.data?.parent_cls?.includes(targetItem.type_cls));
      }

      message.warning(t('connect_warning'));
      return false;
    }

    /**
     * Populate `relatedNodes` with the static nodes (read from the
     * FLOW_NODES_KEY localStorage entry) that could connect to this handle.
     * Does nothing when the cache is missing or not valid JSON.
     */
    function showRelatedNodes() {
      const cache = localStorage.getItem(FLOW_NODES_KEY);
      if (!cache) {
        return;
      }

      let staticNodes: IFlowNode[];
      try {
        const parsed = JSON.parse(cache);
        staticNodes = Array.isArray(parsed) ? parsed : [];
      } catch (error: unknown) {
        // A corrupt cache entry must not break the click handler.
        console.warn('Failed to parse cached flow nodes', error);
        return;
      }

      const typeCls = data.type_cls;
      const isList = data?.is_list;
      let nodes: IFlowNode[] = [];

      if (label === 'inputs') {
        // operators with an output whose type_cls and is_list match this input
        nodes = staticNodes.filter(
          (candidate: IFlowNode) =>
            candidate.flow_type === 'operator' &&
            candidate.outputs?.some(
              (output: IFlowNodeOutput) => output.type_cls === typeCls && output.is_list === isList,
            ),
        );
      } else if (label === 'parameters') {
        // resources whose parent_cls includes this parameter's type_cls
        nodes = staticNodes.filter(
          (candidate: IFlowNode) => candidate.flow_type === 'resource' && candidate.parent_cls?.includes(typeCls),
        );
      } else if (label === 'outputs') {
        if (node.flow_type === 'operator') {
          // operators with an input whose type_cls and is_list match this output
          nodes = staticNodes.filter(
            (candidate: IFlowNode) =>
              candidate.flow_type === 'operator' &&
              candidate.inputs?.some(
                (input: IFlowNodeInput) => input.type_cls === typeCls && input.is_list === isList,
              ),
          );
        } else if (node.flow_type === 'resource') {
          // any node with an input or parameter whose type_cls is in this resource's parent_cls
          const parentCls = node.parent_cls;
          nodes = staticNodes.filter(
            (candidate: IFlowNode) =>
              candidate.inputs?.some((input: IFlowNodeInput) => parentCls?.includes(input.type_cls)) ||
              candidate.parameters?.some((parameter: IFlowNodeParameter) => parentCls?.includes(parameter.type_cls)),
          );
        }
      }

      setRelatedNodes(nodes);
    }

    // "+" button plus the popover listing `relatedNodes`; shared by the left (inputs/parameters)
    // and right (outputs) variants, which differ only in placement and icon spacing.
    const renderRelatedNodesTrigger = (placement: 'left' | 'right', iconClassName: string) => (
      <Popconfirm
        placement={placement}
        icon={null}
        showCancel={false}
        okButtonProps={{ className: 'hidden' }}
        title={t('related_nodes')}
        description={
          <div className='w-60'>
            <StaticNodes nodes={relatedNodes} />
          </div>
        }
      >
        <PlusOutlined className={iconClassName} onClick={showRelatedNodes} />
      </Popconfirm>
    );

    return (
      <div
        className={classNames('relative flex items-center', {
          'justify-start': label === 'parameters' || label === 'inputs',
          'justify-end': label === 'outputs',
        })}
      >
        <Handle
          className={classNames('w-2 h-2', type === 'source' ? '-mr-4' : '-ml-4')}
          type={type}
          position={type === 'source' ? Position.Right : Position.Left}
          id={`${node.id}|${label}|${index}`}
          isValidConnection={connection => isValidConnection(connection)}
        />
        <Typography
          className={classNames('bg-white dark:bg-[#232734] w-full px-2 py-1 rounded text-neutral-500', {
            'text-right': label === 'outputs',
          })}
        >
          {(label === 'inputs' || label === 'parameters') && renderRelatedNodesTrigger('left', 'cursor-pointer')}
          {label !== 'outputs' && <RequiredIcon optional={data.optional} />}
          {data.type_name}
          {data.description && (
            <Tooltip title={data.description}>
              <InfoCircleOutlined className='ml-2 cursor-pointer' />
            </Tooltip>
          )}
          {label === 'outputs' && renderRelatedNodesTrigger('right', 'ml-2 cursor-pointer')}
        </Typography>
      </div>
    );
  };

  export default NodeHandler;