【Question title】:Python: Mocking a context manager
【Posted】:2025-06-22 05:40:01
【Question】:

I don't understand why I can't mock NamedTemporaryFile.name in this example:

from unittest.mock import patch  # the standalone "mock" backport provides the same patch
import unittest
import tempfile

def myfunc():
    with tempfile.NamedTemporaryFile() as mytmp:
        return mytmp.name

class TestMock(unittest.TestCase):
    @patch('tempfile.NamedTemporaryFile')
    def test_cm(self, mock_tmp):
        mytmpname = 'abcde'
        mock_tmp.__enter__.return_value.name = mytmpname
        self.assertEqual(myfunc(), mytmpname)

Test result:

AssertionError: <MagicMock name='NamedTemporaryFile().__enter__().name' id='140275675011280'> != 'abcde'

【Discussion】:

    Tags: python mocking


    【Answer 1】:

    You are configuring the wrong mock: mock_tmp is not the context manager, it returns the context manager. Replace your setup line with:

    mock_tmp.return_value.__enter__.return_value.name = mytmpname
    

    and your test will pass.
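To make the difference concrete, here is a minimal sketch contrasting the two setups on bare MagicMock objects (no patching). The names `question_mock` and `answer_mock` are illustrative only. The question's line configures what `with mock_tmp as f:` yields, while `myfunc` actually does `with mock_tmp() as f:`, so the configured value is never reached.

```python
from unittest.mock import MagicMock

# Question's setup: configures what entering the mock *itself* yields,
# i.e. `with question_mock as f:` (no call).
question_mock = MagicMock()
question_mock.__enter__.return_value.name = "abcde"

# Answer 1's setup: configures what entering the *return value* yields,
# i.e. `with answer_mock() as f:`, which is what myfunc does.
answer_mock = MagicMock()
answer_mock.return_value.__enter__.return_value.name = "abcde"

with question_mock() as f:
    question_name = f.name  # an auto-created MagicMock, as in the AssertionError

with answer_mock() as f:
    answer_name = f.name  # "abcde"
```

The question's configuration only takes effect if the mock is entered without being called first.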

    【Comments】:

      【Answer 2】:

      To expand on Nathaniel's answer, this block of code

      with tempfile.NamedTemporaryFile() as mytmp:
          return mytmp.name
      

      effectively does three things:

      # Firstly, it calls NamedTemporaryFile to create a new file object (the context manager).
      context_manager = tempfile.NamedTemporaryFile()
      
      # Secondly, it calls __enter__ on the context manager instance.
      mytmp = context_manager.__enter__()
      
      # Thirdly, we are now "inside" the context and can do some work.
      # (On the way out, even via return, context_manager.__exit__ is also called.)
      return mytmp.name
      

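      When you replace tempfile.NamedTemporaryFile with a MagicMock instance (what patch creates by default; a plain Mock does not provide __enter__ out of the box), the same steps become: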

      context_manager = mock_tmp()
      # This first line, above, will call mock_tmp().
      # Therefore we need to set the return_value with
      # mock_tmp.return_value
      
      mytmp = context_manager.__enter__()
      # This will call mock_tmp.return_value.__enter__() so we need to set 
      # mock_tmp.return_value.__enter__.return_value
      
      return mytmp.name
      # This will access mock_tmp.return_value.__enter__.return_value.name
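
The same three steps can be checked on a MagicMock directly. This is a minimal sketch, without any patching, that runs the `with` statement against a stand-in for NamedTemporaryFile and then asserts that each step touched the child mock the comments above describe, including the `__exit__` call that the expansion leaves implicit.

```python
from unittest.mock import MagicMock

mock_tmp = MagicMock()  # stands in for tempfile.NamedTemporaryFile
mock_tmp.return_value.__enter__.return_value.name = "abcde"

with mock_tmp() as mytmp:  # step 1: call the mock; step 2: __enter__ on its return value
    result = mytmp.name    # step 3: attribute of __enter__'s return value

# Each step left a trace on the corresponding child mock.
mock_tmp.assert_called_once_with()
mock_tmp.return_value.__enter__.assert_called_once_with()
# Leaving the block without an exception calls __exit__(None, None, None).
mock_tmp.return_value.__exit__.assert_called_once_with(None, None, None)
```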
      

      【Comments】:

        【Answer 3】:

        Extending Peter K's answer, using pytest and the mocker fixture (provided by the pytest-mock plugin).

        import tempfile
        
        
        def myfunc():
            with tempfile.NamedTemporaryFile(prefix='fileprefix') as fh:
                return fh.name
        
        
        def test_myfunc(mocker):
            mocker.patch('tempfile.NamedTemporaryFile').return_value.__enter__.return_value.name = 'tempfilename'
            assert myfunc() == 'tempfilename'
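
If pytest-mock is not installed, a sketch of the same test using only the standard library is below: `unittest.mock.patch` used as a context manager restores the real NamedTemporaryFile when the `with` block exits, similar to how mocker undoes its patches at test teardown. The `assert_called_once_with` line is an addition not in the answer; it checks the `prefix` argument `myfunc` passed.

```python
import tempfile
from unittest.mock import patch


def myfunc():
    with tempfile.NamedTemporaryFile(prefix="fileprefix") as fh:
        return fh.name


def test_myfunc():
    # The patch is active only inside this with block.
    with patch("tempfile.NamedTemporaryFile") as mock_tmp:
        mock_tmp.return_value.__enter__.return_value.name = "tempfilename"
        assert myfunc() == "tempfilename"
        # Also verify the arguments passed to NamedTemporaryFile.
        mock_tmp.assert_called_once_with(prefix="fileprefix")
```

pytest collects `test_myfunc` as a plain function; no fixture is needed.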
        

        【Comments】:

          【Answer 4】:

          Here is an alternative with pytest's mocker fixture, which is also common practice:

          def test_myfunc(mocker):
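              mock_tempfile = mocker.MagicMock(name='tempfile')
              # Replace the name `tempfile` in this module, where myfunc looks it up
              # (assumes myfunc is defined in the same module as this test).
              mocker.patch(__name__ + '.tempfile', new=mock_tempfile)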
              mytmpname = 'abcde'
              mock_tempfile.NamedTemporaryFile.return_value.__enter__.return_value.name = mytmpname
              assert myfunc() == mytmpname
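
Answer 4 patches `tempfile` in the module where `myfunc` looks it up. That rule, patch where a name is looked up rather than where it is defined, matters most when the code under test does `from tempfile import NamedTemporaryFile`: patching `tempfile.NamedTemporaryFile` then has no effect. A minimal sketch follows; the module `app` and both helper functions are hypothetical, and `app` is built at runtime with `exec` only so the example fits in one file.

```python
import tempfile
import types
from unittest.mock import patch

# Hypothetical module "app" whose code imports the function by name.
app = types.ModuleType("app")
exec(
    "from tempfile import NamedTemporaryFile\n"
    "\n"
    "def myfunc():\n"
    "    with NamedTemporaryFile() as mytmp:\n"
    "        return mytmp.name\n",
    app.__dict__,
)


def name_when_patching_tempfile_module():
    # Replaces tempfile.NamedTemporaryFile, but app already holds its own
    # reference to the real function, so a real temporary file is created.
    with patch.object(tempfile, "NamedTemporaryFile") as mock_tmp:
        mock_tmp.return_value.__enter__.return_value.name = "abcde"
        return app.myfunc()


def name_when_patching_app():
    # Replaces the name where myfunc actually looks it up.
    with patch.object(app, "NamedTemporaryFile") as mock_tmp:
        mock_tmp.return_value.__enter__.return_value.name = "abcde"
        return app.myfunc()
```

In a real project the second form is the string target `patch("app.NamedTemporaryFile")`.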
          

          【Comments】:

            【Answer 5】:

            I expanded hmobrienv's answer into a small working program:

            import tempfile
            import pytest
            
            
            def myfunc():
                with tempfile.NamedTemporaryFile(prefix="fileprefix") as fh:
                    return fh.name
            
            
            def test_myfunc(mocker):
                mocker.patch("tempfile.NamedTemporaryFile").return_value.__enter__.return_value.name = "tempfilename"
                assert myfunc() == "tempfilename"
            
            
            if __name__ == "__main__":
                pytest.main(args=[__file__])
            

            【Comments】:

              【Answer 6】:

              Another possibility is to use a factory that creates an object implementing the context manager protocol:

              import unittest
              import unittest.mock
              import tempfile
              
              
              def myfunc():
                  with tempfile.NamedTemporaryFile() as mytmp:
                      return mytmp.name
              
              
              def mock_named_temporary_file(tmpname):
                  class MockNamedTemporaryFile(object):
                      def __init__(self, *args, **kwargs):
                          self.name = tmpname
              
                      def __enter__(self):
                          return self
              
                      def __exit__(self, exc_type, exc_value, traceback):
                          return False  # do not suppress exceptions
              
                  return MockNamedTemporaryFile()
              
              
              class TestMock(unittest.TestCase):
                  @unittest.mock.patch("tempfile.NamedTemporaryFile")
                  def test_cm(self, mock_tmp):
                      mytmpname = "abcde"
                      mock_tmp.return_value = mock_named_temporary_file(mytmpname)
                      self.assertEqual(myfunc(), mytmpname)
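
A lighter-weight variant of the same factory idea is sketched below, using `contextlib.contextmanager` and `types.SimpleNamespace` instead of a hand-written class; `fake_named_temporary_file` is a hypothetical helper name. A generator-based context manager can be entered only once, so it is installed via `side_effect` (a fresh fake on every call) rather than a single shared `return_value`.

```python
import tempfile
import types
import unittest
import unittest.mock
from contextlib import contextmanager


def myfunc():
    with tempfile.NamedTemporaryFile() as mytmp:
        return mytmp.name


@contextmanager
def fake_named_temporary_file(tmpname):
    # Yields a stand-in object that only has a .name attribute.
    yield types.SimpleNamespace(name=tmpname)


class TestMock(unittest.TestCase):
    @unittest.mock.patch("tempfile.NamedTemporaryFile")
    def test_cm(self, mock_tmp):
        mytmpname = "abcde"
        # Each call to the patched NamedTemporaryFile gets a new fake;
        # any arguments it receives are ignored.
        mock_tmp.side_effect = lambda *args, **kwargs: fake_named_temporary_file(mytmpname)
        self.assertEqual(myfunc(), mytmpname)
```

The class-based version in the answer returns the same object from every call, which works because its `__enter__` simply returns `self`.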
              
              

              【Comments】: