frankjosh commited on
Commit
bdb68e8
·
verified ·
1 Parent(s): 693093c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -35,14 +35,26 @@ if 'history' not in st.session_state:
35
  if 'feedback' not in st.session_state:
36
  st.session_state.feedback = {}
37
 
38
-
39
-
40
  @st.cache_data
41
- def generate_embedding(_model, tokenizer, text):
42
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
 
43
  with torch.no_grad():
44
- outputs = model(**inputs)
45
- return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # Step 1: Load Dataset and Precompute Embeddings
48
  @st.cache_resource
 
35
  if 'feedback' not in st.session_state:
36
  st.session_state.feedback = {}
37
 
 
 
38
  @st.cache_data
39
+ def generate_embedding(_model, _tokenizer, text):
40
+ inputs = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
41
+ if torch.cuda.is_available():
42
+ inputs = {k: v.to('cuda') for k, v in inputs.items()}
43
  with torch.no_grad():
44
+ outputs = _model.encoder(**inputs)
45
+ embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
46
+ if torch.cuda.is_available():
47
+ embedding = embedding.cpu()
48
+ return embedding.numpy()
49
+
50
+ #error handling
51
+ try:
52
+ query_embedding = generate_embedding(model, tokenizer, user_query)
53
+ except Exception as e:
54
+ st.error(f"Error generating embedding: {str(e)}")
55
+ st.stop()
56
+
57
+
58
 
59
  # Step 1: Load Dataset and Precompute Embeddings
60
  @st.cache_resource