SivaMallikarjun commited on
Commit
76b3996
·
verified ·
1 Parent(s): ba6b64b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -22
app.py CHANGED
@@ -1,22 +1,22 @@
1
- import gradio as gr
2
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
- import torch
4
-
5
- model_path = "./models/fine_tuned_xlm_roberta_quantized"
6
- model = AutoModelForSequenceClassification.from_pretrained(model_path)
7
- tokenizer = AutoTokenizer.from_pretrained(model_path)
8
-
9
- def classify_text(text):
10
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
11
- outputs = model(**inputs)
12
- prediction = torch.argmax(outputs.logits, dim=1).item()
13
- label = "Correct" if prediction == 1 else "Incorrect"
14
- return label
15
-
16
- iface = gr.Interface(fn=classify_text,
17
- inputs="text",
18
- outputs="text",
19
- title="Multi-Language RL Text Classifier")
20
-
21
- if __name__ == "__main__":
22
- iface.launch()
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ import torch
4
+
5
+ model_path = "SivaMallikarjun/multi-language-rl-model"
6
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
7
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
8
+
9
+ def classify_text(text):
10
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
11
+ outputs = model(**inputs)
12
+ prediction = torch.argmax(outputs.logits, dim=1).item()
13
+ label = "Correct" if prediction == 1 else "Incorrect"
14
+ return label
15
+
16
+ iface = gr.Interface(fn=classify_text,
17
+ inputs="text",
18
+ outputs="text",
19
+ title="Multi-Language RL Text Classifier")
20
+
21
+ if __name__ == "__main__":
22
+ iface.launch()