Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -72,18 +72,18 @@ def process_entity(batch, model, device):
|
|
72 |
position_list_y=position_list_y)
|
73 |
#NOTE: we are ommitting the pseudo_sentence here. Verify that this is correct
|
74 |
|
75 |
-
|
76 |
|
77 |
# Extract the [CLS] token embedding (first token)
|
78 |
-
|
79 |
|
80 |
#pivot_token_len = batch['pivot_token_len'].item()
|
81 |
#pivot_embeddings = embeddings[:, :pivot_token_len, :]
|
82 |
|
83 |
#return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
|
84 |
-
return
|
85 |
|
86 |
-
|
87 |
for i, batch in enumerate(data_loader):
|
88 |
if i >= 2: # Stop after processing 3 batches
|
89 |
break
|
@@ -105,8 +105,8 @@ def get_bert_embedding(review_text):
|
|
105 |
outputs = bert_model(**inputs)
|
106 |
|
107 |
# Extract embeddings from the last hidden state
|
108 |
-
|
109 |
-
return
|
110 |
|
111 |
|
112 |
|
@@ -160,7 +160,7 @@ if st.button("Highlight Geo-Entities"):
|
|
160 |
st.write("Embeddings:", bert_embedding)
|
161 |
|
162 |
#combine the embeddings (NOTE: come back and update after testing)
|
163 |
-
combined_embedding = torch.cat((bert_embedding,
|
164 |
st.write("Concatenated Embedding Shape:", combined_embedding.shape)
|
165 |
st.write("Concatenated Embedding:", combined_embedding)
|
166 |
|
|
|
72 |
position_list_y=position_list_y)
|
73 |
#NOTE: we are ommitting the pseudo_sentence here. Verify that this is correct
|
74 |
|
75 |
+
spaBERT_embedding = outputs.hidden_states[-1].to(device)
|
76 |
|
77 |
# Extract the [CLS] token embedding (first token)
|
78 |
+
spaBERT_embedding = embeddings[:, 0, :].detach() # [batch_size, hidden_size]
|
79 |
|
80 |
#pivot_token_len = batch['pivot_token_len'].item()
|
81 |
#pivot_embeddings = embeddings[:, :pivot_token_len, :]
|
82 |
|
83 |
#return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
|
84 |
+
return spaBERT_embedding, input_ids
|
85 |
|
86 |
+
spaBERT_embeddings = []
|
87 |
for i, batch in enumerate(data_loader):
|
88 |
if i >= 2: # Stop after processing 3 batches
|
89 |
break
|
|
|
105 |
outputs = bert_model(**inputs)
|
106 |
|
107 |
# Extract embeddings from the last hidden state
|
108 |
+
bert_embedding = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
|
109 |
+
return bert_embedding
|
110 |
|
111 |
|
112 |
|
|
|
160 |
st.write("Embeddings:", bert_embedding)
|
161 |
|
162 |
#combine the embeddings (NOTE: come back and update after testing)
|
163 |
+
combined_embedding = torch.cat((bert_embedding,spaBERT_embeddings[0]),dim=-1)
|
164 |
st.write("Concatenated Embedding Shape:", combined_embedding.shape)
|
165 |
st.write("Concatenated Embedding:", combined_embedding)
|
166 |
|