Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions app/common/test_log_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import logging

from app.common.log_utils import EndpointFilter, ExtraFieldsFilter


def test_extra_fields_filter_with_all_context(mocker):
# Mock the context variables
mock_trace_id = mocker.patch("app.common.log_utils.ctx_trace_id")
mock_request = mocker.patch("app.common.log_utils.ctx_request")
mock_response = mocker.patch("app.common.log_utils.ctx_response")

# Set context values
mock_trace_id.get.return_value = "test-trace-id"
mock_request.get.return_value = {"url": "http://test.com", "method": "GET"}
mock_response.get.return_value = {"status_code": 200}

# Create a log record
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname=__file__,
lineno=10,
msg="test message",
args=(),
exc_info=None,
)

# Apply filter
log_filter = ExtraFieldsFilter()
result = log_filter.filter(record)

# Assertions
assert result is True
assert record.trace == {"id": "test-trace-id"}
assert record.url == {"full": "http://test.com"}
assert record.http == {
"request": {"method": "GET"},
"response": {"status_code": 200},
}


def test_extra_fields_filter_with_no_context(mocker):
# Mock the context variables to return None/empty
mock_trace_id = mocker.patch("app.common.log_utils.ctx_trace_id")
mock_request = mocker.patch("app.common.log_utils.ctx_request")
mock_response = mocker.patch("app.common.log_utils.ctx_response")

mock_trace_id.get.return_value = None
mock_request.get.return_value = None
mock_response.get.return_value = None

# Create a log record
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname=__file__,
lineno=10,
msg="test message",
args=(),
exc_info=None,
)

# Apply filter
log_filter = ExtraFieldsFilter()
result = log_filter.filter(record)

# Assertions
assert result is True
assert not hasattr(record, "trace")
assert not hasattr(record, "url")
assert not hasattr(record, "http")


def test_endpoint_filter_blocks_matching_path():
filter_path = "/health"
log_filter = EndpointFilter(path=filter_path)

# Create a log record containing the path
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname=__file__,
lineno=10,
msg=f"GET {filter_path} HTTP/1.1",
args=(),
exc_info=None,
)

assert log_filter.filter(record) is False


def test_endpoint_filter_allows_non_matching_path():
filter_path = "/health"
log_filter = EndpointFilter(path=filter_path)

# Create a log record NOT containing the path
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname=__file__,
lineno=10,
msg="GET /api/users HTTP/1.1",
args=(),
exc_info=None,
)

assert log_filter.filter(record) is True
23 changes: 23 additions & 0 deletions app/common/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from app.common.metrics import counter


def test_counter_success(mocker):
mock_put_metric = mocker.patch("app.common.metrics.__put_metric")

counter("test_metric", 123)

mock_put_metric.assert_called_once_with("test_metric", 123, "Count")


def test_counter_handles_exception(mocker):
mocker.patch("app.common.metrics.__put_metric", side_effect=Exception("Test Error"))
mock_logger = mocker.patch("app.common.metrics.logger")

# Should not raise exception but catch it
counter("test_metric", 123)

# Verify error was logged
assert mock_logger.error.call_count == 1
args, _ = mock_logger.error.call_args
assert "Error calling put_metric" in args[0]
assert str(args[1]) == "Test Error"
86 changes: 86 additions & 0 deletions app/common/test_mongo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest

from app.common import mongo
from app.config import config


# Reset the global client variable before each test
@pytest.fixture(autouse=True)
def reset_mongo_client():
mongo.client = None
mongo.db = None
yield
mongo.client = None
mongo.db = None


@pytest.mark.asyncio
async def test_get_mongo_client_initialization(mocker):
mock_client_cls = mocker.patch("app.common.mongo.AsyncMongoClient")
mock_instance = mock_client_cls.return_value

# Setup the async ping command
# get_database() returns a DB object, which has an async command() method
mock_db = mocker.MagicMock()
mock_instance.get_database.return_value = mock_db
mock_db.command = mocker.AsyncMock(return_value={"ok": 1})

client = await mongo.get_mongo_client()

assert client == mock_instance
mock_client_cls.assert_called_once_with(config.mongo_uri)
mock_db.command.assert_awaited_once_with("ping")


@pytest.mark.asyncio
async def test_get_mongo_client_with_custom_tls(mocker, monkeypatch):
# Mock config and custom certs
monkeypatch.setattr(config, "mongo_truststore", "custom-cert-key")
mocker.patch.dict(
"app.common.tls.custom_ca_certs", {"custom-cert-key": "/path/to/cert.pem"}
)

mock_client_cls = mocker.patch("app.common.mongo.AsyncMongoClient")
mock_instance = mock_client_cls.return_value
mock_db = mocker.MagicMock()
mock_instance.get_database.return_value = mock_db
mock_db.command = mocker.AsyncMock(return_value={"ok": 1})

await mongo.get_mongo_client()

# Verify TLS param was passed
mock_client_cls.assert_called_once_with(
config.mongo_uri, tlsCAFile="/path/to/cert.pem"
)


@pytest.mark.asyncio
async def test_get_mongo_client_returns_existing(mocker):
# Set an existing client
existing_client = mocker.Mock()
mongo.client = existing_client

