2024-04-22 20:56:10 +03:00
|
|
|
"""Module providing an API for NSFW image detection."""
|
|
|
|
|
|
|
|
import hashlib
import io
import logging

from cachetools import Cache
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
from transformers import pipeline
from transformers.pipelines import PipelineException
|
|
|
|
|
|
|
|
# FastAPI application instance; the detection route below is registered on it.
app = FastAPI()

# Root-logger configuration; the endpoint logs through the module-level
# logging functions (logging.info etc.), so this format applies to all of them.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Initialize Cache with no TTL
# Bounded in-process cache (max 1000 entries), keyed by the SHA-256 hex digest
# of the raw uploaded bytes; entries never expire by time.
cache = Cache(maxsize=1000)

# Load the model using the transformers pipeline
# NOTE(review): pipeline() fetches model weights from the Hugging Face hub on
# first use — presumably requires network access at startup; confirm for
# offline deployments.
model = pipeline("image-classification", model="falconsai/nsfw_image_detection")
|
|
|
def hash_data(data):
    """Return the SHA-256 hex digest of *data* (raw bytes)."""
    digest = hashlib.sha256(data)
    return digest.hexdigest()
|
|
|
|
|
2024-04-22 21:12:11 +03:00
|
|
|
|
2024-04-22 18:43:04 +03:00
|
|
|
@app.post("/api/v1/detect")
|
|
|
|
async def classify_image(file: UploadFile = File(...)):
|
2024-04-22 20:56:10 +03:00
|
|
|
"""Function analyzing image."""
|
2024-04-22 18:43:04 +03:00
|
|
|
try:
|
2024-04-22 20:56:10 +03:00
|
|
|
logging.info("Processing %s", file.filename)
|
2024-04-22 18:43:04 +03:00
|
|
|
# Read the image file
|
|
|
|
image_data = await file.read()
|
|
|
|
image_hash = hash_data(image_data)
|
|
|
|
|
|
|
|
if image_hash in cache:
|
|
|
|
# Return cached entry
|
2024-04-22 20:56:10 +03:00
|
|
|
logging.info("Returning cached entry for %s", file.filename)
|
2024-04-22 18:43:04 +03:00
|
|
|
return JSONResponse(status_code=200, content=cache[image_hash])
|
|
|
|
|
|
|
|
image = Image.open(io.BytesIO(image_data))
|
|
|
|
|
|
|
|
# Use the model to classify the image
|
|
|
|
results = model(image)
|
|
|
|
|
|
|
|
# Find the prediction with the highest confidence using the max() function
|
2024-04-22 21:12:11 +03:00
|
|
|
best_prediction = max(results, key=lambda x: x["score"])
|
2024-04-22 18:43:04 +03:00
|
|
|
|
|
|
|
# Calculate the confidence score, rounded to the nearest tenth and as a percentage
|
2024-04-22 21:12:11 +03:00
|
|
|
confidence_percentage = round(best_prediction["score"] * 100, 1)
|
2024-04-22 18:43:04 +03:00
|
|
|
|
|
|
|
# Prepare the custom response data
|
|
|
|
response_data = {
|
2024-04-22 20:50:21 +03:00
|
|
|
"file_name": file.filename,
|
2024-04-22 21:12:11 +03:00
|
|
|
"is_nsfw": best_prediction["label"] == "nsfw",
|
|
|
|
"confidence_percentage": confidence_percentage,
|
2024-04-22 18:43:04 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
# Populate hash
|
|
|
|
cache[image_hash] = response_data
|
|
|
|
|
|
|
|
return JSONResponse(status_code=200, content=response_data)
|
|
|
|
|
2024-04-22 20:58:37 +03:00
|
|
|
except PipelineException as e:
|
2024-04-22 18:43:04 +03:00
|
|
|
return JSONResponse(status_code=500, content={"message": str(e)})
|
|
|
|
|
2024-04-22 21:12:11 +03:00
|
|
|
|
2024-04-22 18:43:04 +03:00
|
|
|
if __name__ == "__main__":
|
|
|
|
import uvicorn
|
2024-04-22 21:12:11 +03:00
|
|
|
|
2024-04-22 18:43:04 +03:00
|
|
|
uvicorn.run(app, host="127.0.0.1", port=8000)
|