anamargarida commited on
Commit
78b0615
·
verified ·
1 Parent(s): e4e2e9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -8
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import streamlit as st
2
  import torch
3
- from safetensors.torch import load_file
4
  from transformers import AutoConfig, AutoTokenizer, AutoModel
5
- from ST2ModelV2_6 import ST2ModelV2
6
  from huggingface_hub import login
7
  import re
8
  import copy
 
 
 
9
 
10
  hf_token = st.secrets["HUGGINGFACE_TOKEN"]
11
  login(token=hf_token)
@@ -16,10 +17,9 @@ login(token=hf_token)
16
  @st.cache_resource
17
  def load_model():
18
 
19
- model_name = "anamargarida/Final"
20
-
21
- config = AutoConfig.from_pretrained(model_name)
22
- tokenizer = AutoTokenizer.from_pretrained(model_name)
23
 
24
  class Args:
25
  def __init__(self):
@@ -30,9 +30,19 @@ def load_model():
30
 
31
  args = Args()
32
 
33
- # Load the model directly from Hugging Face
34
- model = ST2ModelV2.from_pretrained(model_name, config=config, args=args)
 
 
 
 
 
 
 
 
 
35
 
 
36
 
37
  return tokenizer, model
38
 
 
1
  import streamlit as st
2
  import torch
 
3
  from transformers import AutoConfig, AutoTokenizer, AutoModel
 
4
  from huggingface_hub import login
5
  import re
6
  import copy
7
+ from src.models.modeling_st2 import ST2ModelV2, SignalDetector
8
+ from huggingface_hub import hf_hub_download
9
+ from safetensors.torch import load_file
10
 
11
  hf_token = st.secrets["HUGGINGFACE_TOKEN"]
12
  login(token=hf_token)
 
17
  @st.cache_resource
18
  def load_model():
19
 
20
+ config = AutoConfig.from_pretrained("roberta-large")
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained("roberta-large", use_fast=True, add_prefix_space=True)
 
23
 
24
  class Args:
25
  def __init__(self):
 
30
 
31
  args = Args()
32
 
33
+ model = ST2ModelV2(args)
34
+
35
+
36
+ repo_id = "anamargarida/SpanExtractionWithSignalCls_2"
37
+ filename = "model.safetensors"
38
+
39
+ # Download the model file
40
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
41
+
42
+ # Load the model weights
43
+ state_dict = load_file(model_path)
44
 
45
+ model.load_state_dict(state_dict)
46
 
47
  return tokenizer, model
48