File size: 1,606 Bytes
07e1105
 
 
 
 
 
 
 
 
7072d7b
07e1105
 
2153e8d
 
 
 
07e1105
 
 
 
 
 
 
 
a32d7be
07e1105
 
 
 
 
 
 
 
 
a32d7be
 
 
 
 
 
 
 
 
 
 
07e1105
 
2fcfba4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torchvision

import cv2
import numpy as np
from models import monet as MoNet
import argparse
from utils.dataset.process import ToTensor, Normalize
import gradio as gr
import os

def load_image(img_path):
    """Load an image and convert it into the network's CHW float input layout.

    Args:
        img_path: Either a filesystem path (``str``) to an image file, or an
            in-memory image accepted by ``np.asarray`` whose channels are in
            RGB order (e.g. the image object Gradio hands to ``predict``).

    Returns:
        A float32 ``numpy`` array of shape (C, 224, 224) in RGB channel order,
        with pixel values divided by 255 (i.e. in [0, 1] for 8-bit input).

    Raises:
        FileNotFoundError: If ``img_path`` is a path OpenCV cannot read.
    """
    if isinstance(img_path, str):
        # cv2.imread yields BGR and signals failure by returning None rather
        # than raising; fail loudly here instead of inside cv2.resize.
        d_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if d_img is None:
            raise FileNotFoundError(f"Could not read image file: {img_path!r}")
    else:
        # In-memory input is RGB; convert to BGR so both branches share the
        # same BGR -> RGB path below.
        d_img = cv2.cvtColor(np.asarray(img_path), cv2.COLOR_RGB2BGR)
    d_img = cv2.resize(d_img, (224, 224), interpolation=cv2.INTER_CUBIC)
    d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
    # astype already returns a new array, so no extra np.array() copy is needed.
    d_img = d_img.astype('float32') / 255
    # HWC -> CHW, the layout PyTorch convolution/ViT backbones expect.
    d_img = np.transpose(d_img, (2, 0, 1))

    return d_img

def predict(image):
    """Run a single quality prediction with the module-level ``model``.

    Args:
        image: A file path or an in-memory RGB image as passed by the Gradio
            ``"image"`` input; may be ``None`` if the user submits without
            uploading anything.

    Returns:
        A human-readable string containing the predicted quality score rounded
        to four decimal places, or a prompt to upload an image when ``image``
        is ``None``.
    """
    if image is None:
        return "Please upload an image first."

    # Project transforms from utils.dataset.process, applied to the numpy array
    # produced by load_image. NOTE(review): Normalize(0.5, 0.5) presumably maps
    # [0, 1] to [-1, 1] — confirm against utils.dataset.process.
    trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])

    img = load_image(image)
    # Follow the model's device instead of hard-coding CUDA; identical to the
    # old behaviour when the model lives on the GPU.
    device = next(model.parameters()).device
    # Inference only: skip autograd bookkeeping (previously masked by .detach()).
    with torch.no_grad():
        img_tensor = trans(img).unsqueeze(0).to(device)
        # Assumes the model returns one score per batch item (shape (1,)) —
        # TODO confirm against MoNet.forward.
        iq = model(img_tensor).cpu().numpy().tolist()[0]

    return "The image quality of the image is: {}".format(round(iq, 4))

# Pretrained weights are published on the Hugging Face Hub, e.g.:
#   wget https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl
# Point --ckpt at the downloaded file (the default is best_model.pkl).

parser = argparse.ArgumentParser()
# model related
parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224', help='The backbone for MoNet.')
parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
# checkpoint related
parser.add_argument('--ckpt', dest='ckpt', type=str, default='best_model.pkl', help='Path to the trained MoNet state_dict.')
config = parser.parse_args()

# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MoNet.MoNet(config).to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(config.ckpt, map_location=device))
model.eval()

interface = gr.Interface(fn=predict, inputs="image", outputs="text")
interface.launch()