# test-gpt-omni / app.py
import gradio as gr
import requests
from bs4 import BeautifulSoup
from gradio_client import Client
from huggingface_hub import InferenceClient
from PIL import Image
from threading import Thread
from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
# Initialize model and processor
model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cpu")
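# Note: inference runs on CPU here; on a GPU Space, model.to("cuda") would
# make generation considerably faster.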
# Initialize inference clients for different hosted models.
# Note: the variable names do not match the models they wrap (client_gemma
# actually points at Mistral), and only client_gemma is used below.
client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
client_mixtral = InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
client_llama = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
client_yi = InferenceClient("01-ai/Yi-1.5-34B-Chat")
def search(query):
    """Performs a Google search and extracts text from the top results.
    Note: this scrapes Google's HTML; the class names below change often,
    so results may come back empty when Google updates its markup.
    """
    session = requests.Session()
    response = session.get(
        "https://www.google.com/search",
        params={"q": query},  # let requests URL-encode the query
        headers={"User-Agent": "Mozilla/5.0"},
    )
    soup = BeautifulSoup(response.text, "html.parser")
    results = []
    for result in soup.find_all("div", class_="BNeawe vvjwJb AP7Wnd"):
        text = result.get_text()
        link = result.find_parent("a")
        if link is None:  # skip results that are not wrapped in a link
            continue
        results.append(f"{text}: {link['href']}")
    return "\n".join(results[:3])
def llava(inputs, history):
    """Processes an image and text input with LLaVA."""
    image = Image.open(inputs["files"][0]).convert("RGB")
    # Qwen-style chat template; the trailing assistant tag prompts the model
    # to generate a reply rather than continue the user turn.
    prompt = f"<|im_start|>user <image>\n{inputs['text']}<|im_end|><|im_start|>assistant"
    return processor(text=prompt, images=image, return_tensors="pt").to("cpu")
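# Usage sketch (hypothetical local image path) showing how the processed
# batch feeds model.generate:
#   batch = llava({"files": ["cat.png"], "text": "What is in this image?"}, [])
#   ids = model.generate(**batch, max_new_tokens=64)
#   print(processor.decode(ids[0], skip_special_tokens=True))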
def respond(message, history):
    """Main response function for the chatbot; yields partial replies as they stream."""
    if "files" in message and message["files"]:
        inputs = llava(message, history)
        # TextIteratorStreamer needs the tokenizer to decode generated ids.
        streamer = TextIteratorStreamer(
            processor.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        # Generate in a background thread so tokens can be streamed out.
        thread = Thread(
            target=model.generate,
            kwargs=dict(**inputs, max_new_tokens=512, streamer=streamer),
        )
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer
    else:
        # Rebuild the conversation as alternating user/assistant turns rather
        # than sending every past message with the "user" role.
        prompt = []
        for user_msg, bot_msg in history:
            prompt.append({"role": "user", "content": user_msg})
            if bot_msg:
                prompt.append({"role": "assistant", "content": bot_msg})
        prompt.append({"role": "user", "content": message["text"]})
        response = client_gemma.chat_completion(prompt, max_tokens=200)
        yield response.choices[0].message.content
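# Usage sketch: respond() is a generator, so callers iterate over it to get
# progressively longer partial replies:
#   for partial in respond({"text": "Hello"}, []):
#       print(partial)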
def generate_image(prompt):
    """Generates an image by calling an external Gradio Space."""
    # InferenceClient has no predict(); the predict()/api_name interface
    # belongs to gradio_client.Client, which is what a hosted Space exposes.
    client = Client("KingNish/Image-Gen-Pro")
    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
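# Usage sketch (the Space's exact return format is an assumption; Gradio
# clients typically return a file path or tuple for image outputs):
#   result = generate_image("the Eiffel Tower at night")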
# Set up Gradio interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(placeholder="Enter your message...")
            file_input = gr.File(label="Upload an image")
        with gr.Column():
            output = gr.Image(label="Generated Image")
    with gr.Row():
        search_button = gr.Button("Search Google")
        image_button = gr.Button("Generate Image")
    examples = [
        {"text": "Who are you?"},
        {"text": "Generate an image of the Eiffel Tower at night."},
        {"text": "Search for the latest trends on YouTube."},
    ]
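    # Make the example prompts clickable in the UI; gr.Examples expects one
    # row per example matching its inputs, hence the nested lists.
    gr.Examples(examples=[[e["text"]] for e in examples], inputs=[text_input])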
    def handle_text(text, chat_history):
        # respond() yields cumulative text; keep the final chunk as the reply.
        chat_history = chat_history or []
        reply = ""
        for partial in respond({"text": text}, chat_history):
            reply = partial
        chat_history.append((text, reply))
        return chat_history
    def handle_file_upload(file, chat_history):
        if file is None: return chat_history or []  # change also fires on clear
        chat_history = chat_history or []
        reply = ""
        for partial in respond({"files": [file.name], "text": "Describe this image."}, chat_history):
            reply = partial
        chat_history.append(("[image uploaded]", reply))
        return chat_history
    # Connect components to callbacks; handlers return the updated chat history.
    text_input.submit(handle_text, [text_input, chatbot], [chatbot])
    file_input.change(handle_file_upload, [file_input, chatbot], [chatbot])
    search_button.click(
        lambda query, h: (h or []) + [(f"Search: {query}", search(query))],
        [text_input, chatbot], [chatbot],
    )
    image_button.click(generate_image, [text_input], [output])
# Launch the Gradio interface
demo.launch()