diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..78ee986 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +backend/test_data/db/children.json +backend/test_data/db/images.json +backend/test_data/db/pending_rewards.json +backend/test_data/db/rewards.json +backend/test_data/db/tasks.json +backend/test_data/db/users.json diff --git a/backend/config/paths.py b/backend/config/paths.py index a5c354b..c1ad229 100644 --- a/backend/config/paths.py +++ b/backend/config/paths.py @@ -9,19 +9,27 @@ TEST_DATA_DIR_NAME = 'test_data' # Project root (two levels up from this file) PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +def get_base_data_dir(data_env: str | None = None) -> str: + """ + Return the absolute base data directory path for the given env. + data_env: 'prod' uses `data`, anything else uses `test_data`. + """ + env = (data_env or os.environ.get('DATA_ENV', 'prod')).lower() + base_name = DATA_DIR_NAME if env == 'prod' else TEST_DATA_DIR_NAME + return os.path.join(PROJECT_ROOT, base_name) + def get_database_dir(db_env: str | None = None) -> str: """ Return the absolute base directory path for the given DB env. db_env: 'prod' uses `data/db`, anything else uses `test_data/db`. """ env = (db_env or os.environ.get('DB_ENV', 'prod')).lower() - base_name = DATA_DIR_NAME if env == 'prod' else TEST_DATA_DIR_NAME - return os.path.join(PROJECT_ROOT, base_name, 'db') + return os.path.join(PROJECT_ROOT, get_base_data_dir(env), 'db') def get_user_image_dir(username: str | None) -> str: """ Return the absolute directory path for storing images for a specific user. """ if username: - return os.path.join(PROJECT_ROOT, DATA_DIR_NAME, 'images', username) + return os.path.join(PROJECT_ROOT, get_base_data_dir(), 'images', username) return os.path.join(PROJECT_ROOT, 'resources', 'images') diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 3ec1c78..b52d8af 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,6 +1,11 @@ import os +os.environ['DB_ENV'] = 'test' +import sys import pytest +# Ensure backend root is in sys.path for imports like 'config.paths' +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + @pytest.fixture(scope="session", autouse=True) def set_test_db_env(): os.environ['DB_ENV'] = 'test' \ No newline at end of file diff --git a/backend/tests/test_image_api.py b/backend/tests/test_image_api.py index cf9c3d3..0acc8ba 100644 --- a/backend/tests/test_image_api.py +++ b/backend/tests/test_image_api.py @@ -1,6 +1,7 @@ # python import io import os +import time from config.paths import get_user_image_dir from PIL import Image as PILImage import pytest @@ -17,13 +18,14 @@ MAX_DIMENSION = 512 # Test user credentials +TEST_USER_ID = "9999999-9999-9999-9999-999999999999" TEST_EMAIL = "testuser@example.com" TEST_PASSWORD = "testpass" def add_test_user(): users_db.remove(Query().email == TEST_EMAIL) users_db.insert({ - "id": "testuserid", + "id": TEST_USER_ID, "first_name": "Test", "last_name": "User", "email": TEST_EMAIL, @@ -38,6 +40,26 @@ def login_and_set_cookie(client): token = resp.headers.get("Set-Cookie") assert token and "token=" in token +def safe_remove(path): + try: + os.remove(path) + except PermissionError as e: + print(f"Warning: Could not remove {path}: {e}. Retrying...") + time.sleep(0.1) + try: + os.remove(path) + except Exception as e2: + print(f"Warning: Still could not remove {path}: {e2}") + +def remove_test_data(): + # Remove uploaded images + user_image_dir = get_user_image_dir(TEST_USER_ID) + if os.path.exists(user_image_dir): + for f in os.listdir(user_image_dir): + safe_remove(os.path.join(user_image_dir, f)) + # Clear image database + image_db.truncate() + @pytest.fixture def client(): app = Flask(__name__) @@ -47,10 +69,12 @@ def client(): app.config['SECRET_KEY'] = 'supersecretkey' with app.test_client() as c: add_test_user() + remove_test_data() + os.makedirs(get_user_image_dir(TEST_USER_ID), exist_ok=True) login_and_set_cookie(c) yield c - for f in os.listdir(UPLOAD_FOLDER): - os.remove(os.path.join(UPLOAD_FOLDER, f)) + for f in os.listdir(get_user_image_dir(TEST_USER_ID)): + safe_remove(os.path.join(get_user_image_dir(TEST_USER_ID), f)) image_db.truncate() def make_image_bytes(w, h, mode='RGB', color=(255, 0, 0, 255), fmt='PNG'): @@ -110,7 +134,7 @@ def test_upload_valid_jpeg_extension_mapping(client): filename = j['filename'] # Accept both .jpg and .jpeg extensions assert filename.endswith('.jpg') or filename.endswith('.jpeg'), "JPEG should be saved with .jpg or .jpeg extension" - user_dir = get_user_image_dir('testuserid') + user_dir = get_user_image_dir(TEST_USER_ID) path = os.path.join(user_dir, filename) assert os.path.exists(path) @@ -121,7 +145,7 @@ def test_upload_png_alpha_preserved(client): resp = client.post('/image/upload', data=data, content_type='multipart/form-data') assert resp.status_code == 200 j = resp.get_json() - user_dir = get_user_image_dir('testuserid') + user_dir = get_user_image_dir(TEST_USER_ID) path = os.path.join(user_dir, j['filename']) assert os.path.exists(path) with PILImage.open(path) as saved: @@ -134,7 +158,7 @@ def test_upload_large_image_resized(client): resp = client.post('/image/upload', data=data, content_type='multipart/form-data') assert resp.status_code == 200 j = resp.get_json() - user_dir = get_user_image_dir('testuserid') + user_dir = get_user_image_dir(TEST_USER_ID) path = os.path.join(user_dir, j['filename']) assert os.path.exists(path) with PILImage.open(path) as saved: @@ -160,7 +184,7 @@ def test_upload_invalid_extension(client): def test_request_image_success(client): image_db.truncate() img = make_image_bytes(30, 30, fmt='PNG') - data = {'file': (img, 'r.png'), 'type': str(IMAGE_TYPE_ICON)} + data = {'file': (img, 'r.png'), 'type': str(IMAGE_TYPE_ICON), 'user_id': TEST_USER_ID} up = client.post('/image/upload', data=data, content_type='multipart/form-data') assert up.status_code == 200 recs = image_db.all() @@ -178,11 +202,11 @@ def test_list_images_filter_type(client): # Upload type 1 for _ in range(2): img = make_image_bytes(20, 20, fmt='PNG') - client.post('/image/upload', data={'file': (img, 'a.png'), 'type': '1'}, content_type='multipart/form-data') + client.post('/image/upload', data={'file': (img, 'a.png'), 'type': '1', 'user_id': TEST_USER_ID}, content_type='multipart/form-data') # Upload type 2 for _ in range(3): img = make_image_bytes(25, 25, fmt='PNG') - client.post('/image/upload', data={'file': (img, 'b.png'), 'type': '2'}, content_type='multipart/form-data') + client.post('/image/upload', data={'file': (img, 'b.png'), 'type': '2', 'user_id': TEST_USER_ID}, content_type='multipart/form-data') resp = client.get('/image/list?type=2') assert resp.status_code == 200 @@ -198,7 +222,7 @@ def test_list_images_all(client): image_db.truncate() for _ in range(4): img = make_image_bytes(10, 10, fmt='PNG') - client.post('/image/upload', data={'file': (img, 'x.png'), 'type': '2'}, content_type='multipart/form-data') + client.post('/image/upload', data={'file': (img, 'x.png'), 'type': '2', 'user_id': TEST_USER_ID}, content_type='multipart/form-data') resp = client.get('/image/list') assert resp.status_code == 200 j = resp.get_json() @@ -208,7 +232,7 @@ def test_list_images_all(client): def test_permanent_flag_false_default(client): image_db.truncate() img = make_image_bytes(32, 32, fmt='PNG') - data = {'file': (img, 't.png'), 'type': '1'} + data = {'file': (img, 't.png'), 'type': '1', 'user_id': TEST_USER_ID} resp = client.post('/image/upload', data=data, content_type='multipart/form-data') assert resp.status_code == 200 recs = image_db.all()