import gradio as gr from PIL import Image from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration, PaliGemmaProcessor import spaces import torch import os access_token = os.getenv('HF_token') model_id = "selamw/BirdWatcher" # model_id = "selamw/bird-Identifier" bnb_config = BitsAndBytesConfig(load_in_8bit=True) def convert_to_markdown(input_text): """Converts bird information text to Markdown format, making specific keywords bold and adding headings. Args: input_text (str): The input text containing bird information. Returns: str: The formatted Markdown text. """ bold_words = ['Look:', 'Cool Fact!:', 'Habitat:', 'Food:', 'Birdie Behaviors:'] # Split into title and content based on the first ":", handling extra whitespace title, content = map(str.strip, input_text.split(":", 1)) # Bold the keywords for word in bold_words: # content = content.replace(word, f'\n\n**{word}\n') content = content.replace(word, f'\n\n**{word}') # content = content.replace(f': **', f':**') # Construct the Markdown output with headings formatted_output = f"**{title}**{content}" return formatted_output.strip() @spaces.GPU def infer_fin_pali(image, question): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, token=access_token) processor = PaliGemmaProcessor.from_pretrained(model_id, token=access_token) inputs = processor(images=image, text=question, return_tensors="pt").to(device) predictions = model.generate(**inputs, max_new_tokens=512) decoded_output = processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n") # Ensure proper Markdown formatting formatted_output = convert_to_markdown(decoded_output) # formatted_output = (decoded_output) return formatted_output css = """ #mkd { height: 500px; overflow: auto; border: 1px solid #ccc; } h1 { text-align: center; } h3 { text-align: center; } h2 { text-align: left; } span.gray-text { color: gray; } """ with gr.Blocks(css=css) as demo: gr.HTML("