feat: Refactor path handling for data directories and enhance test setup with user-specific image management
All checks were successful
Gitea Actions Demo / build-and-push (push) Successful in 11s

This commit is contained in:
2026-02-06 17:02:45 -05:00
parent 0d651129cb
commit 04f50c32ae
4 changed files with 57 additions and 14 deletions

6
.gitignore vendored Normal file
View File

@@ -0,0 +1,6 @@
backend/test_data/db/children.json
backend/test_data/db/images.json
backend/test_data/db/pending_rewards.json
backend/test_data/db/rewards.json
backend/test_data/db/tasks.json
backend/test_data/db/users.json

View File

@@ -9,19 +9,27 @@ TEST_DATA_DIR_NAME = 'test_data'
# Project root (two levels up from this file) # Project root (two levels up from this file)
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def get_base_data_dir(data_env: str | None = None) -> str:
"""
Return the absolute base data directory path for the given env.
data_env: 'prod' uses `data`, anything else uses `test_data`.
"""
env = (data_env or os.environ.get('DATA_ENV', 'prod')).lower()
base_name = DATA_DIR_NAME if env == 'prod' else TEST_DATA_DIR_NAME
return os.path.join(PROJECT_ROOT, base_name)
def get_database_dir(db_env: str | None = None) -> str: def get_database_dir(db_env: str | None = None) -> str:
""" """
Return the absolute base directory path for the given DB env. Return the absolute base directory path for the given DB env.
db_env: 'prod' uses `data/db`, anything else uses `test_data/db`. db_env: 'prod' uses `data/db`, anything else uses `test_data/db`.
""" """
env = (db_env or os.environ.get('DB_ENV', 'prod')).lower() env = (db_env or os.environ.get('DB_ENV', 'prod')).lower()
base_name = DATA_DIR_NAME if env == 'prod' else TEST_DATA_DIR_NAME return os.path.join(PROJECT_ROOT, get_base_data_dir(env), 'db')
return os.path.join(PROJECT_ROOT, base_name, 'db')
def get_user_image_dir(username: str | None) -> str: def get_user_image_dir(username: str | None) -> str:
""" """
Return the absolute directory path for storing images for a specific user. Return the absolute directory path for storing images for a specific user.
""" """
if username: if username:
return os.path.join(PROJECT_ROOT, DATA_DIR_NAME, 'images', username) return os.path.join(PROJECT_ROOT, get_base_data_dir(), 'images', username)
return os.path.join(PROJECT_ROOT, 'resources', 'images') return os.path.join(PROJECT_ROOT, 'resources', 'images')

View File

@@ -1,6 +1,11 @@
import os import os
os.environ['DB_ENV'] = 'test'
import sys
import pytest import pytest
# Ensure backend root is in sys.path for imports like 'config.paths'
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def set_test_db_env(): def set_test_db_env():
os.environ['DB_ENV'] = 'test' os.environ['DB_ENV'] = 'test'

View File

