# See the License for the specific language governing permissions and
# limitations under the License.

-import asyncio
import configparser
import logging
import os
    XINFERENCE_DEFAULT_LOCAL_HOST,
    XINFERENCE_ENV_ENDPOINT,
)
-from ..isolation import Isolation
from ..types import ChatCompletionMessage

try:
@@ -352,66 +350,38 @@ def model_generate(
    stream: bool,
):
    endpoint = get_endpoint(endpoint)
-    if stream:
-
-        async def generate_internal():
-            while True:
-                # the prompt will be written to stdout.
-                # https://docs.python.org/3.10/library/functions.html#input
-                prompt = input("Prompt: ")
-                if prompt == "":
-                    break
-                print(f"Completion: {prompt}", end="", file=sys.stdout)
-                for chunk in model.generate(
-                    prompt=prompt,
-                    generate_config={"stream": stream, "max_tokens": max_tokens},
-                ):
-                    choice = chunk["choices"][0]
-                    if "text" not in choice:
-                        continue
-                    else:
-                        print(choice["text"], end="", flush=True, file=sys.stdout)
-                print("\n", file=sys.stdout)
-
-        client = RESTfulClient(base_url=endpoint)
-        model = client.get_model(model_uid=model_uid)
-
-        loop = asyncio.get_event_loop()
-        coro = generate_internal()
-
-        if loop.is_running():
-            isolation = Isolation(asyncio.new_event_loop(), threaded=True)
-            isolation.start()
-            isolation.call(coro)
+    client = RESTfulClient(base_url=endpoint)
+    model = client.get_model(model_uid=model_uid)
+    if not isinstance(model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle)):
+        raise ValueError(f"model {model_uid} has no generate method")
+
+    while True:
+        # the prompt will be written to stdout.
+        # https://docs.python.org/3.10/library/functions.html#input
+        prompt = input("Prompt: ")
+        if prompt.lower() in ("exit", "quit"):
+            break
+        print(f"Completion: {prompt}", end="", file=sys.stdout)
+
+        if stream:
+            for chunk in model.generate(
+                prompt=prompt,
+                generate_config={"stream": stream, "max_tokens": max_tokens},
+            ):
+                choice = chunk["choices"][0]
+                if "text" not in choice:
+                    continue
+                else:
+                    print(choice["text"], end="", flush=True, file=sys.stdout)
        else:
-            task = loop.create_task(coro)
-            try:
-                loop.run_until_complete(task)
-            except KeyboardInterrupt:
-                task.cancel()
-                loop.run_until_complete(task)
-                # avoid displaying exception-unhandled warnings
-                task.exception()
-    else:
-        restful_client = RESTfulClient(base_url=endpoint)
-        restful_model = restful_client.get_model(model_uid=model_uid)
-        if not isinstance(
-            restful_model, (RESTfulChatModelHandle, RESTfulGenerateModelHandle)
-        ):
-            raise ValueError(f"model {model_uid} has no generate method")
-
-        while True:
-            prompt = input("User: ")
-            if prompt == "":
-                break
-            print(f"Assistant: {prompt}", end="", file=sys.stdout)
-            response = restful_model.generate(
+            response = model.generate(
                prompt=prompt,
                generate_config={"stream": stream, "max_tokens": max_tokens},
            )
            if not isinstance(response, dict):
                raise ValueError("generate result is not valid")
-            print(f"{response['choices'][0]['text']}\n", file=sys.stdout)
+            print(f"{response['choices'][0]['text']}", file=sys.stdout)
+        print("\n", file=sys.stdout)


@cli.command("chat")
@@ -431,80 +401,49 @@ def model_chat(
):
    # TODO: chat model roles may not be user and assistant.
    endpoint = get_endpoint(endpoint)
+    client = RESTfulClient(base_url=endpoint)
+    model = client.get_model(model_uid=model_uid)
+    if not isinstance(
+        model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
+    ):
+        raise ValueError(f"model {model_uid} has no chat method")
+
    chat_history: "List[ChatCompletionMessage]" = []
-    if stream:
-
-        async def chat_internal():
-            while True:
-                # the prompt will be written to stdout.
-                # https://docs.python.org/3.10/library/functions.html#input
-                prompt = input("User: ")
-                if prompt == "":
-                    break
-                chat_history.append(ChatCompletionMessage(role="user", content=prompt))
-                print("Assistant: ", end="", file=sys.stdout)
-                response_content = ""
-                for chunk in model.chat(
-                    prompt=prompt,
-                    chat_history=chat_history,
-                    generate_config={"stream": stream, "max_tokens": max_tokens},
-                ):
-                    delta = chunk["choices"][0]["delta"]
-                    if "content" not in delta:
-                        continue
-                    else:
-                        response_content += delta["content"]
-                        print(delta["content"], end="", flush=True, file=sys.stdout)
-                print("\n", file=sys.stdout)
-                chat_history.append(
-                    ChatCompletionMessage(role="assistant", content=response_content)
-                )
-
-        client = RESTfulClient(base_url=endpoint)
-        model = client.get_model(model_uid=model_uid)
-
-        loop = asyncio.get_event_loop()
-        coro = chat_internal()
-
-        if loop.is_running():
-            isolation = Isolation(asyncio.new_event_loop(), threaded=True)
-            isolation.start()
-            isolation.call(coro)
+    while True:
+        # the prompt will be written to stdout.
+        # https://docs.python.org/3.10/library/functions.html#input
+        prompt = input("User: ")
+        if prompt == "":
+            break
+        chat_history.append(ChatCompletionMessage(role="user", content=prompt))
+        print("Assistant: ", end="", file=sys.stdout)
+
+        response_content = ""
+        if stream:
+            for chunk in model.chat(
+                prompt=prompt,
+                chat_history=chat_history,
+                generate_config={"stream": stream, "max_tokens": max_tokens},
+            ):
+                delta = chunk["choices"][0]["delta"]
+                if "content" not in delta:
+                    continue
+                else:
+                    response_content += delta["content"]
+                    print(delta["content"], end="", flush=True, file=sys.stdout)
        else:
-            task = loop.create_task(coro)
-            try:
-                loop.run_until_complete(task)
-            except KeyboardInterrupt:
-                task.cancel()
-                loop.run_until_complete(task)
-                # avoid displaying exception-unhandled warnings
-                task.exception()
-    else:
-        restful_client = RESTfulClient(base_url=endpoint)
-        restful_model = restful_client.get_model(model_uid=model_uid)
-        if not isinstance(
-            restful_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
-        ):
-            raise ValueError(f"model {model_uid} has no chat method")
-
-        while True:
-            prompt = input("User: ")
-            if prompt == "":
-                break
-            chat_history.append(ChatCompletionMessage(role="user", content=prompt))
-            print("Assistant: ", end="", file=sys.stdout)
-            response = restful_model.chat(
+            response = model.chat(
                prompt=prompt,
                chat_history=chat_history,
                generate_config={"stream": stream, "max_tokens": max_tokens},
            )
-            if not isinstance(response, dict):
-                raise ValueError("chat result is not valid")
            response_content = response["choices"][0]["message"]["content"]
-            print(f"{response_content}\n", file=sys.stdout)
-            chat_history.append(
-                ChatCompletionMessage(role="assistant", content=response_content)
-            )
+            print(f"{response_content}", file=sys.stdout)
+
+        chat_history.append(
+            ChatCompletionMessage(role="assistant", content=response_content)
+        )
+        print("\n", file=sys.stdout)


if __name__ == "__main__":
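In the streaming chat path above, each chunk carries an incremental delta, and the accumulated text is appended to chat_history so the next turn keeps context. A standalone sketch of one streamed turn under the same assumptions as the earlier example (placeholder endpoint and model uid; import paths inferred from the relative imports above):

from xinference.client import RESTfulClient
from xinference.types import ChatCompletionMessage

client = RESTfulClient(base_url="http://127.0.0.1:9997")  # assumed local endpoint
model = client.get_model(model_uid="my-chat-model")  # hypothetical model uid

chat_history = [ChatCompletionMessage(role="user", content="Hello!")]
response_content = ""
for chunk in model.chat(
    prompt="Hello!",
    chat_history=chat_history,
    generate_config={"stream": True, "max_tokens": 64},
):
    delta = chunk["choices"][0]["delta"]
    if "content" in delta:  # some chunks (e.g. the role header) carry no content
        response_content += delta["content"]
chat_history.append(ChatCompletionMessage(role="assistant", content=response_content))
print(response_content)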