# captain1-1 / app.py
import gradio as gr
import spaces
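# The spaces package supplies the @spaces.GPU decorator used below so that, on a
# ZeroGPU Space, run_example executes on a GPU allocated for the duration of the call.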
from transformers import AutoModelForCausalLM, AutoProcessor, GPT2LMHeadModel, GPT2Tokenizer
import torch
from PIL import Image
import subprocess
import os

# Install flash-attn without building the CUDA extension locally; merge in the current
# environment so pip and the CUDA paths remain visible to the child process.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
# Load the model and processor once at startup. revision="main" tracks the default
# branch; replace it with a specific commit hash to pin a known-good revision.
models = {
    "microsoft/Phi-3.5-vision-instruct": AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3.5-vision-instruct",
        revision="main",
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation="flash_attention_2",
    ).cuda().eval()
}

processors = {
    "microsoft/Phi-3.5-vision-instruct": AutoProcessor.from_pretrained(
        "microsoft/Phi-3.5-vision-instruct",
        revision="main",
        trust_remote_code=True,
    )
}
# Text-only fallback (GPT-2), used if the requested vision model is unavailable.
def load_fallback_model():
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").cuda().eval()
return model, tokenizer
# Default description and prompt
DESCRIPTION = "[Phi-3.5-vision Demo](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)"
default_question = "You are an image-to-prompt converter. Observe every detail of the image and craft a detailed prompt under 100 words."
# Gradio callback: generate a descriptive prompt from the input image, with error handling.
@spaces.GPU
def run_example(image, text_input=default_question, model_id="microsoft/Phi-3.5-vision-instruct"):
    try:
        model = models[model_id]
        processor = processors[model_id]
    except KeyError as e:
        # Unknown model id: fall back to GPT-2, which is text-only, and return early.
        print(f"Error loading model: {e}. Falling back to GPT-2 (text-only).")
        model, tokenizer = load_fallback_model()
        inputs = tokenizer(text_input, return_tensors="pt").to("cuda:0")
        generate_ids = model.generate(**inputs, max_new_tokens=1000, pad_token_id=tokenizer.eos_token_id)
        return tokenizer.batch_decode(generate_ids[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
    if image is None:
        return "Please upload an image first."
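    # Build the Phi-3.5-vision chat prompt: <|user|>/<|assistant|>/<|end|> are the chat
    # template markers, and <|image_1|> is the placeholder where the image is injected.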
user_prompt = '<|user|>\n'
assistant_prompt = '<|assistant|>\n'
prompt_suffix = "<|end|>\n"
prompt = f"{user_prompt}<|image_1|>\n{text_input}{prompt_suffix}{assistant_prompt}"
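    # Gradio passes the image as a NumPy array; convert it to an RGB PIL image for the processor.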
image = Image.fromarray(image).convert("RGB")
inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
generate_ids = model.generate(**inputs, max_new_tokens=1000, eos_token_id=processor.tokenizer.eos_token_id)
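    # Strip the prompt tokens so only the newly generated text is decoded.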
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
return response
# Custom CSS for styling
css = """
#output_text {
height: 500px;
overflow: auto;
border: 1px solid #333;
}
#model_selector, #text_input {
display: none !important;
}
#main_container {
border: 2px solid black;
padding: 20px;
border-radius: 10px;
}
"""
# Gradio interface with styling and layout improvements
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row(elem_id="main_container"):
        with gr.Column():
            input_img = gr.Image(label="Input Image", interactive=True)
            model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="microsoft/Phi-3.5-vision-instruct", elem_id="model_selector", visible=False)
            text_input = gr.Textbox(label="Question", value=default_question, elem_id="text_input", visible=False)
            submit_btn = gr.Button(value="Generate Prompt")
            output_text = gr.Textbox(label="Output", elem_id="output_text", interactive=False)
    # Wire the button to the generation callback.
    submit_btn.click(run_example, [input_img, text_input, model_selector], output_text)
# Queue requests and launch the interface with the API surface disabled.
demo.queue(api_open=False)
demo.launch(debug=True, show_api=False)