Spaces:
Runtime error
Runtime error
File size: 5,292 Bytes
bcb5222 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import gradio as gr
import peft
from peft import LoraConfig
from transformers import AutoTokenizer,BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
import torch
from peft import PeftModel
import torch.nn as nn
import whisperx
# --- checkpoints and shared configuration -----------------------------------
clip_model_name = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
phi_model_name = "microsoft/phi-2"

# Token id whose embedding is inserted right after the image patch embeddings.
# The original note says it is the token for the word "comment".
# NOTE(review): confirm against the phi-2 vocabulary.
IMAGE_TOKEN_ID = 23893

device = "cuda" if torch.cuda.is_available() else "cpu"

# Feature widths: clip_embed is the input width of the projection applied to
# CLIP vision outputs; phi_embed is the phi-2 embedding width it projects to.
clip_embed = 640
phi_embed = 2560

# whisperx settings: compute_type is passed to whisperx.load_model;
# audio_batch_size is meant for transcription batching.
compute_type = "float32"
audio_batch_size = 16

# --- tokenizer / image processor --------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
# Reuse EOS as the padding token (presumably because none is defined - confirm).
tokenizer.pad_token = tokenizer.eos_token
processor = AutoProcessor.from_pretrained(clip_model_name)
class SimpleResBlock(nn.Module):
    """Pre-norm residual MLP applied to projected CLIP features.

    The input is layer-normalised first, and the residual is added to the
    *normalised* tensor rather than the raw input:
    ``y = n + proj(n)`` where ``n = pre_norm(x)``.

    The attribute names ``pre_norm`` and ``proj`` (and the Linear/GELU/Linear
    layer order inside ``proj``) define the state-dict keys. A saved
    checkpoint is loaded into this module with ``load_state_dict``, so they
    must stay stable.

    Args:
        phi_embed: feature width; input and output widths are equal.
    """

    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        mlp_layers = [
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed),
        ]
        self.proj = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Return ``LayerNorm(x) + MLP(LayerNorm(x))``; the last dim must equal phi_embed."""
        normed = self.pre_norm(x)
        residual = self.proj(normed)
        return normed + residual
# --- models -----------------------------------------------------------------
# Every model is built once at import time and shared by all requests.
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
phi_model = AutoModelForCausalLM.from_pretrained(
    phi_model_name,
    trust_remote_code=True,
).to(device)
audio_model = whisperx.load_model("tiny", device, compute_type=compute_type)

# --- fine-tuned weights -----------------------------------------------------
# Attach the saved LoRA adapter to phi-2, then fold it into the base weights
# so inference runs on a plain (un-wrapped) causal-LM module.
model_to_merge = PeftModel.from_pretrained(phi_model, './model_chkpt/qlora_adaptor')
merged_model = model_to_merge.merge_and_unload()

# Trained weights for the CLIP->phi projection and the residual refiner.
projection.load_state_dict(
    torch.load('./model_chkpt/ft_projection_layer.pth', map_location=torch.device(device))
)
resblock.load_state_dict(
    torch.load('./model_chkpt/ft_projection_model.pth', map_location=torch.device(device))
)
def _embed_text(text):
    """Tokenize *text* and return its phi-2 input embeddings, shape (1, n_tokens, hidden)."""
    token_ids = tokenizer(text, return_tensors="pt", return_attention_mask=False)["input_ids"].to(device)
    return merged_model.get_input_embeddings()(token_ids)


def _image_prompt_embeds(img):
    """Return [projected image patch embeddings, IMAGE_TOKEN_ID embedding] for *img*."""
    embed_tokens = merged_model.get_input_embeddings()
    pixel_inputs = processor(images=img, return_tensors="pt").to(device)
    # Drop position 0 (the CLS slot) and keep only the per-patch features.
    patch_features = clip_model(**pixel_inputs).last_hidden_state[:, 1:, :]
    # Match the language model's embedding dtype so every prompt segment has
    # the same precision before concatenation (phi-2 is loaded in its default
    # dtype, not forced to float16).
    image_embeds = resblock(projection(patch_features)).to(embed_tokens.weight.dtype)
    separator_ids = torch.tensor([[IMAGE_TOKEN_ID]], device=device)
    return [image_embeds, embed_tokens(separator_ids)]


def _transcribe(audio_path):
    """Return the whisperx transcript of *audio_path* joined into one stripped string."""
    result = audio_model.transcribe(audio_path, batch_size=audio_batch_size)
    return "".join(seg["text"] for seg in result["segments"]).strip()


def _greedy_decode(prompt_embeds, max_new_tokens):
    """Greedily generate up to *max_new_tokens* tokens after *prompt_embeds*.

    Generation stops early at the tokenizer's EOS token. The KV cache is
    reused so each step only feeds the newest token's embedding. Expected to
    be called under ``torch.no_grad()``.
    """
    embed_tokens = merged_model.get_input_embeddings()
    eos_id = tokenizer.eos_token_id
    generated_ids = []
    step_embeds = prompt_embeds
    past_key_values = None
    for _ in range(max_new_tokens):
        outputs = merged_model(
            inputs_embeds=step_embeds,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        next_id = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
        token = next_id.item()
        if token == eos_id:
            break
        generated_ids.append(token)
        step_embeds = embed_tokens(next_id)
    return tokenizer.decode(generated_ids, skip_special_tokens=True)


def model_generate_ans(img=None, img_audio=None, val_q=None, max_generate_length=100):
    """Answer a multimodal query with greedy decoding from the merged phi-2 model.

    The prompt is built directly in embedding space, in this order:
      1. projected CLIP patch embeddings followed by the IMAGE_TOKEN_ID
         embedding (if an image is given);
      2. the embedded whisperx transcript (if audio is given and non-empty);
      3. the embedded text question (if non-empty).

    Args:
        img: PIL image from the Gradio Image component, or None.
        img_audio: path to an audio file (Gradio Audio, type='filepath'), or None.
        val_q: text question; None or "" means no text.
        max_generate_length: maximum number of new tokens to generate.

    Returns:
        The decoded answer with special tokens removed. If no usable input is
        given, a short message asking for input is returned instead.
    """
    with torch.no_grad():
        prompt_parts = []
        if img is not None:
            prompt_parts.extend(_image_prompt_embeds(img))
        if img_audio is not None:
            audio_text = _transcribe(img_audio)
            if audio_text:
                prompt_parts.append(_embed_text(audio_text))
        if val_q:
            prompt_parts.append(_embed_text(val_q))
        if not prompt_parts:
            # torch.cat on an empty list would raise; tell the user instead.
            return "Please provide an image, an audio clip, or a text question."
        prompt_embeds = torch.cat(prompt_parts, dim=1)
        return _greedy_decode(prompt_embeds, max_generate_length)
# --- app GUI ----------------------------------------------------------------
# Inputs (image / audio / text) sit in the left column; the answer box and the
# submit button sit in the right column. Submitting passes the three inputs
# positionally to model_generate_ans and shows its string result.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Chat with MultiModal GPT!
        Built by combining the CLIP vision model with the phi-2 language model.
        """
    )
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image', type="pil")
            img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
            img_question = gr.Text(label='Text Query')
        with gr.Column():
            img_answer = gr.Text(label='Answer')
            section_btn = gr.Button("Submit")
    # Event listeners must be registered inside the Blocks context.
    section_btn.click(
        model_generate_ans,
        inputs=[img_input, img_audio, img_question],
        outputs=[img_answer],
    )
demo.launch()