Iruos8805 commited on
Commit
852ec4a
·
1 Parent(s): 1e271fc
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -5,30 +5,29 @@ import gradio as gr
5
  from transformers import pipeline
6
 
7
 
8
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
  # Loading in Model
11
  model_name = "dima806/ai_vs_real_image_detection"
12
- model = ViTForImageClassification.from_pretrained(model_name).to(device)
13
- model.to(device)
14
 
15
 
16
- #Classification function
17
  def classify_image(img: Image.Image):
18
- inputs = model(images=img, return_tensors="pt").to(device)
19
- results = model(inputs)
20
  top = results[0]
21
  label = top["label"]
22
  score = top["score"]
23
  return f"Prediction: {label} (Confidence: {score:.2f})"
24
 
25
-
26
-
27
- # Interface
28
  interface = gr.Interface(
29
  fn=classify_image,
30
  inputs=gr.Image(type="pil"),
31
  outputs="text",
32
- title="Real vs AI Image detection",
33
- description="Check if your image is Real or AI"
34
- )
 
 
 
5
  from transformers import pipeline
6
 
7
 
8
+ device = 0 if torch.cuda.is_available() else -1
9
 
10
  # Loading in Model
11
  model_name = "dima806/ai_vs_real_image_detection"
12
+ pipe = pipeline("image-classification", model=model_name, device = device)
13
+
14
 
15
 
16
+ # Classification function
17
  def classify_image(img: Image.Image):
18
+ results = pipe(img)
 
19
  top = results[0]
20
  label = top["label"]
21
  score = top["score"]
22
  return f"Prediction: {label} (Confidence: {score:.2f})"
23
 
24
+ # Gradio interface
 
 
25
  interface = gr.Interface(
26
  fn=classify_image,
27
  inputs=gr.Image(type="pil"),
28
  outputs="text",
29
+ title="Real vs AI Image Detection",
30
+ description="Upload an image to see if it's REAL or AI-generated."
31
+ )
32
+
33
+ interface.launch()