import os, yaml
import gradio as gr
import requests
import argparse
from PIL import Image
import numpy as np
import torch
from transformers import AutoModelForCausalLM
from huggingface_hub import hf_hub_download

## InstructIR Plugin ##
from insir_models import instructir
from insir_text.models import LanguageModel, LMHead

hf_hub_download(repo_id="marcosv/InstructIR", filename="im_instructir-7d.pt", local_dir="./")
hf_hub_download(repo_id="marcosv/InstructIR", filename="lm_instructir-7d.pt", local_dir="./")

CONFIG = "eval5d.yml"
LM_MODEL = "lm_instructir-7d.pt"
MODEL_NAME = "im_instructir-7d.pt"

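# Turn the nested YAML config into an argparse.Namespace so options can be read
# as attributes below (e.g. cfg.model.width, cfg.llm.model_dim).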
def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace

# parse config file
with open(CONFIG, "r") as f:
    config = yaml.safe_load(f)

cfg = dict2namespace(config)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

ir_model = instructir.create_model(
    input_channels=cfg.model.in_ch, width=cfg.model.width, enc_blks=cfg.model.enc_blks,
    middle_blk_num=cfg.model.middle_blk_num, dec_blks=cfg.model.dec_blks, txtdim=cfg.model.textdim)
ir_model = ir_model.to(device)

print("IMAGE MODEL CKPT:", MODEL_NAME)
ir_model.load_state_dict(torch.load(MODEL_NAME, map_location="cpu"), strict=True)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
LMODEL = cfg.llm.model
language_model = LanguageModel(model=LMODEL)
lm_head = LMHead(embedding_dim=cfg.llm.model_dim, hidden_dim=cfg.llm.embd_dim, num_classes=cfg.llm.nclasses)
lm_head = lm_head.to(device)

print("LMHEAD MODEL CKPT:", LM_MODEL)
lm_head.load_state_dict(torch.load(LM_MODEL, map_location="cpu"), strict=True)

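# InstructIR restoration entry point: if no prompt is supplied, the Co-Instruct
# chat model (loaded further below) is first asked how to improve the image and
# its answer is used as the restoration instruction.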
def process_img(image, prompt=None):
    if prompt is None:
        prompt = chat("How to improve the quality of the image?", [], image, None, None, None)
        prompt += " Please help me improve its quality!"
    print(prompt)

    # HWC uint8 image -> normalized NCHW float tensor
    img = np.array(image)
    img = img / 255.
    img = img.astype(np.float32)
    y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)

    # embed the instruction text, then restore the image conditioned on it
    lm_embd = language_model(prompt)
    lm_embd = lm_embd.to(device)

    with torch.no_grad():
        text_embd, deg_pred = lm_head(lm_embd)
        x_hat = ir_model(y, text_embd)

    restored_img = x_hat.squeeze().permute(1, 2, 0).clamp_(0, 1).cpu().detach().numpy()
    restored_img = np.clip(restored_img, 0., 1.)
    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # float32 to uint8
    return Image.fromarray(restored_img)

## InstructIR Plugin ##

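# Co-Instruct model for multi-image, low-level visual question answering,
# loaded in fp16 directly onto the first GPU.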
model = AutoModelForCausalLM.from_pretrained("q-future/co-instruct-preview",
                                             trust_remote_code=True,
                                             torch_dtype=torch.float16,
                                             attn_implementation="eager",
                                             device_map={"": "cuda:0"})

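# Build the Co-Instruct prompt in its "USER: ... <|image|> ... ASSISTANT:" format,
# replaying previous turns from `history` and attaching up to four images, then
# return only the text after the final "ASSISTANT:" tag.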
def chat(message, history, image_1, image_2, image_3, image_4):
    print(history)
    if history:
        if image_1 is not None and image_2 is None:
            past_message = "USER: The input image: <|image|>" + history[0][0] + " ASSISTANT:" + history[0][1]
            # replay the remaining turns (history[0] is already included above)
            for i in range(len(history) - 1):
                past_message += "USER:" + history[i + 1][0] + " ASSISTANT:" + history[i + 1][1] + "</s>"
            message = past_message + "USER:" + message + " ASSISTANT:"
            images = [image_1]
        if image_1 is not None and image_2 is not None:
            if image_3 is None:
                past_message = "USER: The first image: <|image|>\nThe second image: <|image|>" + history[0][0] + " ASSISTANT:" + history[0][1] + "</s>"
                for i in range(len(history) - 1):
                    past_message += "USER:" + history[i + 1][0] + " ASSISTANT:" + history[i + 1][1] + "</s>"
                message = past_message + "USER:" + message + " ASSISTANT:"
                images = [image_1, image_2]
            else:
                if image_4 is None:
                    past_message = "USER: The first image: <|image|>\nThe second image: <|image|>\nThe third image:<|image|>" + history[0][0] + " ASSISTANT:" + history[0][1] + "</s>"
                    for i in range(len(history) - 1):
                        past_message += "USER:" + history[i + 1][0] + " ASSISTANT:" + history[i + 1][1] + "</s>"
                    message = past_message + "USER:" + message + " ASSISTANT:"
                    images = [image_1, image_2, image_3]
                else:
                    past_message = "USER: The first image: <|image|>\nThe second image: <|image|>\nThe third image:<|image|>\nThe fourth image:<|image|>" + history[0][0] + " ASSISTANT:" + history[0][1] + "</s>"
                    for i in range(len(history) - 1):
                        past_message += "USER:" + history[i + 1][0] + " ASSISTANT:" + history[i + 1][1] + "</s>"
                    message = past_message + "USER:" + message + " ASSISTANT:"
                    images = [image_1, image_2, image_3, image_4]
    else:
        if image_1 is not None and image_2 is None:
            message = "USER: The input image: <|image|>" + message + " ASSISTANT:"
            images = [image_1]
        if image_1 is not None and image_2 is not None:
            if image_3 is None:
                message = "USER: The first image: <|image|>\nThe second image: <|image|>" + message + " ASSISTANT:"
                images = [image_1, image_2]
            else:
                if image_4 is None:
                    message = "USER: The first image: <|image|>\nThe second image: <|image|>\nThe third image:<|image|>" + message + " ASSISTANT:"
                    images = [image_1, image_2, image_3]
                else:
                    message = "USER: The first image: <|image|>\nThe second image: <|image|>\nThe third image:<|image|>\nThe fourth image:<|image|>" + message + " ASSISTANT:"
                    images = [image_1, image_2, image_3, image_4]
    print(message)
    return model.tokenizer.batch_decode(model.chat(message, images, max_new_tokens=600).clamp(0, 100000))[0].split("ASSISTANT:")[-1]

#### Image / prompt examples
examples = [
    ["Which part of the image is relatively clearer, the upper part or the lower part? Please analyze in details.", "examples/sausage.jpg", None],
    ["Which image is noisy, and which one is with motion blur? Please analyze in details.", "examples/211.jpg", "examples/frog.png"],
    ["What is the problem in this image, and how to fix it? Please answer my questions one by one.", "examples/lol_748.png", None],
]

title = "Q-Instruct-Plus🧑🏫🖌️"

with gr.Blocks(title="Q-Instruct-Plus🧑🏫🖌️") as demo:
    title_markdown = ("""
    <h1 align="center"><a href="https://github.com/Q-Future/Q-Instruct"><img src="https://github.com/Q-Future/Q-Instruct/blob/main/q_instruct_logo.png?raw=true" alt="Q-Instruct (mPLUG-Owl-2)" border="0" style="margin: 0 auto; height: 85px;" /></a></h1>
    <h2 align="center">Q-Instruct: Improving Low-level Visual Abilities for Multi-modality Foundation Models</h2>
    <div align="center">Super version of Q-Instruct with multi-image support (up to 4 images, same as GPT-4V)! We also support <a href='https://huggingface.co/marcosv/InstructIR'>InstructIR</a> as a plugin!</div>
    <h5 align="center">Please find our more accurate visual scoring demo at <a href='https://huggingface.co/spaces/teowu/OneScorer'>[OneScorer]</a>!</h5>
    <div align="center">
        <div style="display:flex; gap: 0.25rem;" align="center">
            <a href='https://github.com/Q-Future/Q-Instruct'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
            <a href="https://Q-Instruct.github.io/Q-Instruct/fig/Q_Instruct_v0_1_preview.pdf"><img src="https://img.shields.io/badge/Technical-Report-red"></a>
            <a href='https://github.com/Q-Future/Q-Instruct/stargazers'><img src='https://img.shields.io/github/stars/Q-Future/Q-Instruct.svg?style=social'></a>
        </div>
    </div>
    """)
    gr.Markdown(title_markdown)

    with gr.Row():
        input_img_1 = gr.Image(type='pil', label="Image 1 (First image)")
        input_img_2 = gr.Image(type='pil', label="Image 2 (Second image)")
        input_img_3 = gr.Image(type='pil', label="Image 3 (Third image)")
        input_img_4 = gr.Image(type='pil', label="Image 4 (Fourth image)")

    with gr.Row():
        with gr.Column(scale=2):
            gr.ChatInterface(fn=chat, additional_inputs=[input_img_1, input_img_2, input_img_3, input_img_4], theme="Soft", examples=examples)
        with gr.Column(scale=1):
            input_image_ir = gr.Image(type="pil", label="Image for Auto Restoration")
            output_image_ir = gr.Image(type="pil", label="Output of Auto Restoration")
            gr.Interface(
                fn=process_img,
                inputs=[input_image_ir],
                outputs=[output_image_ir],
                examples=["examples/gopro.png", "examples/noise50.png", "examples/lol_748.png"],
            )

demo.launch(share=True)