Hyeonsieun committed on
Commit afedaa3 · verified · 1 Parent(s): d63114d

Create app.py

Files changed (1)
app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
import torch
import re

import gradio as gr
from transformers import pipeline
from transformers import T5ForConditionalGeneration, T5Tokenizer
from yt_dlp import YoutubeDL

# Local Whisper is only needed if the line below is re-enabled;
# ASR otherwise runs through the transformers pipeline further down.
# import whisper
# whisper_model = whisper.load_model('small')

path = "Hyeonsieun/NTtoGT_1epoch"
tokenizer = T5Tokenizer.from_pretrained(path)
model = T5ForConditionalGeneration.from_pretrained(path)

MODEL_NAME = "openai/whisper-large-v2"
BATCH_SIZE = 8
#FILE_LIMIT_MB = 1000

pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
)


def transcribe(inputs):
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": "transcribe"}, return_timestamps=True)["text"]
    return text


def remove_spaces_within_dollar(text):
    # Remove the spaces inside spans delimited by dollar signs.
    # The regex \$(.*?)\$ finds the shortest (non-greedy) substring that starts
    # and ends with '$'; passing a function as re.sub's repl argument applies
    # the replacement only inside each matched span.
    result = re.sub(r'\$(.*?)\$', lambda match: match.group(0).replace(' ', ''), text)
    return result


def audio_correction(file):
    ASR_result = transcribe(file)
    text_list = split_text_complex_rules_with_warning(ASR_result)
    whole_text = ''
    for text in text_list:
        input_text = f"translate the text pronouncing the formula to a LaTeX equation: {text}"
        inputs = tokenizer.encode(
            input_text,
            return_tensors='pt',
            max_length=325,
            padding='max_length',
            truncation=True
        )
        # Generate the corrected sentence ids with beam search
        # (num_beams=1 would mean greedy decoding instead).
        corrected_ids = model.generate(
            inputs,
            max_length=325,
            num_beams=5,
            early_stopping=True
        )
        # Decode, dropping special tokens such as <pad> and </s>.
        corrected_sentence = tokenizer.decode(
            corrected_ids[0],
            skip_special_tokens=True
        )
        whole_text += corrected_sentence + ' '

    return remove_spaces_within_dollar(whole_text).strip()


def youtubeASR(link):
    # Temporary file name for the audio track downloaded from YouTube
    out_fn = 'temp1.mp3'

    ydl_opts = {
        'format': 'bestaudio/best',  # download the audio only
        'outtmpl': out_fn,           # save it under the fixed file name
    }

    with YoutubeDL(ydl_opts) as ydl:
        ydl.download([link])

    # Transcribe the downloaded audio file (out_fn)
    result = pipe(out_fn, batch_size=BATCH_SIZE, generate_kwargs={"task": "transcribe"}, return_timestamps=True)
    script = result['text']  # keep only the transcript text
    return script


def split_text_complex_rules_with_warning(text):
    # Split into sentences on end punctuation (everything except commas)
    parts = re.split(r'(?<=[.?!])\s+', text)

    result = []
    warnings = []  # warning messages (collected for debugging; not returned)
    for part in parts:
        # Parts longer than 256 characters get split further on commas
        if len(part) > 256:
            subparts = re.split(r',\s*', part)
            for subpart in subparts:
                # Keep only non-empty pieces of at most 256 characters
                trimmed_subpart = subpart.strip()
                if trimmed_subpart and len(trimmed_subpart) <= 256:
                    result.append(trimmed_subpart)
                elif trimmed_subpart:
                    # Flag pieces that still exceed 256 characters
                    warnings.append(f"Sentence exceeds 256 characters: {trimmed_subpart[:50]}... (length: {len(trimmed_subpart)})")
        else:
            # Parts of 256 characters or fewer go straight into the result
            result.append(part.strip())

    return result


def youtube_correction(link):
    ASR_result = youtubeASR(link)
    text_list = split_text_complex_rules_with_warning(ASR_result)
    whole_text = ''
    for text in text_list:
        input_text = f"translate the text pronouncing the formula to a LaTeX equation: {text}"
        inputs = tokenizer.encode(
            input_text,
            return_tensors='pt',
            max_length=325,
            padding='max_length',
            truncation=True
        )
        # Generate the corrected sentence ids with beam search
        # (num_beams=1 would mean greedy decoding instead).
        corrected_ids = model.generate(
            inputs,
            max_length=325,
            num_beams=5,
            early_stopping=True
        )
        # Decode, dropping special tokens such as <pad> and </s>.
        corrected_sentence = tokenizer.decode(
            corrected_ids[0],
            skip_special_tokens=True
        )
        whole_text += corrected_sentence + ' '

    return remove_spaces_within_dollar(whole_text).strip()


demo = gr.Blocks()

file_transcribe = gr.Interface(
    fn=audio_correction,
    inputs=gr.components.Audio(sources=["upload"], type="filepath"),
    outputs="text"
)

yt_transcribe = gr.Interface(
    fn=youtube_correction,
    inputs="text",
    outputs="text"
)

with demo:
    gr.TabbedInterface([file_transcribe, yt_transcribe], ["Audio file", "YouTube"])

demo.launch()
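
For reference, the two text helpers behave roughly as follows on made-up inputs (illustrative values, not output captured from this Space):

remove_spaces_within_dollar('The equation $x ^ 2 + y ^ 2 = 1$ is a circle.')
# -> 'The equation $x^2+y^2=1$ is a circle.'

split_text_complex_rules_with_warning('First sentence. Second sentence!')
# -> ['First sentence.', 'Second sentence!']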