hugging2021 committed on
Commit
9bc983f
·
verified ·
1 Parent(s): 550c97e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -16
app.py CHANGED
@@ -1,22 +1,38 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
3
 
4
- # Laden des Modells für Masked Language Modeling
5
- unmasker = pipeline('fill-mask', model='bert-base-uncased')
 
 
6
 
7
- # Gradio Interface
8
- def masked_language_modeling(text):
9
- results = unmasker(text)
10
- return results[0]['sequence']
 
 
 
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  iface = gr.Interface(
13
- fn=masked_language_modeling,
14
- inputs=gr.Textbox(),
15
- outputs=gr.Textbox(),
16
- title='BERT Masked Language Modeling',
17
- description='Enter a sentence with a [MASK] and see the predictions.'
18
- )
19
 
20
- # Starte die Gradio Benutzeroberfläche
21
- if __name__ == '__main__':
22
- iface.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import BertForMaskedLM, BertTokenizer
4
 
5
# Fetch a fresh copy of the BERT masked-LM checkpoint and its tokenizer.
# force_download=True deliberately bypasses any stale local cache.
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name, force_download=True)
model = BertForMaskedLM.from_pretrained(model_name, force_download=True)
9
 
10
# Inference function: fill in the [MASK] token(s) in the input text.
def inference(input_text):
    """Fill every ``[MASK]`` token in *input_text* with BERT's top prediction.

    Parameters
    ----------
    input_text : str
        A sentence containing one or more ``[MASK]`` tokens.

    Returns
    -------
    str
        The sentence with each ``[MASK]`` replaced by the model's most
        likely token, or an error message when no ``[MASK]`` is present.
    """
    if "[MASK]" not in input_text:
        return "Error: The input text must contain the [MASK] token."

    # Tokenize and locate ALL mask positions.
    # (The previous version predicted only for the first mask and left any
    #  additional [MASK] tokens unfilled.)
    inputs = tokenizer(input_text, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

    # Forward pass without gradient tracking — inference only.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # Top-1 prediction for every [MASK] position.
    # mask_token_logits: (num_masks, vocab_size); topk indices: (num_masks, 1).
    mask_token_logits = logits[0, mask_token_index, :]
    top_tokens = torch.topk(mask_token_logits, 1, dim=1).indices.squeeze(1).tolist()

    # Replace each [MASK] in order with its predicted token.
    result_text = input_text
    for token_id in top_tokens:
        predicted_token = tokenizer.decode([token_id]).strip()
        result_text = result_text.replace("[MASK]", predicted_token, 1)

    return result_text
33
+
34
+ # Gradio Interface definieren
35
  iface = gr.Interface(
36
+ fn=inference,
 
 
 
 
 
# Start the Gradio UI; a fixed port (7862) so the app URL is predictable.
iface.launch(server_port=7862)