frankjosh committed
Commit 2c502aa (verified)
Parent(s): bdb68e8

Update app.py

Files changed (1)
  1. app.py +19 -17
app.py CHANGED
@@ -35,24 +35,7 @@ if 'history' not in st.session_state:
 if 'feedback' not in st.session_state:
     st.session_state.feedback = {}
 
-@st.cache_data
-def generate_embedding(_model, _tokenizer, text):
-    inputs = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
-    if torch.cuda.is_available():
-        inputs = {k: v.to('cuda') for k, v in inputs.items()}
-    with torch.no_grad():
-        outputs = _model.encoder(**inputs)
-    embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
-    if torch.cuda.is_available():
-        embedding = embedding.cpu()
-    return embedding.numpy()
 
-#error handling
-try:
-    query_embedding = generate_embedding(model, tokenizer, user_query)
-except Exception as e:
-    st.error(f"Error generating embedding: {str(e)}")
-    st.stop()
 
 
 
@@ -95,6 +78,25 @@ def load_data_and_model():
 
 tokenizer, model = load_model_and_tokenizer()
 
+@st.cache_data
+def generate_embedding(_model, _tokenizer, text):
+    inputs = _tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    if torch.cuda.is_available():
+        inputs = {k: v.to('cuda') for k, v in inputs.items()}
+    with torch.no_grad():
+        outputs = _model.encoder(**inputs)
+    embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
+    if torch.cuda.is_available():
+        embedding = embedding.cpu()
+    return embedding.numpy()
+
+#error handling
+try:
+    query_embedding = generate_embedding(model, tokenizer, user_query)
+except Exception as e:
+    st.error(f"Error generating embedding: {str(e)}")
+    st.stop()
+
 # Precompute embeddings with GPU support
 @st.cache_data
 def generate_embedding(text):
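For context on the relocation: the removed block originally ran above `load_model_and_tokenizer()`, referencing `model`, `tokenizer`, and `user_query` before any of them existed; the commit moves it below the model load. The underscore prefixes in `generate_embedding(_model, _tokenizer, text)` follow Streamlit's `@st.cache_data` convention: parameters whose names begin with an underscore are excluded from the cache key, which is what allows unhashable objects such as a model or tokenizer to be passed in. A minimal runnable sketch of that pattern (`FakeModel` and `embed` are illustrative names, not from app.py):

```python
import streamlit as st

class FakeModel:
    """Illustrative stand-in for the real model object."""
    def encode(self, text: str) -> list:
        return [float(len(text))]

@st.cache_data
def embed(_model, text: str) -> list:
    # `_model` is skipped when Streamlit builds the cache key (leading
    # underscore), so only `text` is hashed; the unhashable model object
    # can be passed in without tripping Streamlit's unhashable-parameter
    # error, and repeat calls with the same text hit the cache.
    return _model.encode(text)

st.write(embed(FakeModel(), "hello"))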
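The embedding computation itself is a mean pool over the encoder's `last_hidden_state`. Here is the same step as a standalone script, assuming a T5-style seq2seq checkpoint that exposes an `.encoder` attribute; `t5-small` and the query string are placeholders, since the diff does not show which checkpoint `load_model_and_tokenizer()` actually loads:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Placeholder checkpoint; app.py loads its real model in
# load_model_and_tokenizer(), which this diff does not show.
name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSeq2SeqLM.from_pretrained(name)

inputs = tokenizer("example user query", return_tensors="pt",
                   padding=True, truncation=True, max_length=512)
if torch.cuda.is_available():
    model = model.to("cuda")  # keep model and inputs on the same device
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

with torch.no_grad():
    outputs = model.encoder(**inputs)  # encoder-only forward pass

# Mean pooling: average per-token hidden states into one fixed-size vector.
embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
print(embedding.cpu().numpy().shape)  # (512,) for t5-small
```

Mean pooling yields a fixed-size vector regardless of query length, which is what makes the query comparable (e.g. by cosine similarity) against the precomputed embeddings referenced in the surrounding context lines.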