# captain1-1 / app.py
# Author: mrbeliever · last commit 2c8fc65 "Update app.py" (verified) · 3.12 kB
import importlib.util

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Model and Processor Initialization
MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

# FlashAttention-2 needs the optional `flash_attn` package; requesting
# "flash_attention_2" without it makes from_pretrained() raise, so fall back
# to the eager implementation (the alternative the model card documents).
_ATTN_IMPLEMENTATION = (
    "flash_attention_2"
    if importlib.util.find_spec("flash_attn") is not None
    else "eager"
)

# Both dicts are keyed by the model id string; generate_caption() looks the
# entries up by that exact string.
models = {
    MODEL_ID: AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation=_ATTN_IMPLEMENTATION,
    ).cuda().eval()
}
processors = {
    MODEL_ID: AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
}
# Instruction text placed after the image placeholder in every request sent
# to the vision model: asks for a <100-word prompt in a fixed bracketed layout.
default_question = (
    "You are an image-to-prompt converter. "
    "Your work is to observe each and every detail of the image "
    "and craft a detailed prompt under 100 words in this format: "
    "[image content/subject, description of action, state, and mood], "
    "[art form, style], "
    "[artist/photographer reference if needed], "
    "[additional settings such as camera and lens settings, lighting, "
    "colors, effects, texture, background, rendering]."
)
# Function to generate prompt
def generate_caption(image):
    """Turn an uploaded image into a detailed text-to-image style prompt.

    Sends ``default_question`` together with the image to the Phi-3.5 vision
    model and returns only the newly generated text.

    Args:
        image: The uploaded image. The Gradio input is declared with
            ``type="pil"``, so a ``PIL.Image.Image`` is expected; array-like
            input (e.g. a NumPy array) is also accepted.

    Returns:
        The decoded model response as a string.

    Raises:
        gr.Error: If no image was supplied (button clicked before uploading).
    """
    if image is None:
        # Show a readable message in the UI instead of a stack trace.
        raise gr.Error("Please upload an image first.")

    model = models["microsoft/Phi-3.5-vision-instruct"]
    processor = processors["microsoft/Phi-3.5-vision-instruct"]

    # Phi-3.5-vision chat format: <|image_1|> refers to the first entry of the
    # image list handed to the processor.
    prompt = f"<|user|>\n<|image_1|>\n{default_question}<|end|>\n<|assistant|>\n"

    # Image.fromarray() expects array-like data, not the PIL image that
    # gr.Image(type="pil") delivers; convert PIL images directly and keep
    # fromarray() for raw arrays.
    if isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
    else:
        pil_image = Image.fromarray(image).convert("RGB")

    # Move inputs to wherever the model actually lives instead of a fixed device.
    inputs = processor(prompt, [pil_image], return_tensors="pt").to(model.device)
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=1000,
        eos_token_id=processor.tokenizer.eos_token_id,
    )
    # generate() returns prompt tokens followed by the completion; keep only
    # the newly generated part.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response
# Enhanced CSS for streamlined UI.
# Selectors target the elem_id values given to the Gradio components.
# The run button's text is black, so its dark hover background also sets a
# light text color; otherwise the label becomes unreadable on hover.
css = """
#container {
    background-color: #f9f9f9;
    padding: 20px;
    border-radius: 15px;
    border: 2px solid #333;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    max-width: 450px;
    margin: auto;
}
#input_image {
    margin-top: 15px;
    border: 2px solid #333;
    border-radius: 8px;
    height: 180px;
    object-fit: contain;
}
#output_caption {
    margin-top: 15px;
    border: 2px solid #333;
    border-radius: 8px;
    height: 180px;
    overflow-y: auto;
}
#run_button {
    background-color: #fff;
    color: black;
    border-radius: 10px;
    padding: 10px;
    cursor: pointer;
    transition: background-color 0.3s ease;
    margin-top: 15px;
}
#run_button:hover {
    background-color: #333;
    color: #fff;
}
"""
# Gradio Interface with Adjustments.
# A single styled column: image upload, a trigger button, and a copyable
# textbox; clicking the button runs generate_caption on the uploaded image.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        input_image = gr.Image(
            label="Upload Image",
            type="pil",
            elem_id="input_image",
        )
        run_button = gr.Button(
            value="Generate Prompt",
            elem_id="run_button",
        )
        output_caption = gr.Textbox(
            label="Generated Prompt",
            lines=6,
            show_copy_button=True,
            elem_id="output_caption",
        )

    # Wire the button to the captioning function.
    run_button.click(
        fn=generate_caption,
        inputs=[input_image],
        outputs=output_caption,
    )

# Serve locally only (no public share link).
demo.launch(share=False)