|
1 | 1 | import logging |
2 | | -import importlib |
| 2 | +from collections.abc import Callable |
3 | 3 | from celery import current_task |
4 | 4 | from asgi_correlation_id import correlation_id |
5 | 5 |
|
6 | 6 | from app.celery.celery_app import celery_app |
| 7 | +import app.services.llm.jobs as _llm_jobs |
| 8 | +import app.services.response.jobs as _response_jobs |
| 9 | +import app.services.doctransform.job as _doctransform_job |
| 10 | +import app.services.collections.create_collection as _create_collection |
| 11 | +import app.services.collections.delete_collection as _delete_collection |
| 12 | +import app.services.stt_evaluations.batch_job as _stt_batch_job |
| 13 | +import app.services.stt_evaluations.metric_job as _stt_metric_job |
| 14 | +import app.services.tts_evaluations.batch_job as _tts_batch_job |
| 15 | +import app.services.tts_evaluations.batch_result_processing as _tts_result_processing |
7 | 16 |
|
8 | 17 | logger = logging.getLogger(__name__) |
9 | 18 |
|
# Hardcoded dispatch table — avoids dynamic importlib at task execution time.
# Imports above happen once in the main Celery process before worker forks,
# so all child workers inherit them via copy-on-write instead of each loading
# them independently (which was causing OOM with warmup_job_modules).
# NOTE(review): the copy-on-write benefit assumes the prefork worker pool —
# confirm this still holds if the pool type ever changes (threads/gevent).
#
# Keys are the fully-qualified "module.path.function" strings callers pass
# as `function_path`; values are the already-imported callables. A lookup
# miss raises ValueError in _execute_job_internal rather than importing.
_FUNCTION_REGISTRY: dict[str, Callable] = {
    "app.services.llm.jobs.execute_job": _llm_jobs.execute_job,
    "app.services.llm.jobs.execute_chain_job": _llm_jobs.execute_chain_job,
    "app.services.response.jobs.execute_job": _response_jobs.execute_job,
    "app.services.doctransform.job.execute_job": _doctransform_job.execute_job,
    "app.services.collections.create_collection.execute_job": _create_collection.execute_job,
    "app.services.collections.delete_collection.execute_job": _delete_collection.execute_job,
    "app.services.stt_evaluations.batch_job.execute_batch_submission": _stt_batch_job.execute_batch_submission,
    "app.services.stt_evaluations.metric_job.execute_metric_computation": _stt_metric_job.execute_metric_computation,
    "app.services.tts_evaluations.batch_job.execute_batch_submission": _tts_batch_job.execute_batch_submission,
    "app.services.tts_evaluations.batch_result_processing.execute_tts_result_processing": _tts_result_processing.execute_tts_result_processing,
}
| 35 | + |
10 | 36 |
|
11 | 37 | @celery_app.task(bind=True, queue="high_priority") |
12 | 38 | def execute_high_priority_task( |
@@ -85,10 +111,11 @@ def _execute_job_internal( |
85 | 111 | logger.info(f"Set correlation ID context: {trace_id} for job {job_id}") |
86 | 112 |
|
87 | 113 | try: |
88 | | - # Dynamically import and resolve the function |
89 | | - module_path, function_name = function_path.rsplit(".", 1) |
90 | | - module = importlib.import_module(module_path) |
91 | | - execute_function = getattr(module, function_name) |
| 114 | + execute_function = _FUNCTION_REGISTRY.get(function_path) |
| 115 | + if execute_function is None: |
| 116 | + raise ValueError( |
| 117 | + f"[_execute_job_internal] Unknown function path: {function_path}" |
| 118 | + ) |
92 | 119 |
|
93 | 120 | logger.info( |
94 | 121 | f"Executing {priority} job {job_id} (task {task_id}) using function {function_path}" |
|
0 commit comments