Source code for cdh.files.db.manager

from django.db import models
from django.db import connections, router, transaction
from django.db.models import Q, QuerySet, signals
from django.db.models.utils import resolve_callables


[docs]class FileManager(models.Manager):
[docs] def get_queryset_from_model(self, model): """Given a model, it will return a QS returning all files for the given model""" if not issubclass(model, models.Model): raise ValueError("Given model is not a Django model!") return self.get_queryset().filter( app_name=model._meta.app_label, # NoQA model_name=model._meta.object_name, # NoQA )
[docs] def get_queryset_from_field(self, field): """Given a field, it will return a QS returning all files for the given field""" from cdh.files.db import FileField if not issubclass(field, FileField): raise ValueError("Given field is not a FileField!") return self.get_queryset().filter( app_name=field.model._meta.app_label, # NoQA model_name=field.model._meta.object_name, # NoQA field_name=field.name, )
[docs]def create_tracked_file_manager(superclass, rel): """ Create a manager for the TrackedFileField's TrackedFileWrapper This manager subclasses another manager, generally the default manager of the related model, and adds behaviors specific to many-to-many relations. """ class TrackedFileManager(superclass): def __init__(self, instance=None): super().__init__() self.instance = instance self.model = rel.model self.query_field_name = rel.field.related_query_name() self.prefetch_cache_name = rel.field.name self.source_field_name = rel.field.m2m_field_name() self.target_field_name = rel.field.m2m_reverse_field_name() self.symmetrical = rel.symmetrical self.through = rel.through self.reverse = False self.source_field = self.through._meta.get_field(self.source_field_name) self.target_field = self.through._meta.get_field(self.target_field_name) self.core_filters = {} self.pk_field_names = {} for lh_field, rh_field in self.source_field.related_fields: core_filter_key = '%s__%s' % (self.query_field_name, rh_field.name) self.core_filters[core_filter_key] = getattr(instance, rh_field.attname) self.pk_field_names[lh_field.name] = rh_field.name self.related_val = self.source_field.get_foreign_related_value(instance) if None in self.related_val: raise ValueError('"%r" needs to have a value for field "%s" before ' 'this many-to-many relationship can be used.' % (instance, self.pk_field_names[self.source_field_name])) # Even if this relation is not to pk, we require still pk value. # The wish is that the instance has been already saved to DB, # although having a pk value isn't a guarantee of that. if instance.pk is None: raise ValueError("%r instance needs to have a primary key value before " "a many-to-many relationship can be used." % instance.__class__.__name__) def __call__(self, *, manager): manager = getattr(self.model, manager) manager_class = create_tracked_file_manager(manager.__class__, rel) return manager_class(instance=self.instance) do_not_call_in_templates = True def _build_remove_filters(self, removed_vals): filters = Q(**{self.source_field_name: self.related_val}) # No need to add a subquery condition if removed_vals is a QuerySet # without filters. removed_vals_filters = (not isinstance(removed_vals, QuerySet) or removed_vals._has_filters()) if removed_vals_filters: filters &= Q(**{'%s__in' % self.target_field_name: removed_vals}) if self.symmetrical: symmetrical_filters = Q(**{self.target_field_name: self.related_val}) if removed_vals_filters: symmetrical_filters &= Q( **{'%s__in' % self.source_field_name: removed_vals}) filters |= symmetrical_filters return filters def _apply_rel_filters(self, queryset): """ Filter the queryset for the instance this manager is bound to. """ queryset._add_hints(instance=self.instance) if self._db: queryset = queryset.using(self._db) queryset._defer_next_filter = True return queryset._next_is_sticky().filter(**self.core_filters) def _remove_prefetched_objects(self): try: self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name) except (AttributeError, KeyError): pass # nothing to clear from cache def get_queryset(self): try: return self.instance._prefetched_objects_cache[self.prefetch_cache_name] except (AttributeError, KeyError): queryset = super().get_queryset() return self._apply_rel_filters(queryset) def get_prefetch_queryset(self, instances, queryset=None): if queryset is None: queryset = super().get_queryset() queryset._add_hints(instance=instances[0]) queryset = queryset.using(queryset._db or self._db) query = {'%s__in' % self.query_field_name: instances} queryset = queryset._next_is_sticky().filter(**query) # M2M: need to annotate the query in order to get the primary model # that the secondary model was actually related to. We know that # there will already be a join on the join table, so we can just add # the select. # For non-autocreated 'through' models, can't assume we are # dealing with PK values. fk = self.through._meta.get_field(self.source_field_name) join_table = fk.model._meta.db_table connection = connections[queryset.db] qn = connection.ops.quote_name queryset = queryset.extra(select={ '_prefetch_related_val_%s' % f.attname: '%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields}) return ( queryset, lambda result: tuple( getattr(result, '_prefetch_related_val_%s' % f.attname) for f in fk.local_related_fields ), lambda inst: tuple( f.get_db_prep_value(getattr(inst, f.attname), connection) for f in fk.foreign_related_fields ), False, self.prefetch_cache_name, False, ) def add(self, *objs, through_defaults=None): self._remove_prefetched_objects() db = router.db_for_write(self.through, instance=self.instance) with transaction.atomic(using=db, savepoint=False): self._add_items( self.source_field_name, self.target_field_name, *objs, through_defaults=through_defaults, ) # If this is a symmetrical m2m relation to self, add the mirror # entry in the m2m table. if self.symmetrical: self._add_items( self.target_field_name, self.source_field_name, *objs, through_defaults=through_defaults, ) add.alters_data = True def remove(self, *objs): self._remove_prefetched_objects() self._remove_items(self.source_field_name, self.target_field_name, *objs) remove.alters_data = True def clear(self): db = router.db_for_write(self.through, instance=self.instance) with transaction.atomic(using=db, savepoint=False): signals.m2m_changed.send( sender=self.through, action="pre_clear", instance=self.instance, reverse=self.reverse, model=self.model, pk_set=None, using=db, ) self._remove_prefetched_objects() filters = self._build_remove_filters(super().get_queryset().using(db)) self.through._default_manager.using(db).filter(filters).delete() signals.m2m_changed.send( sender=self.through, action="post_clear", instance=self.instance, reverse=self.reverse, model=self.model, pk_set=None, using=db, ) clear.alters_data = True def set(self, objs, *, clear=False, through_defaults=None): # Force evaluation of `objs` in case it's a queryset whose value # could be affected by `manager.clear()`. Refs #19816. objs = tuple(objs) db = router.db_for_write(self.through, instance=self.instance) with transaction.atomic(using=db, savepoint=False): if clear: self.clear() self.add(*objs, through_defaults=through_defaults) else: old_ids = set(self.using(db).values_list(self.target_field.target_field.attname, flat=True)) new_objs = [] for obj in objs: fk_val = ( self.target_field.get_foreign_related_value(obj)[0] if isinstance(obj, self.model) else self.target_field.get_prep_value(obj) ) if fk_val in old_ids: old_ids.remove(fk_val) else: new_objs.append(obj) self.remove(*old_ids) self.add(*new_objs, through_defaults=through_defaults) set.alters_data = True def create(self, *, through_defaults=None, **kwargs): db = router.db_for_write(self.instance.__class__, instance=self.instance) new_obj = super(TrackedFileManager, self.db_manager(db)).create(**kwargs) self.add(new_obj, through_defaults=through_defaults) return new_obj create.alters_data = True def get_or_create(self, *, through_defaults=None, **kwargs): db = router.db_for_write(self.instance.__class__, instance=self.instance) obj, created = super(TrackedFileManager, self.db_manager(db)).get_or_create(**kwargs) # We only need to add() if created because if we got an object back # from get() then the relationship already exists. if created: self.add(obj, through_defaults=through_defaults) return obj, created get_or_create.alters_data = True def update_or_create(self, *, through_defaults=None, **kwargs): db = router.db_for_write(self.instance.__class__, instance=self.instance) obj, created = super(TrackedFileManager, self.db_manager(db)).update_or_create(**kwargs) # We only need to add() if created because if we got an object back # from get() then the relationship already exists. if created: self.add(obj, through_defaults=through_defaults) return obj, created update_or_create.alters_data = True def _get_target_ids(self, target_field_name, objs): """ Return the set of ids of `objs` that the target field references. """ from django.db.models import Model target_ids = set() target_field = self.through._meta.get_field(target_field_name) for obj in objs: if isinstance(obj, self.model): if not router.allow_relation(obj, self.instance): raise ValueError( 'Cannot add "%r": instance is on database "%s", ' 'value is on database "%s"' % (obj, self.instance._state.db, obj._state.db) ) target_id = target_field.get_foreign_related_value(obj)[0] if target_id is None: raise ValueError( 'Cannot add "%r": the value for field "%s" is None' % (obj, target_field_name) ) target_ids.add(target_id) elif isinstance(obj, Model): raise TypeError( "'%s' instance expected, got %r" % (self.model._meta.object_name, obj) ) else: target_ids.add(target_field.get_prep_value(obj)) return target_ids def _get_missing_target_ids(self, source_field_name, target_field_name, db, target_ids): """ Return the subset of ids of `objs` that aren't already assigned to this relationship. """ vals = self.through._default_manager.using(db).values_list( target_field_name, flat=True ).filter(**{ source_field_name: self.related_val[0], '%s__in' % target_field_name: target_ids, }) return target_ids.difference(vals) def _get_add_plan(self, db, source_field_name): """ Return a boolean triple of the way the add should be performed. The first element is whether or not bulk_create(ignore_conflicts) can be used, the second whether or not signals must be sent, and the third element is whether or not the immediate bulk insertion with conflicts ignored can be performed. """ # Conflicts can be ignored when the intermediary model is # auto-created as the only possible collision is on the # (source_id, target_id) tuple. The same assertion doesn't hold for # user-defined intermediary models as they could have other fields # causing conflicts which must be surfaced. can_ignore_conflicts = ( connections[db].features.supports_ignore_conflicts and self.through._meta.auto_created is not False ) # Don't send the signal when inserting duplicate data row # for symmetrical reverse entries. must_send_signals = (self.reverse or source_field_name == self.source_field_name) and ( signals.m2m_changed.has_listeners(self.through) ) # Fast addition through bulk insertion can only be performed # if no m2m_changed listeners are connected for self.through # as they require the added set of ids to be provided via # pk_set. return can_ignore_conflicts, must_send_signals, (can_ignore_conflicts and not must_send_signals) def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None): # source_field_name: the PK fieldname in join table for the source object # target_field_name: the PK fieldname in join table for the target object # *objs - objects to add. Either object instances, or primary keys of object instances. if not objs: return through_defaults = dict(resolve_callables(through_defaults or {})) target_ids = self._get_target_ids(target_field_name, objs) db = router.db_for_write(self.through, instance=self.instance) can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(db, source_field_name) if can_fast_add: self.through._default_manager.using(db).bulk_create([ self.through(**{ '%s_id' % source_field_name: self.related_val[0], '%s_id' % target_field_name: target_id, }) for target_id in target_ids ], ignore_conflicts=True) return missing_target_ids = self._get_missing_target_ids( source_field_name, target_field_name, db, target_ids ) with transaction.atomic(using=db, savepoint=False): if must_send_signals: signals.m2m_changed.send( sender=self.through, action='pre_add', instance=self.instance, reverse=self.reverse, model=self.model, pk_set=missing_target_ids, using=db, ) # Add the ones that aren't there already. self.through._default_manager.using(db).bulk_create([ self.through(**through_defaults, **{ '%s_id' % source_field_name: self.related_val[0], '%s_id' % target_field_name: target_id, }) for target_id in missing_target_ids ], ignore_conflicts=can_ignore_conflicts) if must_send_signals: signals.m2m_changed.send( sender=self.through, action='post_add', instance=self.instance, reverse=self.reverse, model=self.model, pk_set=missing_target_ids, using=db, ) def _remove_items(self, source_field_name, target_field_name, *objs): # source_field_name: the PK colname in join table for the source object # target_field_name: the PK colname in join table for the target object # *objs - objects to remove. Either object instances, or primary # keys of object instances. if not objs: return # Check that all the objects are of the right type old_ids = set() for obj in objs: if isinstance(obj, self.model): fk_val = self.target_field.get_foreign_related_value(obj)[0] old_ids.add(fk_val) else: old_ids.add(obj) db = router.db_for_write(self.through, instance=self.instance) with transaction.atomic(using=db, savepoint=False): # Send a signal to the other end if need be. signals.m2m_changed.send( sender=self.through, action="pre_remove", instance=self.instance, reverse=self.reverse, model=self.model, pk_set=old_ids, using=db, ) target_model_qs = super().get_queryset() if target_model_qs._has_filters(): old_vals = target_model_qs.using(db).filter(**{ '%s__in' % self.target_field.target_field.attname: old_ids}) else: old_vals = old_ids filters = self._build_remove_filters(old_vals) self.through._default_manager.using(db).filter(filters).delete() signals.m2m_changed.send( sender=self.through, action="post_remove", instance=self.instance, reverse=self.reverse, model=self.model, pk_set=old_ids, using=db, ) return TrackedFileManager