JasonTPhillipsJr committed on
Commit
857dba3
·
verified ·
1 Parent(s): 6ed7a92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -72,18 +72,18 @@ def process_entity(batch, model, device):
72
  position_list_y=position_list_y)
73
  #NOTE: we are omitting the pseudo_sentence here. Verify that this is correct
74
 
75
- embeddings = outputs.hidden_states[-1].to(device)
76
 
77
  # Extract the [CLS] token embedding (first token)
78
- embedding = embeddings[:, 0, :].detach() # [batch_size, hidden_size]
79
 
80
  #pivot_token_len = batch['pivot_token_len'].item()
81
  #pivot_embeddings = embeddings[:, :pivot_token_len, :]
82
 
83
  #return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
84
- return embedding, input_ids
85
 
86
- all_embeddings = []
87
  for i, batch in enumerate(data_loader):
88
  if i >= 2: # Stop after processing 3 batches
89
  break
@@ -105,8 +105,8 @@ def get_bert_embedding(review_text):
105
  outputs = bert_model(**inputs)
106
 
107
  # Extract embeddings from the last hidden state
108
- embeddings = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
109
- return embeddings
110
 
111
 
112
 
@@ -160,7 +160,7 @@ if st.button("Highlight Geo-Entities"):
160
  st.write("Embeddings:", bert_embedding)
161
 
162
  #combine the embeddings (NOTE: come back and update after testing)
163
- combined_embedding = torch.cat((bert_embedding,all_embeddings[0]),dim=-1)
164
  st.write("Concatenated Embedding Shape:", combined_embedding.shape)
165
  st.write("Concatenated Embedding:", combined_embedding)
166
 
 
72
  position_list_y=position_list_y)
73
  #NOTE: we are omitting the pseudo_sentence here. Verify that this is correct
74
 
75
+ spaBERT_embedding = outputs.hidden_states[-1].to(device)
76
 
77
  # Extract the [CLS] token embedding (first token)
78
+ spaBERT_embedding = embeddings[:, 0, :].detach() # [batch_size, hidden_size]
79
 
80
  #pivot_token_len = batch['pivot_token_len'].item()
81
  #pivot_embeddings = embeddings[:, :pivot_token_len, :]
82
 
83
  #return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
84
+ return spaBERT_embedding, input_ids
85
 
86
+ spaBERT_embeddings = []
87
  for i, batch in enumerate(data_loader):
88
  if i >= 2: # Stop after processing 3 batches
89
  break
 
105
  outputs = bert_model(**inputs)
106
 
107
  # Extract embeddings from the last hidden state
108
+ bert_embedding = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
109
+ return bert_embedding
110
 
111
 
112
 
 
160
  st.write("Embeddings:", bert_embedding)
161
 
162
  #combine the embeddings (NOTE: come back and update after testing)
163
+ combined_embedding = torch.cat((bert_embedding,spaBERT_embeddings[0]),dim=-1)
164
  st.write("Concatenated Embedding Shape:", combined_embedding.shape)
165
  st.write("Concatenated Embedding:", combined_embedding)
166