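# Gradio demo: a multimodal assistant that accepts text, image, and/or audio
# input. Images are encoded with CLIP and projected into the Phi-2 embedding
# space, audio is transcribed with Whisper, and Phi-2 generates the response.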
from PIL import Image
import gradio as gr
import torch
from torchvision import transforms
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          CLIPModel, CLIPProcessor, pipeline)

from model import build_mlp_vector_projector

device = "cpu"
# Load the CLIP model and processor
clip_model_name = "openai/clip-vit-base-patch16"
clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ]
)
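# Image path -> CLIP image embedding -> trained MLP projection head
# ("stage_2_proj_head_v3.pth") -> pseudo-token embeddings consumable by Phi-2.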
def process_image(img_path):
    image = Image.open(img_path).convert("RGB")
    image = clip_transform(image)
    inputs = clip_processor(text=[""], images=image,
                            return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    img_embedding = clip_model(**inputs).image_embeds
    img_proj_head = build_mlp_vector_projector().to(device)
    img_proj_head.load_state_dict(torch.load(
        'stage_2_proj_head_v3.pth', map_location=torch.device(device)))
    img_tokens = img_proj_head(img_embedding)
    return img_tokens
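# Load the Phi-2 tokenizer, the adapted Phi-2 checkpoint from the local
# "stage2_adaptor" directory (used for image-conditioned generation), and the
# base Phi-2 model (used for text-only generation and its input embedding table).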
phi_model_name = "microsoft/phi-2"
text_tokenizer = AutoTokenizer.from_pretrained(
    phi_model_name, trust_remote_code=True)

with torch.no_grad():
    tuned_phi2 = AutoModelForCausalLM.from_pretrained(
        "stage2_adaptor", trust_remote_code=True,
        device_map=device, torch_dtype=torch.float16
    )
    base_phi2_text = AutoModelForCausalLM.from_pretrained(
        phi_model_name, trust_remote_code=True,
        device_map="auto", torch_dtype=torch.float16
    )
print("phi2 model loaded")
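# Whisper-small ASR pipeline; audio is transcribed to text and merged with the
# typed question before being passed to Phi-2.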
audio_model_name = "openai/whisper-small"
audio_pipe = pipeline(
    task="automatic-speech-recognition",
    model=audio_model_name,
    chunk_length_s=30,
    device=device)
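# Text-only path: generate up to `count` new tokens with the base Phi-2 model.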
def process_text(text, count):
    inputs = text_tokenizer(text, return_tensors="pt",
                            return_attention_mask=False)
    prediction = text_tokenizer.batch_decode(
        base_phi2_text.generate(
            **inputs,
            max_new_tokens=count,
            bos_token_id=text_tokenizer.bos_token_id,
            eos_token_id=text_tokenizer.eos_token_id,
            pad_token_id=text_tokenizer.pad_token_id
        )
    )
    return prediction[0].replace("<|endoftext|>", "").rstrip("\n")
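# Audio path: transcribe the uploaded or recorded clip with Whisper.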
def process_audio(audio):
    if audio is None:
        raise gr.Error(
            "Please provide an audio file or record your input"
        )
    text = audio_pipe(
        audio,
        batch_size=8,
        generate_kwargs={"task": "transcribe"},
        return_timestamps=True
    )["text"]
    return text
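# Main handler: transcribe audio (if any), concatenate it with the typed text,
# then either prepend the projected image tokens to the question token
# embeddings and generate with the tuned Phi-2 via `inputs_embeds` (when an
# image is provided), or fall back to plain text generation with the base model.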
def generate_response(image, audio, text, count):
    count = int(count)
    text_from_audio = process_audio(audio) if audio else ""
    overall_input = (text or "") + text_from_audio
    if image:
        img_tokens = process_image(image)
        q_tokens = text_tokenizer.encode(
            overall_input,
            return_tensors='pt').to(device)
        question_token_embeddings = base_phi2_text.get_submodule(
            'model.embed_tokens')(q_tokens).to(device)
        inputs = torch.concat(
            (img_tokens.unsqueeze(0), question_token_embeddings),
            dim=-2).to(device)
        prediction = text_tokenizer.batch_decode(
            tuned_phi2.generate(
                inputs_embeds=inputs,
                max_new_tokens=count,
                bos_token_id=text_tokenizer.bos_token_id,
                eos_token_id=text_tokenizer.eos_token_id,
                pad_token_id=text_tokenizer.pad_token_id
            )
        )
        return prediction[0].replace("<|endoftext|>", "").rstrip("\n")
    return process_text(overall_input, count)
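# Gradio UI: text box, image upload, audio input, and a token-count field,
# wired to generate_response with the response shown in a text box.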
with gr.Blocks() as demo:
    gr.Markdown("# **AnyModeAssistant**")
    gr.Markdown("Use any mode (text, image, or audio) to interact with the AI assistant")
    with gr.Column():
        with gr.Row():
            text_input = gr.Textbox(placeholder="Enter your question here",
                                    label="Input")
        with gr.Row():
            image_input = gr.Image(type="filepath")
        with gr.Row():
            audio_input = gr.Audio(type="filepath")
        with gr.Row():
            response_count = gr.Textbox(
                placeholder="Number of tokens to respond",
                value="20",
                label="Count")
        with gr.Column():
            response = gr.Textbox(label="AI Response")
        with gr.Row():
            submit_button = gr.Button("Submit")

    submit_button.click(generate_response,
                        inputs=[image_input, audio_input,
                                text_input, response_count],
                        outputs=response)
# gr.Examples(
# examples=[
# ["What is a large language model?", "50"]
# ],
# # , image_input, image_text_input, audio_input],
# inputs=[text_input, text_input_count],
# outputs=[text_output], # , image_text_output, audio_text_output],
# fn=example_inference,
# )
demo.launch()