Skip to content

Commit b29b7e7

Browse files
authored
Merge branch 'main' into fest/menu-endpoint
2 parents e3dfc79 + b29d423 commit b29b7e7

File tree

6 files changed

+168
-6
lines changed

6 files changed

+168
-6
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
pip install .
3333
3434
- name: Run coverage
35-
run: coverage run runtests.py
35+
run: coverage run -m pytest
3636

3737
- name: Upload Coverage to Codecov
3838
uses: codecov/[email protected]

djangocms_rest/middleware.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from typing import Callable
2+
3+
from django.contrib.sites.shortcuts import get_current_site
4+
from django.contrib.sites.models import Site
5+
from django.http import (
6+
HttpRequest,
7+
HttpResponse,
8+
HttpResponseBadRequest,
9+
HttpResponseNotFound,
10+
)
11+
12+
13+
class SiteContextMiddleware:
14+
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
15+
self.get_response = get_response
16+
17+
def __call__(self, request: HttpRequest) -> HttpResponse:
18+
"""
19+
Process the request to determine the site context.
20+
Sets the site object on the request based on the site ID provided in
21+
the request headers or falls back to the current site.
22+
23+
Args:
24+
request: The HTTP request object
25+
26+
Returns:
27+
Optional[HttpResponse]: Either an HTTP error response if site identification
28+
fails, or None to continue down the middleware chain
29+
"""
30+
site_id = request.headers.get("X-Site-ID")
31+
32+
if site_id:
33+
try:
34+
site_id = int(site_id)
35+
# Using _get_site_by_id directly as it leverages Django's internal site caching
36+
site = Site.objects._get_site_by_id(site_id)
37+
request.site = site
38+
except ValueError:
39+
return HttpResponseBadRequest("Invalid site ID format.")
40+
except Site.DoesNotExist:
41+
return HttpResponseNotFound("The requested site could not be found.")
42+
43+
else:
44+
request.site = get_current_site(request)
45+
return self.get_response(request)

djangocms_rest/views_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from django.contrib.sites.shortcuts import get_current_site
22
from django.utils.functional import cached_property
3-
43
from rest_framework.generics import ListAPIView
54
from rest_framework.permissions import IsAdminUser
65
from rest_framework.views import APIView
@@ -18,7 +17,8 @@ def site(self):
1817
"""
1918
Fetch and cache the current site and make it available to all views.
2019
"""
21-
return get_current_site(self.request)
20+
site = getattr(self.request, "site", None)
21+
return site if site is not None else get_current_site(self.request)
2222

