This commit is contained in:
Boroumand, Amir A 2024-04-22 14:12:11 -04:00
parent 54567a7b3a
commit 4c82208be9
3 changed files with 27 additions and 8 deletions

View File

@ -10,6 +10,10 @@ jobs:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: 'pip' # caching pip dependencies
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:

18
main.py
View File

@ -12,18 +12,22 @@ from cachetools import Cache
app = FastAPI()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Initialize Cache with no TTL
cache = Cache(maxsize=1000)
cache = Cache(maxsize=1000)
# Load the model using the transformers pipeline
model = pipeline("image-classification", model="falconsai/nsfw_image_detection")
def hash_data(data):
"""Function for hashing image data."""
return hashlib.sha256(data).hexdigest()
@app.post("/api/v1/detect")
async def classify_image(file: UploadFile = File(...)):
"""Function analyzing image."""
@ -44,16 +48,16 @@ async def classify_image(file: UploadFile = File(...)):
results = model(image)
# Find the prediction with the highest confidence using the max() function
best_prediction = max(results, key=lambda x: x['score'])
best_prediction = max(results, key=lambda x: x["score"])
# Calculate the confidence score, rounded to the nearest tenth and as a percentage
confidence_percentage = round(best_prediction['score'] * 100, 1)
confidence_percentage = round(best_prediction["score"] * 100, 1)
# Prepare the custom response data
response_data = {
"file_name": file.filename,
"is_nsfw": best_prediction['label'] == 'nsfw',
"confidence_percentage": confidence_percentage
"is_nsfw": best_prediction["label"] == "nsfw",
"confidence_percentage": confidence_percentage,
}
# Populate hash
@ -64,6 +68,8 @@ async def classify_image(file: UploadFile = File(...)):
except PipelineException as e:
return JSONResponse(status_code=500, content={"message": str(e)})
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=8000)

View File

@ -7,11 +7,20 @@ client = TestClient(app)
FILE_NAME = "sunflower.jpg"
def test_read_main():
"""Tests that POST /api/v1/detect returns 200 OK with valid request body"""
response = client.post("/api/v1/detect", files={"file": (FILE_NAME, open(FILE_NAME, "rb"), "image/jpeg")})
response = client.post(
"/api/v1/detect",
files={"file": (FILE_NAME, open(FILE_NAME, "rb"), "image/jpeg")},
)
assert response.status_code == 200
assert response.json() == {"file_name": FILE_NAME, "is_nsfw": False, "confidence_percentage": 100.0}
assert response.json() == {
"file_name": FILE_NAME,
"is_nsfw": False,
"confidence_percentage": 100.0,
}
def test_invalid_input():
"""Tests that POST /api/v1/detect returns 422 with empty request body"""