Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_translation_worker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
- name: Run tests
run: |
cd workers/translation-worker
uv sync --frozen --all-extras
uv sync --frozen --extra cpu --dev
uv run --frozen python -m pytest --timeout=180 -vvv --cache-clear --show-capture=all -r A


Expand Down
73 changes: 48 additions & 25 deletions workers/translation-worker/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ name = "translation-worker"
dynamic = ["version"]
description = "Translation worker implementation in Temporal"
authors = [
{ name = "Clément Doumouro", email = "cdoumouro@icij.org" },
{ name = "Clément Doumouro", email = "clement.doumouro@gmail.com" },
{ name = "Lion Summerbell", email = "lsummerbell@icij.org" }
{ name = "Clément Doumouro", email = "cdoumouro@icij.org" },
{ name = "Clément Doumouro", email = "clement.doumouro@gmail.com" },
{ name = "Lion Summerbell", email = "lsummerbell@icij.org" }
]
readme = "README.md"
requires-python = ">=3.11.0, <3.14"
requires-python = ">=3.11.0, <3.13"

dependencies = [
"datashare-python~=0.8.4",
"argostranslate>=1.11.0",
"temporalio>=1.22.0",
"pycountry~=26.2.16",
"datashare-python~=0.8.4",
"argostranslate==1.11.0",
"pydantic-extra-types[pycountry]==2.11.1",
]

[project.entry-points."datashare.workflows"]
Expand All @@ -40,39 +40,62 @@ workflows = "translation_worker.workflows:WORKFLOWS"
[project.entry-points."datashare.activities"]
activities = "translation_worker.activities:ACTIVITIES"

[tool.uv.sources]
torch = [
{ index = "pytorch-cpu" },

[project.optional-dependencies]
cpu = [
"torch==2.10.0",
]
gpu = [
"torch==2.10.0+cu128; sys_platform == 'linux'",
]
datashare-python = { path = "../../datashare-python", editable = true }

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-gpu"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.uv.sources]
torch = [
{ index = "pytorch-gpu", extra = "gpu" },
{ index = "pytorch-cpu", extra = "cpu" },
]
datashare-python = { path = "../../datashare-python", editable = true }

[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "gpu" },
],
]

[dependency-groups]
dev = [
"datashare-python~=0.2",
"nest-asyncio>=1.6.0",
"pre-commit>=4.5.1",
"psutil>=6.1.0",
"pytest~=8.1",
"pytest-asyncio~=0.24",
"pytest-timeout==2.4.0",
"redis[hiredis]>=5.2.1",
"ruff==0.15.2",
"typing-extensions>=4.15.0",
"datashare-python~=0.2",
"nest-asyncio>=1.6.0",
"pre-commit>=4.5.1",
"psutil>=6.1.0",
"pytest~=8.1",
"pytest-asyncio~=0.24",
"pytest-timeout==2.4.0",
"redis[hiredis]>=5.2.1",
"ruff==0.15.2",
"typing-extensions>=4.15.0",
]

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
markers = [
"integration",
"pull",
"integration",
"pull",
]
log_cli = 1
log_cli_level = "DEBUG"
log_file_format = "[%(levelname)s][%(asctime)s.%(msecs)03d][%(name)s]: %(message)s"
log_file_date_format = "%Y-%m-%d %H:%M:%S"
log_file_date_format = "%Y-%m-%d %H:%M:%S"
152 changes: 87 additions & 65 deletions workers/translation-worker/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import asyncio
import uuid
from collections.abc import AsyncGenerator
from concurrent.futures import ThreadPoolExecutor

