Skip to content

Commit d4234f6

Browse files
committed
task manager added
based on https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py * classified * this way, gc.collect() will work as intended.
1 parent 1b16c62 commit d4234f6

File tree

4 files changed

+107
-7
lines changed

4 files changed

+107
-7
lines changed

modules/call_queue.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import html
44
import time
55

6-
from modules import shared, progress, errors, devices, fifo_lock, profiling
6+
from modules import shared, progress, errors, devices, fifo_lock, profiling, manager
77

88
queue_lock = fifo_lock.FIFOLock()
99

@@ -34,7 +34,7 @@ def f(*args, **kwargs):
3434
progress.start_task(id_task)
3535

3636
try:
37-
res = func(*args, **kwargs)
37+
res = manager.task.run_and_wait_result(func, *args, **kwargs)
3838
progress.record_results(id_task, res)
3939
finally:
4040
progress.finish_task(id_task)

modules/launch_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,11 +463,17 @@ def configure_for_tests():
463463
def start():
464464
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
465465
import webui
466+
467+
from modules import manager
468+
466469
if '--nowebui' in sys.argv:
467470
webui.api_only()
468471
else:
469472
webui.webui()
470473

474+
manager.task.main_loop()
475+
return
476+
471477

472478
def dump_sysinfo():
473479
from modules import sysinfo

modules/manager.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#
2+
# based on forge's work from https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py
3+
#
4+
# Original author comment:
5+
# This file is the main thread that handles all gradio calls for major t2i or i2i processing.
6+
# Other gradio calls (like those from extensions) are not influenced.
7+
# By using one single thread to process all major calls, model moving is significantly faster.
8+
#
9+
# 2024/09/28 classified,
10+
11+
import random
12+
import string
13+
import threading
14+
import time
15+
16+
from collections import OrderedDict
17+
18+
19+
class Task:
20+
def __init__(self, **kwargs):
21+
self.__dict__.update(kwargs)
22+
23+
24+
class TaskManager:
25+
last_exception = None
26+
pending_tasks = []
27+
finished_tasks = OrderedDict()
28+
lock = None
29+
running = False
30+
31+
def __init__(self):
32+
self.lock = threading.Lock()
33+
34+
def work(self, task):
35+
try:
36+
task.result = task.func(*task.args, **task.kwargs)
37+
except Exception as e:
38+
task.exception = e
39+
self.last_exception = e
40+
41+
42+
def stop(self):
43+
self.running = False
44+
45+
46+
def main_loop(self):
47+
self.running = True
48+
while self.running:
49+
time.sleep(0.01)
50+
if len(self.pending_tasks) > 0:
51+
with self.lock:
52+
task = self.pending_tasks.pop(0)
53+
54+
self.work(task)
55+
56+
self.finished_tasks[task.task_id] = task
57+
58+
59+
def push_task(self, func, *args, **kwargs):
60+
if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
61+
task_id = args[0]
62+
else:
63+
task_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=7))
64+
task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs, result=None, exception=None)
65+
self.pending_tasks.append(task)
66+
67+
return task.task_id
68+
69+
70+
def run_and_wait_result(self, func, *args, **kwargs):
71+
current_id = self.push_task(func, *args, **kwargs)
72+
73+
while True:
74+
time.sleep(0.01)
75+
if current_id in self.finished_tasks:
76+
finished = self.finished_tasks.pop(current_id)
77+
if finished.exception is not None:
78+
raise finished.exception
79+
80+
return finished.result
81+
82+
83+
task = TaskManager()

webui.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from modules import timer
77
from modules import initialize_util
88
from modules import initialize
9+
from modules import manager
10+
from threading import Thread
911

1012
startup_timer = timer.startup_timer
1113
startup_timer.record("launcher")
@@ -14,6 +16,8 @@
1416

1517
initialize.check_versions()
1618

19+
initialize.initialize()
20+
1721

1822
def create_api(app):
1923
from modules.api.api import Api
@@ -23,12 +27,10 @@ def create_api(app):
2327
return api
2428

2529

26-
def api_only():
30+
def _api_only():
2731
from fastapi import FastAPI
2832
from modules.shared_cmd_options import cmd_opts
2933

30-
initialize.initialize()
31-
3234
app = FastAPI()
3335
initialize_util.setup_middleware(app)
3436
api = create_api(app)
@@ -83,11 +85,10 @@ def abspath(path):
8385
{"!"*25} Warning {"!"*25}''')
8486

8587

86-
def webui():
88+
def _webui():
8789
from modules.shared_cmd_options import cmd_opts
8890

8991
launch_api = cmd_opts.api
90-
initialize.initialize()
9192

9293
from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks
9394

@@ -177,6 +178,7 @@ def webui():
177178
print("Stopping server...")
178179
# If we catch a keyboard interrupt, we want to stop the server and exit.
179180
shared.demo.close()
181+
manager.task.stop()
180182
break
181183

182184
# disable auto launch webui in browser for subsequent UI Reload
@@ -193,10 +195,19 @@ def webui():
193195
initialize.initialize_rest(reload_script_modules=True)
194196

195197

198+
def api_only():
199+
Thread(target=_api_only, daemon=True).start()
200+
201+
202+
def webui():
203+
Thread(target=_webui, daemon=True).start()
204+
196205
if __name__ == "__main__":
197206
from modules.shared_cmd_options import cmd_opts
198207

199208
if cmd_opts.nowebui:
200209
api_only()
201210
else:
202211
webui()
212+
213+
manager.task.main_loop()

0 commit comments

Comments
 (0)