File size: 2,417 Bytes
6974603
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import torch 
from PIL import Image
import gradio as gr

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision is only requested on the GPU: float16 inference is not
# reliably supported by PyTorch's CPU kernels, so fall back to float32 there.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32


def _load_classifier(model_id):
  """Build an image-classification pipeline for the checkpoint *model_id*.

  The model and its feature extractor are both loaded from the same
  checkpoint id and placed on the module-level ``device`` with ``dtype``.
  """
  return pipeline("image-classification",
                  model=AutoModelForImageClassification.from_pretrained(model_id),
                  feature_extractor=AutoFeatureExtractor.from_pretrained(model_id),
                  device=device,
                  torch_dtype=dtype)


# Loaded once at import time and reused by every request.
nsfw_pipe = _load_classifier("carbon225/vit-base-patch16-224-hentai")
style_pipe = _load_classifier("cafeai/cafe_style")
aesthetic_pipe = _load_classifier("cafeai/cafe_aesthetic")

def predict(image, files=None):
  """Classify one image, or a batch of uploaded images, with all three pipelines.

  Args:
    image: Path of a single image file. Used only when ``files`` is empty
      or ``None``.
    files: Optional list of uploaded files. Each entry is either an object
      exposing its path as ``.name`` or a plain path. When non-empty it
      takes precedence over ``image``.

  Returns:
    A ``(first, all)`` tuple. ``all`` holds one ``{label: score}`` dict per
    input image, merging the rows returned by ``style_pipe``,
    ``aesthetic_pipe`` and ``nsfw_pipe``; ``first`` is ``all[0]``.

  Raises:
    ValueError: If neither ``image`` nor ``files`` supplies an image.
  """
  if files:
    image_paths = [getattr(upload, "name", upload) for upload in files]
  elif image is not None:
    image_paths = [image]
  else:
    raise ValueError("No image provided: upload an image or select files.")

  pil_images = []
  for path in image_paths:
    # Close the underlying file promptly; convert() returns an independent
    # in-memory copy, so it stays usable after the handle is released.
    with Image.open(path) as opened:
      pil_images.append(opened.convert("RGB"))

  style = style_pipe(pil_images)
  aesthetic = aesthetic_pipe(pil_images)
  nsfw = nsfw_pipe(pil_images)

  # Each pipeline yields one list of {"label", "score"} rows per image;
  # concatenate the three lists image by image.
  merged_rows = [a + b + c for (a, b, c) in zip(style, aesthetic, nsfw)]
  label_data = [{row["label"]: row["score"] for row in rows} for rows in merged_rows]

  return label_data[0], label_data

# Two-column layout: inputs on the left, prediction outputs on the right.
with gr.Blocks() as blocks:
  with gr.Row():
    with gr.Column():
      image = gr.Image(label="Image to test", type="filepath")
      files = gr.File(label="Multiple Images", file_types=["image"], file_count="multiple")
    with gr.Column():
      label = gr.Label(label="style")
      results = gr.JSON(label="Results")
  btn = gr.Button("Run")

  # Run predict() on click; the endpoint is also registered under the
  # API name "inference".
  btn.click(fn=predict, inputs=[image, files], outputs=[label, results], api_name="inference")

blocks.queue()
blocks.launch(debug=True, inline=True)