import datashare_python
import pytest
import translation_worker
from _pytest.tmpdir import TempPathFactory
from datashare_python.config import (
DatashareClientConfig,
LogFormat,
LoggingConfig,
TemporalClientConfig,
)
from datashare_python.conftest import ( # noqa: F401
TEST_PROJECT,
event_loop,
index_docs,
test_deps,
test_es_client,
test_es_client_session,
test_task_client,
Expand All @@ -18,16 +25,44 @@
worker_lifetime_deps,
)
from datashare_python.objects import Document
from datashare_python.types_ import TemporalClient
from datashare_python.types_ import ContextManagerFactory, TemporalClient
from datashare_python.worker import worker_context
from dependencies import set_es_client, set_worker_config
from icij_common.es import ESClient
from temporalio.worker import Worker
from translation_worker.activities import (
CreateTranslationBatches,
TranslateDocs,
resolve_language_alpha_code,
)
from translation_worker.objects import TaskQueues, TranslationWorkerConfig
from translation_worker.workflows import TranslationWorkflow
from translation_worker.activities import TranslationActivities
from translation_worker.objects import TranslationWorkerConfig
from translation_worker.workflows import TaskQueue, TranslationWorkflow


@pytest.fixture(scope="session")
def test_deps() -> list[ContextManagerFactory]:
    """Return the dependency factories handed to the test workers.

    The list holds ``set_worker_config`` followed by ``set_es_client``, in
    that order; the worker fixtures in this file forward it unchanged as the
    ``dependencies`` argument of ``worker_context``.
    """
    dependencies: list[ContextManagerFactory] = [set_worker_config, set_es_client]
    return dependencies


@pytest.fixture(scope="session")
def test_worker_config(tmp_path_factory: TempPathFactory) -> TranslationWorkerConfig:  # noqa: ANN001, ARG001, F811
    """Build the session-wide worker configuration used by the test workers.

    A fresh temporary directory is created through ``tmp_path_factory`` and
    two sub-directories are made inside it: ``artifacts`` (passed as
    ``artifacts_root``) and ``workdir`` (passed as ``workdir``). Logging is
    set to ``DEBUG`` for both the ``datashare_python`` and the
    ``translation_worker`` loggers.

    Args:
        tmp_path_factory: pytest's session-scoped temporary path factory.

    Returns:
        A ``TranslationWorkerConfig`` pointing at a Datashare instance on
        ``http://localhost:8080`` and a Temporal server on ``localhost:7233``.
    """
    tmp_path = tmp_path_factory.mktemp("test-")
    # Only the directories that are actually handed to the config are created.
    artifacts_root = tmp_path / "artifacts"
    artifacts_root.mkdir()
    workdir = tmp_path / "workdir"
    workdir.mkdir()
    logging_config = LoggingConfig(
        loggers={
            datashare_python.__name__: "DEBUG",
            translation_worker.__name__: "DEBUG",
        },
        format=LogFormat.DEFAULT,
    )
    return TranslationWorkerConfig(
        logging=logging_config,
        datashare=DatashareClientConfig(url="http://localhost:8080"),
        temporal=TemporalClientConfig(host="localhost:7233"),
        artifacts_root=artifacts_root,
        workdir=workdir,
    )


EN = "en"
FR = "fr"
Expand Down Expand Up @@ -85,72 +120,59 @@ def translation_worker_config() -> TranslationWorkerConfig:


@pytest.fixture(scope="session")
async def io_worker(
    test_worker_config: TranslationWorkerConfig,  # noqa: F811
    test_temporal_client_session: TemporalClient,  # noqa: F811
    event_loop: asyncio.AbstractEventLoop,  # noqa: F811
    test_deps: list[ContextManagerFactory],  # noqa: F811
) -> AsyncGenerator[None, None]:
    """Keep a worker bound to ``TaskQueue.IO`` open for the test session.

    The worker is given ``TranslationWorkflow`` plus two activities taken from
    a ``TranslationActivities`` instance: ``translation_worker_config`` and
    ``create_translation_batches``.

    Args:
        test_worker_config: configuration forwarded as ``worker_config``.
        test_temporal_client_session: Temporal client shared by the session.
        event_loop: event loop given to the activities and to the worker.
        test_deps: dependency factories forwarded as ``dependencies``.

    Yields:
        ``None``; the fixture is requested only for its side effect.
    """
    client = test_temporal_client_session
    worker_id = f"test-io-worker-{uuid.uuid4()}"
    translation_activities = TranslationActivities(
        temporal_client=client, event_loop=event_loop
    )
    batching_activities = [
        translation_activities.translation_worker_config,
        translation_activities.create_translation_batches,
    ]
    workflows = [TranslationWorkflow]
    task_queue = TaskQueue.IO
    # NOTE(review): worker_context presumably starts the worker on enter and
    # shuts it down on exit - confirm against datashare_python.worker.
    worker_ctx = worker_context(
        worker_id,
        activities=batching_activities,
        workflows=workflows,
        worker_config=test_worker_config,
        client=client,
        event_loop=event_loop,
        task_queue=task_queue,
        dependencies=test_deps,
    )
    async with worker_ctx:
        yield


