aznasut committed
Commit 507cd9a · 1 Parent(s): b2dfed5

Add application file

Files changed (5)
  1. .gitignore +160 -0
  2. Dockerfile +25 -0
  3. main.py +224 -0
  4. models.py +49 -0
  5. requirements.txt +10 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
Dockerfile ADDED
@@ -0,0 +1,25 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.12-slim
+
+ # Set the working directory in the container
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     gcc \
+     build-essential \
+     pkg-config \
+     libhdf5-dev \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy the current directory contents into the container at /app
+ COPY . /app
+
+ # Install any needed packages specified in requirements.txt
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Make port 8000 available to the world outside this container
+ EXPOSE 8000
+
+ # Run the Uvicorn server; bind to 0.0.0.0 so the exposed port is reachable
+ # from outside the container (binding to localhost would not be)
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
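
As a usage sketch (the image tag below is only an example, not a name defined anywhere in this commit): the container can be built with docker build -t ai-image-detector . and started with docker run -p 8000:8000 ai-image-detector, which publishes the exposed port 8000 on the host.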
main.py ADDED
@@ -0,0 +1,224 @@
+ """Module providing an API for detecting AI-generated (deepfake) images."""
+
+ import io
+ import hashlib
+ import logging
+ import aiohttp
+ from fastapi import FastAPI, File, UploadFile, HTTPException
+ from fastapi.responses import JSONResponse
+ from transformers.pipelines import PipelineException
+ from transformers import AutoImageProcessor, ViTForImageClassification
+ from PIL import Image
+ from cachetools import Cache
+ import tensorflow as tf
+ import torch
+ import torch.nn.functional as F
+ from models import (
+     FileImageDetectionResponse,
+     UrlImageDetectionResponse,
+     ImageUrlsRequest,
+ )
+
+
+ app = FastAPI()
+
+ logging.basicConfig(
+     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+ )
+
+ # Initialize cache with no TTL
+ cache = Cache(maxsize=1000)
+
+ # Detect the device used by TensorFlow
+ DEVICE = "GPU" if tf.config.list_physical_devices("GPU") else "CPU"
+ logging.info("TensorFlow version: %s", tf.__version__)
+ logging.info("Model is using: %s", DEVICE)
+
+ if DEVICE == "GPU":
+     logging.info("GPUs available: %d", len(tf.config.list_physical_devices("GPU")))
+
+
+ async def download_image(image_url: str) -> bytes:
+     """Download an image from a URL."""
+     async with aiohttp.ClientSession() as session:
+         async with session.get(image_url) as response:
+             if response.status != 200:
+                 raise HTTPException(
+                     status_code=response.status, detail="Image could not be retrieved."
+                 )
+             return await response.read()
+
+
+ def hash_data(data):
+     """Return the SHA-256 hex digest of the given image bytes."""
+     return hashlib.sha256(data).hexdigest()
+
+
+ @app.post("/v1/detect", response_model=FileImageDetectionResponse)
+ async def classify_image(file: UploadFile = File(None)):
+     """Classify an uploaded image file."""
+     if file is None:
+         raise HTTPException(
+             status_code=400,
+             detail="An image file must be provided.",
+         )
+
+     try:
+         logging.info("Processing %s", file.filename)
+
+         # Read the image file and hash it for the cache lookup
+         image_data = await file.read()
+         image_hash = hash_data(image_data)
+
+         if image_hash in cache:
+             # Return cached entry
+             logging.info("Returning cached entry for %s", file.filename)
+
+             cached_response = cache[image_hash]
+             response_data = {**cached_response, "file_name": file.filename}
+
+             return FileImageDetectionResponse(**response_data)
+
+         image = Image.open(io.BytesIO(image_data))
+
+         # Load the classifier and run inference
+         image_processor = AutoImageProcessor.from_pretrained(
+             "dima806/ai_vs_real_image_detection"
+         )
+         model = ViTForImageClassification.from_pretrained(
+             "dima806/ai_vs_real_image_detection"
+         )
+
+         inputs = image_processor(image, return_tensors="pt")
+
+         with torch.no_grad():
+             logits = model(**inputs).logits
+
+         # The model predicts one of its configured labels (real vs. AI-generated)
+         probs = F.softmax(logits, dim=-1)
+         predicted_label_id = probs.argmax(-1).item()
+         predicted_label = model.config.id2label[predicted_label_id]
+         logging.info("Predicted label: %s", predicted_label)
+
+         # Confidence score, rounded to the nearest tenth and expressed as a percentage
+         confidence_percentage = round(probs.max().item() * 100, 1)
+
+         # Prepare the custom response data
+         response_data = {
+             "prediction": predicted_label,
+             "confidence_percentage": confidence_percentage,
+         }
+
+         # Populate the cache
+         cache[image_hash] = response_data.copy()
+
+         # Add file_name to the API response
+         response_data["file_name"] = file.filename
+
+         return FileImageDetectionResponse(**response_data)
+
+     except PipelineException as e:
+         logging.error("Error processing image: %s", str(e))
+         raise HTTPException(
+             status_code=500, detail=f"Error processing image: {str(e)}"
+         ) from e
+
+
+ @app.post("/v1/detect/urls", response_model=list[UrlImageDetectionResponse])
+ async def classify_images(request: ImageUrlsRequest):
+     """Classify images downloaded from a list of URLs."""
+     response_data = []
+
+     for image_url in request.urls:
+         try:
+             logging.info("Downloading image from URL: %s", image_url)
+             image_data = await download_image(image_url)
+             image_hash = hash_data(image_data)
+
+             if image_hash in cache:
+                 # Return cached entry
+                 logging.info("Returning cached entry for %s", image_url)
+
+                 cached_response = cache[image_hash]
+                 response = {**cached_response, "url": image_url}
+
+                 response_data.append(response)
+                 continue
+
+             image = Image.open(io.BytesIO(image_data))
+
+             # Load the deepfake-detection classifier and run inference
+             image_processor = AutoImageProcessor.from_pretrained(
+                 "Wvolf/ViT_Deepfake_Detection"
+             )
+             model = ViTForImageClassification.from_pretrained(
+                 "Wvolf/ViT_Deepfake_Detection"
+             )
+             inputs = image_processor(image, return_tensors="pt")
+
+             with torch.no_grad():
+                 logits = model(**inputs).logits
+                 probs = F.softmax(logits, dim=-1)
+                 predicted_label_id = probs.argmax(-1).item()
+                 predicted_label = model.config.id2label[predicted_label_id]
+                 confidence = probs.max().item()
+
+             # Prepare the custom response data; confidence is rounded to the
+             # nearest tenth and expressed as a percentage
+             detection_result = {
+                 "prediction": predicted_label,
+                 "confidence_percentage": round(confidence * 100, 1),
+             }
+
+             # Populate the cache
+             cache[image_hash] = detection_result.copy()
+
+             # Add url to the API response
+             detection_result["url"] = image_url
+
+             response_data.append(detection_result)
+
+         except PipelineException as e:
+             logging.error("Error processing image from %s: %s", image_url, str(e))
+             raise HTTPException(
+                 status_code=500,
+                 detail=f"Error processing image from {image_url}: {str(e)}",
+             ) from e
+
+     return JSONResponse(status_code=200, content=response_data)
+
+
+ @app.get("/hello")
+ async def hello_world():
+     """Simple health-check endpoint."""
+     return {"message": "hello_world"}
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(app, host="127.0.0.1", port=8000)
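
For reference, a minimal client sketch for the two endpoints above, assuming the server is running locally on port 8000. It uses the requests package, which is not part of this commit's requirements and is shown only for illustration; the file name and URL are placeholders.

# Hypothetical client for the /v1/detect and /v1/detect/urls endpoints.
import requests

# Single-file detection: upload an image as multipart form data.
with open("example.jpg", "rb") as f:  # placeholder local image
    resp = requests.post(
        "http://127.0.0.1:8000/v1/detect",
        files={"file": ("example.jpg", f, "image/jpeg")},
    )
print(resp.json())  # {"prediction": ..., "confidence_percentage": ..., "file_name": ...}

# Batch detection from URLs: post a JSON body matching ImageUrlsRequest.
resp = requests.post(
    "http://127.0.0.1:8000/v1/detect/urls",
    json={"urls": ["https://example.com/image.jpg"]},  # placeholder URL
)
print(resp.json())  # list of {"prediction": ..., "confidence_percentage": ..., "url": ...}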
models.py ADDED
@@ -0,0 +1,49 @@
+ """Module providing base models."""
+
+ from pydantic import BaseModel
+
+
+ class ImageUrlsRequest(BaseModel):
+     """
+     Model representing the request body for the /v1/detect/urls endpoint.
+
+     Attributes:
+         urls (list[str]): List of image URLs to be processed.
+     """
+
+     urls: list[str]
+
+
+ class ImageDetectionResponse(BaseModel):
+     """
+     Base model representing the response body for image detection.
+
+     Attributes:
+         prediction (str): Label predicted by the classifier.
+         confidence_percentage (float): Confidence of the prediction, as a percentage.
+     """
+
+     prediction: str
+     confidence_percentage: float
+
+
+ class FileImageDetectionResponse(ImageDetectionResponse):
+     """
+     Model extending ImageDetectionResponse with a file attribute.
+
+     Attributes:
+         file_name (str): The name of the file that was processed.
+     """
+
+     file_name: str
+
+
+ class UrlImageDetectionResponse(ImageDetectionResponse):
+     """
+     Model extending ImageDetectionResponse with a URL attribute.
+
+     Attributes:
+         url (str): The URL of the image that was processed.
+     """
+
+     url: str
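
A small sketch of how these response models are populated by main.py. The label strings and numbers below are illustrative only; the actual labels come from each classifier's id2label mapping.

# Hypothetical values to show how the Pydantic models validate and serialize.
from models import FileImageDetectionResponse, UrlImageDetectionResponse

file_resp = FileImageDetectionResponse(
    prediction="REAL",            # example label; depends on the model's id2label
    confidence_percentage=97.3,   # rounded percentage
    file_name="example.jpg",
)
print(file_resp.model_dump())

url_resp = UrlImageDetectionResponse(
    prediction="FAKE",
    confidence_percentage=88.1,
    url="https://example.com/image.jpg",
)
print(url_resp.model_dump())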
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi==0.110.2
+ uvicorn[standard]==0.29.0
+ transformers==4.40.0
+ aiohttp==3.9.5
+ pillow==10.3.0
+ python-multipart==0.0.9
+ tensorflow==2.16.1
+ tf-keras==2.16.0
+ cachetools==5.3.3
+ pydantic==2.7.2
+ # torch is imported by main.py (ViTForImageClassification); version left unpinned
+ torch