jvdzwaan committed on
Commit
fa785ea
·
1 Parent(s): adfd272

First version of task 1 demo

Browse files

Uses a 'random' (untrained) token classification head on top of the BERT model.

Files changed (2) hide show
  1. app.py +52 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+
4
+ from datasets import Dataset
5
+ from transformers import AutoTokenizer, BertForTokenClassification
6
+
7
+ from ocrpostcorrection.icdar_data import generate_sentences, process_input_ocr
8
+ from ocrpostcorrection.token_classification import tokenize_and_align_labels
9
+ from ocrpostcorrection.utils import predictions_to_labels, predictions2entity_output
10
+
11
# Base checkpoint for both the tokenizer and the model. Per the commit
# message, the token classification head on top of it is 'random', so the
# demo's predictions are not meaningful yet.
model_name = 'bert-base-multilingual-cased'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertForTokenClassification.from_pretrained(model_name)
15
+
16
+
17
def get_datasets(text_obj, key, size=150, step=150):
    """Turn one processed OCR text into inputs for the model.

    Args:
        text_obj: Processed text; its ``score``, ``tokens`` and
            ``input_tokens`` attributes are read to build the metadata row.
        key: Name under which ``text_obj`` is stored in the returned mapping.
        size: Forwarded to ``generate_sentences`` as ``size``.
        step: Forwarded to ``generate_sentences`` as ``step``.

    Returns:
        A ``(data, dataset, tokenized)`` tuple: the ``{key: text_obj}``
        mapping, the ``Dataset`` built from the generated sentences, and the
        tokenizer output with its ``'labels'`` entry removed.
    """
    texts = {key: text_obj}

    # Single-row metadata table describing the input text.
    # NOTE(review): 'file_name' is hard-coded to 'ocr_input' while `key` is a
    # parameter -- confirm whether generate_sentences needs them to match.
    metadata_columns = {
        'language': ['?'],
        'file_name': ['ocr_input'],
        'score': [text_obj.score],
        'num_tokens': [len(text_obj.tokens)],
        'num_input_tokens': [len(text_obj.input_tokens)],
    }
    metadata = pd.DataFrame(metadata_columns)

    sentences = generate_sentences(metadata, texts, size=size, step=step)
    dataset = Dataset.from_pandas(sentences)

    tokenize = tokenize_and_align_labels(tokenizer, return_tensors='pt')
    tokenized = tokenize(dataset)
    # Only the model inputs are returned; the labels entry is dropped.
    del tokenized['labels']

    return texts, dataset, tokenized
30
+
31
+
32
def tag(text):
    """Run the token classification model on OCR text for the demo UI.

    Args:
        text: Raw OCR text entered by the user.

    Returns:
        A dict with the submitted ``text`` and the predicted ``entities``;
        this is the value handed to the ``gr.HighlightedText`` output.
        Empty or whitespace-only input yields an empty ``entities`` list
        without running the model.
    """
    # An empty submission from the UI carries nothing to classify, so answer
    # directly instead of pushing it through tokenization and the model.
    if not text or not text.strip():
        return {"text": text or "", "entities": []}

    # Imported locally because the module itself does not import torch.
    import torch

    key = 'ocr_input'
    text_obj = process_input_ocr(text)
    data, dataset, tokenized = get_datasets(text_obj, key=key)

    # Inference only: skip building the autograd graph for the forward pass.
    with torch.no_grad():
        pred = model(**tokenized)
    predictions = predictions_to_labels(pred.logits.numpy())

    outputs = predictions2entity_output(dataset, predictions, tokenizer, data)

    return {"text": text, "entities": outputs[key]}
43
+
44
# Sample input offered in the UI (the misspelling is presumably deliberate,
# mimicking an OCR mistake).
examples = ['This is a cxample...']

# Text in, highlighted text out; flagging of results is switched off.
demo = gr.Interface(
    fn=tag,
    inputs=gr.Textbox(placeholder="Enter sentence here..."),
    outputs=gr.HighlightedText(),
    examples=examples,
    allow_flagging='never',
)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ datasets
2
+ git+https://github.com/jvdzwaan/ocrpostcorrection.git#egg=ocrpostcorrection
3
+ pandas
4
+ transformers