def test_add_fk_column(self):
    """Evolving adds a non-null ForeignKeyField column to an existing table.

    ``Car`` is first created without any foreign key, then redefined with a
    non-null ``owner`` FK to ``Person``; after the second evolve a ``Car``
    can be created for a freshly created ``Person``.
    """
    class Person(pw.Model):
        class Meta:
            database = self.db

    class Car(pw.Model):
        class Meta:
            database = self.db

    # Create both tables. (Helper defined elsewhere; by its name it evolves
    # the schema and checks that evolving again is a no-op — confirm.)
    self.evolve_and_check_noop()
    # Forget the first Car definition before redefining the model below.
    peeweedbevolve.unregister(Car)

    class Car(pw.Model):
        owner = pw.ForeignKeyField(rel_model=Person, null=False)

        class Meta:
            database = self.db

    self.evolve_and_check_noop()
    person = Person.create()
    car = Car.create(owner=person)
Python ForeignKeyField() class: example source code
def test_change_int_column_to_fk(self):
    """Turning a plain integer column into a ForeignKeyField keeps its data.

    ``Car.owner_id`` starts as a non-null IntegerField holding a valid
    ``Person`` id. After redefining it as the ``owner`` FK and evolving, the
    stored id is unchanged, and creating a ``Car`` whose owner id (-1) has no
    matching ``Person`` must raise.
    """
    class Person(pw.Model):
        class Meta:
            database = self.db

    class Car(pw.Model):
        owner_id = pw.IntegerField(null=False)

        class Meta:
            database = self.db

    self.evolve_and_check_noop()
    person = Person.create()
    car = Car.create(owner_id=person.id)
    # Forget the integer-column Car definition before redefining the model.
    peeweedbevolve.unregister(Car)

    class Car(pw.Model):
        owner = pw.ForeignKeyField(rel_model=Person, null=False)

        class Meta:
            database = self.db

    self.evolve_and_check_noop()
    # The existing row still references the same Person id.
    self.assertEqual(Car.select().first().owner_id, person.id)
    # With the FK in place, a dangling owner id is rejected.
    self.assertRaises(Exception, lambda: Car.create(owner=-1))
def test_change_fk_column_to_int(self):
    """Turning a ForeignKeyField back into a plain integer column keeps data.

    After ``owner`` is redefined as the IntegerField ``owner_id`` and the
    schema is evolved, the stored id is unchanged and a ``Car`` with an owner
    id that matches no ``Person`` (-1) can be created.
    """
    class Person(pw.Model):
        class Meta:
            database = self.db

    class Car(pw.Model):
        owner = pw.ForeignKeyField(rel_model=Person, null=False)

        class Meta:
            database = self.db

    self.evolve_and_check_noop()
    person = Person.create()
    car = Car.create(owner=person)
    # Forget the FK-based Car definition before redefining the model.
    peeweedbevolve.unregister(Car)

    class Car(pw.Model):
        owner_id = pw.IntegerField(null=False)

        class Meta:
            database = self.db

    self.evolve_and_check_noop()
    self.assertEqual(Car.select().first().owner_id, person.id)
    # The FK constraint must be gone, so a dangling id is accepted.
    Car.create(owner_id=-1) # this should not fail
def test_change_integer_to_fake_fk_column(self):
    """Redefining an integer column as a ``fake=True`` FK adds no constraint.

    ``Car.owner_id`` already holds a dangling id (-1). After it is redefined
    as ``ForeignKeyField(..., fake=True)`` and the schema is evolved, another
    ``Car`` with a nonexistent owner (-2) can still be created, leaving two
    rows.
    """
    class Person(pw.Model):
        class Meta:
            database = self.db

    class Car(pw.Model):
        owner_id = pw.IntegerField(null=False)

        class Meta:
            database = self.db

    self.evolve_and_check_noop()
    car = Car.create(owner_id=-1)
    # Forget the integer-column Car definition before redefining the model.
    peeweedbevolve.unregister(Car)

    class Car(pw.Model):
        # ``fake`` is not a stock ForeignKeyField kwarg; it is expected to be
        # accepted via a patched __init__ (see _add_fake_fk_field_hook).
        owner = pw.ForeignKeyField(rel_model=Person, null=False, fake=True)

        class Meta:
            database = self.db

    self.evolve_and_check_noop()
    person = Person.create()
    # No Person has id -2; this must succeed because the FK is fake.
    car = Car.create(owner=-2)
    self.assertEqual(Car.select().count(), 2)
