Chananchida committed on
Commit caf3054 · verified · 1 Parent(s): 5ba2965

Create app.py

Files changed (1)
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
+ import time
+ import pickle
+ import numpy as np
+ import torch
+ import faiss
+ import gradio as gr  # provides gr.ChatInterface
+ from sklearn.preprocessing import normalize
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
+ from sentence_transformers import SentenceTransformer
+ from unstructured.partition.html import partition_html
+
+ # Source page for the retrievable passages. `url` must be set before this
+ # runs; the address below is a placeholder.
+ url = 'https://example.com/article'
+ elements = partition_html(url=url)
+ # Keep only elements long enough to be meaningful passages.
+ context = [str(element) for element in elements if len(str(element)) > 60]
+
+ DEFAULT_MODEL = 'wangchanberta'
+ DEFAULT_SENTENCE_EMBEDDING_MODEL = 'intfloat/multilingual-e5-base'
+
+ MODEL_DICT = {
+     'wangchanberta': 'Chananchida/wangchanberta-xet_ref-params',
+     'wangchanberta-hyp': 'Chananchida/wangchanberta-xet_hyp-params',
+ }
+
+ EMBEDDINGS_PATH = 'data/embeddings.pkl'
+
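+ # Pipeline overview: passages scraped from `url` are embedded with the
+ # sentence-embedding model, indexed with FAISS, and the Gradio chat returns
+ # the passages nearest to each question.
+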
+ def load_model(model_name=DEFAULT_MODEL):
+     model = AutoModelForQuestionAnswering.from_pretrained(MODEL_DICT[model_name])
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_DICT[model_name])
+     print('Load model done')
+     return model, tokenizer
+
+ def load_embedding_model(model_name=DEFAULT_SENTENCE_EMBEDDING_MODEL):
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     embedding_model = SentenceTransformer(model_name, device=device)
+     print('Load sentence embedding model done')
+     return embedding_model
+
+
+ def set_index(vector):
+     # Exact (brute-force) L2 index over the passage embeddings.
+     index = faiss.IndexFlatL2(vector.shape[1])
+     if torch.cuda.is_available():
+         # Move the index to GPU 0 when FAISS GPU support is available.
+         res = faiss.StandardGpuResources()
+         index = faiss.index_cpu_to_gpu(res, 0, index)
+     index.add(vector)
+     return index
+
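+ # Because the vectors are L2-normalized (see prepare_sentences_vector below),
+ # IndexFlatL2 distances rank passages identically to cosine similarity:
+ # for unit vectors, ||a - b||^2 = 2 - 2*cos(a, b).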
+
+ def get_embeddings(embedding_model, text_list):
+     # Works for a single string or a list of strings.
+     return embedding_model.encode(text_list)
+
+
+ def prepare_sentences_vector(encoded_list):
+     # Stack into one (n, dim) float32 matrix (FAISS requires float32)
+     # and L2-normalize each row.
+     encoded_list = [i.reshape(1, -1) for i in encoded_list]
+     encoded_list = np.vstack(encoded_list).astype('float32')
+     encoded_list = normalize(encoded_list)
+     return encoded_list
+
+ def load_embeddings(file_path=EMBEDDINGS_PATH):
+     # The pickle holds {'sentences': [...], 'embeddings': [...]};
+     # only the embeddings are needed to build the index.
+     with open(file_path, "rb") as fIn:
+         stored_data = pickle.load(fIn)
+     stored_embeddings = stored_data['embeddings']
+     print('Load (questions) embeddings done')
+     return stored_embeddings
+
+ def faiss_search(index, question_vector, k=1):
+     distances, indices = index.search(question_vector, k)
+     return distances, indices
+
+ def model_pipeline(model, tokenizer, question, context):
+     # Extractive QA: decode the span between the argmax start and end logits.
+     inputs = tokenizer(question, context, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model(**inputs)
+     answer_start_index = outputs.start_logits.argmax()
+     answer_end_index = outputs.end_logits.argmax()
+     predict_answer_tokens = inputs.input_ids[0, answer_start_index: answer_end_index + 1]
+     answer = tokenizer.decode(predict_answer_tokens)
+     return answer.replace('<unk>', '@')
+
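+ # Note: model_pipeline() is not called by the chat flow below; the demo
+ # returns retrieved passages directly instead of an extracted answer span.
+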
+ def predict_test(model, tokenizer, embedding_model, context, question, index):
+     # Embed the question, search the index, and return the top-3 passages.
+     # Assumes the indexed embeddings were computed from this same `context` list.
+     t = time.time()
+     question = question.strip()
+     question_vector = get_embeddings(embedding_model, question)
+     question_vector = prepare_sentences_vector([question_vector])
+     distances, indices = faiss_search(index, question_vector, 3)  # top 3
+
+     most_similar_contexts = ''
+     for i in range(3):
+         most_sim_context = context[indices[0][i]].strip()
+         most_similar_contexts += str(i) + ': ' + most_sim_context + "\n\n"
+
+     elapsed = round(time.time() - t, 3)
+     print(f'Question: {question} | time: {elapsed}s | '
+           f'distance: {round(float(distances[0][0]), 4)}')
+     return most_similar_contexts
+
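+ # gr.ChatInterface calls fn(message, history); history is unused here but
+ # required by the signature.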
+ def chat_interface(question, history):
+     response = predict_test(model, tokenizer, embedding_model, context, question, index)
+     return response
+
+ examples = ['ภูมิทัศน์สื่อไทยในปี 2567 มีแนวโน้มว่า ',  # "The Thai media landscape in 2024 is likely to..."
+             'Fragmentation คือ',  # "Fragmentation is..."
+             'ติ๊กต๊อก คือ',  # "TikTok is..."
+             'รายงานจาก Reuters Institute',  # "A report from the Reuters Institute"
+             ]
+ interface = gr.ChatInterface(fn=chat_interface,
+                              examples=examples)
+
+
+ if __name__ == "__main__":
+     # predict_test() only does retrieval, so the QA model is optional here;
+     # None placeholders keep chat_interface's globals defined.
+     # model, tokenizer = load_model('wangchanberta-hyp')
+     model, tokenizer = None, None
+     embedding_model = load_embedding_model()
+     index = set_index(prepare_sentences_vector(load_embeddings(EMBEDDINGS_PATH)))
+     interface.launch()
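
Note: load_embeddings() expects data/embeddings.pkl to hold a dict with 'sentences' and 'embeddings' keys, but the script that builds it is not part of this commit. A minimal sketch of how such a file could be produced, assuming the passages are the same `context` list and the same e5 model (this helper flow is an assumption, not the author's actual preprocessing):

import pickle
from sentence_transformers import SentenceTransformer
from unstructured.partition.html import partition_html

url = 'https://example.com/article'  # placeholder; use the same page as app.py
elements = partition_html(url=url)
context = [str(element) for element in elements if len(str(element)) > 60]

# Encode once and cache to disk in the format load_embeddings() reads.
embedding_model = SentenceTransformer('intfloat/multilingual-e5-base')
embeddings = embedding_model.encode(context)

with open('data/embeddings.pkl', 'wb') as fOut:
    pickle.dump({'sentences': context, 'embeddings': embeddings}, fOut)

Running python app.py then builds the FAISS index from this pickle and launches the Gradio chat UI locally.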