Skip to content
48 changes: 43 additions & 5 deletions s3file/forms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import base64
import html
import logging
import pathlib
import uuid
from html.parser import HTMLParser

from django.conf import settings
from django.templatetags.static import static
Expand All @@ -16,6 +18,43 @@
logger = logging.getLogger("s3file")


class InputToS3FileRewriter(HTMLParser):
"""HTML parser that rewrites <input type="file"> to <s3-file> custom elements."""

def __init__(self):
super().__init__()
self.output = []

def handle_starttag(self, tag, attrs):
if tag == "input" and dict(attrs).get("type") == "file":
self.output.append("<s3-file")
for name, value in attrs:
if name != "type":
self.output.append(f' {name}="{html.escape(value, quote=True)}"' if value else f" {name}")
self.output.append(">")
else:
self.output.append(self.get_starttag_text())

def handle_endtag(self, tag):
self.output.append(f"</{tag}>")

def handle_data(self, data):
self.output.append(data)

def handle_startendtag(self, tag, attrs):
if tag == "input" and dict(attrs).get("type") == "file":
self.output.append("<s3-file")
for name, value in attrs:
if name != "type":
self.output.append(f' {name}="{html.escape(value, quote=True)}"' if value else f" {name}")
self.output.append(">")
else:
self.output.append(self.get_starttag_text())

def get_html(self):
return "".join(self.output)


@html_safe
class Asset:
"""A generic asset that can be included in a template."""
Expand Down Expand Up @@ -99,11 +138,10 @@ def build_attrs(self, *args, **kwargs):

def render(self, name, value, attrs=None, renderer=None):
"""Render the widget as a custom element for Safari compatibility."""
return mark_safe( # noqa: S308
str(super().render(name, value, attrs=attrs, renderer=renderer)).replace(
f'<input type="{self.input_type}"', "<s3-file"
)
)
html_output = str(super().render(name, value, attrs=attrs, renderer=renderer))
parser = InputToS3FileRewriter()
parser.feed(html_output)
return mark_safe(parser.get_html()) # noqa: S308

def get_conditions(self, accept):
conditions = [
Expand Down
63 changes: 63 additions & 0 deletions tests/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,54 @@ def test_str(self, settings):
assert str(js) == '<script src="/static/path" type="module"></script>'


class TestInputToS3FileRewriter:
def test_transforms_file_input(self):
parser = forms.InputToS3FileRewriter()
parser.feed('<input type="file" name="test">')
assert parser.get_html() == '<s3-file name="test">'

def test_preserves_non_file_input(self):
parser = forms.InputToS3FileRewriter()
parser.feed('<input type="text" name="test">')
assert parser.get_html() == '<input type="text" name="test">'

def test_handles_attribute_ordering(self):
parser = forms.InputToS3FileRewriter()
parser.feed('<input name="test" type="file" class="foo">')
result = parser.get_html()
assert result.startswith('<s3-file')
assert 'name="test"' in result
assert 'class="foo"' in result
assert 'type="file"' not in result

def test_handles_multiple_attributes(self):
parser = forms.InputToS3FileRewriter()
parser.feed('<input type="file" name="test" accept="image/*" required multiple>')
result = parser.get_html()
assert result.startswith('<s3-file')
assert 'name="test"' in result
assert 'accept="image/*"' in result
assert 'required' in result
assert 'multiple' in result

def test_escapes_html_entities(self):
parser = forms.InputToS3FileRewriter()
parser.feed('<input type="file" name="test" data-value="test&value">')
result = parser.get_html()
assert 'data-value="test&amp;value"' in result

def test_handles_self_closing_tag(self):
parser = forms.InputToS3FileRewriter()
parser.feed('<input type="file" name="test" />')
assert parser.get_html() == '<s3-file name="test">'

def test_preserves_surrounding_elements(self):
parser = forms.InputToS3FileRewriter()
parser.feed('<p><input type="file" name="test"></p>')
result = parser.get_html()
assert result == '<p><s3-file name="test"></p>'


@contextmanager
def wait_for_page_load(driver, timeout=30):
old_page = driver.find_element(By.TAG_NAME, "html")
Expand Down Expand Up @@ -186,6 +234,21 @@ def test_render_wraps_in_s3_file_element(self, freeze_upload_folder):
# Check that the output is the s3-file custom element
assert html.startswith("<s3-file")

def test_render_preserves_attributes(self, freeze_upload_folder):
widget = ClearableFileInput(attrs={"class": "test-class", "accept": "image/*"})
html = widget.render(name="file", value=None)
assert html.startswith("<s3-file")
assert 'name="file"' in html
assert 'class="test-class"' in html
assert 'accept="image/*"' in html
assert 'type="file"' not in html

def test_render_excludes_type_attribute(self, freeze_upload_folder):
widget = ClearableFileInput()
html = widget.render(name="file", value=None)
assert 'type="file"' not in html
assert html.startswith("<s3-file")

@pytest.mark.selenium
def test_no_js_error(self, driver, live_server):
driver.get(live_server + self.create_url)
Expand Down
Loading