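# Gradio app: retrieval-augmented question answering with GigaChat over a local
# FAISS index of text fragments, each linked to a YouTube video and a timestamp,
# so the sources can be shown as embedded players alongside the answer.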
import gradio as gr
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_community.chat_models.gigachat import GigaChat
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import os
import time
import json
#import telebot


def save_json(data):
    # Log the full chain response (query, answer and retrieved fragments) to a
    # timestamped JSON file so user requests can be inspected later.
    timestamp = int(time.time())
    filename = f"{timestamp}.json"
    filepath = f"./requests_from_users/{filename}"
    # Make sure the log directory exists before writing.
    os.makedirs("./requests_from_users", exist_ok=True)
    # Document objects are not JSON-serializable, so replace each retrieved
    # fragment with a plain dict of its text and metadata (this mutates the
    # caller's dict in place).
    docs = data['context']
    list_docs = []
    for doc in docs:
        dict_doc = {
            'page_content': doc.page_content,
            'metadata': doc.metadata
        }
        list_docs.append(dict_doc)
    data['context'] = list_docs
    with open(filepath, 'w') as json_file:
        json.dump(data, json_file)

def get_yt_links(contexts):
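    # HTML template for an embedded YouTube player: the first placeholder takes
    # the video URL, the second the start offset passed via the ?start= query
    # parameter (seconds), so the embed opens at the cited fragment.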
    html = '''
    <iframe width="100%" height="200" src="{}?start={}" \
    title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; \
    encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" \
    allowfullscreen></iframe>
    '''
    yt_htmls = []
    for context in contexts:
        # Each retrieved document carries the source video URL and the start
        # offset of the matching fragment in its metadata.
        link = context.metadata['link']
        start = context.metadata['time']
        yt_htmls.append(html.format(link, start))
    return yt_htmls

'''def resp2msg(resp):
    req = resp['input']
    ans = resp['answer']
    return req + '\n' + ans'''

def get_context(contexts):
    # Join the retrieved fragments into one text block for display
    # ("Фрагмент" = "Fragment"); works for any number of retrieved documents.
    return '\n'.join(
        f'Фрагмент {i}: {doc.page_content}'
        for i, doc in enumerate(contexts, start=1)
    )
    
    
def process_input(text):
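    # Run the RAG chain: retrieve the most relevant fragments, have GigaChat
    # answer from them, and return the answer text, the combined context and
    # embeds for the three source videos (the retriever returns k=3 documents).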
    response = retrieval_chain.invoke({"input": text})
    #bot.send_message(user_id, resp2msg(response))
    youtube_links = get_yt_links(response['context'])
    context = get_context(response['context'])
    save_json(response)
    return response['answer'], context, youtube_links[0], youtube_links[1], youtube_links[2]

# GigaChat credentials come from the GIGA environment variable; the
# commented-out lines below are a disabled Telegram notification hook.
giga = os.getenv('GIGA')
#token = os.getenv('BOT')
#user_id = os.getenv('CREATOR')
#bot = telebot.TeleBot(token)
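# Multilingual sentence-transformers model for embedding user queries; it must
# match the model that was used to build the FAISS index.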
model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
embedding = HuggingFaceEmbeddings(model_name=model_name,
                                  model_kwargs=model_kwargs,
                                  encode_kwargs=encode_kwargs)

# Load the prebuilt FAISS index from disk; allow_dangerous_deserialization is
# needed because LangChain stores the docstore with pickle.
vector_db = FAISS.load_local('faiss_index',
                             embeddings=embedding,
                             allow_dangerous_deserialization=True)
# GigaChat LLM; SSL certificate verification and the built-in profanity check
# are disabled.
llm = GigaChat(credentials=giga, verify_ssl_certs=False, profanity_check=False)

# Prompt (in Russian): "Answer the user's question. Use only information from
# the context. If the context contains no information for the answer, tell the
# user so."
prompt = ChatPromptTemplate.from_template('''Ответь на вопрос пользователя. \
Используй при этом только информацию из контекста. Если в контексте нет \
информации для ответа, сообщи об этом пользователю.
Контекст: {context}
Вопрос: {input}
Ответ:'''
)

# Retrieve the 3 most similar fragments for each query; the UI and
# process_input expect exactly three source documents.
embedding_retriever = vector_db.as_retriever(search_kwargs={"k": 3})

# "Stuff" documents chain: all retrieved fragments are inserted into the
# prompt in a single LLM call.
document_chain = create_stuff_documents_chain(
    llm=llm,
    prompt=prompt,
)

# Full RAG chain: retrieve with the FAISS retriever, then answer with GigaChat.
retrieval_chain = create_retrieval_chain(embedding_retriever, document_chain)

# UI: query box, answer and retrieved context on the left; the three source
# videos embedded on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Введите запрос")  # "Enter a query"
            submit_btn = gr.Button("Отправить запрос")  # "Submit query"
            text_output = gr.Textbox(label="Ответ", interactive=False)  # "Answer"
            text_context = gr.Textbox(label="Контекст", interactive=False)  # "Context"

        with gr.Column():
            # Embedded YouTube players for the three retrieved fragments.
            youtube_video1 = gr.HTML()
            youtube_video2 = gr.HTML()
            youtube_video3 = gr.HTML()

    submit_btn.click(
        process_input,
        inputs=text_input,
        outputs=[text_output, text_context, youtube_video1, youtube_video2, youtube_video3],
    )


demo.launch()