Spaces:
Runtime error
Runtime error
File size: 3,224 Bytes
563f98d bc3802f 7d58261 563f98d e0c81f0 563f98d bc3802f 7d58261 bc3802f fb7a950 b220b28 fb7a950 b220b28 fb7a950 bc3802f 60eaa44 b220b28 e80c4ee bc3802f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import argparse
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModel, AutoProcessor
from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Page heading and stylesheet handed to the Gradio layout.
title = """<h1 style="text-align: center;">Product description generator</h1>"""
css = """
div#col-container {
margin: 0 auto;
max-width: 840px;
}
"""

# The model and its processor are loaded from the same checkpoint.
_CHECKPOINT = "unum-cloud/uform-gen2-qwen-500m"
model = AutoModel.from_pretrained(_CHECKPOINT, trust_remote_code=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(_CHECKPOINT, trust_remote_code=True)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the newest token is a stop id.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the last token of the first sequence is a stop id."""
        # NOTE(review): 151645 is presumably the tokenizer's end-of-turn
        # token id -- confirm against the checkpoint's tokenizer.
        newest_token = input_ids[0][-1]
        return any(newest_token == stop_id for stop_id in (151645,))
@torch.no_grad()
def response(message, history, image):
    """Stream the model's reply to ``message`` about ``image`` into the chat.

    Generation runs on a background thread; this generator reads the text
    streamer and yields after every received chunk.

    Args:
        message: Text entered by the user for this turn.
        history: Chat history as a list of ``[user, assistant]`` pairs, or
            ``None`` for an empty chat. A new pair is appended to it and its
            assistant slot is updated in place while text arrives.
        image: Image handed to ``processor.feature_extractor``.

    Yields:
        A 3-tuple ``(history, button, button)``: the updated history, a
        hidden button update, and a visible, interactive button update.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Fail with a readable message instead of crashing in the feature extractor.
    if image is None:
        raise gr.Error("Please upload an image first.")
    history = [] if history is None else history

    stop = StopOnTokens()

    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})

    # Only the first user turn carries the image placeholder.
    if not history:
        message = f" <image>{message}"
    messages.append({"role": "user", "content": message})

    input_ids = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    image = processor.feature_extractor(image).unsqueeze(0)

    # The mask length is the prompt length plus the image latents, minus one.
    # NOTE(review): presumably the single "<image>" token is replaced by the
    # latents inside the model -- confirm against the model code.
    attention_mask = torch.ones(
        1, input_ids.shape[1] + processor.num_image_latents - 1
    )

    model_inputs = {
        "input_ids": input_ids,
        "images": image,
        "attention_mask": attention_mask,
    }
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        timeout=30.,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    # Generate in the background so the streamer can be consumed here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    history.append([message, ""])
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        history[-1][1] = partial_response
        yield history, gr.Button(visible=False), gr.Button(visible=True, interactive=True)
# Page layout and event wiring for the demo.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        image = gr.Image(type="pil")
        submit = gr.Button(value="Upload", variant="primary")
        # Hidden by default; `response` makes it visible while it is streaming.
        stop = gr.Button(value="Stop", variant="stop", visible=False)
        chat = gr.Chatbot(show_label=False)
        message = gr.Textbox(interactive=True, show_label=False, container=False)

    # `response` yields (history, hidden button, visible button), so it needs
    # three output components: the chat, then submit (hidden), then stop (shown).
    response_handler = (
        response,
        [message, chat, image],
        [chat, submit, stop],
    )
    # After generation the buttons swap back: stop hidden, submit visible.
    postresponse_handler = (
        lambda: (gr.Button(visible=False), gr.Button(visible=True)),
        None,
        [stop, submit],
    )

    event = submit.click(*response_handler)
    event.then(*postresponse_handler)
    # Stop cancels the running generation and restores the buttons itself,
    # since the chained `.then` step may not run for a cancelled event.
    stop.click(*postresponse_handler, cancels=[event])

demo.launch()