@@ -1,6 +1,7 @@
# python # python
import io import io
import os import os
import time
from config.paths import get_user_image_dir from config.paths import get_user_image_dir
from PIL import Image as PILImage from PIL import Image as PILImage
import pytest import pytest
@@ -17,13 +18,14 @@ MAX_DIMENSION = 512
# Test user credentials # Test user credentials
TEST_USER_ID = "9999999-9999-9999-9999-999999999999"
TEST_EMAIL = "testuser@example.com" TEST_EMAIL = "testuser@example.com"
TEST_PASSWORD = "testpass" TEST_PASSWORD = "testpass"
def add_test_user(): def add_test_user():
users_db.remove(Query().email == TEST_EMAIL) users_db.remove(Query().email == TEST_EMAIL)
users_db.insert({ users_db.insert({
"id": "testuserid", "id": TEST_USER_ID,
"first_name": "Test", "first_name": "Test",
"last_name": "User", "last_name": "User",
"email": TEST_EMAIL, "email": TEST_EMAIL,
@@ -38,6 +40,26 @@ def login_and_set_cookie(client):
token = resp.headers.get("Set-Cookie") token = resp.headers.get("Set-Cookie")
assert token and "token=" in token assert token and "token=" in token
def safe_remove(path):
try:
os.remove(path)
except PermissionError as e:
print(f"Warning: Could not remove {path}: {e}. Retrying...")
time.sleep(0.1)
try:
os.remove(path)
except Exception as e2:
print(f"Warning: Still could not remove {path}: {e2}")
def remove_test_data():
# Remove uploaded images
user_image_dir = get_user_image_dir(TEST_USER_ID)
if os.path.exists(user_image_dir):
for f in os.listdir(user_image_dir):
safe_remove(os.path.join(user_image_dir, f))
# Clear image database
image_db.truncate()
@pytest.fixture @pytest.fixture
def client(): def client():
app = Flask(__name__) app = Flask(__name__)
@@ -47,10 +69,12 @@ def client():
app.config['SECRET_KEY'] = 'supersecretkey' app.config['SECRET_KEY'] = 'supersecretkey'
with app.test_client() as c: with app.test_client() as c:
add_test_user() add_test_user()
remove_test_data()
os.makedirs(get_user_image_dir(TEST_USER_ID), exist_ok=True)
login_and_set_cookie(c) login_and_set_cookie(c)
yield c yield c
for f in os.listdir(UPLOAD_FOLDER): for f in os.listdir(get_user_image_dir(TEST_USER_ID)):
os.remove(os.path.join(UPLOAD_FOLDER, f)) safe_remove(os.path.join(get_user_image_dir(TEST_USER_ID), f))
image_db.truncate() image_db.truncate()
def make_image_bytes(w, h, mode='RGB', color=(255, 0, 0, 255), fmt='PNG'): def make_image_bytes(w, h, mode='RGB', color=(255, 0, 0, 255), fmt='PNG'):
@@ -110,7 +134,7 @@ def test_upload_valid_jpeg_extension_mapping(client):
filename = j['filename'] filename = j['filename']
# Accept both .jpg and .jpeg extensions # Accept both .jpg and .jpeg extensions
assert filename.endswith('.jpg') or filename.endswith('.jpeg'), "JPEG should be saved with .jpg or .jpeg extension" assert filename.endswith('.jpg') or filename.endswith('.jpeg'), "JPEG should be saved with .jpg or .jpeg extension"
user_dir = get_user_image_dir('testuserid') user_dir = get_user_image_dir(TEST_USER_ID)
path = os.path.join(user_dir, filename) path = os.path.join(user_dir, filename)
assert os.path.exists(path) assert os.path.exists(path)
@@ -121,7 +145,7 @@ def test_upload_png_alpha_preserved(client):
resp = client.post('/image/upload', data=data, content_type='multipart/form-data') resp = client.post('/image/upload', data=data, content_type='multipart/form-data')
assert resp.status_code == 200 assert resp.status_code == 200
j = resp.get_json() j = resp.get_json()
user_dir = get_user_image_dir('testuserid') user_dir = get_user_image_dir(TEST_USER_ID)
path = os.path.join(user_dir, j['filename']) path = os.path.join(user_dir, j['filename'])
assert os.path.exists(path) assert os.path.exists(path)
with PILImage.open(path) as saved: with PILImage.open(path) as saved:
@@ -134,7 +158,7 @@ def test_upload_large_image_resized(client):
resp = client.post('/image/upload', data=data, content_type='multipart/form-data') resp = client.post('/image/upload', data=data, content_type='multipart/form-data')
assert resp.status_code == 200 assert resp.status_code == 200
j = resp.get_json() j = resp.get_json()
user_dir = get_user_image_dir('testuserid') user_dir = get_user_image_dir(TEST_USER_ID)
path = os.path.join(user_dir, j['filename']) path = os.path.join(user_dir, j['filename'])
assert os.path.exists(path) assert os.path.exists(path)
with PILImage.open(path) as saved: with PILImage.open(path) as saved:
@@ -160,7 +184,7 @@ def test_upload_invalid_extension(client):
def test_request_image_success(client): def test_request_image_success(client):
image_db.truncate() image_db.truncate()
img = make_image_bytes(30, 30, fmt='PNG') img = make_image_bytes(30, 30, fmt='PNG')
data = {'file': (img, 'r.png'), 'type': str(IMAGE_TYPE_ICON)} data = {'file': (img, 'r.png'), 'type': str(IMAGE_TYPE_ICON), 'user_id': TEST_USER_ID}
up = client.post('/image/upload', data=data, content_type='multipart/form-data') up = client.post('/image/upload', data=data, content_type='multipart/form-data')
assert up.status_code == 200 assert up.status_code == 200
recs = image_db.all() recs = image_db.all()
@@ -178,11 +202,11 @@ def test_list_images_filter_type(client):
# Upload type 1 # Upload type 1
for _ in range(2): for _ in range(2):
img = make_image_bytes(20, 20, fmt='PNG') img = make_image_bytes(20, 20, fmt='PNG')
client.post('/image/upload', data={'file': (img, 'a.png'), 'type': '1'}, content_type='multipart/form-data') client.post('/image/upload', data={'file': (img, 'a.png'), 'type': '1', 'user_id': TEST_USER_ID}, content_type='multipart/form-data')
# Upload type 2 # Upload type 2
for _ in range(3): for _ in range(3):
img = make_image_bytes(25, 25, fmt='PNG') img = make_image_bytes(25, 25, fmt='PNG')
client.post('/image/upload', data={'file': (img, 'b.png'), 'type': '2'}, content_type='multipart/form-data') client.post('/image/upload', data={'file': (img, 'b.png'), 'type': '2', 'user_id': TEST_USER_ID}, content_type='multipart/form-data')
resp = client.get('/image/list?type=2') resp = client.get('/image/list?type=2')
assert resp.status_code == 200 assert resp.status_code == 200
@@ -198,7 +222,7 @@ def test_list_images_all(client):
image_db.truncate() image_db.truncate()
for _ in range(4): for _ in range(4):
img = make_image_bytes(10, 10, fmt='PNG') img = make_image_bytes(10, 10, fmt='PNG')
client.post('/image/upload', data={'file': (img, 'x.png'), 'type': '2'}, content_type='multipart/form-data') client.post('/image/upload', data={'file': (img, 'x.png'), 'type': '2', 'user_id': TEST_USER_ID}, content_type='multipart/form-data')
resp = client.get('/image/list') resp = client.get('/image/list')
assert resp.status_code == 200 assert resp.status_code == 200
j = resp.get_json() j = resp.get_json()
@@ -208,7 +232,7 @@ def test_list_images_all(client):
def test_permanent_flag_false_default(client): def test_permanent_flag_false_default(client):
image_db.truncate() image_db.truncate()
img = make_image_bytes(32, 32, fmt='PNG') img = make_image_bytes(32, 32, fmt='PNG')
data = {'file': (img, 't.png'), 'type': '1'} data = {'file': (img, 't.png'), 'type': '1', 'user_id': TEST_USER_ID}
resp = client.post('/image/upload', data=data, content_type='multipart/form-data') resp = client.post('/image/upload', data=data, content_type='multipart/form-data')
assert resp.status_code == 200 assert resp.status_code == 200
recs = image_db.all() recs = image_db.all()