zzarif commited on
Commit
17e181f
·
1 Parent(s): a92483e

pooling update

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -28,10 +28,16 @@ def predict_similarity(question, candidate_answer, ai_answer):
28
  }
29
  ort_outputs = ort_session.run(None, ort_inputs)
30
 
31
- # Calculate cosine similarity
32
  embeddings = ort_outputs[0]
33
- similarity = np.dot(embeddings[0], embeddings[1]) / \
34
- (np.linalg.norm(embeddings[0]) * np.linalg.norm(embeddings[1]))
 
 
 
 
 
 
35
 
36
  return float(similarity)
37
 
 
28
  }
29
  ort_outputs = ort_session.run(None, ort_inputs)
30
 
31
+ # Get embeddings (shape: (seq_length, 768))
32
  embeddings = ort_outputs[0]
33
+
34
+ # Apply mean pooling to reduce (seq_length, 768) to (768,)
35
+ candidate_embedding = np.mean(embeddings[0], axis=0) # Shape (768,)
36
+ ai_embedding = np.mean(embeddings[1], axis=0) # Shape (768,)
37
+
38
+ # Calculate cosine similarity
39
+ similarity = np.dot(candidate_embedding, ai_embedding) / \
40
+ (np.linalg.norm(candidate_embedding) * np.linalg.norm(ai_embedding))
41
 
42
  return float(similarity)
43