Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -8,21 +8,16 @@ from transformers import AutoModel, AutoTokenizer | |
| 8 | 
             
            from diffusers import StableDiffusion3Pipeline
         | 
| 9 | 
             
            from parler_tts import ParlerTTSForConditionalGeneration
         | 
| 10 | 
             
            import soundfile as sf
         | 
| 11 | 
            -
            from langchain.agents import AgentExecutor, create_react_agent, initialize_agent, Tool
         | 
| 12 | 
            -
            from langchain.agents import AgentType
         | 
| 13 | 
             
            from langchain_groq import ChatGroq
         | 
| 14 | 
            -
            from langchain.prompts import PromptTemplate
         | 
| 15 | 
             
            from PIL import Image
         | 
| 16 | 
             
            from tavily import TavilyClient
         | 
| 17 | 
            -
            import requests
         | 
| 18 | 
            -
            from huggingface_hub import hf_hub_download
         | 
| 19 | 
            -
            from safetensors.torch import load_file
         | 
| 20 | 
             
            from langchain.schema import AIMessage
         | 
| 21 | 
             
            from langchain_community.embeddings import HuggingFaceEmbeddings
         | 
| 22 | 
             
            from langchain_community.vectorstores import FAISS
         | 
| 23 | 
             
            from langchain_community.document_loaders import TextLoader
         | 
| 24 | 
             
            from langchain.text_splitter import CharacterTextSplitter
         | 
| 25 | 
             
            from langchain.chains import RetrievalQA
         | 
|  | |
| 26 |  | 
| 27 | 
             
            # Initialize models and clients
         | 
| 28 | 
             
            MODEL = 'llama3-groq-70b-8192-tool-use-preview'
         | 
| @@ -53,54 +48,46 @@ def play_voice_output(response): | |
| 53 | 
             
                sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
         | 
| 54 | 
             
                return "output.wav"
         | 
| 55 |  | 
| 56 | 
            -
            #  | 
| 57 | 
            -
             | 
| 58 | 
            -
                 | 
| 59 | 
            -
                 | 
| 60 | 
            -
             | 
| 61 | 
            -
                def _run(self, query: str) -> str:
         | 
| 62 | 
            -
                    print("Executing NumpyCodeCalculator tool")
         | 
| 63 | 
            -
                    try:
         | 
| 64 | 
            -
                        local_dict = {"np": np}
         | 
| 65 | 
            -
                        exec(query, local_dict)
         | 
| 66 | 
            -
                        result = local_dict.get("result", "No result found")
         | 
| 67 | 
            -
                        return str(result)
         | 
| 68 | 
            -
                    except Exception as e:
         | 
| 69 | 
            -
                        return f"Error: {e}"
         | 
| 70 | 
            -
             | 
| 71 | 
            -
            # Web Search Tool
         | 
| 72 | 
            -
            class WebSearch(Tool):
         | 
| 73 | 
            -
                name = "Web"
         | 
| 74 | 
            -
                description = "Useful for advanced web searching beyond general information"
         | 
| 75 | 
            -
             | 
| 76 | 
            -
                def _run(self, query: str) -> str:
         | 
| 77 | 
            -
                    print("Executing WebSearch tool")
         | 
| 78 | 
            -
                    answer = tavily_client.qna_search(query=query)
         | 
| 79 | 
            -
                    return answer
         | 
| 80 | 
            -
             | 
| 81 | 
            -
            # Image Generation Tool
         | 
| 82 | 
            -
            class ImageGeneration(Tool):
         | 
| 83 | 
            -
                name = "Image"
         | 
| 84 | 
            -
                description = "Useful for generating images based on text descriptions"
         | 
| 85 | 
            -
             | 
| 86 | 
            -
                def _run(self, query: str) -> str:
         | 
| 87 | 
            -
                    print("Executing ImageGeneration tool")
         | 
| 88 | 
            -
                    image = pipe(
         | 
| 89 | 
            -
                        query,
         | 
| 90 | 
            -
                        negative_prompt="",
         | 
| 91 | 
            -
                        num_inference_steps=15,
         | 
| 92 | 
            -
                        guidance_scale=7.0,
         | 
| 93 | 
            -
                    ).images[0]
         | 
