mirror of
https://github.com/steelcityamir/safe-content-ai.git
synced 2024-10-26 18:49:27 +03:00
Updates
This commit is contained in:
parent
54567a7b3a
commit
4c82208be9
4
.github/workflows/pylint.yml
vendored
4
.github/workflows/pylint.yml
vendored
@ -10,6 +10,10 @@ jobs:
|
|||||||
python-version: ["3.10"]
|
python-version: ["3.10"]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
cache: 'pip' # caching pip dependencies
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v3
|
uses: actions/setup-python@v3
|
||||||
with:
|
with:
|
||||||
|
18
main.py
18
main.py
@ -12,18 +12,22 @@ from cachetools import Cache
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize Cache with no TTL
|
# Initialize Cache with no TTL
|
||||||
cache = Cache(maxsize=1000)
|
cache = Cache(maxsize=1000)
|
||||||
|
|
||||||
# Load the model using the transformers pipeline
|
# Load the model using the transformers pipeline
|
||||||
model = pipeline("image-classification", model="falconsai/nsfw_image_detection")
|
model = pipeline("image-classification", model="falconsai/nsfw_image_detection")
|
||||||
|
|
||||||
|
|
||||||
def hash_data(data):
|
def hash_data(data):
|
||||||
"""Function for hashing image data."""
|
"""Function for hashing image data."""
|
||||||
return hashlib.sha256(data).hexdigest()
|
return hashlib.sha256(data).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/v1/detect")
|
@app.post("/api/v1/detect")
|
||||||
async def classify_image(file: UploadFile = File(...)):
|
async def classify_image(file: UploadFile = File(...)):
|
||||||
"""Function analyzing image."""
|
"""Function analyzing image."""
|
||||||
@ -44,16 +48,16 @@ async def classify_image(file: UploadFile = File(...)):
|
|||||||
results = model(image)
|
results = model(image)
|
||||||
|
|
||||||
# Find the prediction with the highest confidence using the max() function
|
# Find the prediction with the highest confidence using the max() function
|
||||||
best_prediction = max(results, key=lambda x: x['score'])
|
best_prediction = max(results, key=lambda x: x["score"])
|
||||||
|
|
||||||
# Calculate the confidence score, rounded to the nearest tenth and as a percentage
|
# Calculate the confidence score, rounded to the nearest tenth and as a percentage
|
||||||
confidence_percentage = round(best_prediction['score'] * 100, 1)
|
confidence_percentage = round(best_prediction["score"] * 100, 1)
|
||||||
|
|
||||||
# Prepare the custom response data
|
# Prepare the custom response data
|
||||||
response_data = {
|
response_data = {
|
||||||
"file_name": file.filename,
|
"file_name": file.filename,
|
||||||
"is_nsfw": best_prediction['label'] == 'nsfw',
|
"is_nsfw": best_prediction["label"] == "nsfw",
|
||||||
"confidence_percentage": confidence_percentage
|
"confidence_percentage": confidence_percentage,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Populate hash
|
# Populate hash
|
||||||
@ -64,6 +68,8 @@ async def classify_image(file: UploadFile = File(...)):
|
|||||||
except PipelineException as e:
|
except PipelineException as e:
|
||||||
return JSONResponse(status_code=500, content={"message": str(e)})
|
return JSONResponse(status_code=500, content={"message": str(e)})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run(app, host="127.0.0.1", port=8000)
|
uvicorn.run(app, host="127.0.0.1", port=8000)
|
||||||
|
13
test_main.py
13
test_main.py
@ -7,11 +7,20 @@ client = TestClient(app)
|
|||||||
|
|
||||||
FILE_NAME = "sunflower.jpg"
|
FILE_NAME = "sunflower.jpg"
|
||||||
|
|
||||||
|
|
||||||
def test_read_main():
|
def test_read_main():
|
||||||
"""Tests that POST /api/v1/detect returns 200 OK with valid request body"""
|
"""Tests that POST /api/v1/detect returns 200 OK with valid request body"""
|
||||||
response = client.post("/api/v1/detect", files={"file": (FILE_NAME, open(FILE_NAME, "rb"), "image/jpeg")})
|
response = client.post(
|
||||||
|
"/api/v1/detect",
|
||||||
|
files={"file": (FILE_NAME, open(FILE_NAME, "rb"), "image/jpeg")},
|
||||||
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"file_name": FILE_NAME, "is_nsfw": False, "confidence_percentage": 100.0}
|
assert response.json() == {
|
||||||
|
"file_name": FILE_NAME,
|
||||||
|
"is_nsfw": False,
|
||||||
|
"confidence_percentage": 100.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_input():
|
def test_invalid_input():
|
||||||
"""Tests that POST /api/v1/detect returns 422 with empty request body"""
|
"""Tests that POST /api/v1/detect returns 422 with empty request body"""
|
||||||
|
Loading…
Reference in New Issue
Block a user