import gradio as gr
from transformers import BertTokenizer, BertForTokenClassification
from transformers import pipeline
from collections import defaultdict

# Load the pretrained NER model and tokenizer, then wrap them in a pipeline.
model_name = "b3x0m/bert-xomlac-ner"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForTokenClassification.from_pretrained(model_name)
nlp_ner = pipeline("ner", model=model, tokenizer=tokenizer)


def accumulate_entities(chunk_text, entity_count):
    """Run the NER pipeline on one chunk of text and tally completed entities.

    Tokens tagged B-/M-/I- start or extend the entity being built; an E- tag
    closes it, at which point the (text, label) pair is counted.
    """
    ner_results = nlp_ner(chunk_text)
    current_entity = None
    for entity in ner_results:
        tag = entity['entity']
        if tag.startswith(("B-", "M-", "I-")):
            if current_entity is None:
                current_entity = {'text': entity['word'], 'label': tag[2:]}
            else:
                current_entity['text'] += entity['word']
        elif tag.startswith("E-"):
            if current_entity:
                current_entity['text'] += entity['word']
                current_entity['label'] = tag[2:]
                entity_count[(current_entity['text'], current_entity['label'])] += 1
                current_entity = None


def ner(file, selected_entities, min_count):
    with open(file.name) as f:
        text = f.read()

    # Group the file's lines into batches of 32 lines each.
    lines = text.splitlines()
    batch_size = 32
    batches = [lines[i:i + batch_size] for i in range(0, len(lines), batch_size)]

    entity_count = defaultdict(int)

    for batch in batches:
        batch_text = " ".join(batch)
        tokens = tokenizer(batch_text)['input_ids']
        if len(tokens) > 128:
            # Split long batches into 128-token chunks so the model's
            # sequence-length limit is not exceeded.
            for i in range(0, len(tokens), 128):
                sub_tokens = tokens[i:i + 128]
                sub_batch_text = tokenizer.decode(sub_tokens, skip_special_tokens=True)
                accumulate_entities(sub_batch_text, entity_count)
        else:
            accumulate_entities(batch_text, entity_count)

    # Keep entities that meet the frequency threshold and, when a filter is
    # selected, belong to one of the chosen labels.
    output = []
    for (name, label), count in entity_count.items():
        if count >= min_count and (not selected_entities or label in selected_entities):
            output.append(f"{name}={label}={count}")

    return "\n".join(output)


css = '''
h1#title {
  text-align: center;
}
'''
theme = gr.themes.Soft()
demo = gr.Blocks(css=css, theme=theme)

with demo:
    input_file = gr.File(label="Upload File (.txt)", file_types=[".txt"])
    entity_filter = gr.CheckboxGroup(
        label="Entities",
        choices=["PER", "ORG", "LOC", "GPE"],
        type="value",
    )
    count_entities = gr.Number(
        label="Frequency",
        minimum=1,
        maximum=10,
        step=1,
        value=3,
    )
    output_text = gr.Textbox(
        label="Output",
        show_copy_button=True,
        interactive=False,
        lines=10,
        max_lines=20,
    )

    interface = gr.Interface(
        fn=ner,
        inputs=[input_file, entity_filter, count_entities],
        outputs=[output_text],
        allow_flagging="never",
    )

demo.launch()