import gradio as gr from PIL import Image from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration, PaliGemmaProcessor import spaces import torch import os import re access_token = os.getenv('HF_token') model_id = "selamw/BirdWatcher-AI" # model_id = "selamw/bird-Identifier" bnb_config = BitsAndBytesConfig(load_in_8bit=True) def convert_to_markdown(input_text): input_text = input_text.replace("!:", ":") # Find all words before ': **' and replace with bold markdown output_text = re.sub(r'(\w+)\s*:\s*\*\*', r'**\1**:', input_text) # Replace double asterisks with double hashtags for remaining headings output_text = output_text.replace("**", "##") # Remove any extra whitespace at the beginning of lines output_text = re.sub(r'^\s+', '', output_text, flags=re.MULTILINE) # Add an extra newline after each heading (after the colon) output_text = "## " + re.sub(r'(##\s*.*):', r'\1:\n\n', output_text) return output_text @spaces.GPU def infer_fin_pali(image, question): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, token=access_token) processor = PaliGemmaProcessor.from_pretrained(model_id, token=access_token) inputs = processor(images=image, text=question, return_tensors="pt").to(device) predictions = model.generate(**inputs, max_new_tokens=512) decoded_output = processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n") # Ensure proper Markdown formatting formatted_output = convert_to_markdown(decoded_output) # formatted_output = (decoded_output) return formatted_output css = """ #mkd { height: 500px; overflow: auto; border: 1px solid #ccc; } h1 { text-align: center; } h3 { text-align: center; } h2 { text-align: left; } span.gray-text { color: gray; } """ with gr.Blocks(css=css) as demo: gr.HTML("