jwalanthi committed on
Commit 2ffd102 · 1 Parent(s): 541d16e

this should work out there

Files changed (1)
  1. app.py +7 -4
app.py CHANGED
@@ -1,8 +1,8 @@
 import gradio as gr
 import torch
 from minicons import cwe
-import pandas as pd
-import numpy as np
+from huggingface_hub import hf_hub_download
+import os
 
 from model import FFNModule, FeatureNormPredictor, FFNParams, TrainingParams
 
@@ -14,16 +14,19 @@ def predict (Word, Sentence, LM, Layer, Norm):
     lm = cwe.CWE(models[LM])
     if Layer not in range (lm.layers): return "invalid input: layer not in lm"
 
+    model_path = hf_hub_download("jwalanthi/bert_layer8_to_binder", model_name+".ckpt", use_auth_token=os.environ['TOKEN'])
+    label_path = hf_hub_download("jwalanthi/bert_layer8_to_binder", model_name+".txt", use_auth_token=os.environ['TOKEN'])
+
     # labels = "These are some fake features".split(" ")
     # vals = np.random.randint(-10,10,(5))
     # return model_name+" \n"+"\n".join([labels[i]+" "+str(vals[i]) for i in range(len(labels)) if vals[i]>0])
     model = FeatureNormPredictor.load_from_checkpoint(
-        checkpoint_path='models/'+model_name+'.ckpt',
+        checkpoint_path=model_path,
         map_location=None
     )
     model.eval()
 
-    with open ('models/'+model_name+'.txt', "r") as file:
+    with open (label_path, "r") as file:
         labels = [line.rstrip() for line in file.readlines()]
 
     data = (Sentence, Word)
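The commit swaps the bundled models/ directory for files fetched from the Hub at runtime. A minimal sketch of the new loading path, assuming the Space defines a TOKEN secret with read access to jwalanthi/bert_layer8_to_binder; model_name below is a hypothetical placeholder for the name that predict() builds from its inputs:

import os
from huggingface_hub import hf_hub_download

from model import FeatureNormPredictor

model_name = "example_model"  # placeholder; predict() derives the real name

# Fetch the checkpoint and its label list from the Hub.
# hf_hub_download returns the local cache path of the downloaded file.
model_path = hf_hub_download("jwalanthi/bert_layer8_to_binder", model_name + ".ckpt",
                             use_auth_token=os.environ["TOKEN"])
label_path = hf_hub_download("jwalanthi/bert_layer8_to_binder", model_name + ".txt",
                             use_auth_token=os.environ["TOKEN"])

# Load the predictor from the cached checkpoint and read its feature labels.
model = FeatureNormPredictor.load_from_checkpoint(checkpoint_path=model_path, map_location=None)
model.eval()

with open(label_path, "r") as file:
    labels = [line.rstrip() for line in file.readlines()]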