diff --git a/app/common/test_log_utils.py b/app/common/test_log_utils.py new file mode 100644 index 0000000..4ab941c --- /dev/null +++ b/app/common/test_log_utils.py @@ -0,0 +1,107 @@ +import logging + +from app.common.log_utils import EndpointFilter, ExtraFieldsFilter + + +def test_extra_fields_filter_with_all_context(mocker): + # Mock the context variables + mock_trace_id = mocker.patch("app.common.log_utils.ctx_trace_id") + mock_request = mocker.patch("app.common.log_utils.ctx_request") + mock_response = mocker.patch("app.common.log_utils.ctx_response") + + # Set context values + mock_trace_id.get.return_value = "test-trace-id" + mock_request.get.return_value = {"url": "http://test.com", "method": "GET"} + mock_response.get.return_value = {"status_code": 200} + + # Create a log record + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname=__file__, + lineno=10, + msg="test message", + args=(), + exc_info=None, + ) + + # Apply filter + log_filter = ExtraFieldsFilter() + result = log_filter.filter(record) + + # Assertions + assert result is True + assert record.trace == {"id": "test-trace-id"} + assert record.url == {"full": "http://test.com"} + assert record.http == { + "request": {"method": "GET"}, + "response": {"status_code": 200}, + } + + +def test_extra_fields_filter_with_no_context(mocker): + # Mock the context variables to return None/empty + mock_trace_id = mocker.patch("app.common.log_utils.ctx_trace_id") + mock_request = mocker.patch("app.common.log_utils.ctx_request") + mock_response = mocker.patch("app.common.log_utils.ctx_response") + + mock_trace_id.get.return_value = None + mock_request.get.return_value = None + mock_response.get.return_value = None + + # Create a log record + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname=__file__, + lineno=10, + msg="test message", + args=(), + exc_info=None, + ) + + # Apply filter + log_filter = ExtraFieldsFilter() + result = log_filter.filter(record) + + # Assertions + assert result is True + assert not hasattr(record, "trace") + assert not hasattr(record, "url") + assert not hasattr(record, "http") + + +def test_endpoint_filter_blocks_matching_path(): + filter_path = "/health" + log_filter = EndpointFilter(path=filter_path) + + # Create a log record containing the path + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname=__file__, + lineno=10, + msg=f"GET {filter_path} HTTP/1.1", + args=(), + exc_info=None, + ) + + assert log_filter.filter(record) is False + + +def test_endpoint_filter_allows_non_matching_path(): + filter_path = "/health" + log_filter = EndpointFilter(path=filter_path) + + # Create a log record NOT containing the path + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname=__file__, + lineno=10, + msg="GET /api/users HTTP/1.1", + args=(), + exc_info=None, + ) + + assert log_filter.filter(record) is True diff --git a/app/common/test_metrics.py b/app/common/test_metrics.py new file mode 100644 index 0000000..41ee6ba --- /dev/null +++ b/app/common/test_metrics.py @@ -0,0 +1,23 @@ +from app.common.metrics import counter + + +def test_counter_success(mocker): + mock_put_metric = mocker.patch("app.common.metrics.__put_metric") + + counter("test_metric", 123) + + mock_put_metric.assert_called_once_with("test_metric", 123, "Count") + + +def test_counter_handles_exception(mocker): + mocker.patch("app.common.metrics.__put_metric", side_effect=Exception("Test Error")) + mock_logger = mocker.patch("app.common.metrics.logger") + + # Should not raise exception but catch it + counter("test_metric", 123) + + # Verify error was logged + assert mock_logger.error.call_count == 1 + args, _ = mock_logger.error.call_args + assert "Error calling put_metric" in args[0] + assert str(args[1]) == "Test Error" diff --git a/app/common/test_mongo.py b/app/common/test_mongo.py new file mode 100644 index 0000000..d20e00d --- /dev/null +++ b/app/common/test_mongo.py @@ -0,0 +1,86 @@ +import pytest + +from app.common import mongo +from app.config import config + + +# Reset the global client variable before each test +@pytest.fixture(autouse=True) +def reset_mongo_client(): + mongo.client = None + mongo.db = None + yield + mongo.client = None + mongo.db = None + + +@pytest.mark.asyncio +async def test_get_mongo_client_initialization(mocker): + mock_client_cls = mocker.patch("app.common.mongo.AsyncMongoClient") + mock_instance = mock_client_cls.return_value + + # Setup the async ping command + # get_database() returns a DB object, which has an async command() method + mock_db = mocker.MagicMock() + mock_instance.get_database.return_value = mock_db + mock_db.command = mocker.AsyncMock(return_value={"ok": 1}) + + client = await mongo.get_mongo_client() + + assert client == mock_instance + mock_client_cls.assert_called_once_with(config.mongo_uri) + mock_db.command.assert_awaited_once_with("ping") + + +@pytest.mark.asyncio +async def test_get_mongo_client_with_custom_tls(mocker, monkeypatch): + # Mock config and custom certs + monkeypatch.setattr(config, "mongo_truststore", "custom-cert-key") + mocker.patch.dict( + "app.common.tls.custom_ca_certs", {"custom-cert-key": "/path/to/cert.pem"} + ) + + mock_client_cls = mocker.patch("app.common.mongo.AsyncMongoClient") + mock_instance = mock_client_cls.return_value + mock_db = mocker.MagicMock() + mock_instance.get_database.return_value = mock_db + mock_db.command = mocker.AsyncMock(return_value={"ok": 1}) + + await mongo.get_mongo_client() + + # Verify TLS param was passed + mock_client_cls.assert_called_once_with( + config.mongo_uri, tlsCAFile="/path/to/cert.pem" + ) + + +@pytest.mark.asyncio +async def test_get_mongo_client_returns_existing(mocker): + # Set an existing client + existing_client = mocker.Mock() + mongo.client = existing_client + + mock_client_cls = mocker.patch("app.common.mongo.AsyncMongoClient") + + result = await mongo.get_mongo_client() + + # Should return existing without creating new one or pinging + assert result == existing_client + mock_client_cls.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_db(mocker): + mock_client = mocker.MagicMock() + mock_db = mocker.Mock() + mock_client.get_database.return_value = mock_db + + # First call initializes + result = await mongo.get_db(mock_client) + assert result == mock_db + mock_client.get_database.assert_called_once_with(config.mongo_database) + + # Second call returns cached + result2 = await mongo.get_db(mock_client) + assert result2 == mock_db + assert mock_client.get_database.call_count == 1 diff --git a/app/common/test_tls.py b/app/common/test_tls.py new file mode 100644 index 0000000..03cea12 --- /dev/null +++ b/app/common/test_tls.py @@ -0,0 +1,105 @@ +import base64 +import ssl + +from app.common.tls import ( + extract_all_certs, + init_custom_certificates, + load_certs_into_context, +) + + +class TestExtractAllCerts: + def test_extract_valid_certs(self, mocker, monkeypatch, tmp_path): + cert_content = b"cert1" + encoded_cert = base64.b64encode(cert_content).decode() + monkeypatch.setenv("TRUSTSTORE_CERT1", encoded_cert) + + cert_path = tmp_path / "cert1.pem" + + mock_named_temp_file = mocker.patch( + "app.common.tls.tempfile.NamedTemporaryFile" + ) + mock_file_obj = mocker.MagicMock() + mock_file_obj.name = str(cert_path) + mock_named_temp_file.return_value.__enter__.return_value = mock_file_obj + + certs = extract_all_certs() + + assert len(certs) == 1 + assert certs["TRUSTSTORE_CERT1"] == str(cert_path) + + # Check if decoded content was written + mock_file_obj.write.assert_called_once_with(b"cert1") + + def test_extract_invalid_base64_cert(self, monkeypatch): + monkeypatch.setenv("TRUSTSTORE_BAD", "invalid-base64!") + + certs = extract_all_certs() + assert len(certs) == 0 + + def test_extract_no_truststore_vars(self, monkeypatch): + monkeypatch.setenv("NORMAL_VAR", "value") + + certs = extract_all_certs() + assert len(certs) == 0 + + +class TestLoadCertsIntoContext: + def test_load_valid_certs(self, mocker): + mock_create_context = mocker.patch("app.common.tls.ssl.create_default_context") + mock_ctx = mocker.MagicMock() + mock_create_context.return_value = mock_ctx + + certs = { + "TRUSTSTORE_1": "/path/to/cert1.pem", + "TRUSTSTORE_2": "/path/to/cert2.pem", + } + + ctx = load_certs_into_context(certs) + + assert ctx == mock_ctx + assert mock_ctx.load_verify_locations.call_count == 2 + mock_ctx.load_verify_locations.assert_any_call("/path/to/cert1.pem") + mock_ctx.load_verify_locations.assert_any_call("/path/to/cert2.pem") + + def test_load_certs_error(self, mocker): + mock_create_context = mocker.patch("app.common.tls.ssl.create_default_context") + mock_ctx = mocker.MagicMock() + mock_create_context.return_value = mock_ctx + # Make load_verify_locations raise for the first one + mock_ctx.load_verify_locations.side_effect = [ssl.SSLError("Bad cert"), None] + + certs = { + "TRUSTSTORE_BAD": "/path/to/bad.pem", + "TRUSTSTORE_GOOD": "/path/to/good.pem", + } + + ctx = load_certs_into_context(certs) + + # Should proceed to load the second one despite error in first + assert ctx == mock_ctx + assert mock_ctx.load_verify_locations.call_count == 2 + + +class TestInitCustomCertificates: + def test_init_globals(self, mocker): + mock_extract = mocker.patch("app.common.tls.extract_all_certs") + mock_load = mocker.patch("app.common.tls.load_certs_into_context") + + mock_certs = {"cert": "path"} + mock_ctx = mocker.MagicMock() + + mock_extract.return_value = mock_certs + mock_load.return_value = mock_ctx + + result = init_custom_certificates() + + assert result == mock_certs + mock_extract.assert_called_once() + mock_load.assert_called_once_with(mock_certs) + + # Check globals are set + from app.common import tls + + assert tls.custom_ca_certs == mock_certs + assert tls.ctx == mock_ctx diff --git a/app/config.py b/app/config.py index ed48942..994b101 100644 --- a/app/config.py +++ b/app/config.py @@ -5,8 +5,8 @@ class AppConfig(BaseSettings): model_config = SettingsConfigDict() python_env: str | None = None - host: str | None = None - port: int | None = None + host: str = "127.0.0.1" + port: int = 8086 log_config: str | None = None mongo_uri: str | None = None mongo_database: str = "cdp-python-backend-template" diff --git a/app/example/router.py b/app/example/router.py index acd2af3..46ae672 100644 --- a/app/example/router.py +++ b/app/example/router.py @@ -10,13 +10,14 @@ logger = getLogger(__name__) -# remove this example route +# basic endpoint example @router.get("/test") async def root(): logger.info("TEST ENDPOINT") return {"ok": True} +# database endpoint example @router.get("/db") async def db_query(db=Depends(get_db)): await db.example.insert_one({"foo": "bar"}) @@ -24,6 +25,7 @@ async def db_query(db=Depends(get_db)): return {"ok": data} +# http client endpoint example @router.get("/http") async def http_query(client=Depends(create_async_client)): endpoint = config.aws_endpoint_url or "http://localstack:4566" diff --git a/app/example/test_router.py b/app/example/test_router.py new file mode 100644 index 0000000..9ee931f --- /dev/null +++ b/app/example/test_router.py @@ -0,0 +1,46 @@ +from fastapi.testclient import TestClient + +from app.common.http_client import create_async_client +from app.common.mongo import get_db +from app.main import app + +client = TestClient(app) + + +def test_root_success(): + response = client.get("/example/test") + assert response.status_code == 200 + assert response.json() == {"ok": True} + + +def test_db_query_success(mocker): + mock_db = mocker.AsyncMock() + mock_db.example.insert_one.return_value = None + mock_db.example.find_one.return_value = {"foo": "bar", "id": 123} + + app.dependency_overrides[get_db] = lambda: mock_db + + try: + response = client.get("/example/db") + + assert response.status_code == 200 + assert response.json() == {"ok": {"foo": "bar", "id": 123}} + + mock_db.example.insert_one.assert_called_once() + finally: + app.dependency_overrides = {} + + +def test_http_query_success(mocker): + mock_client = mocker.AsyncMock() + mock_client.get.return_value.status_code = 200 + + app.dependency_overrides[create_async_client] = lambda: mock_client + + try: + response = client.get("/example/http") + + assert response.status_code == 200 + assert response.json() == {"ok": 200} + finally: + app.dependency_overrides = {} diff --git a/app/main.py b/app/main.py index a7d6c51..53cef53 100644 --- a/app/main.py +++ b/app/main.py @@ -35,15 +35,15 @@ async def lifespan(_: FastAPI): app.include_router(example_router) -def main() -> None: +def main() -> None: # pragma: no cover uvicorn.run( "app.main:app", host=config.host, port=config.port, log_config=config.log_config, - reload=config.python_env == "development" + reload=config.python_env == "development", ) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/app/test_main.py b/app/test_main.py index 7919ea0..d7c657c 100644 --- a/app/test_main.py +++ b/app/test_main.py @@ -5,6 +5,17 @@ client = TestClient(app) +def test_lifespan(mocker): + mock_mongo_client = mocker.AsyncMock() + mock_get_mongo = mocker.patch("app.main.get_mongo_client", return_value=mock_mongo_client) + + # Using TestClient as a context manager triggers lifespan startup/shutdown + with TestClient(app): + mock_get_mongo.assert_called_once() # Startup: connect called + + mock_mongo_client.close.assert_awaited_once() # Shutdown: close called + + def test_example(): response = client.get("/example/test") assert response.status_code == 200 diff --git a/pyproject.toml b/pyproject.toml index 77d9b97..a24491e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dev = [ "pytest==8.4.0", "pytest-asyncio==1.0.0", "pytest-cov==6.2.1", + "pytest-mock>=3.15.1", "ruff==0.11.13", "taskipy==1.14.1", ] diff --git a/uv.lock b/uv.lock index 21bbf9d..fb04823 100644 --- a/uv.lock +++ b/uv.lock @@ -190,6 +190,7 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "pytest-mock" }, { name = "ruff" }, { name = "taskipy" }, ] @@ -214,6 +215,7 @@ dev = [ { name = "pytest", specifier = "==8.4.0" }, { name = "pytest-asyncio", specifier = "==1.0.0" }, { name = "pytest-cov", specifier = "==6.2.1" }, + { name = "pytest-mock", specifier = ">=3.15.1" }, { name = "ruff", specifier = "==0.11.13" }, { name = "taskipy", specifier = "==1.14.1" }, ] @@ -997,6 +999,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/16/4ea354101abb1287856baa4af2732be351c7bee728065aed451b678153fd/pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5", size = 24644, upload-time = "2025-06-12T10:47:45.932Z" }, ] +[[package]] +name = "pytest-mock" +version = "3.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"