Reality123b commited on
Commit
125d37d
·
verified ·
1 Parent(s): f80e2a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -18
app.py CHANGED
@@ -249,8 +249,14 @@ class XylariaChat:
249
 
250
  def query_knowledge_graph(self, query):
251
  query_embedding = self.embedding_model.encode(query, convert_to_tensor=True)
252
-
253
- node_embeddings = {node: self.embedding_model.encode(node, convert_to_tensor=True) for node in self.knowledge_graph.nodes()}
 
 
 
 
 
 
254
 
255
  similarities = {node: util.pytorch_cos_sim(query_embedding, embedding)[0][0].item() for node, embedding in node_embeddings.items()}
256
 
@@ -340,27 +346,17 @@ class XylariaChat:
340
  return f"Error during Math OCR: {e}"
341
 
342
  def extract_entities_and_relations(self, text):
343
- doc = self.embedding_model.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
344
-
345
- with torch.no_grad():
346
- outputs = self.embedding_model(**doc)
347
-
348
  entities = []
349
  relations = []
350
- for i in range(len(doc['input_ids'][0])):
351
- token = self.embedding_model.tokenizer.decode(doc['input_ids'][0][i])
352
- if outputs['last_hidden_state'][0][i].norm() > 3:
353
- entities.append(token)
354
 
355
- if len(entities) >= 2:
356
- for i in range(len(entities) - 1):
357
- relation = f"{entities[i]} related_to {entities[i+1]}"
358
- relations.append(relation)
359
 
360
  return entities, relations
361
 
362
- def update_knowledge_graph(self, text):
363
- entities, relations = self.extract_entities_and_relations(text)
364
  for entity in entities:
365
  self.knowledge_graph.add_node(entity)
366
  for relation in relations:
@@ -372,7 +368,8 @@ class XylariaChat:
372
 
373
  def get_response(self, user_input, image=None):
374
  try:
375
- self.update_knowledge_graph(user_input)
 
376
 
377
  messages = []
378
 
 
249
 
250
  def query_knowledge_graph(self, query):
251
  query_embedding = self.embedding_model.encode(query, convert_to_tensor=True)
252
+
253
+ node_embeddings = {}
254
+ for node in self.knowledge_graph.nodes():
255
+ try:
256
+ node_embedding = self.embedding_model.encode(node, convert_to_tensor=True)
257
+ node_embeddings[node] = node_embedding
258
+ except Exception as e:
259
+ print(f"Error encoding node {node}: {e}")
260
 
261
  similarities = {node: util.pytorch_cos_sim(query_embedding, embedding)[0][0].item() for node, embedding in node_embeddings.items()}
262
 
 
346
  return f"Error during Math OCR: {e}"
347
 
348
  def extract_entities_and_relations(self, text):
349
+ inputs = self.embedding_model.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
350
+
 
 
 
351
  entities = []
352
  relations = []
 
 
 
 
353
 
354
+ entities, relations = self.extract_entities_and_relations(message)
355
+ self.update_knowledge_graph(entities, relations)
 
 
356
 
357
  return entities, relations
358
 
359
+ def update_knowledge_graph(self, entities, relations):
 
360
  for entity in entities:
361
  self.knowledge_graph.add_node(entity)
362
  for relation in relations:
 
368
 
369
  def get_response(self, user_input, image=None):
370
  try:
371
+ entities, relations = self.extract_entities_and_relations(user_input)
372
+ self.update_knowledge_graph(entities, relations)
373
 
374
  messages = []
375