Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	fix v1/detect
Browse files
    	
        main.py
    CHANGED
    
    | @@ -5,9 +5,9 @@ import aiohttp | |
| 5 | 
             
            from fastapi import FastAPI, File, UploadFile, HTTPException
         | 
| 6 | 
             
            from fastapi.responses import JSONResponse
         | 
| 7 |  | 
| 8 | 
            -
             | 
| 9 | 
            -
            from transformers import pipeline
         | 
| 10 | 
            -
            from transformers.pipelines import PipelineException
         | 
| 11 | 
             
            from PIL import Image
         | 
| 12 | 
             
            from cachetools import Cache
         | 
| 13 | 
             
            import torch
         | 
| @@ -27,10 +27,9 @@ logging.basicConfig( | |
| 27 | 
             
            cache = Cache(maxsize=1000)
         | 
| 28 |  | 
| 29 | 
             
            # Load the model using the transformers pipeline
         | 
| 30 | 
            -
            model = pipeline("image-classification", model="dima806/deepfake_vs_real_image_detection")
         | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
            # model = ViTForImageClassification.from_pretrained("Wvolf/ViT_Deepfake_Detection")
         | 
| 34 |  | 
| 35 | 
             
            # Detect the device used by TensorFlow
         | 
| 36 | 
             
            # DEVICE = "GPU" if tf.config.list_physical_devices("GPU") else "CPU"
         | 
| @@ -84,16 +83,26 @@ async def classify_image(file: UploadFile = File(None)): | |
| 84 |  | 
| 85 | 
             
                    image = Image.open(io.BytesIO(image_data))
         | 
| 86 |  | 
| 87 | 
            -
                     | 
| 88 | 
            -
                    inputs = model(image)
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                     | 
| 91 | 
            -
             | 
| 92 | 
            -
                     | 
| 93 | 
            -
                     | 
| 94 | 
            -
                     | 
| 95 | 
            -
                     | 
| 96 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 97 |  | 
| 98 | 
             
                # model predicts one of the 1000 ImageNet classes
         | 
| 99 | 
             
                #     predicted_label = logits.argmax(-1).item()
         | 
| @@ -101,16 +110,16 @@ async def classify_image(file: UploadFile = File(None)): | |
| 101 | 
             
                #     logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
         | 
| 102 | 
             
                # # print(model.config.id2label[predicted_label])
         | 
| 103 | 
             
                # Find the prediction with the highest confidence using the max() function
         | 
| 104 | 
            -
                    predicted_label = max(inputs, key=lambda x: x["score"])
         | 
| 105 | 
             
                # logging.info("best_prediction %s", best_prediction)
         | 
| 106 | 
             
                # best_prediction2 = results[1]["label"]
         | 
| 107 | 
             
                # logging.info("best_prediction2 %s", best_prediction2)
         | 
| 108 |  | 
| 109 | 
             
                # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
         | 
| 110 | 
            -
                    confidence = round(predicted_label["score"] * 100, 1)
         | 
| 111 |  | 
| 112 | 
             
                # # Prepare the custom response data
         | 
| 113 | 
            -
                     | 
| 114 | 
             
                        "prediction": predicted_label,
         | 
| 115 | 
             
                        "confidence":confidence,
         | 
| 116 | 
             
                    }
         | 
| @@ -130,20 +139,20 @@ async def classify_image(file: UploadFile = File(None)): | |
| 130 | 
             
                    # }
         | 
| 131 |  | 
| 132 | 
             
                    # Populate hash
         | 
| 133 | 
            -
                    cache[image_hash] =  | 
| 134 |  | 
| 135 | 
             
                    # Add url to the API response
         | 
| 136 | 
            -
                     | 
| 137 |  | 
| 138 | 
            -
                    response_data.append(detection_result)
         | 
| 139 |  | 
| 140 | 
             
                    # Add file_name to the API response
         | 
| 141 | 
            -
                    response_data["file_name"] = file.filename
         | 
| 142 |  | 
| 143 | 
             
                    return FileImageDetectionResponse(**response_data)
         | 
| 144 |  | 
| 145 | 
            -
                 | 
| 146 | 
            -
                except PipelineException as e:
         | 
