We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a3dfc9c, commit 8a48521 (Copy full SHA for 8a48521)
jina/serve/executors/__init__.py
@@ -393,7 +393,7 @@ def __init__(
393
self._add_dynamic_batching(dynamic_batching)
394
self._add_runtime_args(runtime_args)
395
self.logger = JinaLogger(self.__class__.__name__, **vars(self.runtime_args))
396
- self._validate_sagemaker()
+ self._validate_csp()
397
self._init_instrumentation(runtime_args)
398
self._init_monitoring()
399
self._init_workspace = workspace
@@ -599,14 +599,14 @@ def _add_requests(self, _requests: Optional[Dict]):
599
f'expect {typename(self)}.{func} to be a function, but receiving {typename(_func)}'
600
)
601
602
- def _validate_sagemaker(self):
603
- # sagemaker expects the POST /invocations endpoint to be defined.
+ def _validate_csp(self):
+ # cloud service providers (currently sagemaker/gcp, per the check below) expect the POST /invocations endpoint to be defined.
604
# if it is not defined, we check if there is only one endpoint defined,
605
# and if so, we use it as the POST /invocations endpoint, or raise an error
606
if (
607
not hasattr(self, 'runtime_args')
608
or not hasattr(self.runtime_args, 'provider')
609
- or self.runtime_args.provider != ProviderType.SAGEMAKER.value
+ or self.runtime_args.provider not in (ProviderType.SAGEMAKER.value, ProviderType.GCP.value)
610
):
611
return
612
jina/serve/runtimes/asyncio.py
@@ -206,6 +206,23 @@ def _get_server(self):
206
cors=getattr(self.args, 'cors', None),
207
is_cancel=self.is_cancel,
208
209
+ elif (
210
+ hasattr(self.args, 'provider')
211
+ and self.args.provider == ProviderType.GCP
212
+ ):
213
+ from jina.serve.runtimes.servers.http import GCPHTTPServer
214
+
215
+ return GCPHTTPServer(
216
+ name=self.args.name,
217
+ runtime_args=self.args,
218
+ req_handler_cls=self.req_handler_cls,
219
+ proxy=getattr(self.args, 'proxy', None),
220
+ uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None),
221
+ ssl_keyfile=getattr(self.args, 'ssl_keyfile', None),
222
+ ssl_certfile=getattr(self.args, 'ssl_certfile', None),
223
+ cors=getattr(self.args, 'cors', None),
224
+ is_cancel=self.is_cancel,
225
+ )
226
elif not hasattr(self.args, 'protocol') or (
227
len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.GRPC
228
jina/serve/runtimes/worker/http_gcp_app.py
@@ -41,7 +41,7 @@ def get_fastapi_app(
41
from jina.serve.runtimes.gateway.models import _to_camel_case
42
43
if not docarray_v2:
44
- logger.warning('Only docarray v2 is supported with Sagemaker. ')
+ logger.warning('Only docarray v2 is supported with GCP. ')
45
46
47
class Header(BaseModel):
@@ -129,7 +129,6 @@ async def process(body) -> output_model:
129
raise HTTPException(status_code=499, detail=status.description)
130
else:
131
return {"predictions": resp.docs}
132
- return output_model(predictions=resp.docs)
133
134
@app.api_route(**app_kwargs)
135
async def post(request: Request):
@@ -175,7 +174,7 @@ async def post(request: Request):
175
174
176
from jina.serve.runtimes.gateway.health_model import JinaHealthModel
177
178
- # `/ping` route is required by AWS Sagemaker
+ # `/ping` route is required by GCP
179
@app.get(
180
path='/ping',
181
summary='Get the health of Jina Executor service',
jina/serve/runtimes/worker/request_handling.py
@@ -326,7 +326,7 @@ def _init_monitoring(
326
if metrics_registry:
327
with ImportExtensions(
328
required=True,
329
- help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
+ help_text='You need to install the `prometheus_client` to use the monitoring functionality of jina',
330
331
from prometheus_client import Counter, Summary
332
tests/integration/docarray_v2/gcp/test_gcp.py
@@ -70,5 +70,25 @@ def test_provider_gcp_pod_inference():
70
assert resp.status_code == 200
71
resp_json = resp.json()
72
assert len(resp_json['predictions']) == 2
73
- print(resp_json)
74
75
+def test_provider_gcp_deployment_inference():
76
+ with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')):
77
+ dep_port = random_port()
78
+ with Deployment(uses='config.yml', provider='gcp', port=dep_port):
79
+ # Test the `GET /ping` endpoint (added by jina for gcp)
80
+ resp = requests.get(f'http://localhost:{dep_port}/ping')
81
+ assert resp.status_code == 200
82
+ assert resp.json() == {}
83
84
+ # Test the `POST /invocations` endpoint
85
+ # Note: this endpoint is not implemented in the sample executor
86
+ resp = requests.post(
87
+ f'http://localhost:{dep_port}/invocations',
88
+ json={
89
+ 'instances': ["hello world", "good apple"]
90
+ },
91
92
93
+ resp_json = resp.json()
94
+ assert len(resp_json['predictions']) == 2
0 commit comments