PedroLancharesSanchez commited on
Commit
a95ce8e
verified
1 Parent(s): e3fc442

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -15
app.py CHANGED
@@ -2,6 +2,7 @@
2
  import transformers
3
  from transformers import GraphormerForGraphClassification
4
  import gradio as gr
 
5
 
6
  import os
7
  try:
@@ -13,19 +14,19 @@ print('todo en orden')
13
  model = GraphormerForGraphClassification.from_pretrained("PedroLancharesSanchez/graph-regression")
14
 
15
  def predict(instancia):
16
- instancia_preprocesada=preprocess_item(instancia)
17
- inputs={}
18
- inputs['input_nodes'] = torch.tensor([instancia_preprocesada['input_nodes']])
19
- inputs['input_edges'] = torch.tensor([instancia_preprocesada['input_edges']])
20
- inputs['attn_bias'] = torch.tensor([instancia_preprocesada['attn_bias']])
21
- inputs['in_degree'] = torch.tensor([instancia_preprocesada['in_degree']])
22
- inputs['out_degree'] = torch.tensor([instancia_preprocesada['out_degree']])
23
- inputs['spatial_pos'] = torch.tensor([instancia_preprocesada['spatial_pos']])
24
- inputs['attn_edge_type'] = torch.tensor([instancia_preprocesada['attn_edge_type']])
25
- with torch.no_grad():
26
- logits = model(**inputs).logits
27
- predicted_class_id = logits.argmax().item()
28
- return logits
 
29
 
30
-
31
- gr.Interface(fn=predict, inputs='Grafo', outputs='Puntuaci贸n de regresi贸n',examples=['grafo1.json','grafo2.json']).launch(share=False)
 
2
  import transformers
3
  from transformers import GraphormerForGraphClassification
4
  import gradio as gr
5
+ import json
6
 
7
  import os
8
  try:
 
14
  model = GraphormerForGraphClassification.from_pretrained("PedroLancharesSanchez/graph-regression")
15
 
16
  def predict(instancia):
17
+ instancia=json.loads(instancia)
18
+ instancia_preprocesada=preprocess_item(instancia)
19
+ inputs={}
20
+ inputs['input_nodes'] = torch.tensor([instancia_preprocesada['input_nodes']])
21
+ inputs['input_edges'] = torch.tensor([instancia_preprocesada['input_edges']])
22
+ inputs['attn_bias'] = torch.tensor([instancia_preprocesada['attn_bias']])
23
+ inputs['in_degree'] = torch.tensor([instancia_preprocesada['in_degree']])
24
+ inputs['out_degree'] = torch.tensor([instancia_preprocesada['out_degree']])
25
+ inputs['spatial_pos'] = torch.tensor([instancia_preprocesada['spatial_pos']])
26
+ inputs['attn_edge_type'] = torch.tensor([instancia_preprocesada['attn_edge_type']])
27
+ with torch.no_grad():
28
+ logits = model(**inputs).logits
29
+ predicted_class_id = logits.argmax().item()
30
+ return str(logits.item())
31
 
32
+ gr.Interface(fn=predict, inputs='str', outputs='str',examples=['grafo1.txt','grafo2.txt','grafo3.txt']).launch(share=False)