# Spaces: Running
| import io | |
| import hashlib | |
| import logging | |
| import aiohttp | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| # from transformers import AutoImageProcessor, ViTForImageClassification | |
| from transformers import pipeline | |
| from transformers.pipelines import PipelineException | |
| from PIL import Image | |
| from cachetools import Cache | |
| import torch | |
| import torch.nn.functional as F | |
| from models import ( | |
| FileImageDetectionResponse, | |
| UrlImageDetectionResponse, | |
| ImageUrlsRequest, | |
| ImageUrlRequest, | |
| ) | |
# Configure root logging first so every later message shares one format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# FastAPI application object served by this module.
app = FastAPI()

# Size-bounded result cache (at most 1000 entries, no time-based expiry).
cache = Cache(maxsize=1000)

# Image-classification pipeline shared by all classification functions.
# Checkpoints referenced by earlier revisions of this file:
#   "NYUAD-ComNets/AI-generated_images_detector"
#   "dima806/deepfake_vs_real_image_detection"
model = pipeline("image-classification", model="aznasut/ai_vs_fake_image")
async def download_image(image_url: str, timeout_seconds: float = 30.0) -> bytes:
    """Download an image from a URL and return its raw bytes.

    Args:
        image_url: Address of the image to fetch.
        timeout_seconds: Upper bound, in seconds, for the whole request
            (connection, redirects and body read).

    Returns:
        The response body as bytes.

    Raises:
        HTTPException: With the upstream status code when the server does not
            answer 200, with 504 when the request times out, and with 502
            when the request fails at the network level (DNS failure,
            refused connection, malformed URL, ...).
    """
    # Function-scope import: the module's import block does not bring in
    # asyncio, and it is needed only to name the timeout exception.
    import asyncio

    # NOTE(review): image_url is caller-supplied and is fetched as-is; no host
    # filtering is applied here, so this can reach internal addresses.
    timeout = aiohttp.ClientTimeout(total=timeout_seconds)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(image_url) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=response.status,
                        detail="Image could not be retrieved.",
                    )
                return await response.read()
    except asyncio.TimeoutError as exc:
        # Checked before ClientError: aiohttp's ServerTimeoutError is both.
        logging.error("Timed out downloading %s", image_url)
        raise HTTPException(
            status_code=504, detail="Timed out while retrieving the image."
        ) from exc
    except aiohttp.ClientError as exc:
        logging.error("Error downloading %s: %s", image_url, exc)
        raise HTTPException(
            status_code=502, detail=f"Image could not be retrieved: {exc}"
        ) from exc
def hash_data(data):
    """Return the SHA-256 hex digest of the given image bytes."""
    hasher = hashlib.sha256()
    hasher.update(data)
    return hasher.hexdigest()
async def classify_image(file: UploadFile = File(None)):
    """Classify an uploaded image and report the top-scoring label.

    Results are cached by the SHA-256 of the file content, so re-uploading
    identical bytes (under any file name) skips the model.

    Args:
        file: The uploaded image.

    Returns:
        A FileImageDetectionResponse built from ``prediction`` (label of the
        highest-scoring class), ``confidence`` (score as a percentage rounded
        to one decimal, formatted as a string) and ``file_name``.

    Raises:
        HTTPException: 400 when no file is provided or the bytes cannot be
            decoded as an image; 500 when the classification pipeline fails.
    """
    if file is None:
        raise HTTPException(
            status_code=400,
            detail="An image file must be provided.",
        )

    logging.info("Processing %s", file.filename)
    image_data = await file.read()
    image_hash = hash_data(image_data)

    cached_response = cache.get(image_hash)
    if cached_response is not None:
        logging.info("Returning cached entry for %s", file.filename)
        return FileImageDetectionResponse(
            **{**cached_response, "file_name": file.filename}
        )

    try:
        image = Image.open(io.BytesIO(image_data))
        # Image.open is lazy; force decoding so truncated or corrupt data
        # is reported here as a client error rather than inside the model.
        image.load()
    except OSError as e:  # includes PIL.UnidentifiedImageError
        logging.error("Invalid image %s: %s", file.filename, str(e))
        raise HTTPException(
            status_code=400, detail=f"Invalid or unsupported image file: {str(e)}"
        ) from e

    try:
        # The pipeline returns a list of {"label": ..., "score": ...} dicts.
        inputs = model(image)
    except PipelineException as e:
        logging.error("Error processing image: %s", str(e))
        raise HTTPException(
            status_code=500, detail=f"Error processing image: {str(e)}"
        ) from e

    logging.info("inputs %s", inputs)
    predicted_label = max(inputs, key=lambda x: x["score"])
    confidence = round(predicted_label["score"] * 100, 1)
    response_data = {
        "prediction": predicted_label["label"],
        "confidence": str(confidence),
    }

    # Cache only the content-dependent part; the file name is per-request.
    cache[image_hash] = response_data.copy()
    response_data["file_name"] = file.filename
    return FileImageDetectionResponse(**response_data)
