Kr08 committed (verified)
Commit 9148c64 · Parent(s): a050ac4

Update app.py

Files changed (1): app.py (+15 -10)
app.py CHANGED
@@ -1,20 +1,19 @@
 import gradio as gr
 from audio_processing import process_audio, print_results
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
 import spaces
 import torch
 
 # Check if CUDA is available
 cuda_available = torch.cuda.is_available()
+device = "cuda" if cuda_available else "cpu"
 
 # Initialize the summarization and question-answering models
-summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
-qa_model = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
+summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").to(device)
+summarizer_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
 
-# Move models to GPU if available
-if cuda_available:
-    summarizer.to('cuda')
-    qa_model.to('cuda')
+qa_model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad").to(device)
+qa_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
 
 @spaces.GPU
 def transcribe_audio(audio_file, translate, model_size):
@@ -41,13 +40,19 @@ def transcribe_audio(audio_file, translate, model_size):
 
 @spaces.GPU
 def summarize_text(text):
-    summary = summarizer(text, max_length=150, min_length=50, do_sample=False)[0]['summary_text']
+    inputs = summarizer_tokenizer(text, max_length=1024, truncation=True, return_tensors="pt").to(device)
+    summary_ids = summarizer_model.generate(inputs["input_ids"], max_length=150, min_length=50, do_sample=False)
+    summary = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return summary
 
 @spaces.GPU
 def answer_question(context, question):
-    result = qa_model(question=question, context=context)
-    return result['answer']
+    inputs = qa_tokenizer(question, context, return_tensors="pt").to(device)
+    outputs = qa_model(**inputs)
+    answer_start = torch.argmax(outputs.start_logits)
+    answer_end = torch.argmax(outputs.end_logits) + 1
+    answer = qa_tokenizer.decode(inputs["input_ids"][0][answer_start:answer_end])
+    return answer
 
 @spaces.GPU
 def process_and_summarize(audio_file, translate, model_size):
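
For context on the change itself: `transformers` pipelines are configured for GPU through a `device` argument at construction, not an `nn.Module`-style `.to()` call, so the removed `summarizer.to('cuda')` block is the likely bug this commit works around. A minimal sketch of the pipeline-based alternative, assuming the same two checkpoints as the diff (the sample text and variable names are illustrative):

import torch
from transformers import pipeline

# device=0 selects the first GPU; device=-1 keeps the pipeline on CPU.
device = 0 if torch.cuda.is_available() else -1
summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=device)
qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad", device=device)

long_text = "The meeting opened with a budget review led by Dr. Lee. " * 20  # illustrative input
summary = summarizer(long_text, max_length=150, min_length=50, do_sample=False, truncation=True)[0]["summary_text"]
answer = qa(question="Who led the budget review?", context=long_text)["answer"]

The `max_length=1024, truncation=True` guard in the new `summarize_text` reflects BART's 1024-token encoder limit; longer inputs must be truncated (or chunked) before generation.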
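
The new `answer_question` body is standard extractive QA decoding: the model scores every token as a possible answer start and end, and the argmax of each set of logits bounds the answer span within the input. A self-contained sketch of the same decoding, assuming the distilbert checkpoint from the diff (context and question are illustrative):

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

name = "distilbert-base-cased-distilled-squad"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForQuestionAnswering.from_pretrained(name)

context = "The meeting was held on Tuesday and was chaired by Dr. Lee."  # illustrative
question = "Who chaired the meeting?"

inputs = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():  # inference only, so skip gradient tracking
    outputs = model(**inputs)

# Highest-scoring start and end token positions bound the answer span.
start = torch.argmax(outputs.start_logits)
end = torch.argmax(outputs.end_logits) + 1
print(tokenizer.decode(inputs["input_ids"][0][start:end]))  # e.g. "Dr. Lee"

One difference from the diff: wrapping the forward pass in `torch.no_grad()` skips gradient tracking, which saves memory during inference.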