【问题标题】:How should I test db service module with temp database?我应该如何使用临时数据库测试 db 服务模块?
【发布时间】:2026-02-03 16:50:02
【问题描述】:

我正在寻找使用临时数据库测试服务模块的正确方法。 我有一个可行的示例,但我觉得有更好的解决方案

这是一个测试 db_service 函数的示例(还有更多要检查的函数),我对我提出的解决方案有点恼火,我现在不想为每个 db_service 函数添加这些行。 ~ 你可以看到我在 get_categories 和 original_get_categories 之间的差异中添加的行。

最终我想创建一个可靠的行为,即在默认运行时使用原始应用会话,并在测试时使用临时数据库的会话。

models.py -

from sqlalchemy.orm import declarative_base


Base = declarative_base()


class Category(Base):
    __tablename__ = 'Categories'

    ID = Column(Integer, primary_key=True, autoincrement=True)
    Name = Column(Unicode(100))

db_service.py -

from models import Category
from app import session_scope  # this is the default session my service uses to query

class DBService:
    """provide CRUD operations"""
    
    @staticmethod
        def get_categories(session=None) -> List[Category]:
            if session is None:
                with session_scope() as session:
                    categories = session.query(Category).all()
                    return categories
            categories = session.query(Category).all()
            return categories

    @staticmethod
        def original_get_categories() -> List[Category]:
            with session_scope() as session:
                categories = session.query(Category).all()
                return categories

db_service_test.py -

from typing import List
from fixtures import db
from models import Category
from db_service import DBService


def test_get_all_categories(db):  # uses db fixture to create new database env
    first_category: Category = Category(ID=1, Name="first category")
    second_category: Category = Category(ID=2, Name="second category")
    session = db()
    session.add(first_category)
    session.add(second_category)
    session.commit()

    categories: List[Category] = DBService.get_categories(session)

    assert len(categories) == 2
    assert first_category in categories and second_category in categories

fixtures.py - 它不应该很有趣,但无论如何

import pytest
from models import Base
from sqlalchemy.orm import scoped_session, sessionmaker


@pytest.fixture
def db():
    engine = create_engine('sqlite:///database.sqlite3')
    Base.metadata.create_all(bind=engine)
    db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine))
    yield db_session
    Base.metadata.drop_all(engine)
    db_session().commit()

【问题讨论】:

    标签: python testing sqlalchemy pytest


    【解决方案1】:

    在花了一些时间之后,我想出了两个想法。

    第一个选项 - 将 init 设置为 DBService 类

    from models import Category
    
    
    class DBService:
        """provide CRUD operations"""
    
    
        def __init___(self, session):
            self.session = session
        
    
        def get_categories() -> List[Category]:
            with self.session() as session:
                categories = session.query(Category).all()
                return categories
    

    现在在测试模块上,我只需要使用测试会话初始化 DBService。

    from typing import List
    from fixtures import db
    from models import Category
    from db_service import DBService
    
    
    def test_get_all_categories(db):  # uses db fixture to create new database env
        first_category: Category = Category(ID=1, Name="first category")
        second_category: Category = Category(ID=2, Name="second category")
        session = db()
        session.add(first_category)
        session.add(second_category)
        session.commit()
    
        test_service = DBService(db)
        categories: List[Category] = test_service.get_categories(session)
    
        assert len(categories) == 2
        assert first_category in categories and second_category in categories
    

    第二个选项 - 在每个服务函数中设置可选参数

    from models import Category
    from app import Session # this is the default session my service uses to query
    
    class DBService:
        """provide CRUD operations"""
        
        @staticmethod
        def original_get_categories(session=Session) -> List[Category]:
            with session() as session:
                categories = session.query(Category).all()
                return categories
    

    现在在测试模块上,我只需要将我的测试数据库会话发送给函数。

    from typing import List
    from fixtures import db
    from models import Category
    from db_service import DBService
    
    
    def test_get_all_categories(db):  # uses db fixture to create new database env
        first_category: Category = Category(ID=1, Name="first category")
        second_category: Category = Category(ID=2, Name="second category")
        session = db()
        session.add(first_category)
        session.add(second_category)
        session.commit()
    
        categories: List[Category] = DBService.get_categories(db) #  db, because session is not callable.
    
        assert len(categories) == 2
        assert first_category in categories and second_category in categories
    

    我个人选择第二个选项,我不想在我使用它的任何地方初始化 DBService。

    代码中的较少更改和一个可选参数可以解决我的问题。

    如果您的服务需要保存比会话更多的数据,我会选择第一个选项,更有条理。

    【讨论】: