julian-schelb commited on
Commit
1e77467
·
verified ·
1 Parent(s): 6bfc3e2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import gradio as gr
3
+ import torch
4
+
5
+ # Load the model and tokenizer from Hugging Face Hub
6
+ model_name = "julian-schelb/xlm-roberta-base-latin-intertextuality"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
10
+ model = model.to(device)
11
+
12
+
13
+ def predict_intertextuality(sentence1, sentence2):
14
+ """
15
+ Predict intertextuality using the specified model.
16
+ """
17
+ # Prepare input for the model
18
+ inputs = tokenizer(
19
+ sentence1,
20
+ sentence2,
21
+ return_tensors="pt",
22
+ truncation=True,
23
+ padding="max_length",
24
+ max_length=512 # Adjust based on model's configuration
25
+ ).to(device)
26
+
27
+ # Perform inference
28
+ model.eval()
29
+ with torch.no_grad():
30
+ outputs = model(**inputs)
31
+ logits = outputs.logits
32
+ probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
33
+
34
+ # Map probabilities to labels
35
+ return {"Yes": probs[1], "No": probs[0]}
36
+
37
+
38
+ # Define the Gradio interface
39
+ inputs = [
40
+ gr.Textbox(label="Latin Sentence 1"),
41
+ gr.Textbox(label="Latin Sentence 2")
42
+ ]
43
+ outputs = gr.Label(label="Intertextuality Probabilities", num_top_classes=2)
44
+
45
+ gradio_app = gr.Interface(
46
+ fn=predict_intertextuality,
47
+ inputs=inputs,
48
+ outputs=outputs,
49
+ title="Latin Intertextuality Checker",
50
+ description="Enter two Latin sentences to get the probabilities for 'Yes' (intertextual) or 'No' (not intertextual).",
51
+ # flagging="never" # Disable the flag button
52
+ )
53
+
54
+ if __name__ == "__main__":
55
+ gradio_app.launch()