| 94 | 
            -
                    image.save("output.jpg")
         | 
| 95 | 
            -
                    return "output.jpg"
         | 
| 96 |  | 
| 97 | 
            -
             | 
| 98 | 
            -
             | 
| 99 | 
            -
                 | 
| 100 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 101 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 102 | 
             
                def __init__(self, document):
         | 
| 103 | 
            -
                    super().__init__()
         | 
| 104 | 
             
                    self.document = document
         | 
| 105 | 
             
                    self.qa_chain = self._setup_qa_chain()
         | 
| 106 |  | 
| @@ -120,79 +107,94 @@ class DocumentQuestionAnswering(Tool): | |
| 120 | 
             
                    )
         | 
| 121 | 
             
                    return qa_chain
         | 
| 122 |  | 
| 123 | 
            -
                def  | 
| 124 | 
             
                    print("Executing DocumentQuestionAnswering tool")
         | 
| 125 | 
             
                    response = self.qa_chain.run(query)
         | 
| 126 | 
             
                    return str(response)
         | 
| 127 |  | 
| 128 | 
            -
             | 
| 129 | 
            -
            # Function to handle different input types and choose the right tool
         | 
| 130 | 
             
            def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
         | 
| 131 | 
             
                print(f"Handling input: {user_prompt}")
         | 
| 132 |  | 
| 133 | 
             
                # Initialize the LLM
         | 
| 134 | 
             
                llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
         | 
| 135 |  | 
| 136 | 
            -
                #  | 
| 137 | 
            -
                 | 
| 138 | 
            -
             | 
| 139 | 
            -
                # Add Image Generation Tool
         | 
| 140 | 
            -
                tools.append(ImageGeneration())
         | 
| 141 | 
            -
             | 
| 142 | 
            -
                # Add Calculator Tool
         | 
| 143 | 
            -
                tools.append(NumpyCodeCalculator())
         | 
| 144 | 
            -
             | 
| 145 | 
            -
                # Add Web Search Tool if enabled
         | 
| 146 | 
            -
                if websearch:
         | 
| 147 | 
            -
                    tools.append(WebSearch())
         | 
| 148 | 
            -
             | 
| 149 | 
            -
                # Add Document QA Tool if document is provided
         | 
| 150 | 
            -
                if document:
         | 
| 151 | 
            -
                    tools.append(DocumentQuestionAnswering(document))
         | 
| 152 | 
            -
             | 
| 153 | 
            -
                # Check if any tools are mentioned in the user prompt
         | 
| 154 | 
            -
                requires_tool = any([tool.name.lower() in user_prompt.lower() for tool in tools])
         | 
| 155 | 
            -
             | 
| 156 | 
            -
                # Handle different input scenarios
         | 
| 157 | 
            -
                if image:
         | 
| 158 | 
            -
                    print("Processing image input")
         | 
| 159 | 
            -
                    image = Image.open(image).convert('RGB')
         | 
| 160 | 
            -
                    messages = [{"role": "user", "content": [image, user_prompt]}]
         | 
| 161 | 
            -
                    response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
         | 
| 162 | 
            -
                elif audio:
         | 
| 163 | 
             
                    print("Processing audio input")
         | 
| 164 | 
             
                    transcription = client.audio.transcriptions.create(
         | 
| 165 | 
             
                        file=(audio.name, audio.read()),
         | 
| 166 | 
             
                        model="whisper-large-v3"
         | 
| 167 | 
             
                    )
         | 
| 168 | 
             
                    user_prompt = transcription.text
         | 
| 169 | 
            -
                     | 
| 170 | 
            -
                     | 
| 171 | 
            -
             | 
| 172 | 
            -
             | 
| 173 | 
            -
             | 
| 174 | 
            -
             | 
| 175 | 
            -
             | 
| 176 | 
            -
             | 
| 177 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 178 | 
             
                    else:
         | 
| 179 | 
            -
                         | 
| 180 | 
            -
             | 
| 181 | 
            -
             | 
| 182 | 
            -
                     | 
| 183 | 
            -
             | 
