emanuelaboros committed
Commit 6627fc9 · 1 Parent(s): da7878b

update app

Files changed (1)
  1. app.py +40 -14
app.py CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import requests
 
 tokenizer = AutoTokenizer.from_pretrained("impresso-project/nel-hipe-multilingual")
 model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -9,21 +10,46 @@ model = AutoModelForSeq2SeqLM.from_pretrained(
 print("Model loaded successfully!")
 
 
+def get_wikipedia_title(qid, language="en"):
+    url = f"https://www.wikidata.org/w/api.php"
+    params = {
+        "action": "wbgetentities",
+        "format": "json",
+        "ids": qid,
+        "props": "sitelinks/urls",
+        "sitefilter": f"{language}wiki",
+    }
+
+    response = requests.get(url, params=params)
+    data = response.json()
+
+    try:
+        title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
+        url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
+        return title, url
+    except KeyError:
+        return "NIL", "None"
+
+
 def disambiguate_sentence(sentence):
-    results = []
     entities = []
-    for sentence in [sentence]:
-        outputs = model.generate(
-            **tokenizer([sentence], return_tensors="pt"),
-            num_beams=5,
-            num_return_sequences=5,
-            max_new_tokens=30,
-        )
-        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-        results.append(decoded)
-        entities.append({"label": decoded[0]})
-
-        print(f"Decoded: {decoded}")
+    # Generate model outputs for the sentence
+    outputs = model.generate(
+        **tokenizer([sentence], return_tensors="pt"),
+        num_beams=5,
+        num_return_sequences=5,
+        max_new_tokens=30,
+    )
+    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    qid = decoded[0].split()[-1]  # Assuming QID is the last token in the output
+
+    # Get Wikipedia title and URL
+    title, url = get_wikipedia_title(qid)
+
+    entity_info = f"QID: {qid}, Title: {title}, URL: {url}"
+    entities.append(entity_info)
+
+    print(f"Entities: {entities}")
     return {"text": sentence, "entities": entities}
 
 
@@ -36,7 +62,7 @@ def nel_app_interface():
         "entity should be surrounded by `[START]` and `[END]`. // "
         "!Only one entity per sentence is supported at the moment!",
     )
-    output_entities = gr.HighlightedText(label="Linked Entities")
+    output_entities = gr.Textbox(label="Linked Entities")
 
     # Interface definition
     interface = gr.Interface(
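
For reviewers: the new `get_wikipedia_title` helper can be exercised in isolation against the live Wikidata API. A minimal sketch, assuming network access and using `Q42` (Douglas Adams) as an illustrative QID rather than one produced by the model:

```python
import requests

# Same wbgetentities query the helper builds; props=sitelinks/urls makes
# each sitelink entry carry both the page title and its full URL.
response = requests.get(
    "https://www.wikidata.org/w/api.php",
    params={
        "action": "wbgetentities",
        "format": "json",
        "ids": "Q42",              # illustrative QID, not actual app output
        "props": "sitelinks/urls",
        "sitefilter": "enwiki",    # matches the helper's language="en" default
    },
)
data = response.json()

# Mirrors the helper's try/except KeyError fallback to ("NIL", "None").
sitelink = data["entities"]["Q42"]["sitelinks"].get("enwiki", {})
print(sitelink.get("title", "NIL"), sitelink.get("url", "None"))
```

If the QID has no sitelink in the requested language, or the model emits a malformed QID, the helper returns `("NIL", "None")`, which `disambiguate_sentence` then surfaces verbatim in the entity string shown in the new `gr.Textbox` output.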