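# Gradio Space for a multimodal assistant that accepts text, image, and audio
# input. Images are embedded with CLIP and projected into phi-2's
# token-embedding space by a trained MLP head; audio is transcribed with
# Whisper; a stage-2 fine-tuned phi-2 generates the answer.
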
from PIL import Image
import gradio as gr
import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    pipeline,
)

from model import build_mlp_vector_projector

device = "cpu"

# Load the CLIP model and processor used to embed input images.
# CLIPProcessor performs resizing, rescaling, and normalisation itself, so no
# separate torchvision transform is needed.
clip_model_name = "openai/clip-vit-base-patch16"
clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

def process_image(img_path):
    """Embed an image with CLIP and project it into phi-2's embedding space."""
    image = Image.open(img_path).convert("RGB")
    inputs = clip_processor(text=[""], images=image,
                            return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    img_embedding = clip_model(**inputs).image_embeds
    # Projection head that maps the CLIP image embedding to pseudo-token
    # embeddings in phi-2's embedding space.
    img_proj_head = build_mlp_vector_projector().to(device)
    img_proj_head.load_state_dict(torch.load(
        'stage_2_proj_head_v3.pth', map_location=torch.device(device)))
    img_tokens = img_proj_head(img_embedding)
    return img_tokens

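# phi-2 language model: the tokenizer, the base checkpoint, and the stage-2
# fine-tuned model ("stage2_adaptor") that consumes the projected image tokens.
# The base model is loaded here, but the handlers below use the tuned model.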
phi_model_name = "microsoft/phi-2"
text_tokenizer = AutoTokenizer.from_pretrained(
    phi_model_name, trust_remote_code=True)
with torch.no_grad():
    base_phi2_text = AutoModelForCausalLM.from_pretrained(
        phi_model_name, trust_remote_code=True,
        device_map="auto", torch_dtype=torch.float16
    )
tuned_phi2 = AutoModelForCausalLM.from_pretrained(
    "stage2_adaptor", trust_remote_code=True,
).to("cpu")
print("phi2 model loaded")

# Whisper pipeline for speech-to-text.
audio_model_name = "openai/whisper-small"
audio_pipe = pipeline(
    task="automatic-speech-recognition",
    model=audio_model_name,
    chunk_length_s=30,
    device=device)

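# Text-only path: embed the prompt with the tuned model's embedding layer and
# let phi-2 generate the continuation.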
def process_text(text, count):
    inputs = text_tokenizer.encode(text, return_tensors="pt")
    input_embeds = tuned_phi2.get_submodule(
        'model.embed_tokens')(inputs).to(device)
    prediction = text_tokenizer.batch_decode(
        tuned_phi2.generate(
            inputs_embeds=input_embeds,
            max_new_tokens=count,
            bos_token_id=text_tokenizer.bos_token_id,
            eos_token_id=text_tokenizer.eos_token_id,
            pad_token_id=text_tokenizer.pad_token_id
        )
    )
    # str.rstrip removes characters, not a suffix string, so strip the EOS
    # marker explicitly.
    return prediction[0].replace('<|endoftext|>', '').rstrip("\n")

def process_audio(audio):
    if audio is None:
        raise gr.Error(
            "Please provide an audio file or record your input"
        )
    # Transcribe the recording with Whisper and keep only the text.
    text = audio_pipe(
        audio,
        batch_size=8,
        generate_kwargs={"task": "transcribe"},
        return_timestamps=True
    )["text"]
    return text

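# Main handler wired to the Submit button: transcribe audio if provided,
# combine it with any typed text, and, when an image is supplied, prepend its
# projected tokens to the question embeddings before generation.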
def generate_response(image, audio, text, count):
    count = int(count)
    overall_input = ""
    if audio:
        overall_input = process_audio(audio)
    if text:
        overall_input = (text + " " + overall_input).strip()
    if image:
        img_tokens = process_image(image)
        overall_input = "Question: " + overall_input + " Answer:"
        q_tokens = text_tokenizer.encode(
            overall_input,
            return_tensors='pt').to(device)
        question_token_embeddings = tuned_phi2.get_submodule(
            'model.embed_tokens')(q_tokens).to(device)
        # Prepend the projected image tokens to the question embeddings along
        # the sequence dimension.
        inputs = torch.cat(
            (img_tokens.unsqueeze(0), question_token_embeddings),
            dim=-2).to(device)
        prediction = text_tokenizer.batch_decode(
            tuned_phi2.generate(
                inputs_embeds=inputs,
                max_new_tokens=count,
                bos_token_id=text_tokenizer.bos_token_id,
                eos_token_id=text_tokenizer.eos_token_id,
                pad_token_id=text_tokenizer.pad_token_id
            )
        )
        return prediction[0].replace('<|endoftext|>', '').rstrip("\n")
    else:
        return process_text(overall_input, count)

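# Gradio UI: question textbox, image upload, audio input, and a token-count
# field in the left column; the model's response in the right column.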
with gr.Blocks() as demo:
    gr.Markdown("# **AnyModeAssistant**")
    gr.Markdown("Use any mode (text, image, or audio) to interact with the AI assistant")
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Row():
                text_input = gr.Textbox(placeholder="Enter your question here",
                                        label="Input")
            with gr.Row():
                image_input = gr.Image(type="filepath")
            with gr.Row():
                audio_input = gr.Audio(type="filepath")
            with gr.Row():
                response_count = gr.Textbox(
                    placeholder="Number of tokens to generate",
                    value="20",
                    label="Count")
        with gr.Column(scale=2):
            response = gr.Textbox(label="AI Response")
    with gr.Row():
        submit_button = gr.Button("Submit")
    submit_button.click(generate_response,
                        inputs=[
                            image_input, audio_input,
                            text_input, response_count
                        ],
                        outputs=response)
    # The example files are expected to sit next to this script; the last
    # value matches the default token count.
    gr.Examples(
        examples=[
            ["dog_man_forest.jpg", "audio.m4a",
             "Is there a dog present in the image?", "20"],
        ],
        inputs=[image_input, audio_input, text_input, response_count],
        outputs=[response],
        fn=generate_response,
    )

demo.launch(share=True)