def get_through_model(self):
    """Return the intermediate ("through") model of this many-to-many field.

    The model is built on first use and cached in ``self._through_model``.
    It shares ``self.model_class``'s database, lives in the table
    ``<lhs table>_<rhs table>_through``, holds one ForeignKeyField per side
    of the relation, and has a unique index over that pair of fields.
    """
    if self._through_model:
        return self._through_model

    lhs, rhs = self.get_models()
    lhs_table, rhs_table = lhs._meta.db_table, rhs._meta.db_table

    class Meta:
        database = self.model_class._meta.database
        db_table = '%s_%s_through' % (lhs_table, rhs_table)
        # One composite index over both FK columns; True marks it unique.
        indexes = (((lhs._meta.name, rhs._meta.name), True),)
        validate_backrefs = False

    # 'Meta' is listed last so it overrides any same-named field entry.
    class_attrs = {
        lhs._meta.name: ForeignKeyField(rel_model=lhs),
        rhs._meta.name: ForeignKeyField(rel_model=rhs),
        'Meta': Meta,
    }
    through_name = '%s%sThrough' % (lhs.__name__, rhs.__name__)
    self._through_model = type(through_name, (AioModel,), class_attrs)
    return self._through_model
def test_callable_related_name():
    """A callable ``related_name`` is evaluated for each model owning the FK.

    ``rel_name`` builds ``<model name>_<field name>_ref``. ``Baz`` inherits
    Bar's two FKs and gets its own backrefs on ``Foo``, and the default
    ``<model>_set`` backrefs are not created.
    """
    class Foo(TestModel):
        pass

    def rel_name(field):
        return '{}_{}_ref'.format(field.model_class._meta.name, field.name)

    class Bar(TestModel):
        fk1 = ForeignKeyField(Foo, related_name=rel_name)
        fk2 = ForeignKeyField(Foo, related_name=rel_name)

    # Subclass: the inherited FKs are re-bound to Baz, producing new names.
    class Baz(Bar):
        pass

    assert Foo.bar_fk1_ref.rel_model is Bar
    assert Foo.bar_fk2_ref.rel_model is Bar
    assert Foo.baz_fk1_ref.rel_model is Baz
    assert Foo.baz_fk2_ref.rel_model is Baz
    assert not hasattr(Foo, 'bar_set')
    assert not hasattr(Foo, 'baz_set')
def test_object_id_descriptor_naming():
    """Check which id-alias attribute each ForeignKeyField exposes on ``Foo``.

    A field whose ``db_column`` differs from its name is aliased under that
    column name (``_whatever_``, ``person_id``) and gets no ``<name>_id``
    alias. ``me`` (``db_column`` equal to its name) and ``plain`` (no
    ``db_column``) are aliased as ``<name>_id``.
    """
    class Person(Model):
        pass

    class Foo(Model):
        me = ForeignKeyField(Person, db_column='me', related_name='foo1')
        another = ForeignKeyField(Person, db_column='_whatever_',
                                  related_name='foo2')
        another2 = ForeignKeyField(Person, db_column='person_id',
                                   related_name='foo3')
        plain = ForeignKeyField(Person, related_name='foo4')

    aliases = [
        ('me', 'me_id'),
        ('another', '_whatever_'),
        ('another2', 'person_id'),
        ('plain', 'plain_id'),
    ]
    for field_name, alias in aliases:
        assert getattr(Foo, field_name) is getattr(Foo, alias)

    for absent in ('another_id', 'another2_id'):
        with pytest.raises(AttributeError):
            getattr(Foo, absent)
def test_add_column_foreign_key(self):
    '''
    Versioned Models should not have foreign key references

    Adding a ForeignKeyField column to ``food`` must add it to ``food`` only,
    not to the ``foodversion`` table.
    '''
    another_column = ForeignKeyField(
        Menu, related_name='food', null=True, to_field=Menu.id)
    migrate(migrator.add_column('food', 'another_column', another_column))
    self.assertTableHasColumn('food', 'another_column', ForeignKeyField)
    self.assertTableDoesNotHaveColumn('foodversion', 'another_column')
