lorenzoscottb commited on
Commit
8fd6eff
Β·
verified Β·
1 Parent(s): 4be9695

Update graph_utils.py

Browse files
Files changed (1) hide show
  1. graph_utils.py +6 -6
graph_utils.py CHANGED
@@ -6,17 +6,17 @@ import os
6
  model_id = "DReAMy-lib/t5-base-DreamBank-Generation-Act-Char"
7
 
8
  def get_graph_dict(graph_text):
9
- edge_labels = {}
10
  if graph_text == "":
11
  edge_labels = {("No_Graphs", None):None}
12
 
13
  else:
14
  try:
15
- for trpl in graph_text[1:-1].split(", "):
16
- h,r,t = trpl.split(" : ")
17
- edge_labels[(h,t)] = "_".join(r.split(" "))
18
  except:
19
- edge_labels = {("Error", None):None}
20
  return edge_labels
21
 
22
  def text_to_graph(text):
@@ -41,7 +41,7 @@ def text_to_graph(text):
41
  net = Network(directed=True)
42
 
43
  # nodes & edges
44
- for (h, t), r in edge_labels.items():
45
  if (h == "Error") or (h == "No_Graphs"):
46
  net.add_node(h, shape="circle")
47
  continue
 
6
  model_id = "DReAMy-lib/t5-base-DreamBank-Generation-Act-Char"
7
 
8
  def get_graph_dict(graph_text):
9
+ edge_labels = []
10
  if graph_text == "":
11
  edge_labels = {("No_Graphs", None):None}
12
 
13
  else:
14
  try:
15
+ for trpl in graph_text[1:-1].split("), ("):
16
+ h,r,t = trpl.split(" : ")
17
+ edge_labels.append((h,t, "_".join(r.split(" "))))
18
  except:
19
+ edge_labels.append(("Error", None, None))
20
  return edge_labels
21
 
22
  def text_to_graph(text):
 
41
  net = Network(directed=True)
42
 
43
  # nodes & edges
44
+ for (h, t, r) in edge_labels:
45
  if (h == "Error") or (h == "No_Graphs"):
46
  net.add_node(h, shape="circle")
47
  continue