aznasut commited on
Commit
e5abf1b
·
1 Parent(s): 097dd3b

fix v1/detect

Browse files
Files changed (2) hide show
  1. main.py +55 -46
  2. models.py +2 -2
main.py CHANGED
@@ -5,9 +5,9 @@ import aiohttp
5
  from fastapi import FastAPI, File, UploadFile, HTTPException
6
  from fastapi.responses import JSONResponse
7
 
8
- # from transformers import AutoImageProcessor, ViTForImageClassification
9
- from transformers import pipeline
10
- from transformers.pipelines import PipelineException
11
  from PIL import Image
12
  from cachetools import Cache
13
  import torch
@@ -27,10 +27,9 @@ logging.basicConfig(
27
  cache = Cache(maxsize=1000)
28
 
29
  # Load the model using the transformers pipeline
30
- model = pipeline("image-classification", model="dima806/deepfake_vs_real_image_detection")
31
- # model = pipeline("image-classification", model="Wvolf/ViT_Deepfake_Detection")
32
- # image_processor = AutoImageProcessor.from_pretrained("Wvolf/ViT_Deepfake_Detection")
33
- # model = ViTForImageClassification.from_pretrained("Wvolf/ViT_Deepfake_Detection")
34
 
35
  # Detect the device used by TensorFlow
36
  # DEVICE = "GPU" if tf.config.list_physical_devices("GPU") else "CPU"
@@ -84,16 +83,26 @@ async def classify_image(file: UploadFile = File(None)):
84
 
85
  image = Image.open(io.BytesIO(image_data))
86
 
87
- # inputs = image_processor(image, return_tensors="pt")
88
- inputs = model(image)
89
-
90
- # with torch.no_grad():
91
- # outpus = model(**inputs)
92
- # logits = outpus.logits
93
- # probs = F.softmax(logits, dim=-1)
94
- # predicted_label_id = probs.argmax(-1).item()
95
- # predicted_label = model.config.id2label[predicted_label_id]
96
- # confidence = probs.max().item()
 
 
 
 
 
 
 
 
 
 
97
 
98
  # model predicts one of the 1000 ImageNet classes
99
  # predicted_label = logits.argmax(-1).item()
@@ -101,16 +110,16 @@ async def classify_image(file: UploadFile = File(None)):
101
  # logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
102
  # # print(model.config.id2label[predicted_label])
103
  # Find the prediction with the highest confidence using the max() function
104
- predicted_label = max(inputs, key=lambda x: x["score"])
105
  # logging.info("best_prediction %s", best_prediction)
106
  # best_prediction2 = results[1]["label"]
107
  # logging.info("best_prediction2 %s", best_prediction2)
108
 
109
  # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
110
- confidence = round(predicted_label["score"] * 100, 1)
111
 
112
  # # Prepare the custom response data
113
- detection_result = {
114
  "prediction": predicted_label,
115
  "confidence":confidence,
116
  }
@@ -130,20 +139,20 @@ async def classify_image(file: UploadFile = File(None)):
130
  # }
131
 
132
  # Populate hash
133
- cache[image_hash] = detection_result.copy()
134
 
135
  # Add url to the API response
136
- detection_result["file_name"] = file.filename
137
 
138
- response_data.append(detection_result)
139
 
140
  # Add file_name to the API response
141
- response_data["file_name"] = file.filename
142
 
143
  return FileImageDetectionResponse(**response_data)
144
 
145
- # except Exception as e:
146
- except PipelineException as e:
147
  logging.error("Error processing image: %s", str(e))
148
  raise HTTPException(
149
  status_code=500, detail=f"Error processing image: {str(e)}"
@@ -172,29 +181,29 @@ async def classify_images(request: ImageUrlsRequest):
172
  continue
173
 
174
  image = Image.open(io.BytesIO(image_data))
175
- # inputs = image_processor(image, return_tensors="pt")
176
- inputs = model(image)
177
-
178
-
179
- # with torch.no_grad():
180
- # outpus = model(**inputs)
181
- # logits = outpus.logits
182
- # logging.info("logits %s", logits)
183
- # probs = F.softmax(logits, dim=-1)
184
- # logging.info("probs %s", probs)
185
- # predicted_label_id = probs.argmax(-1).item()
186
- # logging.info("predicted_label_id %s", predicted_label_id)
187
- # predicted_label = model.config.id2label[predicted_label_id]
188
- # logging.info("model.config.id2label %s", model.config.id2label)
189
- # confidence = probs.max().item()
190
 
191
  # model predicts one of the 1000 ImageNet classes
192
  # predicted_label = logits.argmax(-1).item()
193
  # logging.info("predicted_label", predicted_label)
194
  # logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
195
  # # print(model.config.id2label[predicted_label])
196
- logging.info("inputs %s", inputs)
197
- predicted_label = max(inputs, key=lambda x: x["score"])
198
  # best_prediction = max(results, key=lambda x: x["score"])
199
  # logging.info("best_prediction %s", best_prediction)
200
  # best_prediction2 = results[1]["label"]
@@ -202,7 +211,7 @@ async def classify_images(request: ImageUrlsRequest):
202
 
203
  # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
204
  # confidence_percentage = round(best_prediction["score"] * 100, 1)
205
- confidence = round(predicted_label["score"] * 100, 1)
206
 
207
  # # Prepare the custom response data
208
  detection_result = {
@@ -232,8 +241,8 @@ async def classify_images(request: ImageUrlsRequest):
232
 
233
  response_data.append(detection_result)
234
 
235
- # except Exception as e:
236
- except PipelineException as e:
237
  logging.error("Error processing image from %s: %s", image_url, str(e))
238
  raise HTTPException(
239
  status_code=500,
 
5
  from fastapi import FastAPI, File, UploadFile, HTTPException
6
  from fastapi.responses import JSONResponse
7
 
8
+ from transformers import AutoImageProcessor, ViTForImageClassification
9
+ # from transformers import pipeline
10
+ # from transformers.pipelines import PipelineException
11
  from PIL import Image
12
  from cachetools import Cache
13
  import torch
 
27
  cache = Cache(maxsize=1000)
28
 
29
  # Load the model using the transformers pipeline
30
+ # model = pipeline("image-classification", model="dima806/deepfake_vs_real_image_detection")
31
+ image_processor = AutoImageProcessor.from_pretrained("dima806/deepfake_vs_real_image_detection")
32
+ model = ViTForImageClassification.from_pretrained("dima806/deepfake_vs_real_image_detection")
 
33
 
34
  # Detect the device used by TensorFlow
35
  # DEVICE = "GPU" if tf.config.list_physical_devices("GPU") else "CPU"
 
83
 
84
  image = Image.open(io.BytesIO(image_data))
85
 
86
+ inputs = image_processor(image, return_tensors="pt")
87
+ # inputs = model(image)
88
+
89
+ with torch.no_grad():
90
+ outpus = model(**inputs)
91
+ logits = outpus.logits
92
+ logging.info("logits %s", logits)
93
+ probs = F.softmax(logits, dim=-1)
94
+ logging.info("probs %s", probs)
95
+ predicted_label_id = probs.argmax(-1).item()
96
+ logging.info("predicted_label_id %s", predicted_label_id)
97
+ predicted_label = model.config.id2label[predicted_label_id]
98
+ logging.info("model.config.id2label %s", model.config.id2label)
99
+ confidence = probs.max().item()
100
+ # outpus = model(**inputs)
101
+ # logits = outpus.logits
102
+ # probs = F.softmax(logits, dim=-1)
103
+ # predicted_label_id = probs.argmax(-1).item()
104
+ # predicted_label = model.config.id2label[predicted_label_id]
105
+ # confidence = probs.max().item()
106
 
107
  # model predicts one of the 1000 ImageNet classes
108
  # predicted_label = logits.argmax(-1).item()
 
110
  # logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
111
  # # print(model.config.id2label[predicted_label])
112
  # Find the prediction with the highest confidence using the max() function
113
+ # predicted_label = max(inputs, key=lambda x: x["score"])
114
  # logging.info("best_prediction %s", best_prediction)
115
  # best_prediction2 = results[1]["label"]
116
  # logging.info("best_prediction2 %s", best_prediction2)
117
 
118
  # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
119
+ # confidence = round(predicted_label["score"] * 100, 1)
120
 
121
  # # Prepare the custom response data
122
+ response_data = {
123
  "prediction": predicted_label,
124
  "confidence":confidence,
125
  }
 
139
  # }
140
 
141
  # Populate hash
142
+ cache[image_hash] = response_data.copy()
143
 
144
  # Add url to the API response
145
+ response_data["file_name"] = file.filename
146
 
147
+ # response_data.append(detection_result)
148
 
149
  # Add file_name to the API response
150
+ # response_data["file_name"] = file.filename
151
 
152
  return FileImageDetectionResponse(**response_data)
153
 
154
+ except Exception as e:
155
+ # except PipelineException as e:
156
  logging.error("Error processing image: %s", str(e))
157
  raise HTTPException(
158
  status_code=500, detail=f"Error processing image: {str(e)}"
 
181
  continue
182
 
183
  image = Image.open(io.BytesIO(image_data))
184
+ inputs = image_processor(image, return_tensors="pt")
185
+ # inputs = model(image)
186
+
187
+
188
+ with torch.no_grad():
189
+ outpus = model(**inputs)
190
+ logits = outpus.logits
191
+ logging.info("logits %s", logits)
192
+ probs = F.softmax(logits, dim=-1)
193
+ logging.info("probs %s", probs)
194
+ predicted_label_id = probs.argmax(-1).item()
195
+ logging.info("predicted_label_id %s", predicted_label_id)
196
+ predicted_label = model.config.id2label[predicted_label_id]
197
+ logging.info("model.config.id2label %s", model.config.id2label)
198
+ confidence = probs.max().item()
199
 
200
  # model predicts one of the 1000 ImageNet classes
201
  # predicted_label = logits.argmax(-1).item()
202
  # logging.info("predicted_label", predicted_label)
203
  # logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
204
  # # print(model.config.id2label[predicted_label])
205
+ # logging.info("inputs %s", inputs)
206
+ # predicted_label = max(inputs, key=lambda x: x["score"])
207
  # best_prediction = max(results, key=lambda x: x["score"])
208
  # logging.info("best_prediction %s", best_prediction)
209
  # best_prediction2 = results[1]["label"]
 
211
 
212
  # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
213
  # confidence_percentage = round(best_prediction["score"] * 100, 1)
214
+ # confidence = round(predicted_label["score"] * 100, 1)
215
 
216
  # # Prepare the custom response data
217
  detection_result = {
 
241
 
242
  response_data.append(detection_result)
243
 
244
+ except Exception as e:
245
+ # except PipelineException as e:
246
  logging.error("Error processing image from %s: %s", image_url, str(e))
247
  raise HTTPException(
248
  status_code=500,
models.py CHANGED
@@ -23,8 +23,8 @@ class ImageDetectionResponse(BaseModel):
23
  confidence_percentage (float): Confidence level of the NSFW classification.
24
  """
25
 
26
- is_nsfw: bool
27
- confidence_percentage: float
28
 
29
 
30
  class FileImageDetectionResponse(ImageDetectionResponse):
 
23
  confidence_percentage (float): Confidence level of the NSFW classification.
24
  """
25
 
26
+ prediction: str
27
+ confidence: float
28
 
29
 
30
  class FileImageDetectionResponse(ImageDetectionResponse):