from __future__ import unicode_literals
import six
import copy
from rest_framework.response import Response
from itertools import chain
from rest_framework.serializers import ListSerializer
from drf_sideloading.serializers import SideLoadableSerializer
[docs]class SideloadableRelationsMixin(object):
"""
TODO: Implement some protection for too large queries.
* limit the number of sideloadable elements?
if over limit:
- raise error
- show warning
- paginate, show first page and add a link to remaining paginated list of related elements?
- show only the link to paginated list of related elements?
"""
query_param_name = 'sideload'
sideloading_serializer_class = None
_primary_field_name = None
_sideloadable_fields = None
relations_to_sideload = None
def __init__(self, **kwargs):
self.check_sideloading_serializer_class()
self._primary_field_name = self.get_primary_field_name()
self._sideloadable_fields = self.get_sideloadable_fields()
self._prefetches = self.get_sideloading_prefetches()
super(SideloadableRelationsMixin, self).__init__(**kwargs)
[docs] def check_sideloading_serializer_class(self):
assert self.sideloading_serializer_class is not None, (
"'%s' should either include a `sideloading_serializer_class` attribute, "
# "or override the `get_sideloading_serializer_class()` method."
% self.__class__.__name__
)
assert issubclass(self.sideloading_serializer_class, SideLoadableSerializer), (
"'%s' `sideloading_serializer_class` must be a SideLoadableSerializer subclass"
% self.__class__.__name__
)
assert not getattr(self.sideloading_serializer_class, 'many', None), (
'Sideloadable serializer can not be \'many=True\'!'
)
# Check Meta class
assert hasattr(self.sideloading_serializer_class, 'Meta'), (
'Sideloadable serializer must have a Meta class defined with the \'primary\' field name!'
)
assert getattr(self.sideloading_serializer_class.Meta, 'primary', None), (
'Sideloadable serializer must have a Meta attribute called primary!'
)
assert self.sideloading_serializer_class.Meta.primary in self.sideloading_serializer_class._declared_fields, (
'Sideloadable serializer Meta.primary must point to a field in the serializer!'
)
if getattr(self.sideloading_serializer_class.Meta, 'prefetches', None) is not None:
assert isinstance(self.sideloading_serializer_class.Meta.prefetches, dict), (
'Sideloadable serializer Meta attribute \'prefetches\' must be a dict.'
)
# check serializer fields:
for name, field in self.sideloading_serializer_class._declared_fields.items():
assert getattr(field, 'many', None), (
'SideLoadable field \'%s\' must be set as many=True' % name
)
# check serializer fields:
for name, field in self.sideloading_serializer_class._declared_fields.items():
assert getattr(field, 'many', None), (
'SideLoadable field \'%s\' must be set as many=True' % name
)
[docs] def get_primary_field_name(self):
return self.sideloading_serializer_class.Meta.primary
[docs] def get_sideloadable_fields(self):
sideloadable_fields = copy.deepcopy(self.sideloading_serializer_class._declared_fields)
sideloadable_fields.pop(self._primary_field_name, None)
return sideloadable_fields
[docs] def get_sideloading_prefetches(self):
prefetches = getattr(self.sideloading_serializer_class.Meta, 'prefetches', {})
if not prefetches:
return None
cleaned_prefetches = {}
for k, v in prefetches.items():
if v is not None:
if isinstance(v, list):
cleaned_prefetches[k] = v
elif isinstance(v, six.string_types):
cleaned_prefetches[k] = [v]
else:
raise RuntimeError('Sideloadable prefetch values must be presented either as a list or a string')
return cleaned_prefetches
[docs] def list(self, request, *args, **kwargs):
sideload_params = self.parse_query_param(sideload_parameter=request.query_params.get(self.query_param_name, ''))
if not sideload_params:
# do nothing if there is no or empty parameter provided
return super(SideloadableRelationsMixin, self).list(request, *args, **kwargs)
# After this `relations_to_sideload` is safe to use
queryset = self.get_queryset()
# add prefetches if applicable
prefetch_relations = self.get_relevant_prefetches()
if prefetch_relations:
queryset = queryset.prefetch_related(*prefetch_relations)
queryset = self.filter_queryset(queryset)
# create page
page = self.paginate_queryset(queryset)
if page is not None:
sideloadable_page = self.get_sideloadable_page(page)
serializer = self.sideloading_serializer_class(instance=sideloadable_page, context={'request': request})
return self.get_paginated_response(serializer.data)
else:
sideloadable_page = self.get_sideloadable_page_from_queryset(queryset)
serializer = self.sideloading_serializer_class(instance=sideloadable_page, context={'request': request})
return Response(serializer.data)
[docs] def parse_query_param(self, sideload_parameter):
"""
Parse query param and take validated names
:param sideload_parameter string
:return valid relation names list
comma separated relation names may contain invalid or unusable characters.
This function finds string match between requested names and defined relation in view
"""
self.relations_to_sideload = set(sideload_parameter.split(',')) & set(self._sideloadable_fields.keys())
return self.relations_to_sideload
[docs] def get_relevant_prefetches(self):
if not self._prefetches:
return set()
return set(pf for relation in self.relations_to_sideload for pf in self._prefetches.get(relation, []))
[docs] def get_sideloadable_page_from_queryset(self, queryset):
# this works wonders, but can't be used when page is paginated...
sideloadable_page = {self._primary_field_name: queryset}
for relation in self.relations_to_sideload:
if not isinstance(self._sideloadable_fields[relation], ListSerializer):
raise RuntimeError('SideLoadable field \'%s\' must be set as many=True' % relation)
source = self._sideloadable_fields[relation].source or relation
rel_model = self._sideloadable_fields[relation].child.Meta.model
rel_qs = rel_model.objects.filter(pk__in=queryset.values_list(source, flat=True))
sideloadable_page[source] = rel_qs
return sideloadable_page
[docs] def get_sideloadable_page(self, page):
sideloadable_page = {self._primary_field_name: page}
for relation in self.relations_to_sideload:
if not isinstance(self._sideloadable_fields[relation], ListSerializer):
raise RuntimeError('SideLoadable field \'%s\' must be set as many=True' % relation)
source = self._sideloadable_fields[relation].source or relation
sideloadable_page[source] = self.filter_related_objects(related_objects=page, lookup=source)
return sideloadable_page