import os

# HF_HUB_ENABLE_HF_TRANSFER is read once when huggingface_hub is imported,
# so it must be set before the import below to take effect.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

import torch
import spaces
import gradio as gr
from PIL import Image
from transformers.utils import move_cache
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Download the model snapshot and migrate any legacy transformers cache
MODEL_PATH = snapshot_download("THUDM/cogvlm2-llama3-chat-19B")
move_cache()

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# bfloat16 requires compute capability >= 8.0 (Ampere or newer); otherwise fall back to float16
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
# Load the tokenizer and model (CogVLM2 ships custom code, hence trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,
).to(DEVICE).eval()
text_only_template = "USER: {} ASSISTANT:"
@spaces.GPU
def generate_caption(image: Image.Image, prompt: str) -> str:
    # Gradio delivers the upload as a PIL image (type="pil")
    image = image.convert('RGB')
    query = text_only_template.format(prompt)

    # Build the multimodal conversation inputs with CogVLM2's custom helper
    input_by_model = model.build_conversation_input_ids(
        tokenizer,
        query=query,
        history=[],
        images=[image],
        template_version='chat',
    )
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[input_by_model['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        "pad_token_id": 128002,  # pad token id of the underlying Llama-3 tokenizer
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # Drop the prompt tokens so only the generated continuation is decoded
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("<|end_of_text|>")[0]
    print("\nCogVLM2:", response)
    return response
## Make predictions via the API (see the sketch below) ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
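# A minimal client-side sketch (run from a separate process once the Space is
# live). The Space id "user/cogvlm2-captioning" is a placeholder, and recent
# versions of gradio_client wrap local files with handle_file:
#
#     from gradio_client import Client, handle_file
#
#     client = Client("user/cogvlm2-captioning")
#     caption = client.predict(
#         handle_file("cat.jpg"),                 # image input
#         "Describe the image in great detail",   # prompt
#         api_name="/predict",
#     )
#     print(caption)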
demo = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Prompt", value="Describe the image in great detail"),
    ],
    outputs=gr.Textbox(label="Generated Caption"),
)
# Launch the interface
demo.launch(share=True)