import gradio as gr
from huggingface_hub import InferenceClient
from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
from PIL import Image
from threading import Thread

# Initialize model and processor
model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cpu")

# Initialize the hosted inference client for text-only chat
client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

def llava(inputs):
    """Process an image plus text prompt into LLaVA model inputs."""
    image = Image.open(inputs["files"][0]).convert("RGB")
    prompt = f"<|im_start|>user <image>\n{inputs['text']}<|im_end|>"
    processed = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
    return processed

def respond(message, history):
    """Generate a response for text-only or image + text input."""
    if "files" in message and message["files"]:
        # Image + text: run the local LLaVA model and stream tokens
        inputs = llava(message)
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
        thread = Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=512, streamer=streamer))
        thread.start()

        history.append([message["text"], None])  # Placeholder entry for the streamed reply
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            history[-1][1] = buffer  # Update the latest message as tokens arrive
            yield history
    else:
        # Text only: query the hosted inference client
        user_message = message["text"]
        history.append([user_message, None])  # Add user's message with a placeholder response

        # Build the chat prompt from previous user turns
        prompt = [{"role": "user", "content": msg[0]} for msg in history if msg[0]]
        response = client_gemma.chat_completion(prompt, max_tokens=200)

        # Extract the reply and update history
        bot_message = response.choices[0].message.content
        history[-1][1] = bot_message
        yield history

def generate_image(prompt):
    """Generate an image from the user prompt via the KingNish/Image-Gen-Pro Space."""
    # The Space exposes a Gradio API, so call it through gradio_client rather than InferenceClient
    from gradio_client import Client
    client = Client("KingNish/Image-Gen-Pro")
    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
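
# Note: generate_image is not wired into the interface below. A minimal, hypothetical way to
# expose it would be to add a prompt box and an image component inside the gr.Blocks() context,
# e.g. (component names are illustrative and not part of the original app, and the Space's
# return value may need adapting before it can be displayed directly):
#
#     image_prompt = gr.Textbox(placeholder="Describe an image to generate...")
#     image_output = gr.Image()
#     image_prompt.submit(generate_image, image_prompt, image_output)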

# Set up the Gradio interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(placeholder="Enter your message...")
            file_input = gr.File(label="Upload an image")

    def handle_text(text, history):
        """Handle text input and stream the response into the chatbot."""
        yield from respond({"text": text}, history or [])

    def handle_file_upload(file, history):
        """Handle an uploaded image and stream the response into the chatbot."""
        yield from respond({"files": [file], "text": "Describe this image."}, history or [])

    # Connect components to callbacks
    text_input.submit(handle_text, [text_input, chatbot], chatbot)
    file_input.change(handle_file_upload, [file_input, chatbot], chatbot)

# Launch the Gradio app
demo.launch()