File size: 746 Bytes
1c4216d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from rag_fns.generation import do_generation
from rag_fns.retrieval import do_retrieval
from rag_fns.setup_load import import_data, load_oai_model


def do_rag(user_input: str, stream: bool = False, n_results: int = 3):
    """Answer ``user_input`` with a retrieve-then-generate pipeline.

    The function runs these steps in order:

    1. Load the talk data via ``import_data()``. Its result is unpacked as
       (talk ids, embeds, talk info).
    2. Obtain a client via ``load_oai_model()``.
    3. Call ``do_retrieval`` with the query, ``n_results`` and the loaded data.
    4. Pass the retrieved documents to ``do_generation``.

    Both loaders are invoked on every call. NOTE(review): if they are
    expensive, consider caching them at the caller. Confirm first that
    reloading is not relied upon.

    Args:
        user_input: The user's query. It is passed to both retrieval and
            generation.
        stream: Forwarded unchanged to ``do_generation`` as ``stream``.
        n_results: Forwarded unchanged to ``do_retrieval`` as ``n_results``.
            Presumably this is the number of documents to retrieve; confirm
            against ``do_retrieval``.

    Returns:
        A 3-tuple ``(response, retrieved_docs, prompt_tokens)``:

        - ``response`` and ``prompt_tokens`` are the two values returned by
          ``do_generation``.
        - ``retrieved_docs`` is whatever ``do_retrieval`` returned.
    """
    # Load the corpus pieces and the API client that both stages share.
    ids, embeddings, info = import_data()
    client = load_oai_model()

    # Retrieval stage: look up documents relevant to the query.
    docs = do_retrieval(
        query0=user_input,
        n_results=n_results,
        api_client=client,
        talk_ids=ids,
        embeds=embeddings,
        talk_info=info,
    )

    # Generation stage: produce the answer from the query plus retrieved docs.
    answer, n_prompt_tokens = do_generation(
        query1=user_input,
        keep_texts=docs,
        gen_client=client,
        stream=stream,
    )
    return answer, docs, n_prompt_tokens