| 184 | 
            -
                         | 
| 185 | 
            -
                         | 
| 186 | 
            -
                         | 
| 187 | 
            -
                     | 
| 188 | 
            -
             | 
| 189 | 
            -
             | 
| 190 | 
            -
             | 
|  | |
| 191 | 
             
                    response = llm.call(query=user_prompt)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 192 |  | 
| 193 | 
             
                return response
         | 
| 194 |  | 
| 195 | 
            -
                
         | 
| 196 | 
             
            def create_ui():
         | 
| 197 | 
             
                with gr.Blocks(css="""
         | 
| 198 | 
             
                    /* Overall Styling */
         | 
| @@ -403,40 +405,6 @@ def create_ui(): | |
| 403 |  | 
| 404 | 
             
                return demo
         | 
| 405 |  | 
| 406 | 
            -
            # Main interface function
         | 
| 407 | 
            -
            @spaces.GPU(duration=720)
         | 
| 408 | 
            -
            def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
         | 
| 409 | 
            -
                print("Starting main_interface function")
         | 
| 410 | 
            -
                vqa_model.to(device='cuda', dtype=torch.bfloat16)
         | 
| 411 | 
            -
                tts_model.to("cuda")
         | 
| 412 | 
            -
                pipe.to("cuda")
         | 
| 413 | 
            -
             | 
| 414 | 
            -
                print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}, document: {document}")
         | 
| 415 | 
            -
             | 
| 416 | 
            -
                try:
         | 
| 417 | 
            -
                    response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch, document=document)
         | 
| 418 | 
            -
                    print("handle_input function executed successfully")
         | 
| 419 | 
            -
                except Exception as e:
         | 
| 420 | 
            -
                    print(f"Error in handle_input: {e}")
         | 
| 421 | 
            -
                    response = "Error occurred during processing."
         | 
| 422 | 
            -
             | 
| 423 | 
            -
                if voice_only:
         | 
| 424 | 
            -
                    try:
         | 
| 425 | 
            -
                        transcription = client.audio.transcriptions.create(
         | 
| 426 | 
            -
                            file=("input.wav", open("input.wav", "rb").read()),
         | 
| 427 | 
            -
                            model="whisper-large-v3"
         | 
| 428 | 
            -
                        )
         | 
| 429 | 
            -
                        user_prompt = transcription.text
         | 
| 430 | 
            -
                        response = handle_input(user_prompt)
         | 
| 431 | 
            -
                        audio_output = play_voice_output(response)
         | 
| 432 | 
            -
                        print("play_voice_output function executed successfully")
         | 
| 433 | 
            -
                        return "Response generated.", audio_output
         | 
| 434 | 
            -
                    except Exception as e:
         | 
| 435 | 
            -
                        print(f"Error in play_voice_output: {e}")
         | 
| 436 | 
            -
                        return "Error occurred during voice output.", None
         | 
| 437 | 
            -
                else:
         | 
| 438 | 
            -
                    return response, None
         | 
| 439 | 
            -
                    
         | 
| 440 | 
             
            # Launch the UI
         | 
| 441 | 
             
            demo = create_ui()
         | 
| 442 | 
             
            demo.launch()
         | 
|  | |
| 8 | 
             
            from diffusers import StableDiffusion3Pipeline
         | 
| 9 | 
             
            from parler_tts import ParlerTTSForConditionalGeneration
         | 
| 10 | 
             
            import soundfile as sf
         | 
|  | |
|  | |
| 11 | 
             
            from langchain_groq import ChatGroq
         | 
|  | |
| 12 | 
             
            from PIL import Image
         | 
| 13 | 
             
            from tavily import TavilyClient
         | 
|  | |
|  | |
|  | |
| 14 | 
             
            from langchain.schema import AIMessage
         | 
| 15 | 
             
            from langchain_community.embeddings import HuggingFaceEmbeddings
         | 
| 16 | 
             
            from langchain_community.vectorstores import FAISS
         | 
| 17 | 
             
            from langchain_community.document_loaders import TextLoader
         | 
| 18 | 
             
            from langchain.text_splitter import CharacterTextSplitter
         | 
| 19 | 
             
            from langchain.chains import RetrievalQA
         | 
| 20 | 
            +
            import json
         | 
| 21 |  | 
| 22 | 
             
            # Initialize models and clients
         | 
| 23 | 
             
            MODEL = 'llama3-groq-70b-8192-tool-use-preview'
         | 
|  | |
| 48 | 
             
                sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
         | 
| 49 | 
             
                return "output.wav"
         | 
| 50 |  | 
| 51 | 
            +
            # Function to classify user input using LLM
         | 
| 52 | 
            +
            def classify_function(user_prompt):
         | 
| 53 | 
            +
                prompt = f"""
         | 
| 54 | 
            +
                You are a function classifier AI assistant. You are given a user input and you need to classify it into one of the following functions:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 55 |  | 
| 56 | 
            +
                - `image_generation`: If the user wants to generate an image.
         | 
| 57 | 
            +
                - `image_description`: If the user wants to describe an image.
         | 
| 58 | 
            +
                - `document_summarization`: If the user wants to summarize a document.
         | 
| 59 | 
            +
                - `text_to_text`: If the user wants a text-based response.
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                Respond with a JSON object containing only the chosen function. For example:
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                ```json
         | 
| 64 | 
            +
                {{"function": "image_generation"}}
         | 
| 65 | 
            +
                ```
         | 
| 66 |  | 
| 67 | 
            +
                User input: {user_prompt}
         | 
| 68 | 
            +
                """
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                chat_completion = client.chat.completions.create(
         | 
| 71 | 
            +
                    messages=[
         | 
| 72 | 
            +
                        {
         | 
| 73 | 
            +
                            "role": "user",
         | 
| 74 | 
            +
                            "content": prompt,
         | 
| 75 | 
            +
                        }
         | 
| 76 | 
            +
                    ],
         | 
| 77 | 
            +
                    model="llama3-8b-8192",
         | 
| 78 | 
            +
                )
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                try:
         | 
| 81 | 
            +
                    response = json.loads(chat_completion.choices[0].message.content)
         | 
| 82 | 
            +
                    function = response.get("function")
         | 
| 83 | 
            +
                    return function
         | 
| 84 | 
            +
                except json.JSONDecodeError:
         | 
| 85 | 
            +
                    print(f"Error decoding JSON: {chat_completion.choices[0].message.content}")
         | 
| 86 | 
            +
                    return "text_to_text"  # Default to text-to-text if JSON parsing fails
         | 
| 87 | 
            +
             | 
| 88 | 
            +
            # Document Question Answering Tool
         | 
| 89 | 
            +
            class DocumentQuestionAnswering:
         | 
| 90 | 
             
                def __init__(self, document):
         | 
|  | |
| 91 | 
             
                    self.document = document
         | 
| 92 | 
             
                    self.qa_chain = self._setup_qa_chain()
         | 
| 93 |  | 
|  | |
| 107 | 
             
                    )
         | 
| 108 | 
             
                    return qa_chain
         | 
| 109 |  | 
| 110 | 
            +
                def run(self, query: str) -> str:
         | 
| 111 | 
             
                    print("Executing DocumentQuestionAnswering tool")
         | 
| 112 | 
             
                    response = self.qa_chain.run(query)
         | 
| 113 | 
             
                    return str(response)
         | 
| 114 |  | 
| 115 | 
            +
            # Function to handle different input types and choose the right pipeline
         | 
|  | |
| 116 | 
             
            def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
         | 
| 117 | 
             
                print(f"Handling input: {user_prompt}")
         | 
| 118 |  | 
| 119 | 
             
                # Initialize the LLM
         | 
| 120 | 
             
                llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
         | 
| 121 |  | 
| 122 | 
            +
                # Handle voice-only mode
         | 
| 123 | 
            +
                if audio:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 124 | 
             
                    print("Processing audio input")
         | 
| 125 | 
             
                    transcription = client.audio.transcriptions.create(
         | 
| 126 | 
             
                        file=(audio.name, audio.read()),
         | 
| 127 | 
             
                        model="whisper-large-v3"
         | 
| 128 | 
             
                    )
         | 
| 129 | 
             
                    user_prompt = transcription.text
         | 
| 130 | 
            +
                    response = llm.call(query=user_prompt)
         | 
| 131 | 
            +
                    audio_output = play_voice_output(response)
         | 
| 132 | 
            +
                    return "Response generated.", audio_output
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                # Handle websearch mode
         | 
| 135 | 
            +
                if websearch:
         | 
| 136 | 
            +
                    print("Executing Web Search")
         | 
| 137 | 
            +
                    answer = tavily_client.qna_search(query=user_prompt)
         | 
| 138 | 
            +
                    return answer, None
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                # Classify user input using LLM
         | 
| 141 | 
            +
                function = classify_function(user_prompt)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                # Handle different functions
         | 
| 144 | 
            +
                if function == "image_generation":
         | 
| 145 | 
            +
                    print("Executing Image Generation")
         | 
| 146 | 
            +
                    image = pipe(
         | 
| 147 | 
            +
                        user_prompt,
         | 
| 148 | 
            +
                        negative_prompt="",
         | 
| 149 | 
            +
                        num_inference_steps=15,
         | 
| 150 | 
            +
                        guidance_scale=7.0,
         | 
| 151 | 
            +
                    ).images[0]
         | 
| 152 | 
            +
                    image.save("output.jpg")
         | 
| 153 | 
            +
                    return "output.jpg", None
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                elif function == "image_description":
         | 
| 156 | 
            +
                    print("Executing Image Description")
         | 
| 157 | 
            +
                    if image:
         | 
| 158 | 
            +
                        image = Image.open(image).convert('RGB')
         | 
| 159 | 
            +
                        messages = [{"role": "user", "content": [image, user_prompt]}]
         | 
| 160 | 
            +
                        response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
         | 
| 161 | 
            +
                        return response, None
         | 
| 162 | 
             
                    else:
         | 
| 163 | 
            +
                        return "Please upload an image.", None
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                elif function == "document_summarization":
         | 
| 166 | 
            +
                    print("Executing Document Summarization")
         | 
| 167 | 
            +
                    if document:
         | 
| 168 | 
            +
                        document_qa = DocumentQuestionAnswering(document)
         | 
| 169 | 
            +
                        response = document_qa.run(user_prompt)
         | 
| 170 | 
            +
                        return response, None
         | 
| 171 | 
            +
                    else:
         | 
| 172 | 
            +
                        return "Please upload a document.", None
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                else:  # function == "text_to_text"
         | 
| 175 | 
            +
                    print("Executing Text-to-Text")
         | 
| 176 | 
             
                    response = llm.call(query=user_prompt)
         | 
| 177 | 
            +
                    return response, None
         | 
| 178 | 
            +
             | 
| 179 | 
            +
            # Main interface function
         | 
| 180 | 
            +
            @spaces.GPU(duration=720)
         | 
| 181 | 
            +
            def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
         | 
| 182 | 
            +
                print("Starting main_interface function")
         | 
| 183 | 
            +
                vqa_model.to(device='cuda', dtype=torch.bfloat16)
         | 
| 184 | 
            +
                tts_model.to("cuda")
         | 
| 185 | 
            +
                pipe.to("cuda")
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}, document: {document}")
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                try:
         | 
| 190 | 
            +
                    response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch, document=document)
         | 
| 191 | 
            +
                    print("handle_input function executed successfully")
         | 
| 192 | 
            +
                except Exception as e:
         | 
| 193 | 
            +
                    print(f"Error in handle_input: {e}")
         | 
| 194 | 
            +
                    response = "Error occurred during processing."
         | 
| 195 |  | 
| 196 | 
             
                return response
         | 
| 197 |  | 
|  | |
| 198 | 
             
            def create_ui():
         | 
| 199 | 
             
                with gr.Blocks(css="""
         | 
| 200 | 
             
                    /* Overall Styling */
         | 
|  | |
| 405 |  | 
| 406 | 
             
                return demo
         | 
| 407 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 408 | 
             
            # Launch the UI
         | 
| 409 | 
             
            demo = create_ui()
         | 
| 410 | 
             
            demo.launch()
         | 
