import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
from PIL import Image
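
# Select the attention backend up front: "flash_attention_2" requires the
# flash-attn package and a supported GPU (often absent on Spaces hardware,
# hence this hedge); fall back to the eager implementation when it is missing.
try:
    import flash_attn  # noqa: F401
    ATTN_IMPL = "flash_attention_2"
except ImportError:
    ATTN_IMPL = "eager"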
# Model and Processor Initialization
models = {
    "microsoft/Phi-3.5-vision-instruct": AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation=ATTN_IMPL,  # backend chosen above
    ).cuda().eval()
}
processors = {
    "microsoft/Phi-3.5-vision-instruct": AutoProcessor.from_pretrained(
        "microsoft/Phi-3.5-vision-instruct", trust_remote_code=True
    )
}
# Default question
default_question = (
    "You are an image-to-prompt converter. Your task is to observe every detail of the image and "
    "craft a detailed prompt under 100 words in this format: [image content/subject, description of action, state, "
    "and mood], [art form, style], [artist/photographer reference if needed], [additional settings such as camera "
    "and lens settings, lighting, colors, effects, texture, background, rendering]."
)
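
# The prompt below is hand-built in Phi-3.5-vision's chat format
# (<|user|>, <|image_1|>, <|end|>, <|assistant|>). A sketch of an equivalent,
# assuming the bundled tokenizer ships a chat template:
#
#   messages = [{"role": "user", "content": f"<|image_1|>\n{default_question}"}]
#   prompt = processor.tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True
#   )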
# Generate a detailed prompt describing the uploaded image
def generate_caption(image):
    if image is None:
        return "Please upload an image first."
    model = models["microsoft/Phi-3.5-vision-instruct"]
    processor = processors["microsoft/Phi-3.5-vision-instruct"]
    prompt = f"<|user|>\n<|image_1|>\n{default_question}<|end|>\n<|assistant|>\n"
    # gr.Image(type="pil") hands the callback a PIL image; normalize to RGB
    # (Image.fromarray would be needed only for a numpy-array input).
    image = image.convert("RGB")
    # The processor takes the images as a list, matching <|image_1|> in the prompt.
    inputs = processor(prompt, [image], return_tensors="pt").to("cuda:0")
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=1000,
        eos_token_id=processor.tokenizer.eos_token_id,
    )
    # Strip the prompt tokens so only the newly generated text is decoded.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response
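
# Quick local sanity check ("example.jpg" is a placeholder path):
#   print(generate_caption(Image.open("example.jpg")))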
# Enhanced CSS for a streamlined UI
css = """
#container {
    background-color: #f9f9f9;
    padding: 20px;
    border-radius: 15px;
    border: 2px solid #333;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    max-width: 450px;
    margin: auto;
}
#input_image {
    margin-top: 15px;
    border: 2px solid #333;
    border-radius: 8px;
    height: 180px;
    object-fit: contain;
}
#output_caption {
    margin-top: 15px;
    border: 2px solid #333;
    border-radius: 8px;
    height: 180px;
    overflow-y: auto;
}
#run_button {
    background-color: #fff;
    color: black;
    border-radius: 10px;
    padding: 10px;
    cursor: pointer;
    transition: background-color 0.3s ease;
    margin-top: 15px;
}
#run_button:hover {
    background-color: #333;
    color: #fff; /* keep the label legible on the dark hover background */
}
"""
# Gradio interface
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        input_image = gr.Image(type="pil", elem_id="input_image", label="Upload Image")
        run_button = gr.Button(value="Generate Prompt", elem_id="run_button")
        output_caption = gr.Textbox(
            label="Generated Prompt", show_copy_button=True, elem_id="output_caption", lines=6
        )
        run_button.click(
            fn=generate_caption,
            inputs=[input_image],
            outputs=output_caption,
        )
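
# Serialize requests so long generations on shared hardware don't time out;
# max_size is an illustrative cap, not a tuned value.
demo.queue(max_size=10)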
demo.launch(share=False)