Unggi commited on
Commit
6ade039
·
1 Parent(s): 17bc171

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pip
2
+ pip.main(['install', 'torch'])
3
+ pip.main(['install', 'transformers'])
4
+
5
+ import torch
6
+ import gradio as gr
7
+ import transformers
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
+
10
+ def load_model(model_name):
11
+ # model
12
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
13
+ # tokenizer
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+
16
+ return model, tokenizer
17
+
18
+
19
+ def inference(prompt):
20
+ model_name = "Unggi/feedback_prize_kor"
21
+
22
+ model, tokenizer = load_model(
23
+ model_name = model_name
24
+ )
25
+
26
+ inputs = tokenizer(
27
+ prompt,
28
+ return_tensors="pt"
29
+ )
30
+
31
+ with torch.no_grad():
32
+ logits = model(**inputs).logits
33
+
34
+ predicted_class_id = logits.argmax().item()
35
+ class_id = model.config.id2label[predicted_class_id]
36
+
37
+ return class_id
38
+
39
+ demo = gr.Interface(
40
+ fn=inference,
41
+ inputs="text",
42
+ outputs="text", #return 값
43
+ examples=[
44
+ "민주주의 국가에서 국민은 주인이다."
45
+ ]
46
+ examples=[
47
+ "민주주의 국가에서 국민은 주인이다."
48
+ ]
49
+ ).launch() # launch(share=True)를 설정하면 외부에서 접속 가능한 링크가 생성됨
50
+
51
+ demo.launch()