import io
import hashlib
import logging

import aiohttp
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from transformers import pipeline
from transformers.pipelines import PipelineException
from PIL import Image
from cachetools import Cache

from models import (
    FileImageDetectionResponse,
    UrlImageDetectionResponse,
    ImageUrlsRequest,
    ImageUrlRequest,
)
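
# The Pydantic schemas above live in models.py (not shown here). Judging from the
# fields populated below, they are assumed to look roughly like this; the actual
# definitions may differ:
#
#   class FileImageDetectionResponse(BaseModel):
#       prediction: str
#       confidence: str
#       file_name: str
#
#   class UrlImageDetectionResponse(BaseModel):
#       prediction: str
#       confidence: str
#       url: str
#
#   class ImageUrlRequest(BaseModel):
#       url: str
#
#   class ImageUrlsRequest(BaseModel):
#       urls: list[str]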

app = FastAPI()

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Initialize the cache with no TTL; entries persist until evicted by maxsize.
cache = Cache(maxsize=1000)

# Load the classification model via the transformers pipeline.
# (Other checkpoints, e.g. NYUAD-ComNets/AI-generated_images_detector and
# dima806/deepfake_vs_real_image_detection, were experimented with but are not used.)
model = pipeline("image-classification", model="aznasut/ai_vs_fake_image")


async def download_image(image_url: str) -> bytes:
    """Download an image from a URL."""
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as response:
            if response.status != 200:
                raise HTTPException(
                    status_code=response.status, detail="Image could not be retrieved."
                )
            return await response.read()


def hash_data(data: bytes) -> str:
    """Return the SHA-256 hex digest of the given bytes."""
    return hashlib.sha256(data).hexdigest()
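
# Predictions are cached keyed by the SHA-256 of the raw image bytes, so the same
# image uploaded again (or fetched from a different URL) reuses the cached result.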
@app.post("/v1/detect", response_model=FileImageDetectionResponse)
async def classify_image(file: UploadFile = File(None)):
"""Function analyzing image."""
if file is None:
raise HTTPException(
status_code=400,
detail="An image file must be provided.",
)
try:
logging.info("Processing %s", file.filename)
# Read the image file
image_data = await file.read()
image_hash = hash_data(image_data)
if image_hash in cache:
# Return cached entry
logging.info("Returning cached entry for %s", file.filename)
cached_response = cache[image_hash]
response_data = {**cached_response, "file_name": file.filename}
return FileImageDetectionResponse(**response_data)
image = Image.open(io.BytesIO(image_data))
# inputs = image_processor(image, return_tensors="pt")
inputs = model(image)
logging.info("inputs %s", inputs)
predicted_label = max(inputs, key=lambda x: x["score"])
confidence = round(predicted_label["score"] * 100, 1)
# # Prepare the custom response data
response_data = {
# "prediction": predicted_label,
"prediction": predicted_label["label"],
"confidence": str(confidence),
}
# Populate hash
cache[image_hash] = response_data.copy()
# Add url to the API response
response_data["file_name"] = file.filename
return FileImageDetectionResponse(**response_data)
# except Exception as e:
except PipelineException as e:
logging.error("Error processing image: %s", str(e))
raise HTTPException(
status_code=500, detail=f"Error processing image: {str(e)}"
) from e
@app.post("/v1/detect/url", response_model=UrlImageDetectionResponse)
async def classify_images(request: ImageUrlRequest):
try:
image_url = request.url
logging.info("Downloading image from URL: %s", image_url)
image_data = await download_image(image_url)
image_hash = hash_data(image_data)
if image_hash in cache:
# Return cached entry
logging.info("Returning cached entry for %s", image_url)
cached_response = cache[image_hash]
response_data = {**cached_response, "url": image_url}
return UrlImageDetectionResponse(**response_data)
image = Image.open(io.BytesIO(image_data))
# inputs = image_processor(image, return_tensors="pt")
inputs = model(image)
predicted_label = max(inputs, key=lambda x: x["score"])
confidence = round(predicted_label["score"] * 100, 1)
response_data = {
"prediction": predicted_label["label"],
"confidence": str(confidence),
}
# Populate hash
cache[image_hash] = response_data.copy()
# Add url to the API response
response_data["url"] = image_url
return UrlImageDetectionResponse(**response_data)
# except Exception as e:
except PipelineException as e:
logging.error("Error processing image from %s: %s", image_url, str(e))
raise HTTPException(
status_code=500,
detail=f"Error processing image from {image_url}: {str(e)}",
) from e
@app.post("/v1/detect/urls", response_model=list[UrlImageDetectionResponse])
async def classify_images(request: ImageUrlsRequest):
"""Function analyzing images from URLs."""
response_data = []
for image_url in request.urls:
try:
logging.info("Downloading image from URL: %s", image_url)
image_data = await download_image(image_url)
image_hash = hash_data(image_data)
if image_hash in cache:
# Return cached entry
logging.info("Returning cached entry for %s", image_url)
cached_response = cache[image_hash]
response = {**cached_response, "url": image_url}
response_data.append(response)
continue
image = Image.open(io.BytesIO(image_data))
# inputs = image_processor(image, return_tensors="pt")
inputs = model(image)
# with torch.no_grad():
# outpus = model(**inputs)
# logits = outpus.logits
# logging.info("logits %s", logits)
# probs = F.softmax(logits, dim=-1)
# logging.info("probs %s", probs)
# predicted_label_id = probs.argmax(-1).item()
# logging.info("predicted_label_id %s", predicted_label_id)
# predicted_label = model.config.id2label[predicted_label_id]
# logging.info("model.config.id2label %s", model.config.id2label)
# confidence = probs.max().item()
# model predicts one of the 1000 ImageNet classes
# predicted_label = logits.argmax(-1).item()
# logging.info("predicted_label", predicted_label)
# logging.info("model.config.id2label[predicted_label] %s", model.config.id2label[predicted_label])
# # print(model.config.id2label[predicted_label])
# logging.info("inputs %s", inputs)
predicted_label = max(inputs, key=lambda x: x["score"])
# best_prediction = max(results, key=lambda x: x["score"])
# logging.info("best_prediction %s", best_prediction)
# best_prediction2 = results[1]["label"]
# logging.info("best_prediction2 %s", best_prediction2)
# # Calculate the confidence score, rounded to the nearest tenth and as a percentage
# confidence_percentage = round(best_prediction["score"] * 100, 1)
confidence = round(predicted_label["score"] * 100, 1)
# # Prepare the custom response data
detection_result = {
# "prediction": predicted_label,
"prediction": predicted_label["label"],
"confidence": str(confidence),
}
# Use the model to classify the image
# results = model(image)
# Find the prediction with the highest confidence using the max() function
# best_prediction = max(results, key=lambda x: x["score"])
# Calculate the confidence score, rounded to the nearest tenth and as a percentage
# confidence_percentage = round(best_prediction["score"] * 100, 1)
# Prepare the custom response data
# detection_result = {
# "is_nsfw": best_prediction["label"] == "nsfw",
# "confidence_percentage": confidence_percentage,
# }
# Populate hash
cache[image_hash] = detection_result.copy()
# Add url to the API response
detection_result["url"] = image_url
response_data.append(detection_result)
# except Exception as e:
except PipelineException as e:
logging.error("Error processing image from %s: %s", image_url, str(e))
raise HTTPException(
status_code=500,
detail=f"Error processing image from {image_url}: {str(e)}",
) from e
return JSONResponse(status_code=200, content=response_data)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
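
# Example client calls (a sketch, assuming the service is running locally on
# port 7860; "sample.jpg" and the example.com URLs below are placeholders):
#
#   import requests
#
#   with open("sample.jpg", "rb") as f:
#       resp = requests.post("http://localhost:7860/v1/detect", files={"file": f})
#   print(resp.json())
#
#   resp = requests.post(
#       "http://localhost:7860/v1/detect/url",
#       json={"url": "https://example.com/sample.jpg"},
#   )
#   print(resp.json())
#
#   resp = requests.post(
#       "http://localhost:7860/v1/detect/urls",
#       json={"urls": ["https://example.com/a.jpg", "https://example.com/b.jpg"]},
#   )
#   print(resp.json())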