import os
import re
import sys
import json
from itertools import cycle
from urllib.parse import unquote

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

from data import extract_leaves, split_document, handle_broken_output, clean_json_text, sync_empty_fields
from examples import examples as input_examples
from nuextract_logging import log_event


MAX_INPUT_SIZE = 10_000
MAX_NEW_TOKENS = 4_000
MAX_WINDOW_SIZE = 4_000
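# Inputs longer than MAX_WINDOW_SIZE tokens are processed with a sliding
# window; inputs longer than MAX_INPUT_SIZE tokens are rejected outright
# (see the warnings in the description below).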

markdown_description = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body>
    <img src="https://cdn.prod.website-files.com/638364a4e52e440048a9529c/64188f405afcf42d0b85b926_logo_numind_final.png" alt="NuMind Logo" style="vertical-align: middle;width: 200px; height: 50px;">
    <br>
    <ul>
        <li>NuMind is a startup developing custom information extraction solutions.</li>
        <li>NuExtract is a zero-shot model. See the blog posts for more info (<a href="https://numind.ai/blog/nuextract-a-foundation-model-for-structured-extraction">NuExtract</a>, <a href="https://numind.ai/blog/nuextract-1-5---multilingual-infinite-context-still-small-and-better-than-gpt-4o">NuExtract-v1.5</a>).</li>
        <li>We have started to deploy NuMind Enterprise to customize, serve, and monitor NuExtract privately. If that interests you, let's chat 😊.</li>
        <li><strong>Website</strong>: <a href="https://www.numind.ai/">https://www.numind.ai/</a></li>
    </ul>
    <h1>NuExtract-v1.5</h1>
    <p>NuExtract-v1.5 is a fine-tuning of Phi-3.5-mini-instruct, trained on a private high-quality dataset for structured information extraction. 
    It supports long documents and several languages (English, French, Spanish, German, Portuguese, and Italian). 
    To use the model, provide an input text and a JSON template describing the information you need to extract.</p>
    <ul>
        <li><strong>Model</strong>: <a href="https://huggingface.co/numind/NuExtract-v1.5">numind/NuExtract-v1.5</a></li>
    </ul>
    <i>⚠️ In this space we restrict the model inputs to a maximum length of 10k tokens, with anything over 4k being processed in a sliding window. For full model performance, self-host the model or contact us.</i>
    <br>
    <i>⚠️ The model is trained to assume a valid JSON template. Attempts to use invalid JSON could lead to unpredictable results.</i>
</body>
</html>
"""


def highlight_words(input_text, json_output):
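    """Wrap each extracted leaf value found in input_text in a colored <span>.

    Leaves sharing the same JSON path get the same color; matching is
    case-insensitive and tolerant of whitespace differences.
    """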
    colors = cycle(["#90ee90", "#add8e6", "#ffb6c1", "#ffff99", "#ffa07a", "#20b2aa", "#87cefa", "#b0e0e6", "#dda0dd", "#ffdead"])
    color_map = {}
    highlighted_text = input_text

    leaves = extract_leaves(json_output)
    for path, value in leaves:
        path_key = tuple(path)
        if path_key not in color_map:
            color_map[path_key] = next(colors)
        color = color_map[path_key]

        # escape the value, then let any run of whitespace match between words
        # (since Python 3.7, re.escape no longer escapes spaces, so replace the
        # bare space rather than the r'\ ' sequence)
        escaped_value = re.escape(str(value)).replace(r'\ ', ' ').replace(' ', r'\s+')
        pattern = rf"(?<=[ \n\t]){escaped_value}(?=[ \n\t\.\,\?\:\;])"
        replacement = f"<span style='background-color: {color};'>{unquote(str(value))}</span>"
        highlighted_text = re.sub(pattern, replacement, highlighted_text, flags=re.IGNORECASE)

    return highlighted_text

def predict_chunk(text, template, current, model, tokenizer):
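    """Run a single extraction pass over one chunk of the document.

    The prompt packs the template, the current partial prediction, and the
    chunk text into the NuExtract input format; everything the model emits
    after <|output|> is cleaned and returned as JSON text.
    """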
    current = clean_json_text(current)

    # append "{" to prime the model to continue the JSON object
    input_llm = f"<|input|>\n### Template:\n{template}\n### Current:\n{current}\n### Text:\n{text}\n\n<|output|>" + "{"
    inputs = tokenizer(input_llm, return_tensors="pt", truncation=True, max_length=MAX_INPUT_SIZE).to("cuda")
    output = tokenizer.decode(model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)[0], skip_special_tokens=True)

    return clean_json_text(output.split("<|output|>")[1])

def sliding_window_prediction(template, text, model, tokenizer, window_size=4000, overlap=128):
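    """Yield (progress, prediction, highlighted_html) for each window of the text.

    The document is split into overlapping token windows; each window's
    prediction is fed back as the "Current" context for the next window, so
    extracted information accumulates across the whole document.
    """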
    # Split the text into overlapping windows of at most window_size tokens
    chunks = split_document(text, window_size, overlap, tokenizer)

    # Iterate over the chunks, threading the running prediction through as context
    prev = template
    
    for i, chunk in enumerate(chunks):
        print(f"Processing chunk {i}...")
        pred = predict_chunk(chunk, template, prev, model, tokenizer)

        # Handle broken output
        pred = handle_broken_output(pred, prev)
        
        # Create highlighted text; fall back to the plain input if parsing or highlighting fails
        try:
            highlighted_pred = highlight_words(text, json.loads(pred))
        except Exception:
            highlighted_pred = text

        # Attempt JSON parsing of the template and the prediction
        template_dict = None
        pred_dict = None
        try:
            template_dict = json.loads(template)
        except json.JSONDecodeError:
            pass
        try:
            pred_dict = json.loads(pred)
        except json.JSONDecodeError:
            pass
        
        # Sync empty fields
        if template_dict and pred_dict:
            synced_pred = sync_empty_fields(pred_dict, template_dict)
            synced_pred = json.dumps(synced_pred, indent=4, ensure_ascii=False)
        elif pred_dict:
            synced_pred = json.dumps(pred_dict, indent=4, ensure_ascii=False)
        else:
            synced_pred = pred

        # Return progress, current prediction, and updated HTML
        yield f"Processed chunk {i+1}/{len(chunks)}", synced_pred, highlighted_pred

        # Iterate
        prev = pred


######

# The model is not loaded at startup; with ZeroGPU it is loaded (and moved to CUDA) only when a request arrives
model_name = "numind/NuExtract-v1.5"
auth_token = os.environ.get("HF_TOKEN") or False

# Load tokenizer in advance but not the model
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)

# We define a function to load the model when needed
def load_model():
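    """Load NuExtract-v1.5 on demand so that ZeroGPU only allocates a GPU per request."""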
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        use_auth_token=auth_token,
    )
    model.eval()
    return model

@spaces.GPU
def gradio_interface_function(template, text, is_example):
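    """Gradio handler: validate the input length, load the model, and stream
    progress, predictions, and highlighted HTML back to the interface.
    """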
    try:
        if len(tokenizer.tokenize(text)) > MAX_INPUT_SIZE:
            yield "", "Input text is too long for this space. Download the model to run it without restrictions.", ""
            return  # end early since the input is invalid

        # Load the model when needed
        model = load_model()

        # Initialize the sliding window prediction process
        full_pred = ""
        prediction_generator = sliding_window_prediction(template, text, model, tokenizer, window_size=MAX_WINDOW_SIZE)

        # Iterate over the generator to stream values at each step
        for progress, full_pred, html_content in prediction_generator:
            yield progress, full_pred, html_content

        # Conditionally log event if not an example and logging is configured
        if not is_example:
            try:
                log_event(text, template, full_pred)
            except Exception as e:
                print(f"Warning: Could not log event: {e}", file=sys.stderr)
    except Exception as e:
        error_message = f"Error processing request: {str(e)}"
        print(error_message, file=sys.stderr)
        yield "", error_message, ""
        

# Set up the Gradio interface
iface = gr.Interface(
    description=markdown_description,
    fn=gradio_interface_function,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter Template here...", label="Template"),
        gr.Textbox(lines=2, placeholder="Enter input Text here...", label="Input Text"),
        gr.Checkbox(label="Is Example?", visible=False),
    ],
    outputs=[
        gr.Textbox(label="Progress"),
        gr.Textbox(label="Model Output"),
        gr.HTML(label="Model Output with Highlighted Words"),
    ],
    examples=input_examples,
    # live=True  # Enable real-time updates
)

iface.launch(debug=True)