瀏覽代碼

remove superflous get_header method (base class has been updated to use string methods), add coverage for deferring_http_request's done method

Chris McDonough 11 年之前
父節點
當前提交
a3c5bb7d56
共有 4 個文件被更改,包括 150 次插入36 次删除
  1. 3 21
      supervisor/http.py
  2. 6 0
      supervisor/medusa/http_server.py
  3. 21 15
      supervisor/tests/base.py
  4. 120 0
      supervisor/tests/test_http.py

+ 3 - 21
supervisor/http.py

@@ -141,24 +141,6 @@ class deferring_http_request(http_server.http_request):
     about deferred responses, so we override various methods here.  This was
     added to support tail -f like behavior on the logtail handler """
 
-    def get_header(self, header):
-        # this is overridden purely for speed (the base class doesn't
-        # use string methods
-        header = header.lower()
-        hc = self._header_cache
-        if not hc.has_key(header):
-            h = header + ': '
-            for line in self.header:
-                if line.lower().startswith(h):
-                    hl = len(h)
-                    r = line[hl:]
-                    hc[header] = r
-                    return r
-            hc[header] = None
-            return None
-        else:
-            return hc[header]
-
     def done(self, *arg, **kw):
 
         """ I didn't want to override this, but there's no way around
@@ -181,7 +163,7 @@ class deferring_http_request(http_server.http_request):
 
         if self.version == '1.0':
             if connection == 'keep-alive':
-                if not self.has_key ('Content-Length'):
+                if not 'Content-Length' in self:
                     close_it = 1
                 else:
                     self['Connection'] = 'Keep-Alive'
@@ -190,8 +172,8 @@ class deferring_http_request(http_server.http_request):
         elif self.version == '1.1':
             if connection == 'close':
                 close_it = 1
-            elif not self.has_key('Content-Length'):
-                if self.has_key('Transfer-Encoding'):
+            elif not 'Content-Length' in self:
+                if 'Transfer-Encoding' in self:
                     if not self['Transfer-Encoding'] == 'chunked':
                         close_it = 1
                 elif self.use_chunked:

+ 6 - 0
supervisor/medusa/http_server.py

@@ -82,6 +82,9 @@ class http_request:
     def __getitem__ (self, key):
         return self.reply_headers[key]
 
+    def __contains__(self, key):
+        return key in self.reply_headers
+
     def has_key (self, key):
         return self.reply_headers.has_key (key)
 
@@ -449,6 +452,9 @@ class http_request:
              )
             )
 
+    def log_info(self, msg, level):
+        pass
+
 
 # ===========================================================================
 #                                                HTTP Channel Object

+ 21 - 15
supervisor/tests/base.py

@@ -142,7 +142,7 @@ class DummyOptions:
     def get_pid(self):
         import os
         return os.getpid()
-        
+
     def check_execv_args(self, filename, argv, st):
         if filename == '/bad/filename':
             from supervisor.options import NotFound
@@ -261,7 +261,7 @@ class DummyLogger:
         if kw:
             msg = msg % kw
         self.data.append(msg)
-        
+
     def reopen(self):
         self.reopened = True
     def close(self):
@@ -345,7 +345,7 @@ class DummySocketManager:
 
     def get_socket(self):
         return DummySocket(self._config.fd)
-        
+
 class DummyProcess:
     # Initial state; overridden by instance variables
     pid = 0 # Subprocess pid; 0 when not running
@@ -551,7 +551,7 @@ def makeExecutable(file, substitutions=None):
     import os
     import sys
     import tempfile
-    
+
     if substitutions is None:
         substitutions = {}
     data = open(file).read()
@@ -560,7 +560,7 @@ def makeExecutable(file, substitutions=None):
     substitutions['PYTHON'] = sys.executable
     for key in substitutions.keys():
         data = data.replace('<<%s>>' % key.upper(), substitutions[key])
-    
+
     tmpnam = tempfile.mktemp(prefix=last)
     f = open(tmpnam, 'w')
     f.write(data)
@@ -599,7 +599,7 @@ class DummyMedusaChannel:
     def set_terminator(self, terminator):
         pass
 
-class DummyRequest(dict):
+class DummyRequest(object):
     command = 'GET'
     _error = None
     _done = False
@@ -628,6 +628,12 @@ class DummyRequest(dict):
     def __setitem__(self, header, value):
         self.headers[header] = value
 
+    def __getitem__(self, header):
+        return self.headers[header]
+
+    def __delitem__(self, header):
+        del self.headers[header]
+
     def has_key(self, header):
         return self.headers.has_key(header)
 
@@ -645,7 +651,7 @@ class DummyRequest(dict):
 
     def get_server_url(self):
         return 'http://example.com'
-        
+
 
 class DummyRPCInterfaceFactory:
     def __init__(self, supervisord, **config):
@@ -820,9 +826,9 @@ class DummySupervisorRPCNamespace:
             raise Fault(xmlrpc.Faults.NOT_RUNNING, 'NOT_RUNNING')
         if name == 'FAILED':
             raise Fault(xmlrpc.Faults.FAILED, 'FAILED')
-        
+
         return True
