hassoudi commited on
Commit
a4f3cd1
·
verified ·
1 Parent(s): b9cfeca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -45
app.py CHANGED
@@ -1,33 +1,28 @@
1
  import gradio as gr
2
  from huggingface_hub import login
3
- from transformers import AutoModelForTokenClassification, AutoTokenizer
4
  import os
5
- import torch
6
 
7
- # Initialize global model and tokenizer
8
- model = None
9
- tokenizer = None
10
 
11
- def load_healthcare_ner():
12
- """Load the Healthcare NER model and tokenizer."""
13
- global model, tokenizer
14
- if model is None or tokenizer is None:
15
  login(token=os.environ["HF_TOKEN"])
16
- model = AutoModelForTokenClassification.from_pretrained(
17
- "TypicaAI/HealthcareNER-Fr",
18
- use_auth_token=os.environ["HF_TOKEN"]
 
 
19
  )
20
- tokenizer = AutoTokenizer.from_pretrained("TypicaAI/HealthcareNER-Fr")
21
- return model, tokenizer
22
 
23
  def process_text(text):
24
  """Process input text and return highlighted entities."""
25
- model, tokenizer = load_healthcare_ner()
26
- inputs = tokenizer(text, return_tensors="pt", truncation=True)
27
- outputs = model(**inputs)
28
-
29
- # Decode entities from outputs
30
- entities = extract_entities(outputs, tokenizer, text)
31
 
32
  # Highlight entities in the text
33
  html_output = highlight_entities(text, entities)
@@ -37,35 +32,14 @@ def process_text(text):
37
 
38
  return html_output
39
 
40
- def extract_entities(outputs, tokenizer, text):
41
- """Extract entities from model outputs."""
42
- tokens = tokenizer.tokenize(text)
43
- predictions = torch.argmax(outputs.logits, dim=2).squeeze().tolist()
44
-
45
- entities = []
46
- current_entity = None
47
- for token, prediction in zip(tokens, predictions):
48
- label = model.config.id2label[prediction]
49
- if label.startswith("B-"):
50
- if current_entity:
51
- entities.append(current_entity)
52
- current_entity = {"entity": label[2:], "text": token, "start": len(text)}
53
- elif label.startswith("I-") and current_entity:
54
- current_entity["text"] += f" {token}"
55
- elif current_entity:
56
- entities.append(current_entity)
57
- current_entity = None
58
- if current_entity:
59
- entities.append(current_entity)
60
- return entities
61
-
62
  def highlight_entities(text, entities):
63
  """Highlight identified entities in the input text."""
64
  highlighted_text = text
65
  for entity in entities:
 
66
  highlighted_text = highlighted_text.replace(
67
- entity["text"],
68
- f'<mark style="background-color: yellow;">{entity["text"]}</mark>'
69
  )
70
  return f"<p>{highlighted_text}</p>"
71
 
@@ -122,4 +96,3 @@ with gr.Blocks() as marketing_elements:
122
  # Launch the Gradio demo
123
  if __name__ == "__main__":
124
  demo.launch()
125
-
 
1
  import gradio as gr
2
  from huggingface_hub import login
3
+ from transformers import pipeline
4
  import os
 
5
 
6
+ # Initialize global pipeline
7
+ ner_pipeline = None
 
8
 
9
+ def load_healthcare_ner_pipeline():
10
+ """Load the Hugging Face pipeline for Healthcare NER."""
11
+ global ner_pipeline
12
+ if ner_pipeline is None:
13
  login(token=os.environ["HF_TOKEN"])
14
+ ner_pipeline = pipeline(
15
+ "token-classification",
16
+ model="TypicaAI/HealthcareNER-Fr",
17
+ use_auth_token=os.environ["HF_TOKEN"],
18
+ aggregation_strategy="simple" # Groups B- and I- tokens into entities
19
  )
20
+ return ner_pipeline
 
21
 
22
  def process_text(text):
23
  """Process input text and return highlighted entities."""
24
+ pipeline = load_healthcare_ner_pipeline()
25
+ entities = pipeline(text)
 
 
 
 
26
 
27
  # Highlight entities in the text
28
  html_output = highlight_entities(text, entities)
 
32
 
33
  return html_output
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def highlight_entities(text, entities):
36
  """Highlight identified entities in the input text."""
37
  highlighted_text = text
38
  for entity in entities:
39
+ entity_text = entity["word"]
40
  highlighted_text = highlighted_text.replace(
41
+ entity_text,
42
+ f'<mark style="background-color: yellow;">{entity_text}</mark>'
43
  )
44
  return f"<p>{highlighted_text}</p>"
45
 
 
96
  # Launch the Gradio demo
97
  if __name__ == "__main__":
98
  demo.launch()