|
1 | 1 | import collections |
2 | 2 | import copy |
| 3 | +import dataclasses |
3 | 4 | import datetime |
4 | 5 | import json |
5 | 6 | import logging |
@@ -89,22 +90,21 @@ def sleep_mock(): |
89 | 90 | yield sleep |
90 | 91 |
|
91 | 92 |
|
92 | | -class DummyTask(Task): |
| 93 | +@dataclasses.dataclass(frozen=True) |
| 94 | +class DummyResultTask(Task): |
93 | 95 | """ |
94 | | - A Task that simply sleeps and then returns a predetermined _TaskResult. |
| 96 | + A dummy task to directly define a _TaskResult. |
95 | 97 | """ |
96 | 98 |
|
97 | | - def __init__(self, job_id, df_idx, db_update, stats_update): |
98 | | - super().__init__(job_id=job_id, df_idx=df_idx) |
99 | | - self._db_update = db_update or {} |
100 | | - self._stats_update = stats_update or {} |
| 99 | + db_update: dict = dataclasses.field(default_factory=dict) |
| 100 | + stats_update: dict = dataclasses.field(default_factory=dict) |
101 | 101 |
|
102 | 102 | def execute(self) -> _TaskResult: |
103 | 103 | return _TaskResult( |
104 | 104 | job_id=self.job_id, |
105 | 105 | df_idx=self.df_idx, |
106 | | - db_update=self._db_update, |
107 | | - stats_update=self._stats_update, |
| 106 | + db_update=self.db_update, |
| 107 | + stats_update=self.stats_update, |
108 | 108 | ) |
109 | 109 |
|
110 | 110 |
|
@@ -751,10 +751,10 @@ def test_process_threadworker_updates(self, tmp_path, caplog): |
751 | 751 | stats = collections.defaultdict(int) |
752 | 752 |
|
753 | 753 | # Submit tasks covering all cases |
754 | | - pool.submit_task(DummyTask("j-0", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1})) |
755 | | - pool.submit_task(DummyTask("j-1", df_idx=1, db_update={"status": "queued"}, stats_update=None)) |
756 | | - pool.submit_task(DummyTask("j-2", df_idx=2, db_update=None, stats_update={"queued": 1})) |
757 | | - pool.submit_task(DummyTask("j-3", df_idx=3, db_update=None, stats_update=None)) |
| 754 | + pool.submit_task(DummyResultTask("j-0", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1})) |
| 755 | + pool.submit_task(DummyResultTask("j-1", df_idx=1, db_update={"status": "queued"}, stats_update={})) |
| 756 | + pool.submit_task(DummyResultTask("j-2", df_idx=2, db_update={}, stats_update={"queued": 1})) |
| 757 | + pool.submit_task(DummyResultTask("j-3", df_idx=3, db_update={}, stats_update={})) |
758 | 758 |
|
759 | 759 | df_initial = pd.DataFrame( |
760 | 760 | { |
@@ -790,8 +790,8 @@ def test_process_threadworker_updates_unknown(self, tmp_path, caplog): |
790 | 790 | pool = _JobManagerWorkerThreadPool(max_workers=2) |
791 | 791 | stats = collections.defaultdict(int) |
792 | 792 |
|
793 | | - pool.submit_task(DummyTask("j-123", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1})) |
794 | | - pool.submit_task(DummyTask("j-unknown", df_idx=4, db_update={"status": "created"}, stats_update=None)) |
| 793 | + pool.submit_task(DummyResultTask("j-123", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1})) |
| 794 | + pool.submit_task(DummyResultTask("j-unknown", df_idx=4, db_update={"status": "created"}, stats_update={})) |
795 | 795 |
|
796 | 796 | df_initial = pd.DataFrame( |
797 | 797 | { |
|
0 commit comments