def test_drop_column_not_in_version(self):
    """Dropping a ``food`` FK column that ``foodversion`` never had works.

    The column is added (and checked absent from ``foodversion``), then
    dropped and checked absent from ``food``.
    """
    another_column = ForeignKeyField(
        Menu, related_name='food', null=True, to_field=Menu.id)
    migrate(migrator.add_column('food', 'another_column', another_column))
    self.assertTableDoesNotHaveColumn('foodversion', 'another_column')
    migrate(migrator.drop_column('food', 'another_column'))
    self.assertTableDoesNotHaveColumn('food', 'another_column')
def test_rename_column_not_in_version(self):
    """Renaming a ``food`` FK column that ``foodversion`` never had works.

    After the rename, ``food`` has only ``new_column``, and ``foodversion``
    has neither the old nor the new name.
    """
    another_column = ForeignKeyField(
        Menu, related_name='food', null=True, to_field=Menu.id)
    migrate(migrator.add_column('food', 'another_column', another_column))
    self.assertTableDoesNotHaveColumn('foodversion', 'another_column')
    migrate(migrator.rename_column('food', 'another_column', 'new_column'))
    self.assertTableDoesNotHaveColumn('food', 'another_column')
    self.assertTableDoesNotHaveColumn('foodversion', 'another_column')
    self.assertTableHasColumn('food', 'new_column')
    self.assertTableDoesNotHaveColumn('foodversion', 'new_column')
def test_foreign_key():
    """``TableCreator.foreign_key`` adds a peewee ForeignKeyField to the model."""
    tc = TableCreator('awesome')
    tc.foreign_key('int', 'user_id', references='user.id', on_delete='cascade', on_update='cascade')
    assert isinstance(tc.model.user_id, peewee.ForeignKeyField)
def test_foreign_key_index():
    """A foreign key can be indexed through ``TableCreator.add_index``.

    The index ends up in ``model._meta.indexes`` as a list of
    ``(columns, unique)`` tuples.
    """
    tc = TableCreator('awesome')
    tc.foreign_key('int', 'user_id', references='user.id', on_delete='cascade', on_update='cascade')
    # Non-unique index on the FK column.
    tc.add_index(('user_id',), False)
    assert isinstance(tc.model.user_id, peewee.ForeignKeyField)
    assert tc.model._meta.indexes == [(('user_id',), False)]
def foreign_key(self, coltype, name, references, **kwargs):
    """
    Add a foreign key column to the model.

    Foreign keys need a related model to point at, so they are not handled
    like the other column types.

    :param coltype: Column type of the referenced column; looked up in
        ``FIELD_TO_PEEWEE``, falling back to ``peewee.IntegerField``.
    :param name: Name of the foreign key (also used as its ``db_column``).
    :param references: ``"table.column"``, or just ``"table"``, in which case
        the referenced column defaults to ``id``.
    :param kwargs: Additional kwargs passed to the ``ForeignKeyField``
        instance, e.g. ``on_delete`` / ``on_update`` to add constraints.
    :return: None
    """
    if '.' in references:
        rel_table, rel_column = references.split('.', 1)
    else:
        rel_table, rel_column = references, 'id'

    # Stand-in model mapped onto the referenced table that declares only the
    # referenced column. It exists solely so Peewee can generate the foreign
    # key constraint for us.
    class DummyRelated(peewee.Model):
        class Meta:
            primary_key = False
            database = peewee.Proxy()
            db_table = rel_table

    target_column = FIELD_TO_PEEWEE.get(coltype, peewee.IntegerField)()
    target_column.add_to_class(DummyRelated, rel_column)

    fk_field = peewee.ForeignKeyField(
        DummyRelated, db_column=name, to_field=rel_column, **kwargs)
    fk_field.add_to_class(self.model, name)
def get_referenced_models(self):
    """Return the ``rel_model`` of each ForeignKeyField in ``self.fields``.

    Models are listed in field order; duplicates are kept.
    """
    return [
        fk.rel_model
        for fk in self.fields
        if isinstance(fk, peewee.ForeignKeyField)
    ]
