import random
from threading import Thread

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download, InferenceClient
from PIL import Image
from torchvision.transforms.functional import normalize
from transformers import (
    AutoModel,
    AutoProcessor,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

from briarmbg import BriaRMBG
# Background-removal model (RMBG-1.4)
net = BriaRMBG()
# model_path = "./model1.pth"
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
else:
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()

# Captioning model (uform-gen2-qwen-500m)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True)
class StopOnTokens(StoppingCriteria):
    """Stop generation once the Qwen end-of-turn token is produced."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [151645]  # <|im_end|> in the Qwen tokenizer
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
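
# The hard-coded id could instead be looked up at runtime (a sketch, assuming
# the processor's tokenizer exposes the token as usual):
#
#     stop_ids = [processor.tokenizer.convert_tokens_to_ids("<|im_end|>")]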
def format_prompt(message, history):
    """Fold the chat history and the new message into Gemma's turn format."""
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
            prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
    prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
    return prompt
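
# For example, format_prompt("Describe it", [("Hi", "Hello!")]) produces:
# "<start_of_turn>userHi<end_of_turn><start_of_turn>modelHello!<end_of_turn><start_of_turn>userDescribe it<end_of_turn><start_of_turn>model"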
def getProductDetails(history, image):
    """Caption the image, then query Gemma via the Inference API for marketing copy."""
    product_description = getImageDescription(image)
    client = InferenceClient("google/gemma-7b")
    rand_val = random.randint(1, 1111111111111111)
    if not history:
        history = []
    generate_kwargs = dict(
        temperature=0.67,
        max_new_tokens=1024,
        top_p=0.9,
        repetition_penalty=1.0,
        do_sample=True,
        seed=rand_val,
    )
    system_prompt = "you're a helpful e-commerce marketing assistant"
    prompt = "Write me a poem"
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    # for response in stream:
    #     output += response.token.text
    #     yield [(prompt, output)]
    gr.Info('Gemma: ' + product_description)
    # history.append((prompt, output))
    return history
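
# Note: the streamed Gemma tokens above are never consumed; only the caption is
# surfaced via gr.Info. A sketch of showing the reply in the chatbot instead
# (same client API, re-enabling the commented loop):
#
#     for response in stream:
#         output += response.token.text
#     history.append((prompt, output))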
@torch.no_grad()
def getImageDescription(image):
    """Generate a product title for the image with uform-gen2-qwen-500m."""
    message = "Generate a product title for the image"
    gr.Info('Starting... ' + message)
    stop = StopOnTokens()
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    # for user_msg, assistant_msg in history:
    #     messages.append({"role": "user", "content": user_msg})
    #     messages.append({"role": "assistant", "content": assistant_msg})
    if len(messages) == 1:
        message = f" <image>{message}"
    messages.append({"role": "user", "content": message})

    model_inputs = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    image = (
        processor.feature_extractor(image)
        .unsqueeze(0)
    )
    # One attention slot per text token plus the image latents, minus the
    # <image> placeholder token that the latents replace.
    attention_mask = torch.ones(
        1, model_inputs.shape[1] + processor.num_image_latents - 1
    )
    model_inputs = {
        "input_ids": model_inputs,
        "images": image,
        "attention_mask": attention_mask
    }
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    # Run generation on a worker thread and read tokens as they arrive.
    streamer = TextIteratorStreamer(processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # history.append([message, ""])
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        # history[-1][1] = partial_response
        # yield history
    gr.Info('Got: ' + partial_response)
    return partial_response
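
# Since nothing here actually streams to the UI, a non-threaded equivalent is
# possible (a sketch; it decodes after generation completes, slicing off the
# prompt tokens):
#
#     output = model.generate(**model_inputs, max_new_tokens=1024,
#                             stopping_criteria=StoppingCriteriaList([stop]))
#     text = processor.tokenizer.decode(
#         output[0, model_inputs["input_ids"].shape[1]:], skip_special_tokens=True)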
def resize_image(image):
    image = image.convert('RGB')
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image
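
# Note: this stretches to a fixed 1024x1024 (RMBG-1.4's expected input size)
# without preserving aspect ratio; process() below resizes the predicted mask
# back to the original (h, w), so the cutout still lines up with the input.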
def process(image):
    # prepare input: 1x3x1024x1024 float tensor, scaled to [0, 1] then shifted to [-0.5, 0.5]
    orig_image = image
    w, h = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()
    # inference
    result = net(im_tensor)
    # post process: resize the mask back to the input size and min-max scale it
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    # mask to PIL
    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # paste the original image through the mask onto a transparent canvas
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    # new_orig_image = orig_image.convert('RGBA')
    return new_im
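
# Example usage (hypothetical file names, for illustration):
#
#     cutout = process(Image.open("product.jpg"))
#     cutout.save("product_cutout.png")  # RGBA, background removed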
title = """<h1 style="text-align: center;">Product description generator</h1>"""

css = """
div#col-container {
    margin: 0 auto;
    max-width: 840px;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column(elem_id="col-container"):
            image = gr.Image(type="pil")
            chat = gr.Chatbot(show_label=False)
            submit = gr.Button(value="Upload", variant="primary")
        with gr.Column():
            output = gr.Image(type="pil", interactive=False)

    response_handler = (
        getProductDetails,
        [chat, image],
        [chat],
    )
    background_remover_handler = (
        process,
        [image],
        [output],
    )
    # postresponse_handler = (
    #     lambda: (gr.Button(visible=False), gr.Button(visible=True)),
    #     None,
    #     [submit],
    # )

    # Both handlers run on the same click: one fills the chat, one removes the background.
    event = submit.click(*response_handler)
    event2 = submit.click(*background_remover_handler)
    # event.then(*postresponse_handler)

demo.launch()