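"""Gradio demo for MultiModelLLM: a multimodal assistant that combines a CLIP
vision encoder, WhisperX speech-to-text, and a QLoRA-finetuned Phi-2 language
model to answer questions about image, audio, and text inputs."""
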
import torch
import whisperx
import gradio as gr
from peft import PeftModel
from configs import get_config_phase2
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM
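
# Phase-2 config, CLIP vision encoder, and the Phi-2 base model with the QLoRA
# adapter merged in for inference.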
config = get_config_phase2()
clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name")).to(config.get("device"))
base_model = AutoModelForCausalLM.from_pretrained(
    config.get("phi2_model_name"),
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float32,
    trust_remote_code=True
)
ckpts = "ckpts/Qlora_adaptor/"
phi2_model = PeftModel.from_pretrained(base_model, ckpts)
phi2_model = phi2_model.merge_and_unload().to(config.get("device"))
# Projection layer trained in phase 2: maps CLIP patch embeddings to Phi-2's embedding size.
projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed"))
projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
projection_layer = projection_layer.to(config.get("device"))
# Tokenizer for Phi-2, image processor for CLIP, and WhisperX for speech-to-text
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)
audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")
@torch.no_grad()  # inference only: no gradients needed
def generate_answers(img=None, aud=None, q=None, max_tokens=30):
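    """Embed the available inputs (image patches, audio transcript, text question)
    between <iQ> and </iQ> marker tokens and generate an answer with the merged
    Phi-2 model."""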
    batch_size = 1

    # Marker tokens that delimit the multimodal context.
    start_iq = tokenizer.encode("<iQ>")
    end_iq = tokenizer.encode("</iQ>")
    start_iq_ids = torch.tensor(start_iq).repeat(batch_size, 1)
    end_iq_ids = torch.tensor(end_iq).repeat(batch_size, 1)
    start_iq_embeds = phi2_model.model.embed_tokens(start_iq_ids.to(config.get("device")))
    end_iq_embeds = phi2_model.model.embed_tokens(end_iq_ids.to(config.get("device")))

    inputs_embeddings = []
    inputs_embeddings.append(start_iq_embeds)
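
    # Image branch: encode the image with CLIP and project the patch embeddings
    # into the Phi-2 embedding space.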
    if img is not None:
        images = processor(images=img, return_tensors="pt")['pixel_values'].to(config.get("device"))
        clip_outputs = clip_model(pixel_values=images)
        # remove cls token, keep the patch embeddings
        image_features = clip_outputs.last_hidden_state[:, 1:, :]
        image_embeddings = projection_layer(image_features).to(torch.float32)
        inputs_embeddings.append(image_embeddings)
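
    # Audio branch: transcribe with WhisperX and embed the transcript tokens.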
    if aud is not None:
        trans = audio_model.transcribe(aud)
        audio_res = "".join(seg['text'] for seg in trans['segments']).strip()
        audio_tokens = tokenizer(audio_res, return_tensors="pt", return_attention_mask=False)['input_ids']
        audio_embeds = phi2_model.model.embed_tokens(audio_tokens.to(config.get("device")))
        inputs_embeddings.append(audio_embeds)
    # Text question, if provided.
    if q:
        ques = tokenizer(q, return_tensors="pt", return_attention_mask=False)['input_ids']
        q_embeds = phi2_model.model.embed_tokens(ques.to(config.get("device")))
        inputs_embeddings.append(q_embeds)

    inputs_embeddings.append(end_iq_embeds)

    # Combine embeddings
    combined_embeds = torch.cat(inputs_embeddings, dim=1)
    predicted_caption = phi2_model.generate(inputs_embeds=combined_embeds,
                                            max_new_tokens=max_tokens,
                                            return_dict_in_generate=True)
    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption.sequences[:, 1:])[0]
    predicted_captions_decoded = predicted_captions_decoded.replace("<|endoftext|>", "")
    return predicted_captions_decoded
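
# A minimal standalone usage sketch (no UI), assuming the checkpoints above are
# in place and ./examples/Image_1.jpg exists:
#   from PIL import Image
#   print(generate_answers(img=Image.open("./examples/Image_1.jpg"),
#                          q="What is in the image?", max_tokens=20))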
# List of examples (image, audio, question, max_tokens)
examples = [
["./examples/Image_1.jpg", None, "Explain image?", 20],
["./examples/Image_2.jpg", None, "How many animals are there in image?", 10],
["./examples/Image_3.jpg", None, "What is in the image?", 20],
["./examples/Image_4.jpg", None, "What represents this Image?", 20],
]
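
# Gradio UI: image, microphone/uploaded audio, and a text question in; the generated answer out.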
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # MultiModelLLM
        A multimodal GPT that takes image, audio, and text inputs and produces a text answer.
        """
    )
    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Image', type="pil", value=None)
            audio_q = gr.Audio(label="Audio Question", value=None, sources=['microphone', 'upload'], type='filepath')
            question = gr.Text(label='Question?', value=None)
            max_tokens = gr.Slider(1, 50, value=10, step=1, label="Max tokens")
    with gr.Row():
        answer = gr.Text(label='Answer')
    with gr.Row():
        submit = gr.Button("Submit")
        clear_btn = gr.ClearButton([image, audio_q, question, max_tokens, answer])
    submit.click(generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=[answer])
    # Examples: selecting one fills the inputs and runs generate_answers directly.
    gr.Examples(
        examples=examples,
        inputs=[image, audio_q, question, max_tokens],
        outputs=[answer],
        fn=generate_answers,
        run_on_click=True,
    )

if __name__ == "__main__":
    demo.launch(share=True, debug=True)