def is_foreign_key_field(self, field_name):
    """Tell whether ``self.model.<field_name>`` is a peewee ForeignKeyField.

    :param field_name: Attribute name looked up on ``self.model``.
    :return: ``True`` for a ``peewee.ForeignKeyField``, ``False`` otherwise.
    :raises AttributeError: If ``self.model`` has no such attribute.
    """
    field = getattr(self.model, field_name)
    # Return an explicit bool instead of falling off the end (implicit None).
    return isinstance(field, peewee.ForeignKeyField)
def mark_fks_as_deferred(table_names):
    """Switch every ForeignKeyField of the models for ``table_names`` to deferred.

    Candidate models come from ``all_models`` and are matched by
    ``_meta.db_table``. Fields that were not deferred yet are tagged with
    ``__pwdbev__not_deferred = True`` before being deferred (presumably so
    the caller can restore them later — confirm).

    :param table_names: Container of table names to process.
    :return: Every ForeignKeyField found on those models, including ones that
        were already deferred.
    """
    models_by_table = {
        model._meta.db_table: model
        for model in all_models.keys()
        if model._meta.db_table in table_names
    }
    fk_fields = []
    for model in models_by_table.values():
        for field in model._meta.sorted_fields:
            if not isinstance(field, pw.ForeignKeyField):
                continue
            fk_fields.append(field)
            if field.deferred:
                continue
            field.__pwdbev__not_deferred = True
            field.deferred = True
    return fk_fields
def alter_add_column(db, migrator, ntn, column_name, field):
    """Build the statements that add ``field`` as ``column_name`` to table ``ntn``.

    :return: List of nodes parsed by ``db.compiler()``. It holds the
        add-column operation, followed by a separate foreign-key constraint
        when the database is MySQL and ``field`` is a ForeignKeyField.
    """
    compiler = db.compiler()
    add_column_op = migrator.alter_add_column(ntn, column_name, field, generate=True)
    statements = [compiler.parse_node(add_column_op)]

    needs_fk_constraint = is_mysql(db) and isinstance(field, pw.ForeignKeyField)
    if needs_fk_constraint:
        fk_op = compiler._create_foreign_key(field.model_class, field)
        statements.append(compiler.parse_node(fk_op))
    return statements
def _add_fake_fk_field_hook():
    """Patch ``pw.ForeignKeyField.__init__`` to accept an extra ``fake`` kwarg.

    When ``fake=...`` is passed, it is removed from the keyword arguments and
    stored on the instance as ``.fake`` before the original initializer
    runs. Without it, ``.fake`` is not set by this hook.
    """
    original_init = pw.ForeignKeyField.__init__

    def patched_init(*args, **kwargs):
        field = args[0]
        if 'fake' in kwargs:
            field.fake = kwargs.pop('fake')
        original_init(*args, **kwargs)

    pw.ForeignKeyField.__init__ = patched_init
def test_create_table_with_fk(self):
    """Evolving creates two new tables where the second has an FK to the first.

    Rows can then be inserted into both tables, linking them via the FK.
    """
    class SomeModel(pw.Model):
        some_field = pw.CharField(null=True)

        class Meta:
            database = self.db

    class SomeModel2(pw.Model):
        some_field2 = pw.CharField(null=True)
        some_model = pw.ForeignKeyField(rel_model=SomeModel)

        class Meta:
            database = self.db

    self.evolve_and_check_noop()
    sm = SomeModel.create(some_field='woot')
    sm2 = SomeModel2.create(some_field2='woot2', some_model=sm)
def test_circular_deps(self):
    """Evolving succeeds for two models that reference each other.

    ``SomeModel`` points at ``SomeModel2`` before that class exists, via
    ``pw.DeferredRelation``; ``SomeModel2`` points back at ``SomeModel``.
    """
    class SomeModel(pw.Model):
        some_model2 = pw.ForeignKeyField(pw.DeferredRelation('SomeModel2'))

        class Meta:
            database = self.db

    class SomeModel2(pw.Model):
        some_model = pw.ForeignKeyField(SomeModel)

        class Meta:
            database = self.db

    self.evolve_and_check_noop()
