BirdWatcher / app.py
selamw's picture
Update app.py
aa33474 verified
raw
history blame
3.38 kB
import gradio as gr
from PIL import Image
from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces
import torch
import os
import re
# Hugging Face access token taken from the ``HF_token`` environment variable;
# this is ``None`` when the variable is not set.
access_token = os.environ.get("HF_token")
# Hub repository of the PaliGemma checkpoint used for inference.
# (An alternative checkpoint that was tried: "selamw/bird-Identifier".)
model_id = "selamw/BirdWatcher-AI"
# bitsandbytes settings: load the model weights in 8-bit precision.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
def convert_to_markdown(input_text):
    """Turn the model's raw description into Markdown for ``gr.Markdown``.

    The conversion is a fixed sequence of text heuristics:

    1. ``"!:"`` is normalised to ``":"``.
    2. ``word: **`` is rewritten to ``**word**:``.
    3. Every remaining ``**`` is replaced by ``##``.
    4. Leading horizontal whitespace is stripped from every line. Empty lines
       are kept, so Markdown paragraph breaks survive.
    5. A blank line is inserted after the last colon that follows a ``##``
       marker on a line, and the whole text is prefixed with ``"## "``.

    Args:
        input_text: Decoded model output.

    Returns:
        The Markdown-formatted string.
    """
    input_text = input_text.replace("!:", ":")
    # "Label: **" -> "**Label**:" (turned into "##Label##:" by the next step).
    output_text = re.sub(r'(\w+)\s*:\s*\*\*', r'**\1**:', input_text)
    # Any remaining double asterisks become heading markers.
    output_text = output_text.replace("**", "##")
    # Strip indentation only. ``[^\S\n]`` is "whitespace except newline": the
    # previous ``^\s+`` also consumed the newline of empty lines, silently
    # merging separate paragraphs into one.
    output_text = re.sub(r'^[^\S\n]+', '', output_text, flags=re.MULTILINE)
    # Put a blank line after each heading's colon, and make the first line a heading.
    output_text = "## " + re.sub(r'(##\s*.*):', r'\1:\n\n', output_text)
    return output_text
@spaces.GPU
def infer_fin_pali(image, question):
    """Describe the bird in ``image`` by prompting the PaliGemma model.

    The model (quantized with ``bnb_config``) and its processor are loaded from
    ``model_id`` on every call. The image and prompt are encoded, up to 512 new
    tokens are generated, and only the newly generated tokens are decoded and
    passed through ``convert_to_markdown``.

    Args:
        image: Image from the ``gr.Image`` input, or ``None`` when the user
            has not uploaded one.
        question: Text prompt sent to the model together with the image.

    Returns:
        The model's answer as Markdown.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of an opaque processor error.
        raise gr.Error("Please upload a bird image first.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): weights are reloaded on every request; consider loading
    # them once if the hosting hardware allows it.
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id, quantization_config=bnb_config, token=access_token
    )
    processor = PaliGemmaProcessor.from_pretrained(model_id, token=access_token)

    inputs = processor(images=image, text=question, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]
    predictions = model.generate(**inputs, max_new_tokens=512)

    # ``generate`` returns prompt + continuation, so keep only the new tokens.
    # Slicing the decoded text by ``len(question)`` breaks whenever the prompt
    # does not round-trip through the tokenizer character-for-character.
    generated_tokens = predictions[0][prompt_length:]
    decoded_output = processor.decode(generated_tokens, skip_special_tokens=True).lstrip("\n")

    return convert_to_markdown(decoded_output)
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
h1 {
text-align: center;
}
h3 {
text-align: center;
}
h2 {
text-align: left;
}
span.gray-text {
color: gray;
}
"""
# Prompt pre-filled in the text box and used for every example row.
DEFAULT_PROMPT = "Describe this bird species"

# Gradio UI: image upload next to the prompt, Run button and Markdown answer,
# followed by clickable example images. The app is launched at import time.
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1>🦩 BirdWatcher AI 🦜</h1>")
    gr.HTML("<h3>Upload an image of a bird, and the model will generate a detailed description of its species.</h3>")
    with gr.Tab(label="Bird Identification"):
        with gr.Row():
            input_img = gr.Image(label="Input Bird Image")
            with gr.Column():
                with gr.Row():
                    question = gr.Text(label="Default Prompt", value=DEFAULT_PROMPT, elem_id="default-prompt")
                with gr.Row():
                    submit_btn = gr.Button(value="Run")
                with gr.Row():
                    # elem_id="mkd" attaches the scrollable, bordered box defined by
                    # the ``#mkd`` rule in ``css``; otherwise that rule matches nothing.
                    output = gr.Markdown(label="Response", elem_id="mkd")
        submit_btn.click(infer_fin_pali, [input_img, question], [output])
        gr.Examples(
            [[filename, DEFAULT_PROMPT] for filename in ("01.jpg", "02.jpg", "03.jpg", "04.jpeg")],
            inputs=[input_img, question],
            outputs=[output],
            fn=infer_fin_pali,
            # Was the mojibake "πŸ‘‡" (a mis-decoded UTF-8 👇).
            label="Examples 👇",
        )

demo.launch(debug=True)