zhangxiyi.amos commited on
Commit
f0986b2
·
1 Parent(s): 312d284

fix: codet5p 池化张量长度不一致

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -20,9 +20,11 @@ model6 = AutoModel.from_pretrained("Salesforce/codet5p-110m-embedding", config=c
20
 
21
  # 创建一个简单的平均池化函数来获取嵌入
22
  def mean_pooling(model_output, attention_mask):
23
- token_embeddings = model_output[0] #First element of model_output contains all token embeddings
24
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
25
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
 
26
 
27
  @spaces.GPU
28
  def generate(query1, query2, source_code):
@@ -47,8 +49,8 @@ def generate(query1, query2, source_code):
47
  with torch.no_grad():
48
  model_output = model6(**inputs)
49
  embeddings = mean_pooling(model_output, inputs['attention_mask'])
50
- score1 = cos_sim(embeddings[0], embeddings[2])
51
- score2 = cos_sim(embeddings[1], embeddings[2])
52
  results.append([model_names[-1], float(score1), float(score2)])
53
 
54
  return results
 
20
 
21
  # 创建一个简单的平均池化函数来获取嵌入
22
  def mean_pooling(model_output, attention_mask):
23
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
24
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
25
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
26
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
27
+ return sum_embeddings / sum_mask
28
 
29
  @spaces.GPU
30
  def generate(query1, query2, source_code):
 
49
  with torch.no_grad():
50
  model_output = model6(**inputs)
51
  embeddings = mean_pooling(model_output, inputs['attention_mask'])
52
+ score1 = cos_sim(embeddings[0].unsqueeze(0), embeddings[2].unsqueeze(0))
53
+ score2 = cos_sim(embeddings[1].unsqueeze(0), embeddings[2].unsqueeze(0))
54
  results.append([model_names[-1], float(score1), float(score2)])
55
 
56
  return results