def convert_field(self, name, field):
    """
    Convert a single field from a Peewee model field to a validator field.

    :param name: Name of the field as defined on this validator (not used by
        this implementation).
    :param field: Peewee field instance.
    :return: A ``ModelChoiceField`` for a ``peewee.ForeignKeyField``, a
        ``ManyModelChoiceField`` for a ``ManyToManyField``, otherwise an
        instance of ``ModelValidator.FIELD_MAP[field.get_db_field()]``
        (``StringField`` when unmapped).
    """
    field_class = ModelValidator.FIELD_MAP.get(field.get_db_field(), StringField)

    nullable = bool(getattr(field, 'null', True))
    choices = getattr(field, 'choices', ())
    default = getattr(field, 'default', None)
    max_length = getattr(field, 'max_length', None)
    unique = getattr(field, 'unique', False)

    # Validators are appended in a fixed order: required, choices, length,
    # uniqueness.
    validators = []
    if not nullable:
        validators.append(validate_required())
    if choices:
        # Choices are (value, label) pairs; only the values are allowed.
        validators.append(validate_one_of([choice[0] for choice in choices]))
    if max_length:
        validators.append(validate_length(high=max_length))
    if unique:
        validators.append(validate_model_unique(
            field, self.instance.select(), self.pk_field, self.pk_value))

    if isinstance(field, peewee.ForeignKeyField):
        return ModelChoiceField(
            field.rel_model, field.to_field,
            default=default, validators=validators)
    if isinstance(field, ManyToManyField):
        return ManyModelChoiceField(
            field.rel_model, field.rel_model._meta.primary_key,
            default=default, validators=validators)
    return field_class(default=default, validators=validators)
def test_related_name_collision(flushdb):
    """A ``related_name`` that clashes with an existing attribute is rejected.

    ``FooRel.foo`` asks for the backref ``f1`` on ``Foo``, which already has
    a field of that name, so defining ``FooRel`` raises AttributeError.
    ``flushdb`` is presumably a pytest fixture that resets the database —
    confirm.
    """
    class Foo(TestModel):
        f1 = CharField()

    with pytest.raises(AttributeError):
        class FooRel(TestModel):
            foo = ForeignKeyField(Foo, related_name='f1')
def test_meta_rel_for_model():
    """Exercise ``_meta.rel_for_model`` and ``_meta.reverse_rel_for_model``.

    As asserted below, ``X._meta.rel_for_model(M)`` gives X's FK pointing at
    ``M``, and ``X._meta.reverse_rel_for_model(M)`` gives M's FK pointing at
    ``X``. The single form returns the first such field (or ``None``);
    ``multi=True`` returns all of them (or ``[]``).
    """
    class User(Model):
        pass

    class Category(Model):
        parent = ForeignKeyField('self')

    class Tweet(Model):
        user = ForeignKeyField(User)

    class Relationship(Model):
        from_user = ForeignKeyField(User, related_name='r1')
        to_user = ForeignKeyField(User, related_name='r2')

    UM = User._meta
    CM = Category._meta
    TM = Tweet._meta
    RM = Relationship._meta

    # Simple refs work.
    # User has no FK to Tweet; Tweet.user points back at User.
    assert UM.rel_for_model(Tweet) is None
    assert UM.rel_for_model(Tweet, multi=True) == []
    assert UM.reverse_rel_for_model(Tweet) == Tweet.user
    assert UM.reverse_rel_for_model(Tweet, multi=True) == [Tweet.user]

    # Multi fks.
    # Relationship has two FKs to User; without multi=True the first one
    # declared (from_user) is returned.
    assert RM.rel_for_model(User) == Relationship.from_user
    assert RM.rel_for_model(User, multi=True) == [Relationship.from_user,
                                                  Relationship.to_user]
    assert UM.reverse_rel_for_model(Relationship) == Relationship.from_user
    exp = [Relationship.from_user, Relationship.to_user]
    assert UM.reverse_rel_for_model(Relationship, multi=True) == exp

    # Self-refs work.
    assert CM.rel_for_model(Category) == Category.parent
    assert CM.reverse_rel_for_model(Category) == Category.parent

    # Field aliases work: an alias of User still matches Tweet.user.
    UA = User.alias()
    assert TM.rel_for_model(UA) == Tweet.user
def fields(self):
    """Return the wrapped model's fields keyed by name, caching them in ``self._fields``.

    Foreign-key fields are keyed as ``<name>_id``. The mapping is rebuilt
    whenever the cached value is empty.
    """
    if self._fields:
        return self._fields

    self._fields = {}
    for field_name, field in self.val._meta.fields.items():
        key = '%s_id' % field_name if isinstance(field, peewee.ForeignKeyField) else field_name
        self._fields[key] = field
    return self._fields
