Update README.md
Browse files
README.md
CHANGED
```diff
@@ -63,7 +63,7 @@ from transformers import AutoTokenizer, AutoModel
 tokenizer = AutoTokenizer.from_pretrained("nomic-ai/nomic-embed-code")
 model = AutoModel.from_pretrained("nomic-ai/nomic-embed-code")
 
-def last_token_pooling(
+def last_token_pooling(hidden_states, attention_mask):
     sequence_lengths = attention_mask.sum(-1) - 1
     return hidden_states[torch.arange(hidden_states.shape[0]), sequence_lengths]
 
@@ -74,7 +74,8 @@ code_snippets = queries + codes
 encoded_input = tokenizer(code_snippets, padding=True, truncation=True, return_tensors='pt')
 model.eval()
 with torch.no_grad():
-    model_output = model(**encoded_input)
+    model_output = model(**encoded_input)[0]
+
 embeddings = last_token_pooling(model_output, encoded_input['attention_mask'])
 embeddings = F.normalize(embeddings, p=2, dim=1)
 print(embeddings.shape)
 
@@ -95,7 +96,7 @@ model = SentenceTransformer("nomic-ai/nomic-embed-code")
 query_emb = model.encode(queries, prompt_name="query")
 code_emb = model.encode(code_snippets)
 
-similarity = model.similarity(query_emb, code_emb)
+similarity = model.similarity(query_emb[0], code_emb[0])
 print(similarity)
 ```
```