import io
import hashlib
import logging

import aiohttp
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from transformers import pipeline
from transformers.pipelines import PipelineException
from PIL import Image
from cachetools import Cache

from models import (
    FileImageDetectionResponse,
    UrlImageDetectionResponse,
    ImageUrlsRequest,
)

app = FastAPI()

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Initialize Cache with no TTL; results are keyed by a hash of the image bytes.
cache = Cache(maxsize=1000)

# Load the model using the transformers pipeline
model = pipeline("image-classification", model="Wvolf/ViT_Deepfake_Detection")


async def download_image(image_url: str) -> bytes:
    """Download an image from a URL."""
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as response:
            if response.status != 200:
                raise HTTPException(
                    status_code=response.status,
                    detail="Image could not be retrieved.",
                )
            return await response.read()


def hash_data(data):
    """Return a SHA-256 hex digest of the image bytes, used as the cache key."""
    return hashlib.sha256(data).hexdigest()
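
# For reference, a minimal models.py consistent with how these schemas are
# used below. The field names are inferred from the response dicts built in
# the endpoints; the real definitions may differ:
#
#   from pydantic import BaseModel
#
#   class FileImageDetectionResponse(BaseModel):
#       prediction: str
#       confidence_percentage: float
#       file_name: str
#
#   class UrlImageDetectionResponse(BaseModel):
#       prediction: str
#       confidence_percentage: float
#       url: str
#
#   class ImageUrlsRequest(BaseModel):
#       urls: list[str]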

@app.post("/v1/detect", response_model=FileImageDetectionResponse)
async def classify_image(file: UploadFile = File(None)):
    """Analyze an uploaded image file."""
    if file is None:
        raise HTTPException(
            status_code=400,
            detail="An image file must be provided.",
        )

    try:
        logging.info("Processing %s", file.filename)

        # Read the image file
        image_data = await file.read()
        image_hash = hash_data(image_data)

        if image_hash in cache:
            # Return cached entry
            logging.info("Returning cached entry for %s", file.filename)
            cached_response = cache[image_hash]
            response_data = {**cached_response, "file_name": file.filename}
            return FileImageDetectionResponse(**response_data)

        image = Image.open(io.BytesIO(image_data))

        # Use the model to classify the image
        results = model(image)

        # Find the prediction with the highest confidence using the max() function
        best_prediction = max(results, key=lambda x: x["score"])

        # Calculate the confidence score, rounded to the nearest tenth and as a percentage
        confidence_percentage = round(best_prediction["score"] * 100, 1)

        # Prepare the custom response data
        detection_result = {
            "prediction": best_prediction["label"],
            "confidence_percentage": confidence_percentage,
        }

        # Populate hash cache before adding request-specific fields
        cache[image_hash] = detection_result.copy()

        # Add file_name to the API response
        detection_result["file_name"] = file.filename

        return FileImageDetectionResponse(**detection_result)

    except PipelineException as e:
        logging.error("Error processing image: %s", str(e))
        raise HTTPException(
            status_code=500, detail=f"Error processing image: {str(e)}"
        ) from e
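
# A successful /v1/detect call is expected to serialize to JSON along these
# lines (label names come from the model's config and may differ; the values
# shown are illustrative only):
#
#   {"prediction": "Fake", "confidence_percentage": 99.2, "file_name": "photo.jpg"}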

@app.post("/v1/detect/urls", response_model=list[UrlImageDetectionResponse])
async def classify_images(request: ImageUrlsRequest):
    """Analyze images from URLs."""
    response_data = []

    for image_url in request.urls:
        try:
            logging.info("Downloading image from URL: %s", image_url)
            image_data = await download_image(image_url)
            image_hash = hash_data(image_data)

            if image_hash in cache:
                # Return cached entry
                logging.info("Returning cached entry for %s", image_url)
                cached_response = cache[image_hash]
                response = {**cached_response, "url": image_url}
                response_data.append(response)
                continue

            image = Image.open(io.BytesIO(image_data))

            # Use the model to classify the image
            results = model(image)

            # Find the prediction with the highest confidence using the max() function
            best_prediction = max(results, key=lambda x: x["score"])

            # Calculate the confidence score, rounded to the nearest tenth and as a percentage
            confidence_percentage = round(best_prediction["score"] * 100, 1)

            # Prepare the custom response data
            detection_result = {
                "prediction": best_prediction["label"],
                "confidence_percentage": confidence_percentage,
            }

            # Populate hash cache before adding request-specific fields
            cache[image_hash] = detection_result.copy()

            # Add url to the API response
            detection_result["url"] = image_url
            response_data.append(detection_result)

        except PipelineException as e:
            logging.error("Error processing image from %s: %s", image_url, str(e))
            raise HTTPException(
                status_code=500,
                detail=f"Error processing image from {image_url}: {str(e)}",
            ) from e

    return JSONResponse(status_code=200, content=response_data)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
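
# Example invocations once the server is running. The paths, port, and JSON
# key match the routes and request model above; file names and URLs are
# placeholders:
#
#   curl -X POST -F "file=@photo.jpg" http://localhost:7860/v1/detect
#
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"urls": ["https://example.com/photo.jpg"]}' \
#        http://localhost:7860/v1/detect/urls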