Commit 049e524

Remove isolation
1 parent 9a1cfcd commit 049e524

File tree

1 file changed (+62, -123 lines)


xinference/deploy/cmdline.py

Lines changed: 62 additions & 123 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import asyncio
 import configparser
 import logging
 import os
@@ -35,7 +34,6 @@
     XINFERENCE_DEFAULT_LOCAL_HOST,
     XINFERENCE_ENV_ENDPOINT,
 )
-from ..isolation import Isolation
 from ..types import ChatCompletionMessage

 try:
@@ -352,66 +350,38 @@ def model_generate(
     stream: bool,
 ):
     endpoint = get_endpoint(endpoint)
-    if stream:
-
-        async def generate_internal():
-            while True:
-                # the prompt will be written to stdout.
-                # https://docs.python.org/3.10/library/functions.html#input
-                prompt = input("Prompt: ")
-                if prompt == "":
-                    break
-                print(f"Completion: {prompt}", end="", file=sys.stdout)
-                for chunk in model.generate(
-                    prompt=prompt,
-                    generate_config={"stream": stream, "max_tokens": max_tokens},
-                ):
-                    choice = chunk["choices"][0]
-                    if "text" not in choice:
-                        continue
-                    else:
-                        print(choice["text"], end="", flush=True, file=sys.stdout)
-                print("\n", file=sys.stdout)
-
-        client = RESTfulClient(base_url=endpoint)
-        model = client.get_model(model_uid=model_uid)
-
-        loop = asyncio.get_event_loop()
-        coro = generate_internal()
-
-        if loop.is_running():
-            isolation = Isolation(asyncio.new_event_loop(), threaded=True)
-            isolation.start()
-            isolation.call(coro)
+    client = RESTfulClient(base_url=endpoint)
+    model = client.get_model(model_uid=model_uid)
+    if not isinstance(model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle)):
+        raise ValueError(f"model {model_uid} has no generate method")
+
+    while True:
+        # the prompt will be written to stdout.
+        # https://docs.python.org/3.10/library/functions.html#input
+        prompt = input("Prompt: ")
+        if prompt.lower() == "exit" or prompt.lower() == "quit":
+            break
+        print(f"Completion: {prompt}", end="", file=sys.stdout)
+
+        if stream:
+            for chunk in model.generate(
+                prompt=prompt,
+                generate_config={"stream": stream, "max_tokens": max_tokens},
+            ):
+                choice = chunk["choices"][0]
+                if "text" not in choice:
+                    continue
+                else:
+                    print(choice["text"], end="", flush=True, file=sys.stdout)
         else:
-            task = loop.create_task(coro)
-            try:
-                loop.run_until_complete(task)
-            except KeyboardInterrupt:
-                task.cancel()
-                loop.run_until_complete(task)
-                # avoid displaying exception-unhandled warnings
-                task.exception()
-    else:
-        restful_client = RESTfulClient(base_url=endpoint)
-        restful_model = restful_client.get_model(model_uid=model_uid)
-        if not isinstance(
-            restful_model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle)
-        ):
-            raise ValueError(f"model {model_uid} has no generate method")
-
-        while True:
-            prompt = input("User: ")
-            if prompt == "":
-                break
-            print(f"Assistant: {prompt}", end="", file=sys.stdout)
-            response = restful_model.generate(
+            response = model.generate(
                 prompt=prompt,
                 generate_config={"stream": stream, "max_tokens": max_tokens},
             )
             if not isinstance(response, dict):
                 raise ValueError("generate result is not valid")
-            print(f"{response['choices'][0]['text']}\n", file=sys.stdout)
+            print(f"{response['choices'][0]['text']}", file=sys.stdout)
+        print("\n", file=sys.stdout)


 @cli.command("chat")
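For reference, the streaming branch above now reduces to plain synchronous calls on the RESTful client, with no event loop or Isolation thread. A minimal standalone sketch of the same generate call, assuming a server is already running; the endpoint URL and model UID are placeholders, and the top-level import path xinference.client is an assumption:

# Minimal sketch of the streaming generate call outside the CLI loop.
# Endpoint URL and model UID are placeholders; the client calls mirror the diff above.
from xinference.client import RESTfulClient

client = RESTfulClient(base_url="http://127.0.0.1:9997")  # placeholder endpoint
model = client.get_model(model_uid="my-model-uid")  # placeholder model UID

for chunk in model.generate(
    prompt="Once upon a time",
    generate_config={"stream": True, "max_tokens": 128},
):
    choice = chunk["choices"][0]
    if "text" in choice:
        # each chunk carries a partial completion; print it as it arrives
        print(choice["text"], end="", flush=True)
print()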
@@ -431,80 +401,49 @@ def model_chat(
 ):
     # TODO: chat model roles may not be user and assistant.
     endpoint = get_endpoint(endpoint)
+    client = RESTfulClient(base_url=endpoint)
+    model = client.get_model(model_uid=model_uid)
+    if not isinstance(
+        model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
+    ):
+        raise ValueError(f"model {model_uid} has no chat method")
+
     chat_history: "List[ChatCompletionMessage]" = []
-    if stream:
-
-        async def chat_internal():
-            while True:
-                # the prompt will be written to stdout.
-                # https://docs.python.org/3.10/library/functions.html#input
-                prompt = input("User: ")
-                if prompt == "":
-                    break
-                chat_history.append(ChatCompletionMessage(role="user", content=prompt))
-                print("Assistant: ", end="", file=sys.stdout)
-                response_content = ""
-                for chunk in model.chat(
-                    prompt=prompt,
-                    chat_history=chat_history,
-                    generate_config={"stream": stream, "max_tokens": max_tokens},
-                ):
-                    delta = chunk["choices"][0]["delta"]
-                    if "content" not in delta:
-                        continue
-                    else:
-                        response_content += delta["content"]
-                        print(delta["content"], end="", flush=True, file=sys.stdout)
-                print("\n", file=sys.stdout)
-                chat_history.append(
-                    ChatCompletionMessage(role="assistant", content=response_content)
-                )
-
-        client = RESTfulClient(base_url=endpoint)
-        model = client.get_model(model_uid=model_uid)
-
-        loop = asyncio.get_event_loop()
-        coro = chat_internal()
-
-        if loop.is_running():
-            isolation = Isolation(asyncio.new_event_loop(), threaded=True)
-            isolation.start()
-            isolation.call(coro)
+    while True:
+        # the prompt will be written to stdout.
+        # https://docs.python.org/3.10/library/functions.html#input
+        prompt = input("User: ")
+        if prompt == "":
+            break
+        chat_history.append(ChatCompletionMessage(role="user", content=prompt))
+        print("Assistant: ", end="", file=sys.stdout)
+
+        response_content = ""
+        if stream:
+            for chunk in model.chat(
+                prompt=prompt,
+                chat_history=chat_history,
+                generate_config={"stream": stream, "max_tokens": max_tokens},
+            ):
+                delta = chunk["choices"][0]["delta"]
+                if "content" not in delta:
+                    continue
+                else:
+                    response_content += delta["content"]
+                    print(delta["content"], end="", flush=True, file=sys.stdout)
         else:
-            task = loop.create_task(coro)
-            try:
-                loop.run_until_complete(task)
-            except KeyboardInterrupt:
-                task.cancel()
-                loop.run_until_complete(task)
-                # avoid displaying exception-unhandled warnings
-                task.exception()
-    else:
-        restful_client = RESTfulClient(base_url=endpoint)
-        restful_model = restful_client.get_model(model_uid=model_uid)
-        if not isinstance(
-            restful_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
-        ):
-            raise ValueError(f"model {model_uid} has no chat method")
-
-        while True:
-            prompt = input("User: ")
-            if prompt == "":
-                break
-            chat_history.append(ChatCompletionMessage(role="user", content=prompt))
-            print("Assistant: ", end="", file=sys.stdout)
-            response = restful_model.chat(
+            response = model.chat(
                 prompt=prompt,
                 chat_history=chat_history,
                 generate_config={"stream": stream, "max_tokens": max_tokens},
             )
-            if not isinstance(response, dict):
-                raise ValueError("chat result is not valid")
             response_content = response["choices"][0]["message"]["content"]
-            print(f"{response_content}\n", file=sys.stdout)
-            chat_history.append(
-                ChatCompletionMessage(role="assistant", content=response_content)
-            )
+            print(f"{response_content}", file=sys.stdout)
+
+        chat_history.append(
+            ChatCompletionMessage(role="assistant", content=response_content)
+        )
+        print("\n", file=sys.stdout)


 if __name__ == "__main__":
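Likewise, a single non-streaming chat turn through the same client, threading the history along the way the refactored loop does. This is a sketch only: the endpoint URL and model UID are placeholders, and the import paths xinference.client / xinference.types are assumptions based on the modules the diff references.

# Minimal sketch of one non-streaming chat turn, mirroring the loop in the diff above.
from xinference.client import RESTfulClient
from xinference.types import ChatCompletionMessage

client = RESTfulClient(base_url="http://127.0.0.1:9997")  # placeholder endpoint
model = client.get_model(model_uid="my-chat-model-uid")  # placeholder model UID

chat_history = []
prompt = "Hello!"
chat_history.append(ChatCompletionMessage(role="user", content=prompt))

response = model.chat(
    prompt=prompt,
    chat_history=chat_history,
    generate_config={"stream": False, "max_tokens": 128},
)
# non-streaming responses carry the full message in choices[0]
content = response["choices"][0]["message"]["content"]
print(content)
chat_history.append(ChatCompletionMessage(role="assistant", content=content))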
