# NOTE(review): scraped HuggingFace Spaces page chrome ("Spaces:" / "Runtime error"
# banners) removed — it was never part of the source file.
import gradio as gr
import torch
from transformers import CLIPModel, CLIPProcessor

from model import preprocess, load_model
# --- One-time setup: load the CLIP vision tower and the two scorer heads. ---
MODEL = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Only the vision tower is needed to produce embeddings; drop the rest of the
# CLIP model immediately to free memory.
_full_clip = CLIPModel.from_pretrained(MODEL)
vision_model = _full_clip.vision_model.to(DEVICE)
del _full_clip

clip_processor = CLIPProcessor.from_pretrained(MODEL)

# Two scorer heads operating on CLIP pooled embeddings: one predicts an
# aesthetic rating, the other an artifact score (weights loaded from disk).
rating_model = load_model("aesthetics_scorer_rating_openclip_vit_l_14.pth").to(DEVICE)
artifacts_model = load_model("aesthetics_scorer_artifacts_openclip_vit_l_14.pth").to(DEVICE)
def predict(img):
    """Score a PIL image, returning ``(rating, artifacts)`` as Python floats.

    The image is embedded with the CLIP vision tower, the pooled embedding is
    normalized via ``preprocess``, and the two scorer heads are applied.
    Rating is roughly 1-10 (higher is better); artifacts roughly 0-5
    (lower is better).
    """
    pixel_inputs = clip_processor(images=img, return_tensors="pt").to(DEVICE)

    # Inference only — no gradients needed for the vision forward pass.
    with torch.no_grad():
        pooled = vision_model(**pixel_inputs).pooler_output

    embedding = preprocess(pooled)

    with torch.no_grad():
        rating_score = rating_model(embedding)
        artifact_score = artifacts_model(embedding)

    return rating_score.detach().cpu().item(), artifact_score.detach().cpu().item()
# Build the demo UI: a single-image input mapped to the two numeric scores,
# then start the Gradio server.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Number(label="Rating ~1-10 (high is good)"),
        gr.Number(label="Artifacts ~0-5 (low is good)"),
    ],
    title="Aesthetics Scorer",
    description="Predicts aesthetics and artifact scores for images using CLIP-ViT-L. Demo for https://github.com/kenjiqq/aesthetics-scorer",
)
demo.launch()