【问题标题】:SQLAlchemy, array_agg, and matching an input listSQLAlchemy、array_agg 和匹配输入列表
【发布时间】:2018-09-01 11:21:46
【问题描述】:

我正在尝试更充分地使用 SQLAlchemy,而不是在遇到困境的第一个迹象时退回到纯 SQL。在这种情况下,我在 Postgres 数据库 (9.5) 中有一个表,它通过将单个项目 atom_id 与组标识符 group_id 相关联,将一组整数存储为一个组。

给定atom_ids 的列表,我希望能够确定group_id 的集合(如果有的话)atom_ids 属于哪个集合。仅使用 group_idatom_id 列即可解决此问题。

现在我试图概括这样一个“组”不仅由atom_ids 的列表组成,还由其他上下文组成。在下面的示例中,列表通过包含 sequence 列进行排序,但从概念上讲,可以使用其他列代替,例如 weight 列,它为每个 atom_id 提供一个 [0,1] 浮点值,表示atom 在组中的“份额”。

以下是演示我的问题的大部分单元测试。

首先,一些设置:

def test_multi_column_grouping(self):
    class MultiColumnGroups(base.Base):
        __tablename__ = 'multi_groups'

        group_id = Column(Integer)
        atom_id = Column(Integer)
        sequence = Column(Integer)  # arbitrary 'other' column.  In this case, an integer, but it could be a float (e.g. weighting factor)

    base.Base.metadata.create_all(self.engine)

    # Insert 6 rows representing 2 different 'groups' of values
    vals = [
        # Group 1
        {'group_id': 1, 'atom_id': 1, 'sequence': 1},
        {'group_id': 1, 'atom_id': 2, 'sequence': 2},
        {'group_id': 1, 'atom_id': 3, 'sequence': 3},
        # Group 2
        {'group_id': 2, 'atom_id': 1, 'sequence': 3},
        {'group_id': 2, 'atom_id': 2, 'sequence': 2},
        {'group_id': 2, 'atom_id': 3, 'sequence': 1},
    ]

    self.session.bulk_save_objects(
        [MultiColumnGroups(**x) for x in vals])
    self.session.flush()

    self.assertEqual(6, len(self.session.query(MultiColumnGroups).all()))

现在,我想查询上表以找出一组特定的输入属于哪个组。我正在使用(命名的)元组列表来表示查询参数。

    from collections import namedtuple
    Entity = namedtuple('Entity', ['atom_id', 'sequence'])
    values_to_match = [
        # (atom_id, sequence)
        Entity(1, 3),
        Entity(2, 2),
        Entity(3, 1),
        ]
    # The above list _should_ match with `group_id == 2`

原始 SQL 解决方案。我宁愿不要依赖于此,因为本练习的一部分是学习更多 SQLAlchemy。

    r = self.session.execute('''
        select group_id
        from multi_groups
        group by group_id
        having array_agg((atom_id, sequence)) = :query_tuples
        ''', {'query_tuples': values_to_match}).fetchone()
    print(r)  # > (2,)
    self.assertEqual(2, r[0])

这是上面的原始 SQL 解决方案,相当直接地转换为 损坏的 SQLAlchemy 查询。运行此程序会产生 psycopg2 错误:(psycopg2.ProgrammingError) operator does not exist: record[] = integer[]。我相信我需要将array_agg 转换为int[]?只要分组列都是整数(如果需要,这是一个可接受的限制),这将起作用,但理想情况下,这将适用于混合类型的输入元组/表列。

    from sqlalchemy import tuple_
    from sqlalchemy.dialects.postgresql import array_agg

    existing_group = self.session.query(MultiColumnGroups).\
        with_entities(MultiColumnGroups.group_id).\
        group_by(MultiColumnGroups.group_id).\
        having(array_agg(tuple_(MultiColumnGroups.atom_id, MultiColumnGroups.sequence)) == values_to_match).\
        one_or_none()

    self.assertIsNotNone(existing_group)
    print('|{}|'.format(existing_group))

上面的session.query() 关闭了吗?我是否在这里蒙蔽了自己,并且错过了可以通过其他方式解决此问题的超级明显的东西?

