zpn committed · Commit 5e79c62 · verified · 1 Parent(s): e855ea0

Update README.md

Files changed (1): README.md (+4 -3)
README.md CHANGED
@@ -63,7 +63,7 @@ from transformers import AutoTokenizer, AutoModel
 tokenizer = AutoTokenizer.from_pretrained("nomic-ai/nomic-embed-code")
 model = AutoModel.from_pretrained("nomic-ai/nomic-embed-code")
 
-def last_token_pooling(model_output, attention_mask):
+def last_token_pooling(hidden_states, attention_mask):
     sequence_lengths = attention_mask.sum(-1) - 1
     return hidden_states[torch.arange(hidden_states.shape[0]), sequence_lengths]
 
@@ -74,7 +74,8 @@ code_snippets = queries + codes
 encoded_input = tokenizer(code_snippets, padding=True, truncation=True, return_tensors='pt')
 model.eval()
 with torch.no_grad():
-    model_output = model(**encoded_input)
+    model_output = model(**encoded_input)[0]
+
 embeddings = last_token_pooling(model_output, encoded_input['attention_mask'])
 embeddings = F.normalize(embeddings, p=2, dim=1)
 print(embeddings.shape)
@@ -95,7 +96,7 @@ model = SentenceTransformer("nomic-ai/nomic-embed-code")
 query_emb = model.encode(queries, prompt_name="query")
 code_emb = model.encode(code_snippets)
 
-similarity = model.similarity(query_emb, code_emb)
+similarity = model.similarity(query_emb[0], code_emb[0])
 print(similarity)
 ```
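The diff fixes three issues in the README's usage snippets: the pooling function's parameter is renamed so the body's `hidden_states` reference actually resolves, the model call indexes `[0]` so the last hidden state (rather than the full model output object) is passed into pooling, and the SentenceTransformers example scores a single query/code pair. For reference, a minimal sketch of how the corrected Transformers snippet assembles end to end; the `queries`/`codes` values here are placeholders, not the README's real inputs, which are defined above the diffed region:

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("nomic-ai/nomic-embed-code")
model = AutoModel.from_pretrained("nomic-ai/nomic-embed-code")

def last_token_pooling(hidden_states, attention_mask):
    # Index of each sequence's last non-padding token (assumes right padding).
    sequence_lengths = attention_mask.sum(-1) - 1
    return hidden_states[torch.arange(hidden_states.shape[0]), sequence_lengths]

# Placeholder inputs; the README builds code_snippets = queries + codes.
queries = ["binary search in a sorted list"]
codes = ["def binary_search(arr, x):\n    ..."]
code_snippets = queries + codes

encoded_input = tokenizer(code_snippets, padding=True, truncation=True, return_tensors='pt')
model.eval()
with torch.no_grad():
    # [0] selects last_hidden_state from the model's output.
    model_output = model(**encoded_input)[0]

embeddings = last_token_pooling(model_output, encoded_input['attention_mask'])
embeddings = F.normalize(embeddings, p=2, dim=1)
print(embeddings.shape)
```

In the SentenceTransformers path, `model.similarity(query_emb[0], code_emb[0])` now returns the score for just the first query/code pair (cosine similarity by default) instead of the full query-by-code matrix.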