from threading import Thread
import gradio as gr
import torch
from transformers import (
    AutoModel,
    AutoProcessor,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
title = """<h1 style="text-align: center;">Product description generator</h1>"""
css = """
div#col-container {
margin: 0 auto;
max-width: 840px;
}
"""
model = AutoModel.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True)
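# Custom stopping criterion: halt generation once the model emits token id
# 151645, the <|im_end|> end-of-turn marker in the Qwen tokenizer.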
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [151645]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
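# Streaming chat handler: builds a chat-formatted prompt from the history,
# encodes the image, runs generation on a background thread, and yields the
# updated transcript as each token arrives.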
@torch.no_grad()
def response(history, image):
    stop = StopOnTokens()
    message = "Generate a product title for the image"
    gr.Info("Starting: " + message)
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    # On the first user turn, prepend the <image> placeholder so the visual
    # features are spliced into the prompt.
    if len(messages) == 1:
        message = f" <image>{message}"
    messages.append({"role": "user", "content": message})
    model_inputs = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    image = (
        processor.feature_extractor(image)
        .unsqueeze(0)
    )
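    # One attention slot per text token plus the image latents: the single
    # <image> placeholder token is replaced by num_image_latents embeddings,
    # hence the "- 1".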
    attention_mask = torch.ones(
        1, model_inputs.shape[1] + processor.num_image_latents - 1
    )
    model_inputs = {
        "input_ids": model_inputs,
        "images": image,
        "attention_mask": attention_mask,
    }
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
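    # TextIteratorStreamer yields decoded tokens as model.generate produces
    # them on the background thread.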
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    history.append([message, ""])
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        history[-1][1] = partial_response
        yield history
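# UI layout: image upload, prompt box, and chat transcript in the left
# column; a second image panel on the right.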
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column(elem_id="col-container"):
            image = gr.Image(type="pil")
            message = gr.Textbox(interactive=True, show_label=False, container=False)
            chat = gr.Chatbot(show_label=False)
            submit = gr.Button(value="Upload", variant="primary")
        with gr.Column():
            output = gr.Image(type="pil")

    response_handler = (
        response,
        [chat, image],
        [chat],
    )
    # postresponse_handler = (
    #     lambda: (gr.Button(visible=False), gr.Button(visible=True)),
    #     None,
    #     [submit],
    # )
    event = submit.click(*response_handler)
    # event.then(*postresponse_handler)

demo.launch()