"""Gradio app: identify a bird species from a photo with a fine-tuned PaliGemma model.

The model answers a free-text prompt about an uploaded bird image; its raw
output uses ``**``-delimited section headers, which we re-shape into Markdown
before display.
"""

import os

import gradio as gr
import spaces
import torch
from PIL import Image  # noqa: F401 -- kept for parity with the original module surface
from transformers import (
    BitsAndBytesConfig,
    PaliGemmaForConditionalGeneration,
    PaliGemmaProcessor,
)

# Hub token for the (gated) model repo; None is acceptable for public repos.
access_token = os.getenv('HF_token')

model_id = "selamw/bird-Identifier"

# Load weights in 8-bit to fit the Space's GPU memory budget.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# Lazily-initialized singletons so the quantized weights and processor are
# fetched from the hub ONCE, not on every inference request.
_model = None
_processor = None


def convert_to_markdown(input_text):
    """Re-format model output whose sections are delimited by ``**`` into Markdown.

    The first ``**``-free chunk is treated as the bird name (bolded); the
    remaining chunks alternate header / body and are emitted as bold headers
    followed by their content.
    """
    sections = input_text.split("**")

    # Lead with the bird name in bold.
    formatted_output = f"**{sections[0].strip()}**\n"

    # Chunks alternate header (odd index) / content (even index); a trailing
    # header with no content is dropped.
    for i in range(1, len(sections), 2):
        if i + 1 < len(sections):
            header = sections[i].strip() + "** "
            content = sections[i + 1].strip()
            formatted_output += f"\n**{header}{content}\n"

    return formatted_output.strip()


@spaces.GPU
def infer_fin_pali(image, question):
    """Generate a Markdown description of the bird in *image*, guided by *question*.

    Returns the model's decoded answer (prompt echo stripped) re-formatted via
    :func:`convert_to_markdown`.
    """
    global _model, _processor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # First call loads and caches the model/processor; previously they were
    # re-downloaded and re-quantized on EVERY request, dominating latency.
    if _model is None:
        _model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id, quantization_config=bnb_config, token=access_token
        )
        _processor = PaliGemmaProcessor.from_pretrained(model_id, token=access_token)

    inputs = _processor(images=image, text=question, return_tensors="pt").to(device)

    predictions = _model.generate(**inputs, max_new_tokens=512)

    # Decode, then drop the echoed prompt prefix and any leading newlines.
    decoded_output = _processor.decode(
        predictions[0], skip_special_tokens=True
    )[len(question):].lstrip("\n")

    return convert_to_markdown(decoded_output)


css = """
#mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
h1 {
    text-align: center;
}
h3 {
    text-align: center;
}
h2 {
    text-align: left;
}
span.gray-text {
    color: gray;
}
"""

with gr.Blocks(css=css) as demo:
    # NOTE(review): the heading tags were lost in the source mangling; <h1>/<h3>
    # are reconstructed from the CSS selectors styled above — confirm.
    gr.HTML("<h1>🦩 BirdWatcher AI 🦜</h1>")
    gr.HTML(
        "<h3>Upload an image of a bird, and the model will generate a "
        "detailed description of its species.</h3>"
    )

    with gr.Tab(label="Bird Identification"):
        with gr.Row():
            input_img = gr.Image(label="Input Bird Image")
            with gr.Column():
                with gr.Row():
                    question = gr.Text(
                        label="Default Prompt",
                        value="Describe this bird species",
                        elem_id="default-prompt",
                    )
                with gr.Row():
                    submit_btn = gr.Button(value="Run")
                with gr.Row():
                    # Markdown component so convert_to_markdown() output renders.
                    output = gr.Markdown(label="Response")

        submit_btn.click(infer_fin_pali, [input_img, question], [output])

        gr.Examples(
            [
                ["01.jpg", "Describe this bird species"],
                ["02.jpg", "Describe this bird species"],
                ["03.jpg", "Describe this bird species"],
                ["04.jpeg", "Describe this bird species"],
            ],
            inputs=[input_img, question],
            outputs=[output],
            fn=infer_fin_pali,
            label='Examples 👇',
        )

demo.launch(debug=True)