-    
+
     def stopAllProcesses(self):
         from supervisor import xmlrpc
         return [
@@ -1014,15 +1020,15 @@ class DummyProcessGroup:
 
     def stop_all(self):
         self.all_stopped = True
-        
+
     def get_unstopped_processes(self):
         return self.unstopped_processes
 
     def get_dispatchers(self):
         return self.dispatchers
-        
+
 class DummyFCGIProcessGroup(DummyProcessGroup):
-    
+
     def __init__(self, config):
         DummyProcessGroup.__init__(self, config)
         self.socket_manager = DummySocketManager(config.socket_config)
@@ -1093,7 +1099,7 @@ class DummyDispatcher:
         if self.flush_error:
             raise OSError(self.flush_error)
         self.flushed = True
-                
+
 class DummyStream:
     def __init__(self, error=None):
         self.error = error
@@ -1114,7 +1120,7 @@ class DummyStream:
         pass
     def tell(self):
         return len(self.written)
-        
+
 class DummyEvent:
     def __init__(self, serial='abc'):
         if serial is not None:
@@ -1135,7 +1141,7 @@ class DummyPoller:
 
     def poll(self, timeout):
         return self.result
-        
+
 def dummy_handler(event, result):
     pass
 

+ 120 - 0
supervisor/tests/test_http.py

@@ -311,6 +311,126 @@ class DeferringHookedProducerTests(unittest.TestCase):
         self.assertEqual(producer.more(), '')
         self.assertEqual(L, [0])
 
+    def test_more_noproducer(self):
+        producer = self._makeOne(None, None)
+        self.assertEqual(producer.more(), '')
+
+class Test_deferring_http_request(unittest.TestCase):
+    def _getTargetClass(self):
+        from supervisor.http import deferring_http_request
+        return deferring_http_request
+
+    def _makeOne(
+        self,
+        channel=None,
+        req='GET / HTTP/1.0',
+        command='GET',
+        uri='/',
+        version='1.0',
+        header=(),
+        ):
+        return self._getTargetClass()(
+            channel, req, command, uri, version, header
+            )
+
+    def _makeChannel(self):
+        class Channel:
+            closed = False
+            def close_when_done(self):
+                self.closed = True
+            def push_with_producer(self, producer):
+                self.producer = producer
+        return Channel()
+    
+    def test_done_http_10_nokeepalive(self):
+        channel = self._makeChannel()
+        inst = self._makeOne(channel=channel, version='1.0')
+        inst.done()
+        self.assertTrue(channel.closed)
+
+    def test_done_http_10_keepalive_no_content_length(self):
+        channel = self._makeChannel()
+        inst = self._makeOne(
+            channel=channel,
+            version='1.0',
+            header=['Connection: Keep-Alive'],
+            )
+        
+        inst.done()
+        self.assertTrue(channel.closed)
+        
+    def test_done_http_10_keepalive_and_content_length(self):
+        channel = self._makeChannel()
+        inst = self._makeOne(
+            channel=channel,
+            version='1.0',
+            header=['Connection: Keep-Alive'],
+            )
+        inst.reply_headers['Content-Length'] = 1
+        inst.done()
+        self.assertEqual(inst['Connection'], 'Keep-Alive')
+        self.assertFalse(channel.closed)
+
+    def test_done_http_11_connection_close(self):
+        channel = self._makeChannel()
+        inst = self._makeOne(
+            channel=channel,
+            version='1.1',
+            header=['Connection: close']
+            )
+        inst.done()
+        self.assertTrue(channel.closed)
+
+    def test_done_http_11_unknown_transfer_encoding(self):
+        channel = self._makeChannel()
+        inst = self._makeOne(
+            channel=channel,
+            version='1.1',
+            )
+        inst.reply_headers['Transfer-Encoding'] = 'notchunked'
+        inst.done()
+        self.assertTrue(channel.closed)
+
+    def test_done_http_11_chunked_transfer_encoding(self):
+        channel = self._makeChannel()
+        inst = self._makeOne(
+            channel=channel,
+            version='1.1',
+            )
+        inst.reply_headers['Transfer-Encoding'] = 'chunked'
+        inst.done()
+        self.assertFalse(channel.closed)
+
+    def test_done_http_11_use_chunked(self):
+        channel = self._makeChannel()
+        inst = self._makeOne(
+            channel=channel,
+            version='1.1',
+            )
+        inst.use_chunked = True
+        inst.done()
+        self.assertTrue('Transfer-Encoding' in inst)
+        self.assertFalse(channel.closed)
+
+    def test_done_http_11_wo_content_length_no_te_no_use_chunked_close(self):
+        channel = self._makeChannel()
+        inst = self._makeOne(
+            channel=channel,
+            version='1.1',
+            )
+        inst.use_chunked = False
+        inst.done()
+        self.assertTrue(channel.closed)
+
+    def test_done_http_09(self):
+        channel = self._makeChannel()
+        inst = self._makeOne(
+            channel=channel,
+            version=None,
+            )
+        inst.done()
+        self.assertTrue(channel.closed)
+        
 class EncryptedDictionaryAuthorizedTests(unittest.TestCase):
     def _getTargetClass(self):
         from supervisor.http import encrypted_dictionary_authorizer