richardkovacs commited on
Commit
793061a
·
1 Parent(s): a220c84

feat: add scademy models

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +27 -6
  3. requirements.txt +13 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  venv
2
  flagged
 
 
1
  venv
2
  flagged
3
+ .env
app.py CHANGED
@@ -1,15 +1,34 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
 
3
 
4
- raw_model_name = 'distilbert-base-uncased'
5
- raw_model = pipeline('sentiment-analysis', model=raw_model_name)
6
- fine_tuned_model_name = 'distilbert-base-uncased-finetuned-sst-2-english'
7
- fine_tuned_model = pipeline('sentiment-analysis', model=fine_tuned_model_name)
 
 
 
 
 
 
 
 
 
 
8
 
9
  def get_model_output(input_text, model_choice):
10
  raw_result = raw_model(input_text)
 
 
11
  fine_tuned_result = fine_tuned_model(input_text)
12
- return format_model_output(raw_result[0]), format_model_output(fine_tuned_result[0])
 
 
 
 
 
13
 
14
  def format_model_output(output):
15
  return f"I am {output['score']*100:.2f}% sure that the sentiment is {output['label']}"
@@ -22,7 +41,9 @@ iface = gr.Interface(
22
  ],
23
  outputs=[
24
  gr.Textbox(label="Base DistilBERT output (distilbert-base-uncased)"),
25
- gr.Textbox(label="Fine-tuned DistilBERT output")
 
 
26
  ],
27
  )
28
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ import os
4
+ from dotenv import load_dotenv
5
 
6
+ load_dotenv()
7
+ token = os.environ.get("HF_TOKEN")
8
+
9
+ raw_model_name = "distilbert-base-uncased"
10
+ raw_model = pipeline("sentiment-analysis", model=raw_model_name)
11
+
12
+ scademy_500_500_model_name = "scademy/bert-finetuned-500-500-0"
13
+ scademy_500_500_model = pipeline("sentiment-analysis", model=scademy_500_500_model_name, token=token)
14
+
15
+ scademy_2500_2500_model_name = "scademy/bert-finetuned-2500-2500-0"
16
+ scademy_2500_2500_model = pipeline("sentiment-analysis", model=scademy_2500_2500_model_name, token=token)
17
+
18
+ fine_tuned_model_name = "distilbert-base-uncased-finetuned-sst-2-english"
19
+ fine_tuned_model = pipeline("sentiment-analysis", model=fine_tuned_model_name)
20
 
21
  def get_model_output(input_text, model_choice):
22
  raw_result = raw_model(input_text)
23
+ scademy_500_500_result = scademy_500_500_model(input_text)
24
+ scademy_2500_2500_result = scademy_2500_2500_model(input_text)
25
  fine_tuned_result = fine_tuned_model(input_text)
26
+ return (
27
+ format_model_output(raw_result[0]),
28
+ format_model_output(scademy_500_500_result[0]),
29
+ format_model_output(scademy_2500_2500_result[0]),
30
+ format_model_output(fine_tuned_result[0])
31
+ )
32
 
33
  def format_model_output(output):
34
  return f"I am {output['score']*100:.2f}% sure that the sentiment is {output['label']}"
 
41
  ],
42
  outputs=[
43
  gr.Textbox(label="Base DistilBERT output (distilbert-base-uncased)"),
44
+ gr.Textbox(label="Scademy DistilBERT 500-500 output"),
45
+ gr.Textbox(label="Scademy DistilBERT 2500-2500 output"),
46
+ gr.Textbox(label="Fine-tuned DistilBERT output"),
47
  ],
48
  )
49
 
requirements.txt CHANGED
@@ -1,7 +1,10 @@
1
  aiofiles==23.2.1
 
 
2
  altair==5.1.2
3
  annotated-types==0.6.0
4
  anyio==3.7.1
 
5
  attrs==23.1.0
6
  certifi==2023.11.17
7
  charset-normalizer==3.3.2
@@ -9,11 +12,14 @@ click==8.1.3
9
  colorama==0.4.6
10
  contourpy==1.2.0
11
  cycler==0.12.1
 
 
12
  exceptiongroup==1.2.0
13
  fastapi==0.104.1
14
  ffmpy==0.3.1
15
  filelock==3.13.1
16
  fonttools==4.45.0
 
17
  fsspec==2023.10.0
18
  gradio==4.5.0
19
  gradio_client==0.7.0
@@ -32,12 +38,16 @@ MarkupSafe==2.1.3
32
  matplotlib==3.8.2
33
  mdurl==0.1.2
34
  mpmath==1.3.0
 
 
35
  networkx==3.2.1
36
  numpy==1.26.2
37
  orjson==3.9.10
38
  packaging==23.2
39
  pandas==2.1.3
40
  Pillow==10.1.0
 
 
41
  pydantic==2.5.1
42
  pydantic_core==2.14.3
43
  pydub==0.25.1
@@ -45,6 +55,7 @@ Pygments==2.17.1
45
  pyparsing==3.1.1
46
  pystache==0.6.0
47
  python-dateutil==2.8.2
 
48
  python-multipart==0.0.6
49
  pytz==2023.3.post1
50
  PyYAML==6.0.1
@@ -72,3 +83,5 @@ tzdata==2023.3
72
  urllib3==2.1.0
73
  uvicorn==0.24.0.post1
74
  websockets==11.0.3
 
 
 
1
  aiofiles==23.2.1
2
+ aiohttp==3.9.0
3
+ aiosignal==1.3.1
4
  altair==5.1.2
5
  annotated-types==0.6.0
6
  anyio==3.7.1
7
+ async-timeout==4.0.3
8
  attrs==23.1.0
9
  certifi==2023.11.17
10
  charset-normalizer==3.3.2
 
12
  colorama==0.4.6
13
  contourpy==1.2.0
14
  cycler==0.12.1
15
+ datasets==2.15.0
16
+ dill==0.3.7
17
  exceptiongroup==1.2.0
18
  fastapi==0.104.1
19
  ffmpy==0.3.1
20
  filelock==3.13.1
21
  fonttools==4.45.0
22
+ frozenlist==1.4.0
23
  fsspec==2023.10.0
24
  gradio==4.5.0
25
  gradio_client==0.7.0
 
38
  matplotlib==3.8.2
39
  mdurl==0.1.2
40
  mpmath==1.3.0
41
+ multidict==6.0.4
42
+ multiprocess==0.70.15
43
  networkx==3.2.1
44
  numpy==1.26.2
45
  orjson==3.9.10
46
  packaging==23.2
47
  pandas==2.1.3
48
  Pillow==10.1.0
49
+ pyarrow==14.0.1
50
+ pyarrow-hotfix==0.5
51
  pydantic==2.5.1
52
  pydantic_core==2.14.3
53
  pydub==0.25.1
 
55
  pyparsing==3.1.1
56
  pystache==0.6.0
57
  python-dateutil==2.8.2
58
+ python-dotenv==1.0.0
59
  python-multipart==0.0.6
60
  pytz==2023.3.post1
61
  PyYAML==6.0.1
 
83
  urllib3==2.1.0
84
  uvicorn==0.24.0.post1
85
  websockets==11.0.3
86
+ xxhash==3.4.1
87
+ yarl==1.9.3