# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
# to make it harder for the user to import the wrong thing without realizing.
import io
from importlib import import_module
import django
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import WSGIHandler
from django.test import override_settings, testcases
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from django.test.client import RequestFactory as DjangoRequestFactory
from django.utils.encoding import force_bytes
from django.utils.http import urlencode
from rest_framework.compat import coreapi, requests
from rest_framework.settings import api_settings
def force_authenticate(request, user=None, token=None):
    """
    Tag `request` with a forced user and/or token.

    The values are stored on the request as `_force_auth_user` and
    `_force_auth_token`. Both default to `None`.
    """
    request._force_auth_user, request._force_auth_token = user, token
if requests is not None:
    class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
        """
        Header dictionary that also exposes a `get_all` lookup.
        """

        def get_all(self, key, default):
            # `default` is accepted but never consulted; the lookup is
            # delegated entirely to `getheaders`.
            return self.getheaders(key)

    class MockOriginalResponse:
        """
        Small stand-in handed to urllib3 as the `original_response`.

        Exposes the response headers as `msg` and tracks a closed flag.
        """

        def __init__(self, headers):
            self.msg = HeaderDict(headers)
            self.closed = False

        def isclosed(self):
            return self.closed

        def close(self):
            self.closed = True

    class DjangoTestAdapter(requests.adapters.HTTPAdapter):
        """
        Transport adapter for `requests` that dispatches each request to the
        Django WSGI application instead of sending it over the network.
        """

        def __init__(self):
            self.app = WSGIHandler()
            self.factory = DjangoRequestFactory()

        def get_environ(self, request):
            """
            Build a WSGI environ dict from a `requests.PreparedRequest`.
            """
            body = request.body
            headers = request.headers
            params = {}

            # Request content, if any. File-like bodies are read in full.
            if body is not None:
                params['data'] = body.read() if hasattr(body, 'read') else body
            if 'content-type' in headers:
                params['content_type'] = headers['content-type']

            # Remaining headers become `HTTP_*` entries. The skipped headers
            # are not forwarded this way.
            for name, value in headers.items():
                name = name.upper()
                if name not in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
                    params['HTTP_%s' % name.replace('-', '_')] = value

            wsgi_request = self.factory.generic(
                request.method, request.url, **params
            )
            return wsgi_request.environ

        def send(self, request, *args, **kwargs):
            """
            Run the request through the Django WSGI application and return
            the resulting `requests` response.
            """
            response_kwargs = {}

            def start_response(wsgi_status, wsgi_headers, exc_info=None):
                # The status line looks like "200 OK": split code and phrase.
                code, _, phrase = wsgi_status.partition(' ')
                response_kwargs.update(
                    status=int(code),
                    reason=phrase,
                    headers=wsgi_headers,
                    version=11,
                    preload_content=False,
                    original_response=MockOriginalResponse(wsgi_headers),
                )

            # Make the outgoing request via WSGI.
            environ = self.get_environ(request)
            wsgi_response = self.app(environ, start_response)

            # Build the underlying urllib3.HTTPResponse.
            response_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
            raw = requests.packages.urllib3.HTTPResponse(**response_kwargs)

            # Build the requests.Response.
            return self.build_response(request, raw)

        def close(self):
            pass

    class RequestsClient(requests.Session):
        """
        A `requests.Session` whose `http://` and `https://` traffic is routed
        through `DjangoTestAdapter`.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            adapter = DjangoTestAdapter()
            for scheme in ('http://', 'https://'):
                self.mount(scheme, adapter)

        def request(self, method, url, *args, **kwargs):
            # Only fully qualified URLs are accepted.
            if url.startswith('http'):
                return super().request(method, url, *args, **kwargs)
            raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)

else:
    def RequestsClient(*args, **kwargs):
        """
        Placeholder used when `requests` is unavailable; always raises.
        """
        raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')
if coreapi is not None:
    class CoreAPIClient(coreapi.Client):
        """
        A `coreapi.Client` whose HTTP transport uses a `RequestsClient`
        session.
        """

        def __init__(self, *args, **kwargs):
            self._session = RequestsClient()
            transport = coreapi.transports.HTTPTransport(session=self.session)
            kwargs['transports'] = [transport]
            super().__init__(*args, **kwargs)

        @property
        def session(self):
            """The `RequestsClient` session created for this client."""
            return self._session

else:
    def CoreAPIClient(*args, **kwargs):
        """
        Placeholder used when `coreapi` is unavailable; always raises.
        """
        raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
class APIRequestFactory(DjangoRequestFactory):
    """
    Request factory that encodes request bodies with the renderers listed
    in `TEST_REQUEST_RENDERER_CLASSES`, selected by a `format` argument.
    """
    renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
    default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT

    def __init__(self, enforce_csrf_checks=False, **defaults):
        self.enforce_csrf_checks = enforce_csrf_checks
        # Map each renderer's `format` name to its class for quick lookup.
        self.renderer_classes = {
            renderer_cls.format: renderer_cls
            for renderer_cls in self.renderer_classes_list
        }
        super().__init__(**defaults)

    def _encode_data(self, data, format=None, content_type=None):
        """
        Encode the data returning a two tuple of (bytes, content_type)

        `format` and `content_type` are mutually exclusive. With an explicit
        `content_type` the data is treated as a raw bytestring; otherwise it
        is rendered with the renderer registered for `format` (falling back
        to `default_format`). Raises `AssertionError` if both are given or
        if `format` is not a known renderer format.
        """
        if data is None:
            # Return empty bytes (not str) so the result matches the
            # documented (bytes, content_type) contract.
            return (b'', content_type)

        # NOTE: the asserts are kept as-is; callers may rely on
        # AssertionError being raised for invalid arguments.
        assert format is None or content_type is None, (
            'You may not set both `format` and `content_type`.'
        )

        if content_type:
            # Content type specified explicitly, treat data as a raw bytestring
            ret = force_bytes(data, settings.DEFAULT_CHARSET)

        else:
            format = format or self.default_format

            assert format in self.renderer_classes, (
                "Invalid format '{}'. Available formats are {}. "
                "Set TEST_REQUEST_RENDERER_CLASSES to enable "
                "extra request formats.".format(
                    format,
                    ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
                )
            )

            # Use format and render the data into a bytestring
            renderer = self.renderer_classes[format]()
            ret = renderer.render(data)

            # Determine the content-type header from the renderer
            content_type = renderer.media_type
            if renderer.charset:
                content_type = "{}; charset={}".format(
                    content_type, renderer.charset
                )

            # Coerce text to bytes if required.
            if isinstance(ret, str):
                ret = ret.encode(renderer.charset)

        return ret, content_type

    def get(self, path, data=None, **extra):
        """
        Build a GET request. The query string comes from `data` when given,
        otherwise from whatever follows the first '?' in `path`.
        """
        r = {
            'QUERY_STRING': urlencode(data or {}, doseq=True),
        }
        if not data and '?' in path:
            # Fix to support old behavior where you have the arguments in the
            # url. See #1461.
            # Split on the first '?' only: a '?' is legal inside a query
            # string, so everything after the first one must be preserved.
            query_string = force_bytes(path.split('?', 1)[1])
            query_string = query_string.decode('iso-8859-1')
            r['QUERY_STRING'] = query_string
        r.update(extra)
        return self.generic('GET', path, **r)

    def post(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('POST', path, data, content_type, **extra)

    def put(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PUT', path, data, content_type, **extra)

    def patch(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('PATCH', path, data, content_type, **extra)

    def delete(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('DELETE', path, data, content_type, **extra)

    def options(self, path, data=None, format=None, content_type=None, **extra):
        data, content_type = self._encode_data(data, format, content_type)
        return self.generic('OPTIONS', path, data, content_type, **extra)

    def generic(self, method, path, data='',
                content_type='application/octet-stream', secure=False, **extra):
        # Include the CONTENT_TYPE, regardless of whether or not data is empty.
        if content_type is not None:
            extra['CONTENT_TYPE'] = str(content_type)

        return super().generic(
            method, path, data, content_type, secure, **extra)

    def request(self, **kwargs):
        """
        Build the request and flag whether CSRF checks should be skipped,
        based on `enforce_csrf_checks`.
        """
        request = super().request(**kwargs)
        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
        return request
class ForceAuthClientHandler(ClientHandler):
    """
    `ClientHandler` variant that stamps a forced user/token onto every
    request before it is processed.
    """

    def __init__(self, *args, **kwargs):
        # No forced credentials until the owning client sets them.
        self._force_user = self._force_token = None
        super().__init__(*args, **kwargs)

    def get_response(self, request):
        # This is the simplest place we can hook into to patch the
        # request object.
        force_authenticate(
            request, user=self._force_user, token=self._force_token
        )
        return super().get_response(request)
class APIClient(APIRequestFactory, DjangoClient):
    """
    Test client combining `APIRequestFactory` encoding with the Django test
    client, plus persistent credentials and forced authentication.
    """

    def __init__(self, enforce_csrf_checks=False, **defaults):
        super().__init__(**defaults)
        self.handler = ForceAuthClientHandler(enforce_csrf_checks)
        self._credentials = {}

    def credentials(self, **kwargs):
        """
        Sets headers that will be used on every outgoing request.
        """
        self._credentials = kwargs

    def force_authenticate(self, user=None, token=None):
        """
        Forcibly authenticates outgoing requests with the given
        user and/or token.
        """
        handler = self.handler
        handler._force_user = user
        handler._force_token = token
        if user is None and token is None:
            # Also clear any possible session info if required
            self.logout()

    def request(self, **kwargs):
        # Ensure that any credentials set get added to every request.
        kwargs.update(self._credentials)
        return super().request(**kwargs)

    def get(self, path, data=None, follow=False, **extra):
        response = super().get(path, data=data, **extra)
        if not follow:
            return response
        return self._handle_redirects(response, data=data, **extra)

    def post(self, path, data=None, format=None, content_type=None,
             follow=False, **extra):
        response = super().post(
            path, data=data, format=format, content_type=content_type, **extra)
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, format=format, content_type=content_type, **extra)

    def put(self, path, data=None, format=None, content_type=None,
            follow=False, **extra):
        response = super().put(
            path, data=data, format=format, content_type=content_type, **extra)
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, format=format, content_type=content_type, **extra)

    def patch(self, path, data=None, format=None, content_type=None,
              follow=False, **extra):
        response = super().patch(
            path, data=data, format=format, content_type=content_type, **extra)
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, format=format, content_type=content_type, **extra)

    def delete(self, path, data=None, format=None, content_type=None,
               follow=False, **extra):
        response = super().delete(
            path, data=data, format=format, content_type=content_type, **extra)
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, format=format, content_type=content_type, **extra)

    def options(self, path, data=None, format=None, content_type=None,
                follow=False, **extra):
        response = super().options(
            path, data=data, format=format, content_type=content_type, **extra)
        if not follow:
            return response
        return self._handle_redirects(
            response, data=data, format=format, content_type=content_type, **extra)

    def logout(self):
        """
        Drop stored credentials and forced authentication, then log out of
        the session if one exists.
        """
        self._credentials = {}

        # Also clear any `force_authenticate`
        handler = self.handler
        handler._force_user = None
        handler._force_token = None

        if self.session:
            super().logout()
class APITransactionTestCase(testcases.TransactionTestCase):
    """
    Django `TransactionTestCase` with `client_class` set to `APIClient`.
    """
    client_class = APIClient
class APITestCase(testcases.TestCase):
    """
    Django `TestCase` with `client_class` set to `APIClient`.
    """
    client_class = APIClient
class APISimpleTestCase(testcases.SimpleTestCase):
    """
    Django `SimpleTestCase` with `client_class` set to `APIClient`.
    """
    client_class = APIClient
class APILiveServerTestCase(testcases.LiveServerTestCase):
    """
    Django `LiveServerTestCase` with `client_class` set to `APIClient`.
    """
    client_class = APIClient
def cleanup_url_patterns(cls):
    """
    Undo the `urlpatterns` swap on `cls._module`.

    If `cls` has `_module_urlpatterns`, that value is put back on the
    module; otherwise the module's `urlpatterns` attribute is deleted.
    """
    module = cls._module
    if not hasattr(cls, '_module_urlpatterns'):
        # Nothing was saved on the class, so just remove the attribute.
        del module.urlpatterns
        return
    module.urlpatterns = cls._module_urlpatterns
class URLPatternsTestCase(testcases.SimpleTestCase):
    """
    Isolate URL patterns on a per-TestCase basis. For example,

    class ATestCase(URLPatternsTestCase):
        urlpatterns = [...]

        def test_something(self):
            ...

    class AnotherTestCase(URLPatternsTestCase):
        urlpatterns = [...]

        def test_something_else(self):
            ...
    """
    @classmethod
    def setUpClass(cls):
        # Get the module of the TestCase subclass
        module_name = cls.__module__
        cls._module = import_module(module_name)
        cls._override = override_settings(ROOT_URLCONF=module_name)

        # Remember any `urlpatterns` the module already defines so that it
        # can be put back during cleanup.
        if hasattr(cls._module, 'urlpatterns'):
            cls._module_urlpatterns = cls._module.urlpatterns

        cls._module.urlpatterns = cls.urlpatterns

        cls._override.enable()

        if django.VERSION > (4, 0):
            # Register class-level cleanups instead of using tearDownClass.
            cls.addClassCleanup(cls._override.disable)
            cls.addClassCleanup(cleanup_url_patterns, cls)

        super().setUpClass()

    if django.VERSION < (4, 0):
        @classmethod
        def tearDownClass(cls):
            super().tearDownClass()
            cls._override.disable()
            # Same restore/delete logic as the class cleanup path.
            cleanup_url_patterns(cls)