# safe-content-ai/main.py
"""Module providing an API for NSFW image detection."""
import io
import hashlib
import logging
2024-04-22 18:43:04 +03:00
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
2024-04-22 21:02:01 +03:00
from transformers import pipeline
from transformers.pipelines import PipelineException
2024-04-22 18:43:04 +03:00
from PIL import Image
from cachetools import Cache
# ASGI application object; routes are registered on it via decorators.
app = FastAPI()

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Result cache with no TTL, bounded at 1000 entries. Entries are keyed by the
# SHA-256 hex digest of the uploaded bytes.
cache = Cache(maxsize=1000)

# Load the model once, at import time, using the transformers pipeline.
model = pipeline("image-classification", model="falconsai/nsfw_image_detection")
2024-04-22 21:12:11 +03:00
2024-04-22 18:43:04 +03:00
def hash_data(data: bytes) -> str:
    """Return the SHA-256 hex digest of *data*.

    Args:
        data: Raw bytes to hash (here, the uploaded image content).

    Returns:
        The digest as a 64-character lowercase hexadecimal string.
    """
    return hashlib.sha256(data).hexdigest()
2024-04-22 21:12:11 +03:00
2024-04-22 18:43:04 +03:00
@app.post("/api/v1/detect")
async def classify_image(file: UploadFile = File(...)):
    """Classify an uploaded image as NSFW or not.

    The uploaded bytes are hashed and the digest is used as the cache key, so
    uploading identical bytes again skips model inference.

    Args:
        file: The uploaded image file.

    Returns:
        A ``JSONResponse``:

        * 200 with ``file_name``, ``is_nsfw`` and ``confidence_percentage``
          on success (fresh or cached).
        * 400 with ``message`` when the upload cannot be decoded as an image.
        * 500 with ``message`` when the classification pipeline raises
          ``PipelineException``.
    """
    logging.info("Processing %s", file.filename)

    # Read the image file and derive its content-based cache key
    image_data = await file.read()
    image_hash = hash_data(image_data)

    cached = cache.get(image_hash)
    if cached is not None:
        logging.info("Returning cached entry for %s", file.filename)
        # The cache is keyed by content only, so the stored entry may carry
        # the name of an earlier upload; report the name of *this* upload.
        return JSONResponse(
            status_code=200,
            content={**cached, "file_name": file.filename},
        )

    try:
        image = Image.open(io.BytesIO(image_data))
        # Image.open is lazy; decode now so corrupt data is rejected here
        # as a client error rather than failing later inside the model call.
        image.load()
    except OSError:
        # PIL.UnidentifiedImageError is a subclass of OSError
        logging.warning("Could not decode %s as an image", file.filename)
        return JSONResponse(
            status_code=400,
            content={"message": "Uploaded file is not a valid image."},
        )

    try:
        # Use the model to classify the image
        results = model(image)
    except PipelineException as e:
        logging.exception("Classification failed for %s", file.filename)
        return JSONResponse(status_code=500, content={"message": str(e)})

    # Pick the prediction with the highest confidence
    best_prediction = max(results, key=lambda x: x["score"])
    # Confidence as a percentage, rounded to one decimal place
    confidence_percentage = round(best_prediction["score"] * 100, 1)

    response_data = {
        "file_name": file.filename,
        "is_nsfw": best_prediction["label"] == "nsfw",
        "confidence_percentage": confidence_percentage,
    }

    # Populate the cache for subsequent uploads of the same bytes
    cache[image_hash] = response_data
    return JSONResponse(status_code=200, content=response_data)
2024-04-22 21:12:11 +03:00
2024-04-22 18:43:04 +03:00
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script
    import uvicorn

    # Serve the app on the loopback interface, port 8000
    uvicorn.run(app, host="127.0.0.1", port=8000)