PedroLancharesSanchez commited on
Commit
d34a682
verified
1 Parent(s): 611a628

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -2
app.py CHANGED
@@ -1,6 +1,7 @@
1
- # Load model directly
2
  import transformers
3
  from transformers import GraphormerForGraphClassification
 
4
 
5
  import os
6
  try:
@@ -9,4 +10,24 @@ except ImportError:
9
  os.system('pip install toml')
10
  import toml
11
  print('todo en orden')
12
- model = GraphormerForGraphClassification.from_pretrained("PedroLancharesSanchez/graph-regression")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
  import transformers
3
  from transformers import GraphormerForGraphClassification
4
+ import gradio as gr
5
 
6
  import os
7
  try:
 
10
  os.system('pip install toml')
11
  import toml
12
  print('todo en orden')
13
+ model = GraphormerForGraphClassification.from_pretrained("PedroLancharesSanchez/graph-regression")
14
+
15
+ def predict(instancia):
16
+ instancia_preprocesada=preprocess_item(instancia)
17
+ inputs={}
18
+ inputs['input_nodes'] = torch.tensor([instancia_preprocesada['input_nodes']])
19
+ inputs['input_edges'] = torch.tensor([instancia_preprocesada['input_edges']])
20
+ inputs['attn_bias'] = torch.tensor([instancia_preprocesada['attn_bias']])
21
+ inputs['in_degree'] = torch.tensor([instancia_preprocesada['in_degree']])
22
+ inputs['out_degree'] = torch.tensor([instancia_preprocesada['out_degree']])
23
+ inputs['spatial_pos'] = torch.tensor([instancia_preprocesada['spatial_pos']])
24
+ inputs['attn_edge_type'] = torch.tensor([instancia_preprocesada['attn_edge_type']])
25
+ with torch.no_grad():
26
+ logits = model(**inputs).logits
27
+ predicted_class_id = logits.argmax().item()
28
+ return logits
29
+
30
+ graph_input = gr.inputs.Graph(graph_type="networkx", label="Grafo de entrada")
31
+ regression_output = gr.outputs.Textbox(label="Valor de regresi贸n")
32
+
33
+ gr.Interface(fn=predict, inputs=graph_input, outputs=regression_output,examples=['grafo1.json','grafo2.json']).launch(share=False)