@pytest.fixture(scope="session")
async def translation_cpu_worker(
    test_worker_config: TranslationWorkerConfig,  # noqa: F811
    test_temporal_client_session: TemporalClient,  # noqa: F811
    event_loop: asyncio.AbstractEventLoop,  # noqa: F811
    test_deps: list[ContextManagerFactory],  # noqa: F811
) -> AsyncGenerator[None, None]:
    """Keep a worker bound to ``TaskQueue.INFERENCE_CPU`` open for the session.

    The worker is given a single activity, ``translate_docs``, taken from a
    ``TranslationActivities`` instance; no workflow is passed to it.

    Args:
        test_worker_config: configuration forwarded as ``worker_config``.
        test_temporal_client_session: Temporal client shared by the session.
        event_loop: event loop given to the activities and to the worker.
        test_deps: dependency factories forwarded as ``dependencies``.

    Yields:
        ``None``; the fixture is requested only for its side effect.
    """
    client = test_temporal_client_session
    # Use a prefix distinct from the IO worker's ("test-io-worker-") so the
    # two workers can be told apart by their identity.
    worker_id = f"test-translation-cpu-worker-{uuid.uuid4()}"
    activities_provider = TranslationActivities(
        temporal_client=client, event_loop=event_loop
    )
    translation_activities = [activities_provider.translate_docs]
    task_queue = TaskQueue.INFERENCE_CPU
    # NOTE(review): worker_context presumably starts the worker on enter and
    # shuts it down on exit - confirm against datashare_python.worker.
    worker_ctx = worker_context(
        worker_id,
        activities=translation_activities,
        worker_config=test_worker_config,
        client=client,
        event_loop=event_loop,
        task_queue=task_queue,
        dependencies=test_deps,
    )
    async with worker_ctx:
        yield
10 changes: 5 additions & 5 deletions workers/translation-worker/tests/test_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ async def mock_get_es_docs(*args, **kwargs):
):
result = await create_translation_batches(
project=TEST_PROJECT,
target_language_alpha_code=EN,
target_language=EN,
)

assert result == []
Expand All @@ -221,7 +221,7 @@ async def mock_get_es_docs(*args, **kwargs):
):
result = await create_translation_batches(
project=TEST_PROJECT,
target_language_alpha_code=EN,
target_language=EN,
)

assert len(result) == 1
Expand All @@ -241,7 +241,7 @@ async def mock_get_es_docs(*args, **kwargs):
):
result = await create_translation_batches(
project=TEST_PROJECT,
target_language_alpha_code=EN,
target_language=EN,
)

assert len(result) == 1
Expand All @@ -264,7 +264,7 @@ async def mock_get_es_docs(*args, **kwargs):
):
result = await create_translation_batches(
project=TEST_PROJECT,
target_language_alpha_code=EN,
target_language=EN,
)

langs = [lang for lang, _ in result]
Expand Down Expand Up @@ -292,7 +292,7 @@ async def mock_get_es_docs(*args, **kwargs):
"translation_worker.activities._get_es_docs", side_effect=mock_get_es_docs
):
result = await create_translation_batches(
project=TEST_PROJECT, target_language_alpha_code=EN, max_batch_byte_len=1000
project=TEST_PROJECT, target_language=EN, max_batch_byte_len=1000
)

_, batches = result[0]
Expand Down
Loading
Loading