# File size: 5,744 Bytes
# Commit: fd8d20e
import torch
import whisperx
import gradio as gr
from peft import PeftModel
from configs import get_config_phase2
from transformers import AutoTokenizer, AutoProcessor, CLIPVisionModel, AutoModelForCausalLM
# Phase-2 configuration: model names, device, and embedding dimensions.
config = get_config_phase2()
# Vision tower: CLIP image encoder producing per-patch hidden states.
clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))
# Base language model (Phi-2), loaded in fp32.
base_model = AutoModelForCausalLM.from_pretrained(
config.get("phi2_model_name"),
low_cpu_mem_usage=True,
return_dict=True,
torch_dtype=torch.float32,
trust_remote_code=True
)
# Load the QLoRA adapter and merge its weights into the base model so
# inference runs on a plain (non-PEFT) module.
ckpts = "ckpts/Qlora_adaptor/"
phi2_model = PeftModel.from_pretrained(base_model, ckpts)
phi2_model = phi2_model.merge_and_unload().to(config.get("device"))
# Linear projection mapping CLIP patch embeddings (clip_embed dims) into
# Phi-2's token-embedding space (phi_embed dims); weights trained in phase 2.
projection_layer = torch.nn.Linear(config.get("clip_embed"), config.get("phi_embed"))
projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
# Tokenizer/processor for text and images, plus whisperx for speech-to-text.
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
processor = AutoProcessor.from_pretrained(config.get("clip_model_name"), trust_remote_code=True)
audio_model = whisperx.load_model('tiny', 'cpu', compute_type="float32")
def generate_answers(img=None, aud=None, q=None, max_tokens=30):
    """Generate a text answer from any combination of image, audio, and text.

    The prompt is assembled directly in embedding space:
    <iQ> marker, then CLIP patch embeddings projected into Phi-2 space,
    then the whisperx transcript of the audio question, then the typed
    question, then the </iQ> marker.

    Args:
        img: PIL image (or None) to answer questions about.
        aud: filepath of an audio clip (or None); transcribed with whisperx.
        q: typed question string; may be None or '' (both are skipped).
        max_tokens: maximum number of new tokens to generate.

    Returns:
        The decoded answer string with <|endoftext|> markers removed.
    """
    device = config.get("device")
    batch_size = 1

    # <iQ> ... </iQ> sentinel tokens delimit the multimodal question section.
    start_ids = torch.tensor(tokenizer.encode("<iQ>")).repeat(batch_size, 1)
    end_ids = torch.tensor(tokenizer.encode("</iQ>")).repeat(batch_size, 1)
    start_iq_embeds = phi2_model.model.embed_tokens(start_ids.to(device))
    end_iq_embeds = phi2_model.model.embed_tokens(end_ids.to(device))

    inputs_embeddings = [start_iq_embeds]

    if img is not None:
        # CLIP encoder -> drop CLS token -> project into Phi-2 embedding space.
        pixel_values = processor(images=img, return_tensors="pt")['pixel_values'].to(device)
        clip_outputs = clip_model(pixel_values=pixel_values)
        patch_states = clip_outputs.last_hidden_state[:, 1:, :]  # remove cls token
        image_embeddings = projection_layer(patch_states).to(torch.float32)
        inputs_embeddings.append(image_embeddings)

    if aud is not None:
        # Transcribe speech, then embed the transcript like a typed question.
        transcription = audio_model.transcribe(aud)
        audio_text = "".join(seg['text'] for seg in transcription['segments']).strip()
        audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids']
        inputs_embeddings.append(phi2_model.model.embed_tokens(audio_tokens.to(device)))

    # Truthiness guard: the UI may submit None (gr.Text(value=None)) as well
    # as '' — the old `q != ''` check crashed the tokenizer on None.
    if q:
        ques = tokenizer(q, return_tensors="pt", return_attention_mask=False)['input_ids']
        inputs_embeddings.append(phi2_model.model.embed_tokens(ques.to(device)))

    inputs_embeddings.append(end_iq_embeds)
    combined_embeds = torch.cat(inputs_embeddings, dim=1)

    # Inference only — no gradients needed.
    with torch.no_grad():
        generated = phi2_model.generate(inputs_embeds=combined_embeds,
                                        max_new_tokens=max_tokens,
                                        return_dict_in_generate=True)

    # NOTE(review): [:, 1:] skips the first generated token — presumably a
    # leading marker artifact of the training format; confirm before changing.
    decoded = tokenizer.batch_decode(generated.sequences[:, 1:])[0]
    return decoded.replace("<|endoftext|>", "")
# Pre-canned demo examples; column order matches the Gradio inputs:
# (image path, audio path, question, max_tokens).
examples = [
["./examples/Image_1.jpg", None, "Explain image?", 20],
["./examples/Image_2.jpg", None, "How many animals are there in image?", 10],
["./examples/Image_3.jpg", None, "What is in the image?", 20],
["./examples/Image_4.jpg", None, "What represents this Image?", 20],
]
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # MultiModelLLM
        Multimodel GPT with inputs as Image, Audio, Text with output as Text.
        """
    )
    with gr.Row():
        with gr.Column():
            # The three optional modalities plus the generation-length slider.
            image = gr.Image(label='Image', type="pil", value=None)
            audio_q = gr.Audio(label="Audio Question", value=None, sources=['microphone', 'upload'], type='filepath')
            question = gr.Text(label='Question?', value=None)
            max_tokens = gr.Slider(1, 50, value=10, step=1, label="Max tokens")
    with gr.Row():
        answer = gr.Text(label='Answer')
    with gr.Row():
        submit = gr.Button("Submit")
        submit.click(generate_answers, inputs=[image, audio_q, question, max_tokens], outputs=[answer])
        clear_btn = gr.ClearButton([image, audio_q, question, max_tokens, answer])
    # Clickable examples that pre-fill the inputs above.
    gr.Examples(examples=examples, inputs=[image, audio_q, question, max_tokens], outputs=answer)

if __name__ == "__main__":
    # share=True exposes a public link; debug=True prints server errors.
    # (Removed a stray trailing '|' that made this line a syntax error.)
    demo.launch(share=True, debug=True)