Update app.py
app.py CHANGED
@@ -35,14 +35,26 @@ if 'history' not in st.session_state:
 if 'feedback' not in st.session_state:
     st.session_state.feedback = {}
 
-
-
 @st.cache_data
-def generate_embedding(_model,
-    inputs =
+def generate_embedding(_model, _tokenizer, text):
+    inputs = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    if torch.cuda.is_available():
+        inputs = {k: v.to('cuda') for k, v in inputs.items()}
     with torch.no_grad():
-        outputs =
-
+        outputs = _model.encoder(**inputs)
+    embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
+    if torch.cuda.is_available():
+        embedding = embedding.cpu()
+    return embedding.numpy()
+
+# error handling
+try:
+    query_embedding = generate_embedding(model, tokenizer, user_query)
+except Exception as e:
+    st.error(f"Error generating embedding: {str(e)}")
+    st.stop()
+
+
 
 # Step 1: Load Dataset and Precompute Embeddings
 @st.cache_resource
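A note on the underscore-prefixed parameters in the new generate_embedding: st.cache_data builds its cache key by hashing the function's arguments, and a leading underscore tells Streamlit to skip hashing that parameter. That is why the unhashable _model and _tokenizer objects are passed with underscores while text alone keys the cache. Below is a minimal, runnable sketch of the same pattern; the t5-small checkpoint and the load_model helper are illustrative assumptions, not the Space's actual code.

import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

@st.cache_resource  # heavy, unhashable objects: cache the resource itself
def load_model():
    # Hypothetical checkpoint for illustration; the real model is not shown in this diff.
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    return model, tokenizer

@st.cache_data  # cache key comes from `text` only; `_model`/`_tokenizer` are not hashed
def generate_embedding(_model, _tokenizer, text):
    inputs = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        # Encoder-only forward pass of the seq2seq model, then mean-pool the token states.
        outputs = _model.encoder(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

model, tokenizer = load_model()
embedding = generate_embedding(model, tokenizer, "example query")
# A second call with the same text returns the cached array without re-running the model.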