def to_dict(self, available_columns=None):
    """Serialize the wrapped record into a plain dict.

    Values come from ``model_to_dict(self.val, recurse=False)``, and the keys
    of foreign-key fields get an ``_id`` suffix. When ``self.selected`` is
    truthy, only those (suffixed) keys are kept. When ``available_columns``
    is truthy, the result is finally passed through ``dict_filter``.
    """
    model_fields = self.val._meta.fields
    result = {}
    for key, value in model_to_dict(self.val, recurse=False).items():
        if isinstance(model_fields[key], peewee.ForeignKeyField):
            key += '_id'
        if self.selected and key not in self.selected:
            continue
        result[key] = value

    if not available_columns:
        return result
    return dict_filter(result, available_columns)
def _get_args(self, args):
    """Turn ``(field_name, op, value)`` triples into peewee expressions.

    Values for blob fields are converted with ``to_bin``, and values for
    boolean fields with ``bool_parse``. For ``op == 'in'``, each element of
    ``value`` is converted. For a ForeignKeyField, the type of its
    ``to_field`` decides which conversion applies.

    :param args: Iterable of ``(field_name, op, value)``; ``field_name`` must
        be a key of ``self.view.fields`` and ``op`` a key of
        ``_peewee_method_map``.
    :return: List of peewee expressions, or ``None`` if a value could not be
        converted. In that case ``self.err`` is set to
        ``(RETCODE.INVALID_HTTP_PARAMS, message)``.
    """
    pw_args = []
    for field_name, op, value in args:
        field = self.view.fields[field_name]
        if isinstance(field, peewee.ForeignKeyField):
            tfield = field.to_field
        else:
            tfield = field

        # NOTE(review): the original comment here was lost to an encoding
        # error. Only blob and boolean values are converted explicitly;
        # presumably other types (e.g. int/float) are left to peewee — confirm.
        conv_func = None
        if isinstance(tfield, peewee.BlobField):
            conv_func = to_bin
        elif isinstance(tfield, peewee.BooleanField):
            conv_func = bool_parse

        if conv_func:
            try:
                if op == 'in':
                    value = list(map(conv_func, value))
                else:
                    value = conv_func(value)
            # binascii.Error subclasses ValueError, so it must be caught first.
            except binascii.Error:
                self.err = RETCODE.INVALID_HTTP_PARAMS, 'Invalid query value for blob: Odd-length string'
                return
            except ValueError as e:
                self.err = RETCODE.INVALID_HTTP_PARAMS, ' '.join(map(str, e.args))
                # Abort like the blob branch: continuing would build the
                # query from the unconverted, invalid value.
                return
        pw_args.append(getattr(field, _peewee_method_map[op])(value))
    return pw_args
def _fetch_fields(cls_or_self):
    """Populate ``fields``, ``foreign_keys`` and ``table_name`` from ``.model``.

    Does nothing when ``cls_or_self.model`` is falsy. Foreign-key fields are
    keyed as ``<name>_id`` in ``fields``. ``foreign_keys`` maps those same
    keys to the related model's table name.
    """
    if not cls_or_self.model:
        return

    cls_or_self.foreign_keys = {}
    collected = {}
    for field_name, field in cls_or_self.model._meta.fields.items():
        if isinstance(field, peewee.ForeignKeyField):
            field_name = '%s_id' % field_name
            cls_or_self.foreign_keys[field_name] = field.rel_model._meta.db_table
        collected[field_name] = field
    cls_or_self.fields = collected
    cls_or_self.table_name = cls_or_self.model._meta.db_table
def folder_model(database):
    """Define the ``Folder`` model, bind it to a database and create its table.

    :param database: Object whose ``.database`` attribute becomes the model's
        database (presumably a wrapper around a peewee Database — confirm).
    :return: The ``Folder`` model class.
    """
    class Folder(FieldSignatureMixin, ArchivedMixin):
        # This class represents a Folder in a file system. Two Folders with
        # the same name cannot exist in the same Folder. If the Folder has
        # no Parent Folder, it exists in the top level of the file system.
        name = peewee.CharField(max_length=255, null=False)
        parent_folder = peewee.ForeignKeyField('self', null=True)

        class Meta:
            # NOTE(review): presumably consumed by FieldSignatureMixin to
            # enforce the (name, parent_folder) uniqueness described above.
            signature_fields = ('name', 'parent_folder')

    Folder._meta.database = database.database
    # NOTE(review): the positional True is presumably peewee's
    # fail_silently/safe flag (no error if the table exists) — confirm.
    Folder.create_table(True)
    return Folder