2323
@property
2424
def content_getter(self):
@@ -46,5 +46,4 @@ class BaseListAPIView(BaseAPIMixin, ListAPIView):
4646
"""
4747
This is a base class for all list API views. It supports default pagination and sets the allowed methods to GET and OPTIONS.
4848
"""
49-
5049
pass
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from django.urls import reverse
2+
3+
from tests.base import BaseCMSRestTestCase
4+
5+
from django.contrib.sites.models import Site
6+
from cms.api import create_page, publish_page
7+
8+
9+
class SiteContextMiddlewareTestCase(BaseCMSRestTestCase):
10+
@classmethod
11+
def setUpClass(cls):
12+
"""
13+
Sets up a test environment with multiple sites.
14+
"""
15+
super().setUpClass()
16+
17+
cls.site1 = Site.objects.get(id=1)
18+
cls.site1.domain = "site1.example.com"
19+
cls.site1.name = "Site 1"
20+
cls.site1.save()
21+
22+
cls.site2 = Site.objects.create(domain="site2.example.com", name="Site 2")
23+
24+
@classmethod
25+
def tearDownClass(cls):
26+
"""
27+
Clean up the test environment and reset site 1.
28+
"""
29+
try:
30+
cls.site2.delete()
31+
except Site.DoesNotExist:
32+
pass
33+
34+
try:
35+
cls.site1.domain = "example.com"
36+
cls.site1.name = "example.com"
37+
cls.site1.save()
38+
except Site.DoesNotExist:
39+
pass
40+
41+
super().tearDownClass()
42+
43+
def test_site_middleware_with_header(self):
44+
"""
45+
Test the SiteContextMiddleware correctly handles X-Site-ID header
46+
and returns different pages based on the site id.
47+
48+
Verifies:
49+
- Middleware uses the site ID from X-Site-ID header
50+
- the Same path returns different content based on site ID
51+
- Invalid site ID returns appropriate error
52+
- Missing site ID uses default site
53+
"""
54+
# Create specific test pages with unique titles for each site
55+
site1_test_page = create_page(
56+
title="Site 1 Test Page",
57+
template="page.html",
58+
language="en",
59+
slug="test-page",
60+
site=self.site1,
61+
)
62+
publish_page(site1_test_page, "en", True)
63+
64+
site2_test_page = create_page(
65+
title="Site 2 Test Page",
66+
template="page.html",
67+
language="en",
68+
slug="test-page",
69+
site=self.site2,
70+
)
71+
publish_page(site2_test_page, "en", True)
72+
73+
# Test with site 1 header
74+
response = self.client.get(
75+
reverse("page-detail", kwargs={"language": "en", "path": "test-page"}),
76+
HTTP_X_SITE_ID="1",
77+
)
78+
self.assertEqual(response.status_code, 200)
79+
site1_data = response.json()
80+
81+
# Test with site 2 header
82+
response = self.client.get(
83+
reverse("page-detail", kwargs={"language": "en", "path": "test-page"}),
84+
HTTP_X_SITE_ID="2",
85+
)
86+
self.assertEqual(response.status_code, 200)
87+
site2_data = response.json()
88+
89+
# Compare titles - these should be different
90+
self.assertEqual(site1_data.get("title"), "Site 1 Test Page")
91+
self.assertEqual(site2_data.get("title"), "Site 2 Test Page")
92+
self.assertNotEqual(site1_data.get("title"), site2_data.get("title"))
93+
94+
# Test invalid site ID
95+
response = self.client.get(
96+
reverse("page-detail", kwargs={"language": "en", "path": "test-page"}),
97+
HTTP_X_SITE_ID="999",
98+
)
99+
self.assertEqual(response.status_code, 404)
100+
101+
# Test invalid site ID format
102+
response = self.client.get(
103+
reverse("page-detail", kwargs={"language": "en", "path": "test-page"}),
104+
HTTP_X_SITE_ID="invalid",
105+
)
106+
self.assertEqual(response.status_code, 400)
107+
108+
# Test without a site ID header (should default to site 1)
109+
response = self.client.get(
110+
reverse("page-detail", kwargs={"language": "en", "path": "test-page"})
111+
)
112+
self.assertEqual(response.status_code, 200)
113+
default_data = response.json()
114+
115+
# Verify default site returns the same content as site 1
116+
self.assertEqual(default_data.get("title"), site1_data.get("title"))
117+
self.assertEqual(default_data.get("title"), "Site 1 Test Page")

tests/settings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ def __getitem__(self, item):
1919
SECRET_KEY = "djangocms-text-test-suite"
2020

2121
INSTALLED_APPS = [
22+
"django.contrib.sites",
2223
"django.contrib.contenttypes",
2324
"django.contrib.auth",
24-
"django.contrib.sites",
2525
"django.contrib.sessions",
2626
"django.contrib.admin",
2727
"django.contrib.messages",
@@ -40,6 +40,7 @@ def __getitem__(self, item):
4040
"django.contrib.sessions.middleware.SessionMiddleware",
4141
"django.contrib.auth.middleware.AuthenticationMiddleware",
4242
"django.contrib.messages.middleware.MessageMiddleware",
43+
"djangocms_rest.middleware.SiteContextMiddleware",
4344
"cms.middleware.user.CurrentUserMiddleware",
4445
"cms.middleware.page.CurrentPageMiddleware",
4546
"cms.middleware.toolbar.ToolbarMiddleware",

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ deps =
2121
commands =
2222
{envpython} --version
2323
{env:COMMAND:coverage} erase
24-
{env:COMMAND:coverage} run runtests.py
24+
{env:COMMAND:coverage} run -m pytest
2525
{env:COMMAND:coverage} report
2626

2727
[testenv:ruff]

0 commit comments

Comments
 (0)