【问题讨论】:

    标签: python postgresql sqlalchemy


    【解决方案1】:

    我认为您的解决方案会产生不确定的结果,因为组中的行未指定顺序,因此数组聚合和给定数组之间的比较可能会基于此产生真或假:

    [local]:5432 u@sopython*=> select group_id
    [local] u@sopython- > from multi_groups 
    [local] u@sopython- > group by group_id
    [local] u@sopython- > having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
     group_id 
    ----------
            2
    (1 row)
    
    [local]:5432 u@sopython*=> update multi_groups set atom_id = atom_id where atom_id = 2;
    UPDATE 2
    [local]:5432 u@sopython*=> select group_id                                             
    from multi_groups 
    group by group_id
    having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
     group_id 
    ----------
    (0 rows)
    

    您可以对两者应用排序,或者尝试完全不同的方法:您可以使用 relational division 代替数组比较。

    为了进行划分,您必须从Entity 记录列表中形成一个临时关系。同样,有很多方法可以解决这个问题。这是一个使用非嵌套数组的方法:

    In [112]: vtm = select([
         ...:     func.unnest(postgresql.array([
         ...:         getattr(e, f) for e in values_to_match
         ...:     ])).label(f)
         ...:     for f in Entity._fields
         ...: ]).alias()
    

    另一个使用联合:

    In [114]: vtm = union_all(*[
         ...:     select([literal(e.atom_id).label('atom_id'),
         ...:             literal(e.sequence).label('sequence')])
         ...:     for e in values_to_match
         ...: ]).alias()
    

    临时表也可以。

    使用手头的新关系,您希望找到“找到那些不存在组中不存在实体的multi_groups”的答案。这是一个可怕的句子,但有道理:

    In [117]: mg = aliased(MultiColumnGroups)
    
    In [119]: session.query(MultiColumnGroups.group_id).\
         ...:     filter(~exists().
         ...:         select_from(vtm).
         ...:         where(~exists().
         ...:             where(MultiColumnGroups.group_id == mg.group_id).
         ...:             where(tuple_(vtm.c.atom_id, vtm.c.sequence) ==
         ...:                   tuple_(mg.atom_id, mg.sequence)).
         ...:             correlate_except(mg))).\
         ...:     distinct().\
         ...:     all()
         ...: 
    Out[119]: [(2)]
    

    另一方面,您也可以只选择组与给定实体的交集:

    In [19]: gs = intersect(*[
        ...:     session.query(MultiColumnGroups.group_id).
        ...:         filter(MultiColumnGroups.atom_id == vtm.atom_id,
        ...:                MultiColumnGroups.sequence == vtm.sequence)
        ...:     for vtm in values_to_match
        ...: ])
    
    In [20]: session.execute(gs).fetchall()
    Out[20]: [(2,)]
    

    错误

    ProgrammingError: (psycopg2.ProgrammingError) operator does not exist: record[] = integer[]
    LINE 3: ...gg((multi_groups.atom_id, multi_groups.sequence)) = ARRAY[AR...
                                                                 ^
    HINT:  No operator matches the given name and argument type(s). You might need to add explicit type casts.
     [SQL: 'SELECT multi_groups.group_id AS multi_groups_group_id \nFROM multi_groups GROUP BY multi_groups.group_id \nHAVING array_agg((multi_groups.atom_id, multi_groups.sequence)) = %(array_agg_1)s'] [parameters: {'array_agg_1': [[1, 3], [2, 2], [3, 1]]}] (Background on this error at: http://sqlalche.me/e/f405)
    

    是您的values_to_match 首先转换为列表列表(原因未知)然后converted to an array by your DB-API driver 的结果。它产生一个整数数组,而不是记录数组(int,int)。使用 raw DB-API connection 和光标,传递元组列表可以按您的预期工作。

    在 SQLAlchemy 中,如果您将列表 values_to_match 包装为 sqlalchemy.dialects.postgresql.array(),它会按照您的预期工作,但请记住结果是不确定的。

    【讨论】:

    • 感谢您的指导和指出缺陷;这很有帮助。
    猜你喜欢
    • 2014-06-09
    • 1970-01-01
    • 2015-02-23
    • 2016-12-04
    • 1970-01-01
    • 2014-05-31
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多