File size: 2,609 Bytes
2ed7223
 
2ba8923
c621812
039f770
ee53056
011a958
9c0a84a
e48e518
 
dcc2eea
e48e518
dc03737
2ba8923
d649fba
 
 
 
 
ee53056
d649fba
2ba8923
ee53056
9c0a84a
ee53056
 
e48e518
9c0a84a
e48e518
ee53056
9c0a84a
e48e518
ee53056
e48e518
2ba8923
ee53056
9c0a84a
e48e518
 
 
d649fba
ee53056
9c0a84a
 
 
 
 
 
 
 
 
2ed7223
ab07d9e
dcc2eea
2ed7223
9c0a84a
2ed7223
 
9c0a84a
c591299
9c0a84a
 
 
 
 
 
dcc2eea
 
 
e48e518
dcc2eea
2ed7223
 
c621812
d649fba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import transformers
import gradio as gr
import librosa
import torch
import spaces
import numpy as np

# Initialize the conversation history globally
# NOTE(review): this module-level list is never read or written anywhere in
# the visible file — per-session history actually travels through the
# gr.State input/output below. Kept in case external code imports it.
conversation_history = []

@spaces.GPU(duration=120)
def transcribe_and_respond(audio_file, chat_history):
    """Run one conversational turn: feed recorded audio to the Shuka model.

    Args:
        audio_file: Filesystem path to the recorded audio (``gr.Audio``
            with ``type="filepath"``).
        chat_history: List of ``{'role': ..., 'content': ...}`` turn dicts
            accumulated so far (from the hidden ``gr.State``).

    Returns:
        A ``(display_pairs, turns)`` tuple: ``display_pairs`` is a list of
        ``(user_message, bot_message)`` pairs for ``gr.Chatbot``; ``turns``
        is the updated turn list persisted back into ``gr.State``.
    """
    try:
        # NOTE(review): the pipeline is rebuilt (model reloaded) on every
        # call. Kept inside the function because @spaces.GPU functions only
        # hold the GPU for the duration of the call — confirm before hoisting.
        pipe = transformers.pipeline(
            model='sarvamai/shuka_v1',
            trust_remote_code=True,
            device=0,
            torch_dtype=torch.bfloat16
        )

        # Resample to 16 kHz on load (the rate passed to the model below).
        audio, sr = librosa.load(audio_file, sr=16000)

        # Debug: Print audio properties for debugging
        print(f"Audio dtype: {audio.dtype}, Audio shape: {audio.shape}, Sample rate: {sr}")

        # Copy so Gradio's state object is never mutated in place.
        turns = chat_history.copy()
        turns.append({'role': 'user', 'content': '<|audio|>'})

        # Debug: Print the updated turns for debugging purposes
        print(f"Updated turns: {turns}")

        # Call the model with the updated conversation turns and audio
        output = pipe({'audio': audio, 'turns': turns, 'sampling_rate': sr}, max_new_tokens=512)

        # Append the model's response to the conversation history
        turns.append({'role': 'system', 'content': output})

        # Debug: Print the model's response
        print(f"Model output: {output}")

        # Bug fix: gr.Chatbot expects (user_message, bot_message) PAIRS.
        # The original appended ("User", ...) and ("AI", ...) tuples, which
        # rendered the literal labels "User"/"AI" as chat messages. Pair
        # each user turn with the reply that follows it instead.
        chat_history_for_display = []
        pending_user = None
        for turn in turns:
            if turn['role'] == 'user':
                pending_user = "🗣️ (Spoken Audio)"
            else:
                chat_history_for_display.append((pending_user, turn['content']))
                pending_user = None
        if pending_user is not None:
            # User turn with no reply yet (shouldn't happen, but stay valid).
            chat_history_for_display.append((pending_user, None))

        return chat_history_for_display, turns

    except Exception as e:
        # Bug fix: the original returned a bare string for the Chatbot
        # output, which is not a valid Chatbot value and would raise a
        # second error inside Gradio. Surface the error as a chat pair and
        # preserve the existing history so the session can continue.
        return [(None, f"Error: {str(e)}")], chat_history

# Wire up the Gradio UI: microphone audio in, chat transcript out, with a
# hidden gr.State on each side carrying the running turn list between calls.
_demo_inputs = [
    gr.Audio(sources="microphone", type="filepath", label="Your Audio (Microphone)"),
    gr.State([]),  # conversation history fed into the handler
]
_demo_outputs = [
    gr.Chatbot(label="Conversation History"),  # rendered transcript
    gr.State([]),  # updated conversation history returned by the handler
]

iface = gr.Interface(
    fn=transcribe_and_respond,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="Shuka demo",
    description="shuka live demo",
    live=True,  # re-run automatically as audio arrives
    allow_flagging="auto",
)

# Entry point: launch the web app only when executed directly (not on import).
if __name__ == "__main__":
    iface.launch()