Skip to content

Commit d5f9c08

Browse files
authored
Merge pull request #133 from whylabs/dev/andy/mlflow
[bugfix] Fix mlfow tracking with predict API
2 parents 4d391a1 + f0964d1 commit d5f9c08

File tree

6 files changed

+35
-12
lines changed

6 files changed

+35
-12
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.1.13-dev3
2+
current_version = 0.1.13-dev8
33
commit = True
44
tag = False
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def setup(app):
119119
# built documents.
120120
#
121121
# The short X.Y version.
122-
version = " 0.1.13-dev3"
122+
version = "0.1.13-dev8"
123123
# The full version, including alpha/beta/rc tags.
124124
release = "" # Is set by calling `setup.py docs`
125125

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
[metadata]
66
name = whylogs
7-
version = 0.1.13-dev3
7+
version = 0.1.13-dev8
88
description = Profile and monitor your ML data pipeline end-to-end
99
author = WhyLabs.ai
1010
author-email = [email protected]

src/whylogs/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""WhyLabs version number."""
22

3-
__version__ = "0.1.13-dev3"
3+
__version__ = "0.1.13-dev8"
Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,48 @@
11
import atexit
2+
from typing import Union
3+
4+
import numpy as np
5+
import pandas as pd
26

37
import whylogs
48
import datetime
59

10+
from logging import getLogger
611
from whylogs.app.config import WHYLOGS_YML
712

13+
logger = getLogger(__name__)
14+
15+
PyFuncOutput = Union[pd.DataFrame, pd.Series, np.ndarray, list]
16+
817

918
class ModelWrapper(object):
1019
def __init__(self, model):
1120
self.model = model
1221
self.session = whylogs.get_or_create_session("/opt/ml/model/" + WHYLOGS_YML)
13-
self.logger = self.create_logger()
22+
self.ylog = self.create_logger()
1423
self.last_upload_time = datetime.datetime.utcnow()
15-
atexit.register(self.logger.close)
24+
atexit.register(self.ylog.close)
1625

1726
def create_logger(self):
1827
# TODO: support different rotation mode and support custom name here
1928
return self.session.logger('live', with_rotation_time='m')
2029

21-
def predict(self, df):
22-
self.logger.log_dataframe(df)
23-
output = self.model.predict(df)
24-
self.logger.log_dataframe(df)
30+
def predict(self, data: pd.DataFrame) -> PyFuncOutput:
31+
"""
32+
Wrapper around https://www.mlflow.org/docs/latest/_modules/mlflow/pyfunc.html#PyFuncModel.predict
33+
This allows us to capture input and predictions into whylogs
34+
"""
35+
self.ylog.log_dataframe(data)
36+
output = self.model.predict(data)
37+
38+
if isinstance(output, np.ndarray) or isinstance(output, pd.Series):
39+
data = pd.DataFrame(data=output, columns=['prediction'])
40+
self.ylog.log_dataframe(data)
41+
elif isinstance(output, pd.DataFrame):
42+
self.ylog.log_dataframe(output)
43+
elif isinstance(output, list):
44+
for e in output:
45+
self.ylog.log(feature_name='prediction', value=e)
46+
else:
47+
logger.warning('Unsupported output type: %s', type(output))
2548
return output

src/whylogs/mlflow/patcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas as pd
88

99
from whylogs.app.logger import Logger
10+
from whylogs import __version__ as whylogs_version
1011

1112
logger = logging.getLogger(__name__)
1213

@@ -153,8 +154,7 @@ def _new_mlflow_conda_env(
153154
):
154155
global _original_mlflow_conda_env
155156
pip_deps = additional_pip_deps or []
156-
if "whylogs" not in pip_deps:
157-
pip_deps.append("whylogs")
157+
pip_deps.append(f"whylogs=={whylogs_version}")
158158
return _original_mlflow_conda_env(
159159
path, additional_conda_deps, pip_deps, additional_conda_channels, install_mlflow
160160
)

0 commit comments

Comments
 (0)