async def classify_images(request: ImageUrlRequest):
    """Download the image at ``request.url`` and classify it.

    Results are cached by the SHA-256 of the downloaded bytes, so identical
    content served from different URLs skips the model.

    NOTE(review): a later definition in this module reuses the name
    ``classify_images`` (taking ImageUrlsRequest). Unless both are registered
    through route decorators not visible here, that later definition shadows
    this one — confirm and rename if needed.

    Args:
        request: Request body carrying the image ``url``.

    Returns:
        A UrlImageDetectionResponse built from ``prediction`` (label of the
        highest-scoring class), ``confidence`` (score as a percentage rounded
        to one decimal, formatted as a string) and ``url``.

    Raises:
        HTTPException: Propagated from download_image when the download
            fails; 400 when the downloaded bytes cannot be decoded as an
            image; 500 when the classification pipeline fails.
    """
    image_url = request.url
    logging.info("Downloading image from URL: %s", image_url)
    image_data = await download_image(image_url)
    image_hash = hash_data(image_data)

    cached_response = cache.get(image_hash)
    if cached_response is not None:
        logging.info("Returning cached entry for %s", image_url)
        return UrlImageDetectionResponse(**{**cached_response, "url": image_url})

    try:
        image = Image.open(io.BytesIO(image_data))
        # Image.open is lazy; force decoding so truncated or corrupt data
        # is reported here as a client error rather than inside the model.
        image.load()
    except OSError as e:  # includes PIL.UnidentifiedImageError
        logging.error("Invalid image from %s: %s", image_url, str(e))
        raise HTTPException(
            status_code=400,
            detail=f"Invalid or unsupported image from {image_url}: {str(e)}",
        ) from e

    try:
        # The pipeline returns a list of {"label": ..., "score": ...} dicts.
        inputs = model(image)
    except PipelineException as e:
        logging.error("Error processing image from %s: %s", image_url, str(e))
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image from {image_url}: {str(e)}",
        ) from e

    predicted_label = max(inputs, key=lambda x: x["score"])
    confidence = round(predicted_label["score"] * 100, 1)
    response_data = {
        "prediction": predicted_label["label"],
        "confidence": str(confidence),
    }

    # Cache only the content-dependent part; the URL is per-request.
    cache[image_hash] = response_data.copy()
    response_data["url"] = image_url
    return UrlImageDetectionResponse(**response_data)
async def classify_images(request: ImageUrlsRequest):
    """Download and classify every image listed in ``request.urls``.

    URLs are processed one after another in request order. Results are cached
    by the SHA-256 of the downloaded bytes. The first failure aborts the whole
    batch with an HTTPException; no partial result is returned.

    Args:
        request: Request body carrying the list of image ``urls``.

    Returns:
        A JSONResponse (status 200) whose content is a list with one dict per
        URL, each holding ``prediction`` (label of the highest-scoring class),
        ``confidence`` (score as a percentage rounded to one decimal,
        formatted as a string) and ``url``.

    Raises:
        HTTPException: Propagated from download_image when a download fails;
            400 when downloaded bytes cannot be decoded as an image; 500 when
            the classification pipeline fails.
    """
    response_data = []
    for image_url in request.urls:
        logging.info("Downloading image from URL: %s", image_url)
        image_data = await download_image(image_url)
        image_hash = hash_data(image_data)

        cached_response = cache.get(image_hash)
        if cached_response is not None:
            logging.info("Returning cached entry for %s", image_url)
            response_data.append({**cached_response, "url": image_url})
            continue

        try:
            image = Image.open(io.BytesIO(image_data))
            # Image.open is lazy; force decoding so truncated or corrupt data
            # is reported here as a client error rather than inside the model.
            image.load()
        except OSError as e:  # includes PIL.UnidentifiedImageError
            logging.error("Invalid image from %s: %s", image_url, str(e))
            raise HTTPException(
                status_code=400,
                detail=f"Invalid or unsupported image from {image_url}: {str(e)}",
            ) from e

        try:
            # The pipeline returns a list of {"label": ..., "score": ...} dicts.
            inputs = model(image)
        except PipelineException as e:
            logging.error("Error processing image from %s: %s", image_url, str(e))
            raise HTTPException(
                status_code=500,
                detail=f"Error processing image from {image_url}: {str(e)}",
            ) from e

        predicted_label = max(inputs, key=lambda x: x["score"])
        confidence = round(predicted_label["score"] * 100, 1)
        detection_result = {
            "prediction": predicted_label["label"],
            "confidence": str(confidence),
        }

        # Cache only the content-dependent part; the URL is per-request.
        cache[image_hash] = detection_result.copy()
        detection_result["url"] = image_url
        response_data.append(detection_result)

    return JSONResponse(status_code=200, content=response_data)
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")