
Commit 1de7f5b

Merge pull request #8 from mithril-security/update-python-docstrings

Update python docstrings

2 parents 3c43561 + 42b99dc

File tree

1 file changed: client/blindai_preview/client.py (134 additions, 104 deletions)
@@ -96,9 +96,7 @@ def __init__(self, fact, datum_type, node_name=None):
 
 
 class Tensor:
-    """
-    Tensor class to convert serialized tensors into convenient objects
-    """
+    """Tensor class to convert serialized tensors into convenient objects."""
 
     info: Union[TensorInfo, dict]
     bytes_data: bytes
@@ -108,7 +106,8 @@ def __init__(self, info: Union[TensorInfo, dict], bytes_data: bytes):
         self.bytes_data = bytes_data
 
     def as_flat(self) -> list:
-        """Convert the prediction calculated by the server to a flat python list."""
+        """Convert the prediction calculated by the server to a flat python
+        list."""
         return self.as_numpy().tolist()
 
     def as_numpy(self):
@@ -119,9 +118,10 @@ def as_numpy(self):
         return arr
 
     def as_torch(self):
-        """
-        Convert the prediction calculated by the server to a Torch Tensor.
-        As torch is heavy, it is an optional dependency of the project and is imported only when needed.
+        """Convert the prediction calculated by the server to a Torch Tensor.
+
+        As torch is heavy, it is an optional dependency of the project and is
+        imported only when needed.
 
         Raises: ImportError if torch isn't installed
         """
@@ -235,8 +235,7 @@ def __init__(
 
 
 def dtype_to_numpy(dtype: ModelDatumType) -> str:
-    """
-    Convert a ModelDatumType to a numpy type
+    """Convert a ModelDatumType to a numpy type.
 
     Raises:
         ValueError: if numpy doesn't support dtype
@@ -260,8 +259,7 @@ def dtype_to_numpy(dtype: ModelDatumType) -> str:
 
 
 def dtype_to_torch(dtype: ModelDatumType) -> str:
-    """
-    Convert a ModelDatumType to a torch type
+    """Convert a ModelDatumType to a torch type.
 
     Raises:
         ValueError: if torch doesn't support dtype
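A short sketch of how the two dtype helpers are meant to be used, assuming `ModelDatumType.F32` exists as elsewhere in this client; the exact string each helper returns is an assumption, not confirmed by this diff:

    import numpy as np

    # Assumed to return a numpy-compatible dtype string such as "float32".
    np_dtype = dtype_to_numpy(ModelDatumType.F32)
    arr = np.zeros((2, 3), dtype=np_dtype)

    # dtype_to_torch mirrors this mapping for torch; both helpers raise
    # ValueError for a ModelDatumType their backend doesn't support.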
@@ -370,8 +368,7 @@ def _is_numpy_array(tensor) -> bool:
 def translate_tensor(
     tensor: Any, or_dtype: ModelDatumType, or_shape: Tuple, name=None
 ) -> Tensor:
-    """
-    Put the flat/numpy/torch tensor into a Tensor object
+    """Put the flat/numpy/torch tensor into a Tensor object.
 
     Args:
         tensor: flat/numpy/torch tensor
@@ -411,8 +408,7 @@ def translate_tensor(
 
 
 def translate_tensors(tensors, dtypes, shapes) -> List[dict]:
-    """
-    Put the flat/numpy/torch tensors into a list of Tensor objects
+    """Put the flat/numpy/torch tensors into a list of Tensor objects.
 
     Args:
         tensor: list or dict of flat/numpy/torch tensors
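A sketch of calling `translate_tensors` with numpy inputs; per the docstring, the dtype and shape arguments are ignored for numpy arrays, so passing None for them is an assumption drawn from that docstring rather than a documented calling convention:

    import numpy as np

    serialized = translate_tensors(
        [np.ones((1, 3), dtype=np.float32), np.zeros((1, 5), dtype=np.int64)],
        None,  # dtypes: ignored for numpy inputs
        None,  # shapes: ignored for numpy inputs
    )
    # `serialized` is a list of dicts ready to be CBOR-encoded into a request.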
@@ -473,75 +469,46 @@ def translate_tensors(tensors, dtypes, shapes) -> List[dict]:
 
 
 class BlindAiConnection(contextlib.AbstractContextManager):
-    conn: requests.Session
+    """A class to represent a connection to a BlindAi server."""
 
-    _disable_attestation_checks: bool = False
-    _disable_untrusted_server_cert_check: bool = False
+    _conn: requests.Session
 
     def __init__(
         self,
         addr: str,
-        untrusted_port: int = 9923,
-        attested_port: int = 9924,
-        hazmat_manifest_path: Optional[pathlib.Path] = None,
-        hazmat_http_on_untrusted_port=False,
+        untrusted_port: int,
+        attested_port: int,
+        hazmat_manifest_path: Optional[pathlib.Path],
+        hazmat_http_on_untrusted_port: bool,
     ):
-        """
-        Connect to the BlindAi server
-
-        Args:
-            addr (str): The address of the BlindAI server you want to connect to. It can be a domain (such as `example.com` or `localhost`) or an IP.
-            hazmat_manifest_path: By default the built-in Manifest.toml provided by Mithril Security will be used.
-                You can override the default Manifest.toml by providing a path to your custom Manifest.toml.
-                Caution: The manifest describes which enclaves are trustworthy; changing the Manifest.toml can impact the security of the solution.
-            hazmat_http_on_untrusted_port: By default, the client fetches the attestation elements from the untrusted port of the server
-                using an HTTPS connection. The certificate should be validated according to your OS defaults.
-                You can opt out of an HTTPS connection and instead ask the client to connect via HTTP by setting this param to True.
-                Caution: This parameter should never be set to True in production. Using an HTTPS connection is critical to
-                get a graceful degradation in case of a failure of Intel SGX attestation.
-            untrusted_port (int, optional): Untrusted connection server port. Defaults to 9923.
-            attested_port (int, optional): Attested connection server port. Defaults to 9924.
-        Raises:
-            HttpError: raised by the requests lib to relay server side errors
-            ValueError: raised when inputs sanity checks fail
-            IdentityError: raised when the enclave signature does not match the enclave signature expected in the manifest
-            EnclaveHeldDataError: raised when the expected enclave held data does not match the one in the quote
-            QuoteValidationError: raised when the returned quote is invalid (TCB outdated, not signed by the hardware provider...)
-            AttestationError: raised when the attestation is not valid (enclave settings mismatching, debug mode unallowed...)
-        """
+        """Connect to a BlindAi service.
 
-        uname = platform.uname()
-
-        self.client_info = _ClientInfo(
-            uid=sha256((socket.gethostname() + "-" + getpass.getuser()).encode("utf-8"))
-            .digest()
-            .hex(),
-            platform_name=uname.system,
-            platform_arch=uname.machine,
-            platform_version=uname.version,
-            platform_release=uname.release,
-            user_agent="blindai_python",
-            user_agent_version=app_version,
-        )
-
-        self._connect_server(
-            addr,
-            untrusted_port,
-            attested_port,
-            hazmat_manifest_path,
-            hazmat_http_on_untrusted_port,
-        )
+        Please refer to the connect function for documentation.
 
-    def _connect_server(
-        self,
-        addr: str,
-        untrusted_port,
-        attested_port,
-        manifest_path,
-        http_on_untrusted_port,
-    ):
+        Args:
+            addr (str):
+            untrusted_port (int):
+            attested_port (int):
+            hazmat_manifest_path (Optional[pathlib.Path]):
+            hazmat_http_on_untrusted_port (bool):
 
-        if http_on_untrusted_port:
+        Returns:
+        """
+        # uname = platform.uname()
+
+        # self.client_info = _ClientInfo(
+        #     uid=sha256((socket.gethostname() + "-" + getpass.getuser()).encode("utf-8"))
+        #     .digest()
+        #     .hex(),
+        #     platform_name=uname.system,
+        #     platform_arch=uname.machine,
+        #     platform_version=uname.version,
+        #     platform_release=uname.release,
+        #     user_agent="blindai_python",
+        #     user_agent_version=app_version,
+        # )
+
+        if hazmat_http_on_untrusted_port:
             self._untrusted_url = "http://" + addr + ":" + str(untrusted_port)
         else:
             self._untrusted_url = "https://" + addr + ":" + str(untrusted_port)
@@ -569,7 +536,9 @@ def cert_verify(self, conn, url, verify, cert):
         quote = cbor.loads(s.get(f"{self._untrusted_url}/quote").content)
         collateral = cbor.loads(s.get(f"{self._untrusted_url}/collateral").content)
 
-        validate_attestation(quote, collateral, cert, manifest_path=manifest_path)
+        validate_attestation(
+            quote, collateral, cert, manifest_path=hazmat_manifest_path
+        )
 
         # requests (http library) takes a path to a file containing the CA
         # there is no easy way to give the CA as a string/bytes directly
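For readers following the attestation flow: the quote and collateral are fetched CBOR-encoded over the untrusted port, then checked against the enclave's TLS certificate and the manifest. A standalone sketch of that exchange, assuming a cbor2-compatible `cbor` module (the module's identity is an assumption) and a hypothetical server address:

    import cbor2 as cbor  # assumption: the `cbor` module used in this file has a cbor2-style API
    import requests

    untrusted_url = "https://localhost:9923"  # hypothetical address
    s = requests.Session()
    quote = cbor.loads(s.get(f"{untrusted_url}/quote").content)
    collateral = cbor.loads(s.get(f"{untrusted_url}/collateral").content)
    # `cert` would be the enclave's TLS certificate captured during the handshake;
    # validate_attestation raises one of the attestation errors on mismatch.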
@@ -590,23 +559,28 @@ def cert_verify(self, conn, url, verify, cert):
         # finally try to connect to the enclave
         trusted_conn.get(self._attested_url)
 
-        self.conn = trusted_conn
+        self._conn = trusted_conn
 
     def upload_model(
         self,
         model: str,
         model_name: Optional[str] = None,
         optimize: bool = True,
     ) -> UploadResponse:
-        """
-        Upload an inference model to the server.
+        """Upload an inference model to the server.
+
         The provided model needs to be in the Onnx format.
+
         ***Security & confidentiality warnings:***
-        *`model`: The model sent in Onnx format is encrypted in transit via TLS (as all connections). It may be subject to inference attacks if an adversary is able to query the trained model repeatedly to determine whether or not a particular example is part of the training dataset.*
+        model: The model sent in Onnx format is encrypted in transit via TLS (as all connections).
+            It may be subject to inference attacks if an adversary is able to query the trained model
+            repeatedly to determine whether or not a particular example is part of the training dataset.
         Args:
             model (str): Path to Onnx model file.
-            model_name (Optional[str], optional): Name of the model. By default, the server will assign a random UUID. You can call the model with the name you specify here.
-            optimize (bool): Whether tract (our inference engine) should optimize the model or not. Optimizing should only be turned off when tract wasn't able to optimize the model.
+            model_name (Optional[str], optional): Name of the model. By default, the server will assign a random UUID.
+                You can call the model with the name you specify here.
+            optimize (bool): Whether tract (our inference engine) should optimize the model or not.
+                Optimizing should only be turned off when tract wasn't able to optimize the model.
         Raises:
             HttpError: raised by the requests lib to relay server side errors
             ValueError: raised when inputs sanity checks fail
628602
optimize=optimize,
629603
)
630604
bytes_data = cbor.dumps(data.__dict__)
631-
r = self.conn.post(f"{self._attested_url}/upload", data=bytes_data)
605+
r = self._conn.post(f"{self._attested_url}/upload", data=bytes_data)
632606
r.raise_for_status()
633607
send_model_reply = SendModelReply(**cbor.loads(r.content))
634608
ret = UploadResponse(model_id=send_model_reply.model_id)
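A hedged usage sketch of the upload path, assuming `conn` is a connection obtained from `connect` (defined at the end of this file) and that "mobilenet.onnx" is a hypothetical local model file:

    response = conn.upload_model(
        "mobilenet.onnx",        # path to an Onnx file (hypothetical)
        model_name="mobilenet",  # optional; a random UUID is assigned otherwise
        optimize=True,
    )
    model_id = response.model_id  # used later by run_model / delete_model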
@@ -641,17 +615,27 @@ def run_model(
         dtypes: Optional[List[ModelDatumType]] = None,
         shapes: Optional[Union[List[List[int]], List[int]]] = None,
     ) -> RunModelResponse:
-        """
-        Send data to the server to make a secure inference.
-        The data provided must be in a list, as the tensor will be rebuilt inside the server.
+        """Send data to the server to make a secure inference.
+
+        The data provided must be in a list, as the tensor will be rebuilt inside the
+        server.
+
         ***Security & confidentiality warnings:***
-        *`model_id`: hash of the uploaded Onnx model. The given hash is returned via gRPC through the proto files. It's a SHA-256 hash that is generated each time a model is uploaded.
-        `tensors`: protected in transit and protected when running on the secure enclave. In the case of a compromised OS, the data is isolated and confidential by SGX design.
+        model_id: hash of the uploaded Onnx model. The given hash is returned via gRPC through the proto files.
+            It's a SHA-256 hash that is generated each time a model is uploaded.
+        tensors: protected in transit and protected when running on the secure enclave.
+            In the case of a compromised OS, the data is isolated and confidential by SGX design.
+
         Args:
             model_id (str): If set, will run a specific model.
-            input_tensors (Union[List[Any], List[List[Any]]]): The input data. It must be an array of numpy arrays, tensors, or flat lists of the datum_type specified in `upload_model`.
-            dtypes (Union[List[ModelDatumType], ModelDatumType], optional): The type of the data you want to upload. Only required if you are uploading flat lists; will be ignored if you are uploading numpy arrays or tensors (this info will be extracted directly from them).
-            shapes (Union[List[List[int]], List[int]], optional): The shape of the data you want to upload. Only required if you are uploading flat lists; will be ignored if you are uploading numpy arrays or tensors (this info will be extracted directly from them).
+            input_tensors (Union[List[Any], List[List[Any]]]): The input data. It must be an array of numpy arrays,
+                tensors, or flat lists of the datum_type specified in `upload_model`.
+            dtypes (Union[List[ModelDatumType], ModelDatumType], optional): The type of the data
+                you want to upload. Only required if you are uploading flat lists; will be ignored
+                if you are uploading numpy arrays or tensors (this info will be extracted directly from them).
+            shapes (Union[List[List[int]], List[int]], optional): The shape of the data you want to upload.
+                Only required if you are uploading flat lists; will be ignored if you are uploading numpy
+                arrays or tensors (this info will be extracted directly from them).
         Raises:
             HttpError: raised by the requests lib to relay server side errors
             ValueError: raised when inputs sanity checks fail
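The matching inference call, continuing the sketch above; numpy arrays carry their own dtype and shape, so `dtypes` and `shapes` can be omitted (the input shape below is purely illustrative):

    import numpy as np

    result = conn.run_model(
        model_id,
        input_tensors=[np.random.rand(1, 3, 224, 224).astype(np.float32)],
    )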
@@ -662,7 +646,7 @@ def run_model(
         tensors = translate_tensors(input_tensors, dtypes, shapes)
         run_data = RunModel(model_id=model_id, inputs=tensors)
         bytes_run_data = cbor.dumps(run_data.__dict__)
-        r = self.conn.post(f"{self._attested_url}/run", data=bytes_run_data)
+        r = self._conn.post(f"{self._attested_url}/run", data=bytes_run_data)
         r.raise_for_status()
         run_model_reply = RunModelReply(**cbor.loads(r.content))

@@ -675,12 +659,18 @@ def run_model(
         return ret
 
     def delete_model(self, model_id: str):
-        """
-        Delete a model in the inference server.
-        This may be used to free up some memory.
-        If you did not specify that you wanted your model to be saved on the server, please note that the model will only be present in memory, and will disappear when the server closes.
-        ***Security & confidentiality warnings:***
-        *model_id: If you are using this on the Mithril Security Cloud, you can only delete models that you uploaded. Otherwise, the deletion of a model relies only on the `model_id`. It doesn't rely on a session token or anything else, hence if the `model_id` is known, its deletion is possible.*
+        """Delete a model in the inference server.
+
+        This may be used to free up some memory. If you did not specify that you
+        wanted your model to be saved on the server, please note that the model will
+        only be present in memory, and will disappear when the server closes.
+
+        ***Security & confidentiality warnings:***
+        model_id: If you are using this on the Mithril Security Cloud, you can only delete models
+            that you uploaded. Otherwise, the deletion of a model relies only on the `model_id`.
+            It doesn't rely on a session token or anything else, hence if the `model_id` is known,
+            its deletion is possible.
+
         Args:
             model_id (str): The id of the model to remove.
         Raises:
@@ -689,21 +679,61 @@ def delete_model(self, model_id: str):
         """
         delete_data = DeleteModel(model_id=model_id)
         bytes_delete_data = cbor.dumps(delete_data.__dict__)
-        r = self.conn.post(f"{self._attested_url}/delete", bytes_delete_data)
+        r = self._conn.post(f"{self._attested_url}/delete", bytes_delete_data)
         r.raise_for_status()
 
     def __enter__(self):
         """Return the BlindAiConnection upon entering the runtime context."""
         return self
 
     def __exit__(self, *args):
-        """Close the connection to the BlindAI server and raise any exception triggered within the runtime context."""
-        self.conn.close()
+        """Close the connection to the BlindAI server."""
+        self._conn.close()
 
 
 from functools import wraps
 
 
-@wraps(BlindAiConnection.__init__, assigned=("__doc__", "__annotations__"))
-def connect(*args, **kwargs):
-    return BlindAiConnection(*args, **kwargs)
+def connect(
+    addr: str,
+    untrusted_port: int = 9923,
+    attested_port: int = 9924,
+    hazmat_manifest_path: Optional[pathlib.Path] = None,
+    hazmat_http_on_untrusted_port=False,
+) -> BlindAiConnection:
+    """Connect to a BlindAi server.
+
+    Args:
+        addr (str): The address of the BlindAI server you want to connect to.
+            It can be a domain (such as "example.com" or "localhost") or an IP.
+        untrusted_port (int, optional): The untrusted port number. Defaults to 9923.
+        attested_port (int, optional): The attested port number. Defaults to 9924.
+        hazmat_manifest_path (Optional[pathlib.Path], optional): Path to the Manifest.toml which describes
+            which enclaves are to be accepted.
+            Defaults to the built-in Manifest.toml provided by Mithril Security as part of the Python package.
+            You can override the default by providing a path to your own Manifest.toml.
+            Caution: Changing the manifest can impact the security of the solution.
+        hazmat_http_on_untrusted_port (bool, optional): If set to True, the client will request the attestation
+            elements of the server using a plain HTTP connection instead of a more secure HTTPS connection.
+            Defaults to False.
+            Caution: This parameter should never be set to True in production. Using an HTTPS connection is
+            critical to get a graceful degradation in case of a failure of the Intel SGX attestation.
+
+    Raises:
+        requests.exceptions.RequestException: If a network or server error occurs
+        ValueError: raised when inputs sanity checks fail
+        IdentityError: raised when the enclave signature does not match the enclave signature expected in the manifest
+        EnclaveHeldDataError: raised when the expected enclave held data does not match the one in the quote
+        QuoteValidationError: raised when the returned quote is invalid (TCB outdated, not signed by the hardware provider...)
+        AttestationError: raised when the attestation is not valid (enclave settings mismatching, debug mode unallowed...)
+
+    Returns:
+        BlindAiConnection: An object representing an active connection to a BlindAi server
+    """
+
+    return BlindAiConnection(
+        addr,
+        untrusted_port,
+        attested_port,
+        hazmat_manifest_path,
+        hazmat_http_on_untrusted_port,
+    )
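Putting the new `connect` entry point together with the methods above, an end-to-end sketch; the import path is assumed from this file's location and the model file is hypothetical:

    import numpy as np
    from blindai_preview.client import connect

    with connect("localhost") as conn:
        uploaded = conn.upload_model("mobilenet.onnx")
        result = conn.run_model(
            uploaded.model_id,
            input_tensors=[np.random.rand(1, 3, 224, 224).astype(np.float32)],
        )
        conn.delete_model(uploaded.model_id)  # free server memory
    # __exit__ closes the underlying requests.Session.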
