captain1-1 / app.py
mrbeliever's picture
Update app.py
3901da8 verified
raw
history blame
3.05 kB
import os
import subprocess
import sys

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
from PIL import Image
# Install flash-attn at startup, before the model is loaded with
# `_attn_implementation="flash_attention_2"`.
#
# - The environment is *extended* rather than replaced, so PATH, HOME and any
#   proxy/cache settings remain visible to pip.
# - `sys.executable -m pip` guarantees the package lands in the interpreter
#   that is running this app, and the argv list avoids going through a shell.
# - FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is passed through to the package's
#   build so the install does not try to compile CUDA kernels.
# - The return code is deliberately not checked (best-effort install, as in
#   the original call).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Model registry, keyed by Hugging Face repo id.
# The checkpoint is loaded once at import time, moved to the GPU and switched
# to eval mode so request handlers can reuse it without reloading.
# `trust_remote_code=True` lets transformers load the modelling code shipped
# in the model repository.
_phi_vision_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    torch_dtype="auto",
    _attn_implementation="flash_attention_2",
)
_phi_vision_model = _phi_vision_model.cuda()
_phi_vision_model = _phi_vision_model.eval()

models = {}
models["microsoft/Phi-3.5-vision-instruct"] = _phi_vision_model
# Processor registry, keyed by the same repo ids as `models`.
# Each processor turns a (prompt, image) pair into model-ready tensors and
# decodes generated token ids back to text.
_phi_vision_processor = AutoProcessor.from_pretrained(
    "microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
)

processors = {}
processors["microsoft/Phi-3.5-vision-instruct"] = _phi_vision_processor
# Instruction sent to the model when the caller supplies no custom question.
# It asks for a prompt of under 100 words in a fixed four-part bracketed format.
default_question = (
    "You are an image to prompt converter. "
    "Your work is to observe each and every detail of the image and craft a "
    "detailed prompt under 100 words in this format: "
    "[image content/subject, description of action, state, and mood], "
    "[art form, style], "
    "[artist/photographer reference if needed], "
    "[additional settings such as camera and lens settings, lighting, colors, "
    "effects, texture, background, rendering]."
)
@spaces.GPU
def run_example(image, text_input=default_question, model_id="microsoft/Phi-3.5-vision-instruct"):
    """Generate a text response for an image using a registered vision model.

    Args:
        image: The picture to describe. Either a ``PIL.Image.Image`` (what the
            UI's ``type="pil"`` image component delivers) or an array accepted
            by ``Image.fromarray``.
        text_input: Instruction placed in the user turn of the chat prompt.
            Defaults to ``default_question``.
        model_id: Key into the module-level ``models`` / ``processors``
            registries.

    Returns:
        The decoded text of the newly generated tokens (prompt excluded).

    Raises:
        gr.Error: If no image was provided.
        KeyError: If ``model_id`` is not a registered model.
    """
    # An empty image component yields None; fail with a readable message
    # instead of an opaque error from deep inside PIL.
    if image is None:
        raise gr.Error("Please upload an image first.")

    model = models[model_id]
    processor = processors[model_id]

    # Single-image chat template: user turn with one image placeholder,
    # followed by an open assistant turn for the model to complete.
    prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"

    # Accept both PIL images and raw arrays; only arrays need conversion.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=1000,
        eos_token_id=processor.tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    generate_ids = generate_ids[:, prompt_length:]

    response = processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return response
# Custom stylesheet passed to gr.Blocks.
# Selectors prefixed with '#' match the `elem_id` values assigned to components
# in the layout. The outer wrapper is tagged with elem_id="container", so it
# must be targeted as `#container` (an id selector), not `.container`.
css = """
#container {
border: 2px solid #333;
padding: 20px;
max-width: 400px;
margin: auto;
}
#input_img, #output_text {
border: 2px solid #333;
width: 100%;
height: 300px;
object-fit: cover;
}
.gr-button {
width: 100%;
margin-top: 10px;
}
#copy_button {
float: right;
margin-top: -30px;
cursor: pointer;
}
"""
# `gr.Box` is not available in newer Gradio releases; fall back to `gr.Group`
# so the layout builds regardless of which container class is present.
_container_cls = getattr(gr, "Box", None) or gr.Group

# UI layout: one bordered container holding the image input, the generate
# button, and a row with the read-only output textbox plus a copy button.
with gr.Blocks(css=css) as demo:
    with _container_cls(elem_id="container"):
        input_img = gr.Image(label="Input Picture", elem_id="input_img", type="pil")
        generate_button = gr.Button("Generate Prompt", elem_id="generate_button")
        with gr.Row():
            output_text = gr.Textbox(label="Output Text", elem_id="output_text", interactive=False)
            copy_button = gr.Button("Copy", elem_id="copy_button")

    # Copy functionality
    # NOTE(review): this handler only echoes the text on the server and has no
    # outputs, so nothing reaches the user's clipboard. A real copy needs a
    # client-side JS handler or the textbox's built-in copy button.
    copy_button.click(fn=lambda text: text, inputs=output_text, outputs=None)

    # Generate functionality
    # Event inputs must be Gradio components. The prompt text is not a
    # component, so only the image is wired in and run_example falls back to
    # its default `text_input` (default_question).
    generate_button.click(run_example, [input_img], [output_text])

# Enable request queuing without exposing the queue's API endpoints, then
# start the server with the API docs hidden.
demo.queue(api_open=False)
demo.launch(debug=True, show_api=False)