Shiwanni commited on
Commit
21aeebf
·
verified ·
1 Parent(s): da75a07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -17
app.py CHANGED
@@ -2,32 +2,52 @@ from transformers import ViTForImageClassification, ViTImageProcessor
2
  import torch
3
  from PIL import Image
4
  import gradio as gr
 
5
 
6
- # Load pre-trained model and processor
7
- model_name = "facebook/deit-base-distilled-patch16-224"
8
- processor = ViTImageProcessor.from_pretrained(model_name)
9
- model = ViTForImageClassification.from_pretrained(model_name)
10
 
11
- def detect_deepfake(image):
12
- # Preprocess the image
13
- inputs = processor(images=image, return_tensors="pt")
 
 
14
 
15
- # Make prediction
16
- outputs = model(**inputs)
17
- logits = outputs.logits
18
- predicted_class_idx = logits.argmax(-1).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # For demonstration, we'll assume class 0 is real and 1 is fake
21
- # (In a real project, you'd need to verify this with your model)
22
- return "Real" if predicted_class_idx == 0 else "Fake (Possible Deepfake)"
23
 
24
- # Create a simple interface
25
  iface = gr.Interface(
26
  fn=detect_deepfake,
27
  inputs=gr.Image(type="pil"),
28
  outputs="text",
29
  title="Deepfake Detection",
30
- description="Upload an image to check if it might be a deepfake."
 
 
 
31
  )
32
 
33
- iface.launch()
 
2
  import torch
3
  from PIL import Image
4
  import gradio as gr
5
+ import warnings
6
 
7
+ # Suppress warnings (optional)
8
+ warnings.filterwarnings('ignore')
 
 
9
 
10
+ try:
11
+ # Load model (smaller version for better performance)
12
+ model_name = "google/vit-base-patch16-224"
13
+ processor = ViTImageProcessor.from_pretrained(model_name)
14
+ model = ViTForImageClassification.from_pretrained(model_name)
15
 
16
+ print("Model loaded successfully!")
17
+ except Exception as e:
18
+ print(f"Error loading model: {e}")
19
+ raise
20
+
21
+ def detect_deepfake(image):
22
+ try:
23
+ # Convert image to RGB
24
+ if image.mode != 'RGB':
25
+ image = image.convert('RGB')
26
+
27
+ # Process image
28
+ inputs = processor(images=image, return_tensors="pt")
29
+
30
+ # Predict
31
+ with torch.no_grad():
32
+ outputs = model(**inputs)
33
+
34
+ # Get result
35
+ predicted_class = outputs.logits.argmax(-1).item()
36
+ return "Real" if predicted_class == 0 else "Fake (Possible Deepfake)"
37
 
38
+ except Exception as e:
39
+ return f"Error processing image: {str(e)}"
 
40
 
41
+ # Create interface
42
  iface = gr.Interface(
43
  fn=detect_deepfake,
44
  inputs=gr.Image(type="pil"),
45
  outputs="text",
46
  title="Deepfake Detection",
47
+ examples=[
48
+ ["real_example.jpg"], # Add your example files
49
+ ["fake_example.jpg"]
50
+ ]
51
  )
52
 
53
+ iface.launch(server_port=7860, share=False) # Disable share for local use