def copy_foreign_keys(self, event):
    """Copy this object's foreign-key values into the Event.

    A value is copied for every field name shared by ``event`` and this
    object whose field here is a ``peewee.ForeignKeyField``. ``created_by``
    is always skipped, because for the Event it is the current user.
    Afterwards, if the Event has a field named after this object's class
    (lower-cased) and ``event.code`` is not ``'AUDIT_DELETE'``, that field is
    set to this object.

    Args:
        event (Event): The Event instance to copy the FKs into.
    """
    event_keys = set(event._meta.fields.keys())
    own_fields = self._meta.fields
    for key in event_keys.intersection(own_fields.keys()):
        is_fk = isinstance(own_fields[key], peewee.ForeignKeyField)
        if key == 'created_by' or not is_fk:
            continue
        setattr(event, key, getattr(self, key))

    # Point the Event's FK named after this class at this object, if one
    # exists. Callbacks must handle any conflicts with this. It is skipped
    # for AUDIT_DELETE, since linking an Event to the record being deleted
    # ends badly.
    self_key = self.__class__.__name__.lower()
    if self_key in event_keys and event.code != 'AUDIT_DELETE':
        setattr(event, self_key, self)
def _rename_table(operation, migrator, introspector, old_name, new_name):
    """Rename a versioned table together with its ``<name>version`` table.

    The version table's ``_original_record_id`` foreign key is dropped before
    the renames and re-added afterwards, pointing at the renamed table. Every
    version row is then re-linked to its original record.

    :param operation: Pending operation that renames ``old_name`` to
        ``new_name``; it is executed with ``operation.run()``.
    :param migrator: Migrator handed to the ``Operation`` objects built here.
    :param introspector: Used to generate models reflecting the live schema.
    :param old_name: Current name of the original table.
    :param new_name: New name of the original table.
    """
    version_old_name = old_name + 'version'
    version_new_name = new_name + 'version'
    # save all of the foreign key references
    models = introspector.generate_models()
    OldVersion = models[version_old_name]
    # Name of the original-table column the version FK targets (per the
    # original note, the original record's primary key).
    to_field_name = OldVersion._original_record.to_field.name
    # (version _id, original record id) pairs, replayed once the FK has been
    # re-created on the renamed tables.
    version_id__original_id = []
    for version in OldVersion.select(OldVersion._id, OldVersion._original_record):
        version_id__original_id.append((version._id, version._original_record_id))
    # drop the foreign key field in the OldVersion model
    drop_field = Operation(migrator, 'drop_column', version_old_name, '_original_record_id')
    drop_field.run()
    # rename the original table
    operation.run()
    # rename the version table
    version_rename_table = Operation(migrator, 'rename_table', version_old_name, version_new_name)
    version_rename_table.run()
    # look up the renamed model so we can add a foreign key to it
    models = introspector.generate_models()
    NewModel = models[new_name]
    # Add a new foreign key reference to the renamed table, targeting the
    # same column as before.
    _original_record = ForeignKeyField(
        NewModel, null=True, on_delete="SET NULL",
        to_field=getattr(NewModel, to_field_name)
    )
    add_foregin_key = Operation(migrator, 'add_column', version_new_name, '_original_record_id', _original_record)
    add_foregin_key.run()
    # reload the models so the version model includes the new foreign key
    models = introspector.generate_models()
    NewModel = models[new_name]
    NewVersionModel = models[version_new_name]
    # re-link all versions
    # NOTE(review): this issues two get() queries and one save() per version
    # row. A saved _original_record_id of None would make the NewModel.get()
    # below find no row — confirm NULL ids cannot occur here.
    for _id, _original_record_id in version_id__original_id:
        version = NewVersionModel.get(NewVersionModel._id == _id)
        # ``to_field_name`` is the name of the original record's primary key
        model = NewModel.get(getattr(NewModel, to_field_name) == _original_record_id)
        version._original_record = model
        version.save()