From 4c82208be96480dfedf3e95387c20b0dc458ae33 Mon Sep 17 00:00:00 2001 From: "Boroumand, Amir A" Date: Mon, 22 Apr 2024 14:12:11 -0400 Subject: [PATCH] Updates --- .github/workflows/pylint.yml | 4 ++++ main.py | 18 ++++++++++++------ test_main.py | 13 +++++++++++-- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 58de2c6..ac97ebf 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -10,6 +10,10 @@ jobs: python-version: ["3.10"] steps: - uses: actions/checkout@v3 + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: 'pip' # caching pip dependencies - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 with: diff --git a/main.py b/main.py index d29e2fc..fb95da7 100644 --- a/main.py +++ b/main.py @@ -12,18 +12,22 @@ from cachetools import Cache app = FastAPI() -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) # Initialize Cache with no TTL -cache = Cache(maxsize=1000) +cache = Cache(maxsize=1000) # Load the model using the transformers pipeline model = pipeline("image-classification", model="falconsai/nsfw_image_detection") + def hash_data(data): """Function for hashing image data.""" return hashlib.sha256(data).hexdigest() + @app.post("/api/v1/detect") async def classify_image(file: UploadFile = File(...)): """Function analyzing image.""" @@ -44,16 +48,16 @@ async def classify_image(file: UploadFile = File(...)): results = model(image) # Find the prediction with the highest confidence using the max() function - best_prediction = max(results, key=lambda x: x['score']) + best_prediction = max(results, key=lambda x: x["score"]) # Calculate the confidence score, rounded to the nearest tenth and as a percentage - confidence_percentage = round(best_prediction['score'] * 100, 1) + confidence_percentage = round(best_prediction["score"] * 100, 1) # Prepare the custom response data response_data = { "file_name": file.filename, - "is_nsfw": best_prediction['label'] == 'nsfw', - "confidence_percentage": confidence_percentage + "is_nsfw": best_prediction["label"] == "nsfw", + "confidence_percentage": confidence_percentage, } # Populate hash @@ -64,6 +68,8 @@ async def classify_image(file: UploadFile = File(...)): except PipelineException as e: return JSONResponse(status_code=500, content={"message": str(e)}) + if __name__ == "__main__": import uvicorn + uvicorn.run(app, host="127.0.0.1", port=8000) diff --git a/test_main.py b/test_main.py index c737ab5..c37923c 100644 --- a/test_main.py +++ b/test_main.py @@ -7,11 +7,20 @@ client = TestClient(app) FILE_NAME = "sunflower.jpg" + def test_read_main(): """Tests that POST /api/v1/detect returns 200 OK with valid request body""" - response = client.post("/api/v1/detect", files={"file": (FILE_NAME, open(FILE_NAME, "rb"), "image/jpeg")}) + response = client.post( + "/api/v1/detect", + files={"file": (FILE_NAME, open(FILE_NAME, "rb"), "image/jpeg")}, + ) assert response.status_code == 200 - assert response.json() == {"file_name": FILE_NAME, "is_nsfw": False, "confidence_percentage": 100.0} + assert response.json() == { + "file_name": FILE_NAME, + "is_nsfw": False, + "confidence_percentage": 100.0, + } + def test_invalid_input(): """Tests that POST /api/v1/detect returns 422 with empty request body"""