Update app.py
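Moves the @st.cache_data-decorated generate_embedding helper, together with its error-handled call site, from before the model-loading code to just after tokenizer, model = load_model_and_tokenizer(), presumably so the model and tokenizer exist by the time the helper is called.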
app.py CHANGED
@@ -35,24 +35,7 @@ if 'history' not in st.session_state:
 if 'feedback' not in st.session_state:
     st.session_state.feedback = {}

-@st.cache_data
-def generate_embedding(_model, _tokenizer, text):
-    inputs = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
-    if torch.cuda.is_available():
-        inputs = {k: v.to('cuda') for k, v in inputs.items()}
-    with torch.no_grad():
-        outputs = _model.encoder(**inputs)
-    embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
-    if torch.cuda.is_available():
-        embedding = embedding.cpu()
-    return embedding.numpy()

-#error handling
-try:
-    query_embedding = generate_embedding(model, tokenizer, user_query)
-except Exception as e:
-    st.error(f"Error generating embedding: {str(e)}")
-    st.stop()



@@ -95,6 +78,25 @@ def load_data_and_model():

 tokenizer, model = load_model_and_tokenizer()

+@st.cache_data
+def generate_embedding(_model, _tokenizer, text):
+    inputs = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    if torch.cuda.is_available():
+        inputs = {k: v.to('cuda') for k, v in inputs.items()}
+    with torch.no_grad():
+        outputs = _model.encoder(**inputs)
+    embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
+    if torch.cuda.is_available():
+        embedding = embedding.cpu()
+    return embedding.numpy()
+
+#error handling
+try:
+    query_embedding = generate_embedding(model, tokenizer, user_query)
+except Exception as e:
+    st.error(f"Error generating embedding: {str(e)}")
+    st.stop()
+
 # Precompute embeddings with GPU support
 @st.cache_data
 def generate_embedding(text):
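Reviewer note: st.cache_data builds its cache key by hashing the decorated function's arguments, and it skips any parameter whose name starts with an underscore. That is why _model and _tokenizer are underscore-prefixed here (torch modules and tokenizers are not reliably hashable), leaving text as the effective cache key. Relocating the helper below tokenizer, model = load_model_and_tokenizer() ensures both objects exist when the try block calls it. The sketch below shows the same pattern in a self-contained form; it is an illustration under assumptions, not the Space's actual code: the t5-small checkpoint is a placeholder, and AutoModelForSeq2SeqLM is assumed only because _model.encoder(**inputs) implies an encoder-decoder model.

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

@st.cache_resource  # heavy objects: load once per process, shared across reruns
def load_model_and_tokenizer():
    name = "t5-small"  # placeholder checkpoint, not the Space's actual model
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSeq2SeqLM.from_pretrained(name)
    if torch.cuda.is_available():
        model = model.to("cuda")
    model.eval()
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer()

@st.cache_data  # keyed on `text` only; `_model` and `_tokenizer` skip hashing
def generate_embedding(_model, _tokenizer, text):
    inputs = _tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
    with torch.no_grad():
        outputs = _model.encoder(**inputs)  # encoder forward pass only
    # Mean-pool the token states into one fixed-size vector
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
    return embedding.cpu().numpy()  # .cpu() is a no-op if already on CPU

One thing the hunk's trailing context makes visible: the file still defines a second, one-argument generate_embedding(text) right after the inserted block, and because Python rebinds a name each time a def executes, that later definition shadows the relocated three-argument one for any call made after it.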