TuringsSolutions committed
Commit 72073e1 · verified · 1 Parent(s): 1ab12a4

Update app.py

Files changed (1)
  1. app.py +26 -53
app.py CHANGED
@@ -1,51 +1,28 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
-import json
-import uuid
-from PIL import Image
-from bs4 import BeautifulSoup
-import requests
-import random
 from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
+from PIL import Image
 from threading import Thread
-import re
-import time
-import torch
 
 # Initialize model and processor
 model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
 processor = LlavaProcessor.from_pretrained(model_id)
 model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cpu")
 
-# Initialize inference clients for different models
+# Initialize inference clients
 client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
-client_mixtral = InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
-client_llama = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
-client_yi = InferenceClient("01-ai/Yi-1.5-34B-Chat")
-
-def search(query):
-    """Performs a Google search and extracts text from the top results."""
-    session = requests.Session()
-    response = session.get(f"https://www.google.com/search?q={query}",
-                           headers={"User-Agent": "Mozilla/5.0"})
-    soup = BeautifulSoup(response.text, "html.parser")
-    results = []
-    for result in soup.find_all("div", class_="BNeawe vvjwJb AP7Wnd"):
-        text = result.get_text()
-        link = result.find_parent("a")["href"]
-        results.append(f"{text}: {link}")
-    return "\n".join(results[:3])
 
 def llava(inputs, history):
-    """Processes an image and text input with Llava."""
+    """Processes an image and text input using Llava."""
    image = Image.open(inputs["files"][0]).convert("RGB")
    prompt = f"<|im_start|>user <image>\n{inputs['text']}<|im_end|>"
    processed = processor(prompt, image, return_tensors="pt").to("cpu")
    return processed
 
 def respond(message, history):
-    """Main response function for the chatbot."""
+    """Generate a response based on text or image input."""
     if "files" in message and message["files"]:
+        # Handle image + text input
         inputs = llava(message, history)
         streamer = TextIteratorStreamer(skip_prompt=True, skip_special_tokens=True)
         thread = Thread(target=model.generate, kwargs=dict(inputs=inputs, max_new_tokens=512, streamer=streamer))
@@ -55,13 +32,21 @@ def respond(message, history):
         buffer += new_text
         yield buffer
     else:
-        prompt = [{"role": "user", "content": msg[0]} for msg in history]
-        prompt.append({"role": "user", "content": message["text"]})
+        # Handle text-only input
+        user_message = message["text"]
+        history.append([user_message, None])  # Append user message to history
+
+        # Prepare prompt for the language model
+        prompt = [{"role": "user", "content": msg[0]} for msg in history if msg[0]]
         response = client_gemma.chat_completion(prompt, max_tokens=200)
-        yield response["choices"][0]["message"]["content"]
+
+        # Extract response and update history
+        bot_message = response["choices"][0]["message"]["content"]
+        history[-1][1] = bot_message  # Update the last entry with bot's response
+        yield history
 
 def generate_image(prompt):
-    """Generates an image using the external model."""
+    """Generates an image based on the user prompt."""
     client = InferenceClient("KingNish/Image-Gen-Pro")
     return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
 
@@ -74,30 +59,18 @@ with gr.Blocks() as demo:
     file_input = gr.File(label="Upload an image")
     with gr.Column():
         output = gr.Image(label="Generated Image")
-    with gr.Row():
-        search_button = gr.Button("Search Google")
-        image_button = gr.Button("Generate Image")
-    examples = [
-        {"text": "Who are you?"},
-        {"text": "Generate an image of the Eiffel Tower at night."},
-        {"text": "Search for the latest trends on YouTube."},
-    ]
 
-    def handle_text(text, state):
-        response = respond({"text": text}, state)
-        return response, state
+    def handle_text(text, history=[]):
+        """Handle text input and generate responses."""
+        return respond({"text": text}, history), history
 
-    def handle_file_upload(files, state):
-        response = respond({"files": files, "text": "Describe this image."}, state)
-        return response, state
+    def handle_file_upload(files, history=[]):
+        """Handle file uploads and generate responses."""
+        return respond({"files": files, "text": "Describe this image."}, history), history
 
     # Connect components to callbacks
-    text_input.submit(handle_text, [text_input], [chatbot])
-    file_input.change(handle_file_upload, [file_input], [chatbot])
-
-    # Search button functionality
-    search_button.click(lambda query: search(query), [text_input], [chatbot])
-    image_button.click(lambda text: generate_image(text), [text_input], [output])
+    text_input.submit(handle_text, [text_input, chatbot], [chatbot])
+    file_input.change(handle_file_upload, [file_input, chatbot], [chatbot])
 
-# Launch the Gradio interface
+# Launch the Gradio app
 demo.launch()
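A note on the streaming branch this diff keeps: `TextIteratorStreamer` takes a tokenizer as its first positional argument, and `model.generate` expects the processor output unpacked into keyword arguments rather than passed whole through `inputs=`. A minimal sketch of that branch under those APIs, reusing the `model` and `processor` defined at the top of app.py (the `stream_reply` helper name is hypothetical):

```python
# Minimal sketch, not the committed code: assumes the `model` and `processor`
# initialized at the top of app.py; stream_reply is a hypothetical helper.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(processed):
    # TextIteratorStreamer requires the tokenizer as its first argument.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # generate() takes input_ids/pixel_values as keyword arguments, so the
    # processor output is unpacked instead of being passed via inputs=.
    generation_kwargs = dict(**processed, max_new_tokens=512, streamer=streamer)
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    buffer = ""
    for new_text in streamer:  # yields decoded text as tokens arrive
        buffer += new_text
        yield buffer
```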
 
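In the rewritten text branch, the prompt is rebuilt from user turns only (`msg[0]`), so earlier assistant replies never reach the model. `chat_completion` accepts alternating user/assistant messages, so the full context can be kept; a sketch under the same `[user, bot]` pair-list history format (`history_to_messages` is a hypothetical helper, not part of the commit):

```python
# Sketch, assuming the pair-list history format used in app.py.
def history_to_messages(history):
    messages = []
    for user_turn, bot_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if bot_turn:
            messages.append({"role": "assistant", "content": bot_turn})
    return messages

response = client_gemma.chat_completion(history_to_messages(history), max_tokens=200)
```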
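`generate_image` is untouched by this commit, but the `predict(..., api_name=...)` call it makes is the `gradio_client` API; `huggingface_hub.InferenceClient` has no `predict` method. A sketch of the same call routed through `gradio_client`, assuming the KingNish/Image-Gen-Pro Space exposes the `/image_gen_pro` endpoint:

```python
# Sketch using gradio_client, whose Client.predict supports api_name=;
# assumes the KingNish/Image-Gen-Pro Space exposes /image_gen_pro.
from gradio_client import Client

def generate_image(prompt):
    client = Client("KingNish/Image-Gen-Pro")
    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
```

A related caveat for the new handlers: `respond` is a generator, so `handle_text` hands Gradio a generator object paired with `history`; making the handlers themselves yield from `respond` (and returning only the chatbot value) would preserve streaming.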