Skip to content

Commit 3bff55d

Browse files
committed
chore: Update backend.py
1 parent 4d61c40 commit 3bff55d

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/art/serverless/backend.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,22 +137,22 @@ async def _train_model(
137137
),
138138
)
139139
after: str | None = None
140-
num_gradient_steps: int | None = None
140+
num_sequences: int | None = None
141141
pbar: tqdm.tqdm | None = None
142142
while True:
143143
await asyncio.sleep(0.5)
144144
async for event in self._client.training_jobs.events.list(
145145
training_job_id=training_job.id, after=after or NOT_GIVEN
146146
):
147147
if event.type == "gradient_step":
148-
assert pbar is not None and num_gradient_steps is not None
148+
assert pbar is not None and num_sequences is not None
149149
pbar.update(1)
150150
pbar.set_postfix(event.data)
151-
yield {**event.data, "num_gradient_steps": num_gradient_steps}
151+
yield {**event.data, "num_gradient_steps": num_sequences}
152152
elif event.type == "training_started":
153-
num_gradient_steps = event.data["num_gradient_steps"]
153+
num_sequences = event.data["num_sequences"]
154154
if pbar is None:
155-
pbar = tqdm.tqdm(total=num_gradient_steps, desc="train")
155+
pbar = tqdm.tqdm(total=num_sequences, desc="train")
156156
continue
157157
elif event.type == "training_ended":
158158
return

0 commit comments

Comments
 (0)