donbr committed
Commit 38f9446 · 1 Parent(s): 1b211a2

monkey patch

Files changed (2)
  1. app.py +44 -24
  2. requirements.txt +1 -1
app.py CHANGED
@@ -7,13 +7,19 @@ from itertools import cycle
 
  import torch
  import gradio as gr
+ import spaces
  from urllib.parse import unquote
  from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
+ from transformers.cache_utils import DynamicCache
+
+ # Add get_max_length method to DynamicCache if it doesn't exist
+ # This is needed for compatibility with Phi-3.5 models
+ if not hasattr(DynamicCache, 'get_max_length'):
+     DynamicCache.get_max_length = lambda self: self.get_seq_length()
 
  from data import extract_leaves, split_document, handle_broken_output, clean_json_text, sync_empty_fields
  from examples import examples as input_examples
  from nuextract_logging import log_event
- import spaces
 
 
  MAX_INPUT_SIZE = 10_000
@@ -131,36 +137,50 @@ def sliding_window_prediction(template, text, model, tokenizer, window_size=4000
 
  ######
 
- # Load the model and tokenizer
+ # Model is loaded here but will be moved to CUDA only when needed with ZeroGPU
  model_name = "numind/NuExtract-v1.5"
  auth_token = os.environ.get("HF_TOKEN") or False
- model = AutoModelForCausalLM.from_pretrained(model_name,
+
+ # Load tokenizer in advance but not the model
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
+
+ # We define a function to load the model when needed
+ def load_model():
+     model = AutoModelForCausalLM.from_pretrained(model_name,
                                               trust_remote_code=True,
                                               torch_dtype=torch.bfloat16,
                                               device_map="auto", use_auth_token=auth_token)
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
- model.eval()
+     model.eval()
+     return model
 
- @spaces.GPU
+ @spaces.GPU(duration=300)
  def gradio_interface_function(template, text, is_example):
-     if len(tokenizer.tokenize(text)) > MAX_INPUT_SIZE:
-         yield "", "Input text too long for space. Download model to use unrestricted.", ""
-         return # End the function since there was an error
-
-     # Initialize the sliding window prediction process
-     prediction_generator = sliding_window_prediction(template, text, model, tokenizer, window_size=MAX_WINDOW_SIZE)
-
-     # Iterate over the generator to return values at each step
-     for progress, full_pred, html_content in prediction_generator:
-         # yield gr.update(value=chunk_info), gr.update(value=progress), gr.update(value=full_pred), gr.update(value=html_content)
-         yield progress, full_pred, html_content
-
-     # Conditionally log event if not an example and logging is configured
-     if not is_example:
-         try:
-             log_event(text, template, full_pred)
-         except Exception as e:
-             print(f"Warning: Could not log event: {e}", file=sys.stderr)
+     try:
+         if len(tokenizer.tokenize(text)) > MAX_INPUT_SIZE:
+             yield "", "Input text too long for space. Download model to use unrestricted.", ""
+             return # End the function since there was an error
+
+         # Load the model when needed
+         model = load_model()
+
+         # Initialize the sliding window prediction process
+         prediction_generator = sliding_window_prediction(template, text, model, tokenizer, window_size=MAX_WINDOW_SIZE)
+
+         # Iterate over the generator to return values at each step
+         for progress, full_pred, html_content in prediction_generator:
+             # yield gr.update(value=chunk_info), gr.update(value=progress), gr.update(value=full_pred), gr.update(value=html_content)
+             yield progress, full_pred, html_content
+
+         # Conditionally log event if not an example and logging is configured
+         if not is_example:
+             try:
+                 log_event(text, template, full_pred)
+             except Exception as e:
+                 print(f"Warning: Could not log event: {e}", file=sys.stderr)
+     except Exception as e:
+         error_message = f"Error processing request: {str(e)}"
+         print(error_message, file=sys.stderr)
+         yield "", error_message, ""
 
 
  # Set up the Gradio interface
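
Note: the DynamicCache shim added at the top of app.py can be exercised on its own. The snippet below is an illustration, not part of the commit; it assumes a transformers release whose DynamicCache exposes update() and get_seq_length() but no longer defines get_max_length(). On releases that still ship get_max_length(), the patch is a no-op.

import torch
from transformers.cache_utils import DynamicCache

# Same shim as in app.py: restore get_max_length() when the installed
# transformers version has dropped it from DynamicCache.
if not hasattr(DynamicCache, 'get_max_length'):
    DynamicCache.get_max_length = lambda self: self.get_seq_length()

# Simulate one decoding step on layer 0, shape (batch, heads, seq_len, head_dim).
cache = DynamicCache()
cache.update(torch.zeros(1, 2, 5, 8), torch.zeros(1, 2, 5, 8), layer_idx=0)

print(cache.get_seq_length())  # 5
print(cache.get_max_length())  # 5 with the shim; None on versions that still define it
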
requirements.txt CHANGED
@@ -1,4 +1,4 @@
  transformers
  torch
  accelerate
- spaces
+ spaces>=0.1.0
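
The spaces dependency pinned above provides the @spaces.GPU decorator used in app.py. As a rough, hypothetical sketch of the ZeroGPU pattern this commit adopts (tokenizer at import time, model load deferred into the GPU-allocated call), not the Space's actual handler:

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "numind/NuExtract-v1.5"
# Tokenizer loading is cheap and CPU-only, so it can stay at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def load_model():
    # Deferred model load; on ZeroGPU this runs only once a GPU is attached to the call.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    model.eval()
    return model

@spaces.GPU(duration=300)  # request a GPU for up to 300 seconds per call
def extract(text):
    # Hypothetical minimal handler, standing in for gradio_interface_function.
    model = load_model()
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=64)
    return tokenizer.decode(output[0], skip_special_tokens=True)
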