|
import { |
|
useCallback, |
|
useEffect, |
|
useRef, |
|
useState, |
|
} from 'react' |
|
import produce from 'immer' |
|
import { isEqual } from 'lodash-es' |
|
import type { ValueSelector, Var } from '../../types' |
|
import { BlockEnum, VarType } from '../../types' |
|
import { |
|
useIsChatMode, useNodesReadOnly, |
|
useWorkflow, |
|
} from '../../hooks' |
|
import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types' |
|
import { |
|
getMultipleRetrievalConfig, |
|
getSelectedDatasetsMode, |
|
} from './utils' |
|
import { RETRIEVE_TYPE } from '@/types/app' |
|
import { DATASET_DEFAULT } from '@/config' |
|
import type { DataSet } from '@/models/datasets' |
|
import { fetchDatasets } from '@/service/datasets' |
|
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' |
|
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' |
|
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' |
|
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' |
|
|
|
/**
 * Panel state and change handlers for a Knowledge Retrieval workflow node.
 *
 * @param id - Id of the node being edited.
 * @param payload - The node's data; handed to `useNodeCrud`, which supplies
 *   the `inputs` this hook reads and the setter it writes through.
 * @returns The read-only flag, current `inputs`, change handlers for the
 *   query variable / retrieval mode / multi-way config / model / completion
 *   params / datasets, the selected datasets (entries without a `name` are
 *   filtered out), single-run state from `useOneStepRun`, and the open state
 *   of the rerank-model popup.
 */
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
  const { nodesReadOnly: readOnly } = useNodesReadOnly()
  const isChatMode = useIsChatMode()
  const { getBeforeNodesInSameBranch } = useWorkflow()
  const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
  const startNodeId = startNode?.id
  const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)

  // Most recent inputs. `setInputs` writes this eagerly so that several updates
  // issued before the next render (e.g. the mount effects below) build on one
  // another instead of on the render-time `inputs` snapshot.
  const inputRef = useRef(inputs)

  // Also follow `inputs` when it changes without going through `setInputs`;
  // otherwise the ref-based handlers/effects would write a stale copy back.
  useEffect(() => {
    inputRef.current = inputs
  }, [inputs])

  /**
   * Persist new node data. Only the config for the active retrieval mode is
   * kept: multi-way drops `single_retrieval_config`, any other mode drops
   * `multiple_retrieval_config`.
   */
  const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
    const newInputs = produce(s, (draft) => {
      if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
        delete draft.single_retrieval_config
      else
        delete draft.multiple_retrieval_config
    })
    doSetInputs(newInputs)
    inputRef.current = newInputs
  }, [doSetInputs])

  const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
    const newInputs = produce(inputs, (draft) => {
      draft.query_variable_selector = newVar as ValueSelector
    })
    setInputs(newInputs)
  }, [inputs, setInputs])

  const {
    currentProvider,
    currentModel,
  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)

  const {
    modelList: rerankModelList,
    defaultModel: rerankDefaultModel,
  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

  // Resolve the default rerank model against the rerank model list
  // (the default model's nested `provider.provider` is flattened to a string).
  const {
    currentModel: currentRerankModel,
  } = useCurrentProviderAndModel(
    rerankModelList,
    rerankDefaultModel
      ? {
        ...rerankDefaultModel,
        provider: rerankDefaultModel.provider.provider,
      }
      : undefined,
  )

  // Reads from `inputRef` so back-to-back calls (model change, then params)
  // see each other's result.
  const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
    const newInputs = produce(inputRef.current, (draft) => {
      if (!draft.single_retrieval_config) {
        draft.single_retrieval_config = {
          model: {
            provider: '',
            name: '',
            mode: '',
            completion_params: {},
          },
        }
      }
      // Guaranteed to exist after the block above.
      const draftModel = draft.single_retrieval_config.model
      draftModel.provider = model.provider
      draftModel.name = model.modelId
      draftModel.mode = model.mode!
    })
    setInputs(newInputs)
  }, [setInputs])

  const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
    // Skip no-op updates so an unchanged params object does not trigger a write.
    if (isEqual(newParams, inputRef.current.single_retrieval_config?.model.completion_params))
      return

    const newInputs = produce(inputRef.current, (draft) => {
      if (!draft.single_retrieval_config) {
        draft.single_retrieval_config = {
          model: {
            provider: '',
            name: '',
            mode: '',
            completion_params: {},
          },
        }
      }
      draft.single_retrieval_config.model.completion_params = newParams
    })
    setInputs(newInputs)
  }, [setInputs])

  // Fill in defaults once the model lists are available: a single-retrieval
  // model from the current text-generation model (when none is set) and a
  // normalised `multiple_retrieval_config`. Returns early when multi-way
  // already has a reranking provider and a rerank model is available, or when
  // one-way already has a model provider.
  // NOTE(review): `currentRerankModel` is read but not listed in the deps, so a
  // rerank model that resolves later does not re-run this — confirm intended.
  useEffect(() => {
    const inputs = inputRef.current
    if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
      return

    if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
      return

    const newInput = produce(inputs, (draft) => {
      if (currentProvider?.provider && currentModel?.model) {
        const hasSetModel = draft.single_retrieval_config?.model?.provider
        if (!hasSetModel) {
          draft.single_retrieval_config = {
            model: {
              provider: currentProvider?.provider,
              name: currentModel?.model,
              mode: currentModel?.model_properties?.mode as string,
              completion_params: {},
            },
          }
        }
      }
      const multipleRetrievalConfig = draft.multiple_retrieval_config
      draft.multiple_retrieval_config = {
        top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
        score_threshold: multipleRetrievalConfig?.score_threshold,
        reranking_model: multipleRetrievalConfig?.reranking_model,
        reranking_mode: multipleRetrievalConfig?.reranking_mode,
        weights: multipleRetrievalConfig?.weights,
        // Keep an explicit user choice; otherwise enable reranking only when a
        // default rerank model resolved.
        reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
          ? multipleRetrievalConfig.reranking_enable
          : Boolean(currentRerankModel && rerankDefaultModel),
      }
    })
    setInputs(newInput)
  }, [currentProvider?.provider, currentModel, rerankDefaultModel])

  const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
  const [rerankModelOpen, setRerankModelOpen] = useState(false)

  const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
    const newInputs = produce(inputs, (draft) => {
      draft.retrieval_mode = newMode
      if (newMode === RETRIEVE_TYPE.multiWay) {
        const multipleRetrievalConfig = draft.multiple_retrieval_config
        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
      }
      else {
        const hasSetModel = draft.single_retrieval_config?.model?.provider
        if (!hasSetModel) {
          draft.single_retrieval_config = {
            model: {
              provider: currentProvider?.provider || '',
              name: currentModel?.model || '',
              mode: currentModel?.model_properties?.mode as string,
              completion_params: {},
            },
          }
        }
      }
    })
    setInputs(newInputs)
  }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel])

  const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
    const newInputs = produce(inputs, (draft) => {
      draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
    })
    setInputs(newInputs)
  }, [inputs, setInputs, selectedDatasets, currentRerankModel])

  // On mount: resolve the saved dataset ids into full DataSet records, then
  // re-persist the node data through `setInputs`.
  useEffect(() => {
    let cancelled = false

    const loadSelectedDatasets = async () => {
      const datasetIds = inputRef.current.dataset_ids
      if (datasetIds?.length) {
        const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
        if (cancelled)
          return
        // Don't clobber a selection the user made while the request was in flight.
        if (isEqual(inputRef.current.dataset_ids, datasetIds))
          setSelectedDatasets(dataSetsWithDetail)
      }
      // Re-read the ref here rather than reusing a pre-await snapshot: the other
      // mount effects (and any user edit) may have updated inputs meanwhile, and
      // persisting the old snapshot would silently revert those changes.
      setInputs(inputRef.current)
    }

    loadSelectedDatasets()

    return () => {
      cancelled = true
    }
  }, [])

  // On mount in chat mode, default an empty query variable to the start
  // node's `sys.query`.
  useEffect(() => {
    const inputs = inputRef.current
    let query_variable_selector: ValueSelector = inputs.query_variable_selector
    if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
      query_variable_selector = [startNodeId, 'sys.query']

    setInputs(produce(inputs, (draft) => {
      draft.query_variable_selector = query_variable_selector
    }))
  }, [])

  const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
    const {
      mixtureHighQualityAndEconomic,
      mixtureInternalAndExternal,
      inconsistentEmbeddingModel,
      allInternal,
      allExternal,
    } = getSelectedDatasetsMode(newDatasets)
    const newInputs = produce(inputs, (draft) => {
      draft.dataset_ids = newDatasets.map(d => d.id)

      // Read the mode from the same `inputs` the draft is produced from.
      if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
        const multipleRetrievalConfig = draft.multiple_retrieval_config
        draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel)
      }
    })
    setInputs(newInputs)
    setSelectedDatasets(newDatasets)

    // Open the rerank-model popup for dataset mixes flagged by
    // getSelectedDatasetsMode.
    if (
      (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel))
      || mixtureInternalAndExternal
      || allExternal
    )
      setRerankModelOpen(true)
  }, [inputs, setInputs, selectedDatasets, currentRerankModel])

  // Only string variables can be used as the retrieval query.
  const filterVar = useCallback((varPayload: Var) => {
    return varPayload.type === VarType.string
  }, [])

  const {
    isShowSingleRun,
    hideSingleRun,
    runningStatus,
    handleRun,
    handleStop,
    runInputData,
    setRunInputData,
    runResult,
  } = useOneStepRun<KnowledgeRetrievalNodeType>({
    id,
    data: inputs,
    defaultRunInputData: {
      query: '',
    },
  })

  const query = runInputData.query
  const setQuery = useCallback((newQuery: string) => {
    setRunInputData({
      ...runInputData,
      query: newQuery,
    })
  }, [runInputData, setRunInputData])

  return {
    readOnly,
    inputs,
    handleQueryVarChange,
    filterVar,
    handleRetrievalModeChange,
    handleMultipleRetrievalConfigChange,
    handleModelChanged,
    handleCompletionParamsChange,
    selectedDatasets: selectedDatasets.filter(d => d.name),
    handleOnDatasetsChange,
    isShowSingleRun,
    hideSingleRun,
    runningStatus,
    handleRun,
    handleStop,
    query,
    setQuery,
    runResult,
    rerankModelOpen,
    setRerankModelOpen,
  }
}

export default useConfig
|
|