-
Notifications
You must be signed in to change notification settings - Fork 105
Open
Labels
Description
Add the functionality this test requires: `load_collection` and `collect`.
`# TODO: add functionality required for this test load_collection & collect`
# limitations under the License.
#
from pywy.dataquanta import WayangContext
from pywy.platforms.java import JavaPlugin
from pywy.platforms.spark import SparkPlugin
import pytest
# TODO: implement load_collection and collect, which this test requires; then remove the skip marker.
@pytest.mark.skip(reason="no way of currently testing this, since we are missing implementations for load_collection & collect")
def test_train_and_predict():
    """Train a decision-tree regressor on a tiny dataset and predict on it.

    Loads four 2-feature samples with labels 3.0..6.0, trains a decision-tree
    regression model on them, predicts on the same features, and checks that
    exactly four float predictions in the range (1.0, 7.0] come back.

    Skipped until ``load_collection`` and ``collect`` are implemented.
    """
    # Initialize context with the Java and Spark platforms.
    ctx = WayangContext().register({JavaPlugin, SparkPlugin})

    # Input features (one 2-element row per sample) and one label per row.
    features = ctx.load_collection([
        [1.0, 2.0],
        [2.0, 3.0],
        [3.0, 4.0],
        [4.0, 5.0],
    ])
    labels = ctx.load_collection([3.0, 4.0, 5.0, 6.0])

    # Train the model.
    model = features.train_decision_tree_regression(labels, max_depth=3, min_instances=1)

    # Run predictions on the same features that were used for training.
    predictions = model.predict(features)

    # Collect and validate.
    result = predictions.collect()
    print("Predictions:", result)
    # Use `==`, not `is`: int identity is an implementation detail, and
    # `is` with a literal triggers a SyntaxWarning.
    assert len(result) == 4, f"Expected len(result) to be 4, but got: {len(result)}"
    for pred in result:
        # `pred is float` would compare the value to the type object itself
        # and always be False; isinstance checks the value's type.
        assert isinstance(pred, float), (
            f"Expected a float prediction, got {type(pred).__name__}: {pred!r}"
        )
        assert 1.0 < pred <= 7.0, f"Prediction {pred} outside expected range (1.0, 7.0]"