|
1 | 1 | import atexit |
| 2 | +from typing import Union |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import pandas as pd |
2 | 6 |
|
3 | 7 | import whylogs |
4 | 8 | import datetime |
5 | 9 |
|
| 10 | +from logging import getLogger |
6 | 11 | from whylogs.app.config import WHYLOGS_YML |
7 | 12 |
|
| 13 | +logger = getLogger(__name__) |
| 14 | + |
| 15 | +PyFuncOutput = Union[pd.DataFrame, pd.Series, np.ndarray, list] |
| 16 | + |
8 | 17 |
|
9 | 18 | class ModelWrapper(object): |
10 | 19 | def __init__(self, model): |
11 | 20 | self.model = model |
12 | 21 | self.session = whylogs.get_or_create_session("/opt/ml/model/" + WHYLOGS_YML) |
13 | | - self.logger = self.create_logger() |
| 22 | + self.ylog = self.create_logger() |
14 | 23 | self.last_upload_time = datetime.datetime.utcnow() |
15 | | - atexit.register(self.logger.close) |
| 24 | + atexit.register(self.ylog.close) |
16 | 25 |
|
17 | 26 | def create_logger(self): |
18 | 27 | # TODO: support different rotation mode and support custom name here |
19 | 28 | return self.session.logger('live', with_rotation_time='m') |
20 | 29 |
|
21 | | - def predict(self, df): |
22 | | - self.logger.log_dataframe(df) |
23 | | - output = self.model.predict(df) |
24 | | - self.logger.log_dataframe(df) |
| 30 | + def predict(self, data: pd.DataFrame) -> PyFuncOutput: |
| 31 | + """ |
| 32 | + Wrapper around https://www.mlflow.org/docs/latest/_modules/mlflow/pyfunc.html#PyFuncModel.predict |
| 33 | + This allows us to capture input and predictions into whylogs |
| 34 | + """ |
| 35 | + self.ylog.log_dataframe(data) |
| 36 | + output = self.model.predict(data) |
| 37 | + |
| 38 | + if isinstance(output, np.ndarray) or isinstance(output, pd.Series): |
| 39 | + data = pd.DataFrame(data=output, columns=['prediction']) |
| 40 | + self.ylog.log_dataframe(data) |
| 41 | + elif isinstance(output, pd.DataFrame): |
| 42 | + self.ylog.log_dataframe(output) |
| 43 | + elif isinstance(output, list): |
| 44 | + for e in output: |
| 45 | + self.ylog.log(feature_name='prediction', value=e) |
| 46 | + else: |
| 47 | + logger.warning('Unsupported output type: %s', type(output)) |
25 | 48 | return output |
0 commit comments