mock_client_cls = mocker.patch("app.common.mongo.AsyncMongoClient")

result = await mongo.get_mongo_client()

# Should return existing without creating new one or pinging
assert result == existing_client
mock_client_cls.assert_not_called()


@pytest.mark.asyncio
async def test_get_db(mocker):
mock_client = mocker.MagicMock()
mock_db = mocker.Mock()
mock_client.get_database.return_value = mock_db

# First call initializes
result = await mongo.get_db(mock_client)
assert result == mock_db
mock_client.get_database.assert_called_once_with(config.mongo_database)

# Second call returns cached
result2 = await mongo.get_db(mock_client)
assert result2 == mock_db
assert mock_client.get_database.call_count == 1
105 changes: 105 additions & 0 deletions app/common/test_tls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import base64
import ssl

from app.common.tls import (
extract_all_certs,
init_custom_certificates,
load_certs_into_context,
)


class TestExtractAllCerts:
def test_extract_valid_certs(self, mocker, monkeypatch, tmp_path):
cert_content = b"cert1"
encoded_cert = base64.b64encode(cert_content).decode()
monkeypatch.setenv("TRUSTSTORE_CERT1", encoded_cert)

cert_path = tmp_path / "cert1.pem"

mock_named_temp_file = mocker.patch(
"app.common.tls.tempfile.NamedTemporaryFile"
)
mock_file_obj = mocker.MagicMock()
mock_file_obj.name = str(cert_path)
mock_named_temp_file.return_value.__enter__.return_value = mock_file_obj

certs = extract_all_certs()

assert len(certs) == 1
assert certs["TRUSTSTORE_CERT1"] == str(cert_path)

# Check if decoded content was written
mock_file_obj.write.assert_called_once_with(b"cert1")

def test_extract_invalid_base64_cert(self, monkeypatch):
monkeypatch.setenv("TRUSTSTORE_BAD", "invalid-base64!")

certs = extract_all_certs()
assert len(certs) == 0

def test_extract_no_truststore_vars(self, monkeypatch):
monkeypatch.setenv("NORMAL_VAR", "value")

certs = extract_all_certs()
assert len(certs) == 0


class TestLoadCertsIntoContext:
def test_load_valid_certs(self, mocker):
mock_create_context = mocker.patch("app.common.tls.ssl.create_default_context")
mock_ctx = mocker.MagicMock()
mock_create_context.return_value = mock_ctx

certs = {
"TRUSTSTORE_1": "/path/to/cert1.pem",
"TRUSTSTORE_2": "/path/to/cert2.pem",
}

ctx = load_certs_into_context(certs)

assert ctx == mock_ctx
assert mock_ctx.load_verify_locations.call_count == 2
mock_ctx.load_verify_locations.assert_any_call("/path/to/cert1.pem")
mock_ctx.load_verify_locations.assert_any_call("/path/to/cert2.pem")

def test_load_certs_error(self, mocker):
mock_create_context = mocker.patch("app.common.tls.ssl.create_default_context")
mock_ctx = mocker.MagicMock()
mock_create_context.return_value = mock_ctx
# Make load_verify_locations raise for the first one
mock_ctx.load_verify_locations.side_effect = [ssl.SSLError("Bad cert"), None]

certs = {
"TRUSTSTORE_BAD": "/path/to/bad.pem",
"TRUSTSTORE_GOOD": "/path/to/good.pem",
}

ctx = load_certs_into_context(certs)

# Should proceed to load the second one despite error in first
assert ctx == mock_ctx
assert mock_ctx.load_verify_locations.call_count == 2


class TestInitCustomCertificates:
def test_init_globals(self, mocker):
mock_extract = mocker.patch("app.common.tls.extract_all_certs")
mock_load = mocker.patch("app.common.tls.load_certs_into_context")

mock_certs = {"cert": "path"}
mock_ctx = mocker.MagicMock()

mock_extract.return_value = mock_certs
mock_load.return_value = mock_ctx

result = init_custom_certificates()

assert result == mock_certs
mock_extract.assert_called_once()
mock_load.assert_called_once_with(mock_certs)

# Check globals are set
from app.common import tls

assert tls.custom_ca_certs == mock_certs
assert tls.ctx == mock_ctx
4 changes: 2 additions & 2 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
class AppConfig(BaseSettings):
model_config = SettingsConfigDict()
python_env: str | None = None
host: str | None = None
port: int | None = None
host: str = "127.0.0.1"
port: int = 8086
log_config: str | None = None
mongo_uri: str | None = None
mongo_database: str = "cdp-python-backend-template"
Expand Down
4 changes: 3 additions & 1 deletion app/example/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,22 @@
logger = getLogger(__name__)


# remove this example route
# basic endpoint example
@router.get("/test")
async def root():
logger.info("TEST ENDPOINT")
return {"ok": True}


# database endpoint example
@router.get("/db")
async def db_query(db=Depends(get_db)):
await db.example.insert_one({"foo": "bar"})
data = await db.example.find_one({}, {"_id": 0})
return {"ok": data}


# http client endpoint example
@router.get("/http")
async def http_query(client=Depends(create_async_client)):
endpoint = config.aws_endpoint_url or "http://localstack:4566"
Expand Down
Loading