SreyanG-NVIDIA committed on
Commit
67492cd
·
verified ·
1 Parent(s): 7272785

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -12
app.py CHANGED
@@ -18,24 +18,48 @@ generation_config_multi = model_multi.default_generation_config
18
  # ---------------------------------
19
  # MULTI-TURN INFERENCE FUNCTION
20
  # ---------------------------------
21
- def multi_turn_chat(user_input, audio_file, history, current_audio):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  try:
23
  if audio_file is not None:
24
- current_audio = audio_file # Update state if a new file is uploaded
25
 
26
- if current_audio is None:
27
- return history + [("System", "❌ Please upload an audio file before chatting.")], history, current_audio
28
 
29
- sound = llava.Sound(current_audio)
30
- prompt = f"<sound>\n{user_input}"
31
 
32
- response = model_multi.generate_content([sound, prompt], generation_config=generation_config_multi)
 
 
33
 
34
  history.append((user_input, response))
35
- return history, history, current_audio
 
36
  except Exception as e:
37
  history.append((user_input, f"❌ Error: {str(e)}"))
38
- return history, history, current_audio
 
 
39
  def speech_prompt_infer(audio_prompt_file):
40
  try:
41
  sound = llava.Sound(audio_prompt_file)
@@ -118,12 +142,16 @@ with gr.Blocks(css="""
118
  user_input_multi = gr.Textbox(label="Your message", placeholder="Ask a question about the audio...", lines=8)
119
  btn_multi = gr.Button("Send")
120
  history_state = gr.State([]) # Chat history
121
- current_audio_state = gr.State(None) # Most recent audio file path
 
 
122
 
123
  btn_multi.click(
124
  fn=multi_turn_chat,
125
- inputs=[user_input_multi, audio_input_multi, history_state, current_audio_state],
126
- outputs=[chatbot, history_state, current_audio_state]
 
 
127
  )
128
  gr.Examples(
129
  examples=[
 
18
  # ---------------------------------
19
  # MULTI-TURN INFERENCE FUNCTION
20
  # ---------------------------------
21
+ # def multi_turn_chat(user_input, audio_file, history, current_audio):
22
+ # try:
23
+ # if audio_file is not None:
24
+ # current_audio = audio_file # Update state if a new file is uploaded
25
+
26
+ # if current_audio is None:
27
+ # return history + [("System", "❌ Please upload an audio file before chatting.")], history, current_audio
28
+
29
+ # sound = llava.Sound(current_audio)
30
+ # prompt = f"<sound>\n{user_input}"
31
+
32
+ # response = model_multi.generate_content([sound, prompt], generation_config=generation_config_multi)
33
+
34
+ # history.append((user_input, response))
35
+ # return history, history, current_audio
36
+ # except Exception as e:
37
+ # history.append((user_input, f"❌ Error: {str(e)}"))
38
+ # return history, history, current_audio
39
+
40
def multi_turn_chat(user_input, audio_file, history, audio_history):
    """Answer one chat turn using every audio clip collected so far.

    Args:
        user_input: Text of the user's message for this turn.
        audio_file: Current value of the audio input, or ``None`` when nothing
            is loaded. It is passed straight to ``llava.Sound``, so it is
            presumably a file path -- confirm against the audio component.
        history: List of ``(user_message, reply)`` tuples. Mutated in place.
        audio_history: List of audio values seen in earlier turns. Mutated in
            place.

    Returns:
        A 3-tuple ``(chat_display, history, audio_history)`` matching the three
        outputs wired to the Send button.
    """
    try:
        # The audio input usually still holds the previous turn's clip, so only
        # record it when it differs from the most recent entry. Otherwise the
        # same clip would be stored (and fed to the model) once per message.
        is_new_clip = audio_file is not None and (
            not audio_history or audio_history[-1] != audio_file
        )
        if is_new_clip:
            audio_history.append(audio_file)

        if not audio_history:
            # The notice is shown in the chat display only; it is deliberately
            # not written into the persisted history.
            notice = ("System", "❌ Please upload an audio file before chatting.")
            return history + [notice], history, audio_history

        # One llava.Sound per stored clip, in upload order.
        audio_sounds = [llava.Sound(audio) for audio in audio_history]

        # NOTE(review): a single "<sound>" placeholder is emitted no matter how
        # many clips are passed -- confirm this is what generate_content expects.
        prompt = f"<sound>\n{user_input}"
        response = model_multi.generate_content(
            audio_sounds + [prompt],
            generation_config=generation_config_multi,
        )

        history.append((user_input, response))
    except Exception as e:
        # UI boundary: surface the failure in the chat instead of crashing the
        # event handler.
        history.append((user_input, f"❌ Error: {e}"))

    return history, history, audio_history
61
+
62
+
63
  def speech_prompt_infer(audio_prompt_file):
64
  try:
65
  sound = llava.Sound(audio_prompt_file)
 
142
  user_input_multi = gr.Textbox(label="Your message", placeholder="Ask a question about the audio...", lines=8)
143
  btn_multi = gr.Button("Send")
144
  history_state = gr.State([]) # Chat history
145
+ # current_audio_state = gr.State(None) # Most recent audio file path
146
+ audio_history_state = gr.State([]) # List of audio file paths
147
+
148
 
149
  btn_multi.click(
150
  fn=multi_turn_chat,
151
+ inputs=[user_input_multi, audio_input_multi, history_state, audio_history_state],
152
+ outputs=[chatbot, history_state, audio_history_state]
153
+ # inputs=[user_input_multi, audio_input_multi, history_state, current_audio_state],
154
+ # outputs=[chatbot, history_state, current_audio_state]
155
  )
156
  gr.Examples(
157
  examples=[