Update app.py
app.py
CHANGED
@@ -20,7 +20,7 @@ from langchain.chains import RetrievalQA
 import json
 
 # Initialize models and clients
-MODEL = '
+MODEL = 'llama-3.1-70b-versatile'
 client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
 
 vqa_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
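For reference, MODEL is the Groq model id and client is the Groq SDK client that the rest of app.py presumably uses for chat completions. A minimal usage sketch, assuming the standard groq Python package; the message content is illustrative only and not part of this diff:

import os
from groq import Groq

client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
MODEL = 'llama-3.1-70b-versatile'

# Illustrative call: send a prompt to the newly configured model.
response = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)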
@@ -54,8 +54,8 @@ def classify_function(user_prompt):
 You are a function classifier AI assistant. You are given a user input and you need to classify it into one of the following functions:
 
 - `image_generation`: If the user wants to generate an image.
-- `
-- `
+- `image_vqa`: If the user wants to ask questions about an image.
+- `document_qa`: If the user wants to ask questions about a document.
 - `text_to_text`: If the user wants a text-based response.
 
 Respond with a JSON object containing only the chosen function. For example:
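Since the prompt asks for a JSON object containing only the chosen function, classify_function presumably sends this system prompt through the Groq client and parses the reply with the json module imported above. The sketch below is an assumed reconstruction of that flow, not the actual implementation in app.py; the "function" key and the use of JSON mode are guesses, and it relies on the client, MODEL, and json names defined earlier in the file:

def classify_function_sketch(user_prompt):
    # Hypothetical re-creation of classify_function; the real system prompt is abbreviated here.
    completion = client.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "system", "content": "You are a function classifier AI assistant. ... Respond with a JSON object containing only the chosen function."},
            {"role": "user", "content": user_prompt},
        ],
        response_format={"type": "json_object"},  # assumption: Groq JSON mode is enabled
    )
    return json.loads(completion.choices[0].message.content)["function"]  # assumed key name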
@@ -137,6 +137,13 @@ def handle_input(user_prompt, image=None, audio=None, websearch=False, document=
         answer = tavily_client.qna_search(query=user_prompt)
         return answer, None
 
+    # Handle cases with only image or document input
+    if user_prompt is None or user_prompt.strip() == "":
+        if image:
+            user_prompt = "Describe this image"
+        elif document:
+            user_prompt = "Summarize this document"
+
     # Classify user input using LLM
     function = classify_function(user_prompt)
 
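The new guard means handle_input no longer requires a text prompt when only an image or a document is supplied. A small standalone restatement of that logic with illustrative calls; the helper name and argument values are hypothetical:

def default_prompt(user_prompt, image=None, document=None):
    # Mirrors the guard added in handle_input: fall back to a generic prompt
    # when the user provided only an image or only a document.
    if user_prompt is None or user_prompt.strip() == "":
        if image:
            user_prompt = "Describe this image"
        elif document:
            user_prompt = "Summarize this document"
    return user_prompt

print(default_prompt(None, image="photo.png"))   # -> Describe this image
print(default_prompt("", document="notes.pdf"))  # -> Summarize this document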