Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 70 additions & 22 deletions guardrails/schema/rail_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@
from guardrails.utils.xml_utils import xml_to_string
from guardrails.validator_base import OnFailAction, Validator

_VALIDATION_TYPE_MAPPING = {
RailTypes.STRING: ValidationType(SimpleTypes.STRING),
RailTypes.INTEGER: ValidationType(SimpleTypes.INTEGER),
RailTypes.FLOAT: ValidationType(SimpleTypes.NUMBER),
RailTypes.BOOL: ValidationType(SimpleTypes.BOOLEAN),
RailTypes.DATE: ValidationType(SimpleTypes.STRING),
RailTypes.TIME: ValidationType(SimpleTypes.STRING),
RailTypes.DATETIME: ValidationType(SimpleTypes.STRING),
RailTypes.PERCENTAGE: ValidationType(SimpleTypes.STRING),
RailTypes.ENUM: ValidationType(SimpleTypes.STRING),
RailTypes.LIST: ValidationType(SimpleTypes.ARRAY),
RailTypes.OBJECT: ValidationType(SimpleTypes.OBJECT),
}


### RAIL to JSON Schema ###
STRING_TAGS = [
Expand Down Expand Up @@ -107,43 +121,65 @@ def parse_element(
) -> ModelSchema:
"""Takes an XML element Extracts validators to add to the 'validators' list
and validator_map Returns a ModelSchema."""

schema_type = element.tag
if element.tag in STRING_TAGS:
schema_type = RailTypes.STRING
elif element.tag == "output":
schema_type: str = element.attrib.get("type", RailTypes.OBJECT) # type: ignore

description = xml_to_string(element.attrib.get("description"))
# Fast path: avoid xml_to_string if possible, and use .get directly
description_raw = element.attrib.get("description")
description = (
description_raw
if (description_raw is None or isinstance(description_raw, str))
else xml_to_string(description_raw)
)

# Extract validators from RAIL and assign into ProcessedSchema
extract_validators(element, processed_schema, json_path)

json_path = json_path.replace(".*", "")

# Consolidate ModelSchema construction using mapping where possible
if schema_type == RailTypes.STRING:
format = xml_to_string(element.attrib.get("format"))
format_raw = element.attrib.get("format")
format = (
format_raw
if (format_raw is None or isinstance(format_raw, str))
else xml_to_string(format_raw)
)
return ModelSchema(
type=ValidationType(SimpleTypes.STRING),
type=_VALIDATION_TYPE_MAPPING[RailTypes.STRING],
description=description,
format=format,
)
elif schema_type == RailTypes.INTEGER:
format = xml_to_string(element.attrib.get("format"))
format_raw = element.attrib.get("format")
format = (
format_raw
if (format_raw is None or isinstance(format_raw, str))
else xml_to_string(format_raw)
)
return ModelSchema(
type=ValidationType(SimpleTypes.INTEGER),
type=_VALIDATION_TYPE_MAPPING[RailTypes.INTEGER],
description=description,
format=format,
)
elif schema_type == RailTypes.FLOAT:
format = xml_to_string(element.attrib.get("format", RailTypes.FLOAT))
format_raw = element.attrib.get("format", RailTypes.FLOAT)
format = (
format_raw if (isinstance(format_raw, str)) else xml_to_string(format_raw)
)
return ModelSchema(
type=ValidationType(SimpleTypes.NUMBER),
type=_VALIDATION_TYPE_MAPPING[RailTypes.FLOAT],
description=description,
format=format,
)
elif schema_type == RailTypes.BOOL:
return ModelSchema(
type=ValidationType(SimpleTypes.BOOLEAN), description=description
type=_VALIDATION_TYPE_MAPPING[RailTypes.BOOL],
description=description,
)
elif schema_type == RailTypes.DATE:
format = extract_format(
Expand All @@ -152,7 +188,7 @@ def parse_element(
internal_format_attr="date-format",
)
return ModelSchema(
type=ValidationType(SimpleTypes.STRING),
type=_VALIDATION_TYPE_MAPPING[RailTypes.DATE],
description=description,
format=format,
)
Expand All @@ -163,7 +199,7 @@ def parse_element(
internal_format_attr="time-format",
)
return ModelSchema(
type=ValidationType(SimpleTypes.STRING),
type=_VALIDATION_TYPE_MAPPING[RailTypes.TIME],
description=description,
format=format,
)
Expand All @@ -174,7 +210,7 @@ def parse_element(
internal_format_attr="datetime-format",
)
return ModelSchema(
type=ValidationType(SimpleTypes.STRING),
type=_VALIDATION_TYPE_MAPPING[RailTypes.DATETIME],
description=description,
format=format,
)
Expand All @@ -185,22 +221,27 @@ def parse_element(
internal_format_attr="",
)
return ModelSchema(
type=ValidationType(SimpleTypes.STRING),
type=_VALIDATION_TYPE_MAPPING[RailTypes.PERCENTAGE],
description=description,
format=format,
)
elif schema_type == RailTypes.ENUM:
format = xml_to_string(element.attrib.get("format"))
csv = xml_to_string(element.attrib.get("values", "")) or ""
format_raw = element.attrib.get("format")
format = (
format_raw
if (format_raw is None or isinstance(format_raw, str))
else xml_to_string(format_raw)
)
csv_raw = element.attrib.get("values", "")
csv = csv_raw if (isinstance(csv_raw, str)) else (xml_to_string(csv_raw) or "")
values = [v.strip() for v in csv.split(",")] if csv else None
return ModelSchema(
type=ValidationType(SimpleTypes.STRING),
type=_VALIDATION_TYPE_MAPPING[RailTypes.ENUM],
description=description,
format=format,
enum=values,
)
elif schema_type == RailTypes.LIST:
items = None
children = list(element)
num_of_children = len(children)
if num_of_children > 1:
Expand All @@ -216,16 +257,20 @@ def parse_element(
)
items = child_schema.to_dict()
return ModelSchema(
type=ValidationType(SimpleTypes.ARRAY), items=items, description=description
type=_VALIDATION_TYPE_MAPPING[RailTypes.LIST],
items=items,
description=description,
)
elif schema_type == RailTypes.OBJECT:
# Use list comprehensions and avoid extra lookups
properties = {}
required: List[str] = []
for child in element:
name = child.get("name")
child_required = child.get("required", "true") == "true"
if not name:
output_path = json_path.replace("$.", "output.")
# Avoid calling .replace if not needed
logger.warning(
f"{output_path} has a nameless child which is not allowed!"
)
Expand All @@ -236,7 +281,7 @@ def parse_element(
properties[name] = child_schema.to_dict()

object_schema = ModelSchema(
type=ValidationType(SimpleTypes.OBJECT),
type=_VALIDATION_TYPE_MAPPING[RailTypes.OBJECT],
properties=properties,
description=description,
required=required,
Expand Down Expand Up @@ -264,7 +309,7 @@ def parse_element(
if not discriminator:
raise ValueError("<choice /> elements must specify a discriminator!")
discriminator_model = ModelSchema(
type=ValidationType(SimpleTypes.STRING), enum=[]
type=_VALIDATION_TYPE_MAPPING[RailTypes.STRING], enum=[]
)
for choice_case in element:
case_name = choice_case.get("name")
Expand Down Expand Up @@ -309,17 +354,20 @@ def parse_element(
properties = {}
properties[discriminator] = discriminator_model.to_dict()
return ModelSchema(
type=ValidationType(SimpleTypes.OBJECT),
type=_VALIDATION_TYPE_MAPPING[RailTypes.OBJECT],
properties=properties,
required=[discriminator],
allOf=allOf,
description=description,
)
else:
# TODO: What if the user specifies a custom tag _and_ a format?
format = xml_to_string(element.attrib.get("format", schema_type))
format_raw = element.attrib.get("format", schema_type)
format = (
format_raw if (isinstance(format_raw, str)) else xml_to_string(format_raw)
)
return ModelSchema(
type=ValidationType(SimpleTypes.STRING),
type=_VALIDATION_TYPE_MAPPING[RailTypes.STRING],
description=description,
format=format,
)
Expand Down