Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
-
|
2 |
import transformers
|
3 |
from transformers import GraphormerForGraphClassification
|
|
|
4 |
|
5 |
import os
|
6 |
try:
|
@@ -9,4 +10,24 @@ except ImportError:
|
|
9 |
os.system('pip install toml')
|
10 |
import toml
|
11 |
print('todo en orden')
|
12 |
-
model = GraphormerForGraphClassification.from_pretrained("PedroLancharesSanchez/graph-regression")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
import transformers
|
3 |
from transformers import GraphormerForGraphClassification
|
4 |
+
import gradio as gr
|
5 |
|
6 |
import os
|
7 |
try:
|
|
|
10 |
os.system('pip install toml')
|
11 |
import toml
|
12 |
print('todo en orden')
|
13 |
+
model = GraphormerForGraphClassification.from_pretrained("PedroLancharesSanchez/graph-regression")
|
14 |
+
|
15 |
+
def predict(instancia):
|
16 |
+
instancia_preprocesada=preprocess_item(instancia)
|
17 |
+
inputs={}
|
18 |
+
inputs['input_nodes'] = torch.tensor([instancia_preprocesada['input_nodes']])
|
19 |
+
inputs['input_edges'] = torch.tensor([instancia_preprocesada['input_edges']])
|
20 |
+
inputs['attn_bias'] = torch.tensor([instancia_preprocesada['attn_bias']])
|
21 |
+
inputs['in_degree'] = torch.tensor([instancia_preprocesada['in_degree']])
|
22 |
+
inputs['out_degree'] = torch.tensor([instancia_preprocesada['out_degree']])
|
23 |
+
inputs['spatial_pos'] = torch.tensor([instancia_preprocesada['spatial_pos']])
|
24 |
+
inputs['attn_edge_type'] = torch.tensor([instancia_preprocesada['attn_edge_type']])
|
25 |
+
with torch.no_grad():
|
26 |
+
logits = model(**inputs).logits
|
27 |
+
predicted_class_id = logits.argmax().item()
|
28 |
+
return logits
|
29 |
+
|
30 |
+
graph_input = gr.inputs.Graph(graph_type="networkx", label="Grafo de entrada")
|
31 |
+
regression_output = gr.outputs.Textbox(label="Valor de regresi贸n")
|
32 |
+
|
33 |
+
gr.Interface(fn=predict, inputs=graph_input, outputs=regression_output,examples=['grafo1.json','grafo2.json']).launch(share=False)
|