File size: 1,707 Bytes
abc16ec
 
3ee66e5
 
abc16ec
 
 
 
 
 
 
 
 
 
3ee66e5
 
603178d
 
 
 
 
 
 
 
40cb81b
3ee66e5
abc16ec
9ad3866
3ee66e5
abc16ec
 
 
 
 
3ee66e5
abc16ec
 
 
 
 
 
3ee66e5
abc16ec
 
 
 
3ee66e5
abc16ec
936abfc
3ee66e5
abc16ec
3ee66e5
ecec617
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import pickle
from threading import Lock
from typing import List, Optional, Tuple

import gradio as gr
from query_data import get_chain

# Load the pre-built vector store once, at import time, from the local
# "vectorstore.pkl" file. pickle can execute arbitrary code while loading,
# so this must only ever point at a file this project produced itself.
with open("vectorstore.pkl", mode="rb") as store_file:
    vectorstore = pickle.load(store_file)

class ChatWrapper:
    """Callable that answers one chat turn for the Gradio interface.

    Each call builds a chain with ``get_chain(vectorstore)`` while the OpenAI
    key is exported, runs it on the question plus the prior history, and
    returns ``(output_text, updated_history)``. Calls are serialized with a
    lock because they mutate process-wide state (``os.environ`` and the
    ``openai`` module's ``api_key`` attribute).
    """

    def __init__(self, api_key: Optional[str] = None):
        """Create the wrapper.

        Args:
            api_key: OpenAI API key to use. When omitted, the value of the
                ``OPENAI_API_KEY`` environment variable at construction time
                is used (empty string if unset). Secrets must never be
                hard-coded in source.
        """
        self.lock = Lock()
        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY", "")
        self.api_key = api_key

    def set_openai_api_key(self, api_key: str):
        """Build and return a chain using *api_key*.

        The key is exported as ``OPENAI_API_KEY`` only while ``get_chain``
        runs; the previous value (or its absence) is restored afterwards,
        even if ``get_chain`` raises, so the key does not linger in the
        environment and any pre-existing value is not destroyed.

        Returns:
            The chain returned by ``get_chain(vectorstore)``, or ``None``
            when *api_key* is empty or otherwise falsy.
        """
        if not api_key:
            return None
        previous = os.environ.get("OPENAI_API_KEY")
        os.environ["OPENAI_API_KEY"] = api_key
        try:
            return get_chain(vectorstore)
        finally:
            if previous is None:
                os.environ.pop("OPENAI_API_KEY", None)
            else:
                os.environ["OPENAI_API_KEY"] = previous

    def __call__(self, inp: str, history: Optional[List[Tuple[str, str]]]):
        """Answer one chat turn.

        Args:
            inp: The user's question.
            history: Prior ``(question, answer)`` pairs, or ``None`` on the
                first turn. A non-empty list is appended to in place.

        Returns:
            ``(output_text, history)``. When no API key is configured, the
            output text is *inp* itself and a prompt asking for a key is
            appended to the history.
        """
        # ``with`` releases the lock even if building or running the chain
        # raises; a manually acquired lock would stay held and deadlock
        # every later request.
        with self.lock:
            history = history or []
            api_key = self.api_key
            chain = self.set_openai_api_key(api_key)
            # A None chain means no API key was configured.
            if chain is None:
                history.append((inp, "Please paste your OpenAI key to use"))
                return inp, history
            # Also set the key on the openai module (a process-wide
            # attribute) before running the chain.
            import openai
            openai.api_key = api_key
            output = chain({"question": inp, "chat_history": history})["answer"]
            history.append((inp, output))
            return output, history

# Wire the chat wrapper into a Gradio interface. The shared `state`
# component carries the conversation history from one call to the next:
# it is both an extra input and an extra output of `chat`.
chat = ChatWrapper()
state = gr.outputs.State()

gradio_interface = gr.Interface(
    fn=chat,
    inputs=["text", state],
    outputs=["text", state],
)
gradio_interface.launch(debug=True)