| 147 | 
             
                    logging.error("Error processing image: %s", str(e))
         | 
| 148 | 
             
                    raise HTTPException(
         | 
| 149 | 
             
                        status_code=500, detail=f"Error processing image: {str(e)}"
         | 
| @@ -172,29 +181,29 @@ async def classify_images(request: ImageUrlsRequest): | |
| 172 | 
             
                            continue
         | 
| 173 |  | 
| 174 | 
             
                        image = Image.open(io.BytesIO(image_data))
         | 
| 175 | 
            -
                         | 
| 176 | 
            -
                        inputs = model(image)
         | 
| 177 | 
            -
             | 
| 178 | 
            -
             | 
| 179 | 
            -
                         | 
| 180 | 
            -
             | 
| 181 | 
            -
                         | 
| 182 | 
            -
                         | 
| 183 | 
            -
                         | 
| 184 | 
            -
                         | 
| 185 | 
            -
                         | 
| 186 | 
            -
                         | 
| 187 | 
            -
                         | 
| 188 | 
            -
                         | 
| 189 | 
            -
                         | 
| 190 |  | 
| 191 | 
             
                    # model predicts one of the 1000 ImageNet classes
         | 
| 192 | 
             
                    #     predicted_label = logits.argmax(-1).item()
         | 
| 193 | 
             
                    #     logging.info("predicted_label", predicted_label)
         | 
| 194 | 
             
                    #     logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
         | 
| 195 | 
             
                    # # print(model.config.id2label[predicted_label])
         | 
| 196 | 
            -
                        logging.info("inputs %s", inputs)
         | 
| 197 | 
            -
                        predicted_label = max(inputs, key=lambda x: x["score"])
         | 
| 198 | 
             
                # best_prediction = max(results, key=lambda x: x["score"])
         | 
| 199 | 
             
                    # logging.info("best_prediction %s", best_prediction)
         | 
| 200 | 
             
                    # best_prediction2 = results[1]["label"]
         | 
| @@ -202,7 +211,7 @@ async def classify_images(request: ImageUrlsRequest): | |
| 202 |  | 
| 203 | 
             
                    # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
         | 
| 204 | 
             
                        # confidence_percentage = round(best_prediction["score"] * 100, 1)
         | 
| 205 | 
            -
                        confidence = round(predicted_label["score"] * 100, 1)
         | 
| 206 |  | 
| 207 | 
             
                    # # Prepare the custom response data
         | 
| 208 | 
             
                        detection_result = {
         | 
| @@ -232,8 +241,8 @@ async def classify_images(request: ImageUrlsRequest): | |
| 232 |  | 
| 233 | 
             
                        response_data.append(detection_result)
         | 
| 234 |  | 
| 235 | 
            -
                     | 
| 236 | 
            -
                    except PipelineException as e:
         | 
| 237 | 
             
                        logging.error("Error processing image from %s: %s", image_url, str(e))
         | 
| 238 | 
             
                        raise HTTPException(
         | 
| 239 | 
             
                            status_code=500,
         | 
|  | |
| 5 | 
             
            from fastapi import FastAPI, File, UploadFile, HTTPException
         | 
| 6 | 
             
            from fastapi.responses import JSONResponse
         | 
| 7 |  | 
| 8 | 
            +
            from transformers import AutoImageProcessor, ViTForImageClassification
         | 
| 9 | 
            +
            # from transformers import pipeline
         | 
| 10 | 
            +
            # from transformers.pipelines import PipelineException
         | 
| 11 | 
             
            from PIL import Image
         | 
| 12 | 
             
            from cachetools import Cache
         | 
| 13 | 
             
            import torch
         | 
|  | |
| 27 | 
             
            cache = Cache(maxsize=1000)
         | 
| 28 |  | 
| 29 | 
             
            # Load the model using the transformers pipeline
         | 
| 30 | 
            +
            # model = pipeline("image-classification", model="dima806/deepfake_vs_real_image_detection")
         | 
| 31 | 
            +
            image_processor = AutoImageProcessor.from_pretrained("dima806/deepfake_vs_real_image_detection")
         | 
| 32 | 
            +
            model = ViTForImageClassification.from_pretrained("dima806/deepfake_vs_real_image_detection")
         | 
|  | |
| 33 |  | 
| 34 | 
             
            # Detect the device used by TensorFlow
         | 
| 35 | 
             
            # DEVICE = "GPU" if tf.config.list_physical_devices("GPU") else "CPU"
         | 
|  | |
| 83 |  | 
| 84 | 
             
                    image = Image.open(io.BytesIO(image_data))
         | 
| 85 |  | 
| 86 | 
            +
                    inputs = image_processor(image, return_tensors="pt")
         | 
| 87 | 
            +
                    # inputs = model(image)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    with torch.no_grad():
         | 
| 90 | 
            +
                        outpus = model(**inputs)
         | 
| 91 | 
            +
                    logits = outpus.logits
         | 
| 92 | 
            +
                    logging.info("logits %s", logits)
         | 
| 93 | 
            +
                    probs = F.softmax(logits, dim=-1)
         | 
| 94 | 
            +
                    logging.info("probs %s", probs)
         | 
| 95 | 
            +
                    predicted_label_id = probs.argmax(-1).item()
         | 
| 96 | 
            +
                    logging.info("predicted_label_id %s", predicted_label_id)
         | 
| 97 | 
            +
                    predicted_label = model.config.id2label[predicted_label_id]
         | 
| 98 | 
            +
                    logging.info("model.config.id2label %s", model.config.id2label)
         | 
| 99 | 
            +
                    confidence = probs.max().item()
         | 
| 100 | 
            +
            # outpus = model(**inputs)
         | 
| 101 | 
            +
            #             logits = outpus.logits
         | 
| 102 | 
            +
            #             probs = F.softmax(logits, dim=-1)
         | 
| 103 | 
            +
            #             predicted_label_id = probs.argmax(-1).item()
         | 
| 104 | 
            +
            #             predicted_label = model.config.id2label[predicted_label_id]
         | 
| 105 | 
            +
            #             confidence = probs.max().item()
         | 
| 106 |  | 
| 107 | 
             
                # model predicts one of the 1000 ImageNet classes
         | 
| 108 | 
             
                #     predicted_label = logits.argmax(-1).item()
         | 
|  | |
| 110 | 
             
                #     logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
         | 
| 111 | 
             
                # # print(model.config.id2label[predicted_label])
         | 
| 112 | 
             
                # Find the prediction with the highest confidence using the max() function
         | 
| 113 | 
            +
                    # predicted_label = max(inputs, key=lambda x: x["score"])
         | 
| 114 | 
             
                # logging.info("best_prediction %s", best_prediction)
         | 
| 115 | 
             
                # best_prediction2 = results[1]["label"]
         | 
| 116 | 
             
                # logging.info("best_prediction2 %s", best_prediction2)
         | 
| 117 |  | 
| 118 | 
             
                # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
         | 
| 119 | 
            +
                    # confidence = round(predicted_label["score"] * 100, 1)
         | 
| 120 |  | 
| 121 | 
             
                # # Prepare the custom response data
         | 
| 122 | 
            +
                    response_data = {
         | 
| 123 | 
             
                        "prediction": predicted_label,
         | 
| 124 | 
             
                        "confidence":confidence,
         | 
| 125 | 
             
                    }
         | 
|  | |
| 139 | 
             
                    # }
         | 
| 140 |  | 
| 141 | 
             
                    # Populate hash
         | 
| 142 | 
            +
                    cache[image_hash] = response_data.copy()
         | 
| 143 |  | 
| 144 | 
             
                    # Add url to the API response
         | 
| 145 | 
            +
                    response_data["file_name"] = file.filename
         | 
| 146 |  | 
| 147 | 
            +
                    # response_data.append(detection_result)
         | 
| 148 |  | 
| 149 | 
             
                    # Add file_name to the API response
         | 
| 150 | 
            +
                    # response_data["file_name"] = file.filename
         | 
| 151 |  | 
| 152 | 
             
                    return FileImageDetectionResponse(**response_data)
         | 
| 153 |  | 
| 154 | 
            +
                except Exception as e:
         | 
| 155 | 
            +
                # except PipelineException as e:
         | 
| 156 | 
             
                    logging.error("Error processing image: %s", str(e))
         | 
| 157 | 
             
                    raise HTTPException(
         | 
| 158 | 
             
                        status_code=500, detail=f"Error processing image: {str(e)}"
         | 
|  | |
| 181 | 
             
                            continue
         | 
| 182 |  | 
| 183 | 
             
                        image = Image.open(io.BytesIO(image_data))
         | 
| 184 | 
            +
                        inputs = image_processor(image, return_tensors="pt")
         | 
| 185 | 
            +
                        # inputs = model(image)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
             | 
| 188 | 
            +
                        with torch.no_grad():
         | 
| 189 | 
            +
                            outpus = model(**inputs)
         | 
| 190 | 
            +
                        logits = outpus.logits
         | 
| 191 | 
            +
                        logging.info("logits %s", logits)
         | 
| 192 | 
            +
                        probs = F.softmax(logits, dim=-1)
         | 
| 193 | 
            +
                        logging.info("probs %s", probs)
         | 
| 194 | 
            +
                        predicted_label_id = probs.argmax(-1).item()
         | 
| 195 | 
            +
                        logging.info("predicted_label_id %s", predicted_label_id)
         | 
| 196 | 
            +
                        predicted_label = model.config.id2label[predicted_label_id]
         | 
| 197 | 
            +
                        logging.info("model.config.id2label %s", model.config.id2label)
         | 
| 198 | 
            +
                        confidence = probs.max().item()
         | 
| 199 |  | 
| 200 | 
             
                    # model predicts one of the 1000 ImageNet classes
         | 
| 201 | 
             
                    #     predicted_label = logits.argmax(-1).item()
         | 
| 202 | 
             
                    #     logging.info("predicted_label", predicted_label)
         | 
| 203 | 
             
                    #     logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
         | 
| 204 | 
             
                    # # print(model.config.id2label[predicted_label])
         | 
| 205 | 
            +
                        # logging.info("inputs %s", inputs)
         | 
| 206 | 
            +
                        # predicted_label = max(inputs, key=lambda x: x["score"])
         | 
| 207 | 
             
                # best_prediction = max(results, key=lambda x: x["score"])
         | 
| 208 | 
             
                    # logging.info("best_prediction %s", best_prediction)
         | 
| 209 | 
             
                    # best_prediction2 = results[1]["label"]
         | 
|  | |
| 211 |  | 
| 212 | 
             
                    # # Calculate the confidence score, rounded to the nearest tenth and as a percentage
         | 
| 213 | 
             
                        # confidence_percentage = round(best_prediction["score"] * 100, 1)
         | 
| 214 | 
            +
                        # confidence = round(predicted_label["score"] * 100, 1)
         | 
| 215 |  | 
| 216 | 
             
                    # # Prepare the custom response data
         | 
| 217 | 
             
                        detection_result = {
         | 
|  | |
| 241 |  | 
| 242 | 
             
                        response_data.append(detection_result)
         | 
| 243 |  | 
| 244 | 
            +
                    except Exception as e:
         | 
| 245 | 
            +
                    # except PipelineException as e:
         | 
| 246 | 
             
                        logging.error("Error processing image from %s: %s", image_url, str(e))
         | 
| 247 | 
             
                        raise HTTPException(
         | 
| 248 | 
             
                            status_code=500,
         | 
    	
        models.py
    CHANGED
    
    | @@ -23,8 +23,8 @@ class ImageDetectionResponse(BaseModel): | |
| 23 | 
             
                    confidence_percentage (float): Confidence level of the NSFW classification.
         | 
| 24 | 
             
                """
         | 
| 25 |  | 
| 26 | 
            -
                 | 
| 27 | 
            -
                 | 
| 28 |  | 
| 29 |  | 
| 30 | 
             
            class FileImageDetectionResponse(ImageDetectionResponse):
         | 
|  | |
| 23 | 
             
                    confidence_percentage (float): Confidence level of the NSFW classification.
         | 
| 24 | 
             
                """
         | 
| 25 |  | 
| 26 | 
            +
                prediction: str
         | 
| 27 | 
            +
                confidence: float
         | 
| 28 |  | 
| 29 |  | 
| 30 | 
             
            class FileImageDetectionResponse(ImageDetectionResponse):
         |