Browse Source

- Move set of 'SUPERVISOR_ENABLED' envvar to process.spawn() method.

- Set 'SUPERVISOR_SERVER_URL' envvar from within process.spawn()
  method (used by childutils.getRPCInterface()).  This can be used by
  children to determine the current XML-RPC server URL for the
  supervisord that it's running under.

- Get rid of options.BasicAuthTransport in favor of
  xmlrpc.SupervisorTransport.  They have the same constructor, but
  SupervisorTransport allows for HTTP/1.1 persistent connections.

- Move gettags function to xmlrpc module.
Chris McDonough 18 years ago
parent
commit
01eb3e0458

+ 39 - 0
src/supervisor/childutils.py

@@ -0,0 +1,39 @@
+##############################################################################
+#
+# Copyright (c) 2007 Agendaless Consulting and Contributors.
+# All Rights Reserved.
+#
+# This software is subject to the provisions of the Zope Public License,
+# Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
+# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
+# FOR A PARTICULAR PURPOSE
+#
+##############################################################################
+
+import sys
+import xmlrpclib
+from supervisor.xmlrpc import SupervisorTransport
+
+def getRPCInterface(env):
+    # dumbass ServerProxy won't allow us to pass in a non-HTTP url,
+    # so we fake the url we pass into it and always use the transport's
+    # 'serverurl' to figure out what to attach to
+    u = env.get('SUPERVISOR_USERNAME', '')
+    p = env.get('SUPERVISOR_PASSWORD', '')
+    return xmlrpclib.ServerProxy(
+        'http://127.0.0.1',
+        transport = SupervisorTransport(u, p, env['SUPERVISOR_SERVER_URL'])
+        )
+
+def write_stderr(msg):
+    sys.stderr.write(msg)
+    sys.stderr.flush()
+
+def write_stdout(msg):
+    sys.stdout.write(msg)
+    sys.stdout.flush()
+
+def get_headers(line):
+    return dict([ x.split(':') for x in line.split() ])

+ 1 - 0
src/supervisor/datatypes.py

@@ -299,3 +299,4 @@ def auto_restart(value):
                               RestartUnconditionally, False):
         raise ValueError("invalid 'autorestart' value %r" % value)
     return computed_value
+

+ 15 - 120
src/supervisor/options.py

@@ -58,6 +58,7 @@ from supervisor.datatypes import auto_restart
 
 from supervisor import loggers
 from supervisor import states
+from supervisor import xmlrpc
 
 here = os.path.abspath(os.path.dirname(__file__))
 version_txt = os.path.join(here, 'version.txt')
@@ -443,6 +444,17 @@ class ServerOptions(Options):
 
         self.identifier = section.identifier
 
+        if section.http_port is None:
+            self.serverurl = None
+
+        else:
+            if section.http_port.family == socket.AF_INET:
+                host, port = section.http_port.address
+                self.serverurl = 'http://%s:%s' % (host, port)
+            else:
+                # domain socket
+                self.serverurl = 'unix://%s' % section.http_port.address
+
     def convert_sockchown(self, sockchown):
         # Convert chown stuff to uid/gid
         user = sockchown[0]
@@ -1214,7 +1226,6 @@ class ClientOptions(Options):
         self.add("password", "supervisorctl.password", "p:", "password=")
 
     def realize(self, *arg, **kw):
-        os.environ['SUPERVISOR_ENABLED'] = '1'
         Options.realize(self, *arg, **kw)
         if not self.args:
             self.interactive = 1
@@ -1256,9 +1267,9 @@ class ClientOptions(Options):
             # so we fake the url we pass into it and always use the transport's
             # 'serverurl' to figure out what to attach to
             'http://127.0.0.1',
-            transport = BasicAuthTransport(self.username,
-                                           self.password,
-                                           self.serverurl)
+            transport = xmlrpc.SupervisorTransport(self.username,
+                                                   self.password,
+                                                   self.serverurl)
             )
 
 _marker = []
@@ -1415,78 +1426,6 @@ class EventListenerPoolConfig(Config):
         from supervisor.process import EventListenerPool
         return EventListenerPool(self)
 
-class BasicAuthTransport(xmlrpclib.Transport):
-    """ A transport that understands basic auth and UNIX domain socket
-    URLs """
-    _use_datetime = 0 # python 2.5 fwd compatibility
-    def __init__(self, username=None, password=None, serverurl=None):
-        self.username = username
-        self.password = password
-        self.verbose = False
-        self.serverurl = serverurl
-
-    def request(self, host, handler, request_body, verbose=False):
-        # issue XML-RPC request
-
-        h = self.make_connection(host)
-        if verbose:
-            h.set_debuglevel(1)
-
-        h.putrequest("POST", handler)
-
-        # required by HTTP/1.1
-        h.putheader("Host", host)
-
-        # required by XML-RPC
-        h.putheader("User-Agent", self.user_agent)
-        h.putheader("Content-Type", "text/xml")
-        h.putheader("Content-Length", str(len(request_body)))
-
-        # basic auth
-        if self.username is not None and self.password is not None:
-            unencoded = "%s:%s" % (self.username, self.password)
-            encoded = unencoded.encode('base64')
-            encoded = encoded.replace('\012', '')
-            h.putheader("Authorization", "Basic %s" % encoded)
-
-        h.endheaders()
-
-        if request_body:
-            h.send(request_body)
-
-        errcode, errmsg, headers = h.getreply()
-
-        if errcode != 200:
-            raise xmlrpclib.ProtocolError(
-                host + handler,
-                errcode, errmsg,
-                headers
-                )
-
-        return self.parse_response(h.getfile())
-
-    def make_connection(self, host):
-        serverurl = self.serverurl
-        if not serverurl.startswith('http'):
-            if serverurl.startswith('unix://'):
-                serverurl = serverurl[7:]
-            http = UnixStreamHTTP(serverurl)
-            return http
-        else:            
-            type, uri = urllib.splittype(serverurl)
-            host, path = urllib.splithost(uri)
-            hostpath = host+path
-            return xmlrpclib.Transport.make_connection(self, hostpath)
-            
-class UnixStreamHTTPConnection(httplib.HTTPConnection):
-    def connect(self):
-        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-        # we abuse the host parameter as the socketname
-        self.sock.connect(self.host)
-
-class UnixStreamHTTP(httplib.HTTP):
-    _connection_class = UnixStreamHTTPConnection
-
 def readFile(filename, offset, length):
     """ Read length bytes from the file named by filename starting at
     offset """
@@ -1559,50 +1498,6 @@ def tailFile(filename, offset, length):
     except (OSError, IOError):
         return ['', offset, False]
 
-def gettags(comment):
-    """ Parse documentation strings into JavaDoc-like tokens """
-
-    tags = []
-
-    tag = None
-    datatype = None
-    name = None
-    tag_lineno = lineno = 0
-    tag_text = []
-
-    for line in comment.split('\n'):
-        line = line.strip()
-        if line.startswith("@"):
-            tags.append((tag_lineno, tag, datatype, name, '\n'.join(tag_text)))
-            parts = line.split(None, 3)
-            if len(parts) == 1:
-                datatype = ''
-                name = ''
-                tag_text = []
-            elif len(parts) == 2:
-                datatype = parts[1]
-                name = ''
-                tag_text = []
-            elif len(parts) == 3:
-                datatype = parts[1]
-                name = parts[2]
-                tag_text = []
-            elif len(parts) == 4:
-                datatype = parts[1]
-                name = parts[2]
-                tag_text = [parts[3].lstrip()]
-            tag = parts[0][1:]
-            tag_lineno = lineno
-        else:
-            if line:
-                tag_text.append(line)
-        lineno = lineno + 1
-
-    tags.append((tag_lineno, tag, datatype, name, '\n'.join(tag_text)))
-
-    return tags
-
-
 # Helpers for dealing with signals and exit status
 
 def decode_wait_status(sts):

+ 4 - 0
src/supervisor/process.py

@@ -269,6 +269,10 @@ class Subprocess:
                     options.write(2, "(%s)\n" % msg)
                 try:
                     env = os.environ.copy()
+                    env['SUPERVISOR_ENABLED'] = '1'
+                    serverurl = self.config.options.serverurl
+                    if serverurl:
+                        env['SUPERVISOR_SERVER_URL'] = serverurl
                     env['SUPERVISOR_PROCESS_NAME'] = self.config.name
                     if self.group:
                         env['SUPERVISOR_GROUP_NAME'] = self.group.config.name

+ 1 - 0
src/supervisor/tests/base.py

@@ -63,6 +63,7 @@ class DummyOptions:
         self.openreturn = None
         self.readfd_result = ''
         self.parse_warnings = []
+        self.serverurl = 'http://localhost:9001'
 
     def getLogger(self, *args, **kw):
         logger = DummyLogger()

+ 0 - 22
src/supervisor/tests/test_options.py

@@ -778,28 +778,6 @@ class ProcessGroupConfigTests(unittest.TestCase):
         from supervisor.process import ProcessGroup
         self.assertEqual(group.__class__, ProcessGroup)
             
-class BasicAuthTransportTests(unittest.TestCase):
-    def _getTargetClass(self):
-        from supervisor.options import BasicAuthTransport
-        return BasicAuthTransport
-
-    def _makeOne(self, username=None, password=None, serverurl=None):
-        klass = self._getTargetClass()
-        return klass(username, password, serverurl)
-
-    def test_ctor(self):
-        instance = self._makeOne('username', 'password', 'serverurl')
-        self.assertEqual(instance.username, 'username')
-        self.assertEqual(instance.password, 'password')
-        self.assertEqual(instance.serverurl, 'serverurl')
-        self.assertEqual(instance.verbose, False)
-
-    def test_works_with_py25(self):
-        instance = self._makeOne('username', 'password', 'serverurl')
-        # the test is just to insure that this method can be called; failure
-        # would be an AttributeError for _use_datetime under Python 2.5
-        parser, unmarshaller = instance.getparser() # this uses _use_datetime
-
 def test_suite():
     return unittest.findTestCases(sys.modules[__name__])
 

+ 6 - 1
src/supervisor/tests/test_process.py

@@ -362,7 +362,7 @@ class SubprocessTests(unittest.TestCase):
         self.assertEqual(options.execv_args, ('/bin/cat', ['/bin/cat']) )
         self.assertEqual(options.execv_environment['_TEST_'], '1')
 
-    def test_spawn_as_child_environment_supervisor_process_name(self):
+    def test_spawn_as_child_environment_supervisor_envvars(self):
         options = DummyOptions()
         options.forkpid = 0
         config = DummyPConfig(options, 'cat', '/bin/cat')
@@ -374,10 +374,15 @@ class SubprocessTests(unittest.TestCase):
         result = instance.spawn()
         self.assertEqual(result, None)
         self.assertEqual(options.execv_args, ('/bin/cat', ['/bin/cat']) )
+        self.assertEqual(
+            options.execv_environment['SUPERVISOR_ENABLED'], '1')
         self.assertEqual(
             options.execv_environment['SUPERVISOR_PROCESS_NAME'], 'cat')
         self.assertEqual(
             options.execv_environment['SUPERVISOR_GROUP_NAME'], 'dummy')
+        self.assertEqual(
+            options.execv_environment['SUPERVISOR_SERVER_URL'],
+            'http://localhost:9001')
 
     def test_spawn_as_child_stderr_redirected(self):
         options = DummyOptions()

+ 1 - 2
src/supervisor/tests/test_rpcinterfaces.py

@@ -1393,7 +1393,6 @@ class SystemNamespaceXMLRPCInterfaceTests(TestBase):
         from supervisor import xmlrpc
         # belt-and-suspenders test for docstring-as-typing parsing correctness
         # and documentation validity vs. implementation
-        from supervisor import options
         _RPCTYPES = ['int', 'double', 'string', 'boolean', 'dateTime.iso8601',
                      'base64', 'binary', 'array', 'struct']
         interface = self._makeOne()
@@ -1415,7 +1414,7 @@ class SystemNamespaceXMLRPCInterfaceTests(TestBase):
             meth = getattr(namespace, method_name)
             code = meth.func_code
             argnames = code.co_varnames[1:code.co_argcount]
-            parsed = options.gettags(str(meth.__doc__))
+            parsed = xmlrpc.gettags(str(meth.__doc__))
 
             plines = []
             ptypes = []

+ 123 - 0
src/supervisor/tests/test_xmlrpc.py

@@ -114,6 +114,129 @@ class TraverseTests(unittest.TestCase):
         xmlrpc.traverse(dummy, 'foo', [1])
         self.assertEqual(L, [1])
 
+class TesstSupervisorTransport(unittest.TestCase):
+    def _getTargetClass(self):
+        from supervisor.xmlrpc import SupervisorTransport
+        return SupervisorTransport
+
+    def _makeOne(self, *arg, **kw):
+        return self._getTargetClass()(*arg, **kw)
+
+    def test_ctor_unix(self):
+        from supervisor import xmlrpc
+        transport = self._makeOne('user', 'pass', 'unix:///foo/bar')
+        conn = transport._get_connection()
+        self.failUnless(isinstance(conn, xmlrpc.UnixStreamHTTPConnection))
+        self.assertEqual(conn.host, '/foo/bar')
+
+    def test__get_connection_http_9001(self):
+        from supervisor import xmlrpc
+        import httplib
+        transport = self._makeOne('user', 'pass', 'http://127.0.0.1:9001/')
+        conn = transport._get_connection()
+        self.failUnless(isinstance(conn, httplib.HTTPConnection))
+        self.assertEqual(conn.host, '127.0.0.1')
+        self.assertEqual(conn.port, 9001)
+
+    def test__get_connection_http_80(self):
+        from supervisor import xmlrpc
+        import httplib
+        transport = self._makeOne('user', 'pass', 'http://127.0.0.1/')
+        conn = transport._get_connection()
+        self.failUnless(isinstance(conn, httplib.HTTPConnection))
+        self.assertEqual(conn.host, '127.0.0.1')
+        self.assertEqual(conn.port, 80)
+
+    def test_request_non_200_response(self):
+        import xmlrpclib
+        transport = self._makeOne('user', 'pass', 'http://127.0.0.1/')
+        dummy_conn = DummyConnection(400, '')
+        def getconn():
+            return dummy_conn
+        transport._get_connection = getconn
+        self.assertRaises(xmlrpclib.ProtocolError,
+                          transport.request, 'localhost', '/', '')
+        self.assertEqual(transport.connection, None)
+        self.assertEqual(dummy_conn.closed, True)
+
+    def test_request_400_response(self):
+        import xmlrpclib
+        transport = self._makeOne('user', 'pass', 'http://127.0.0.1/')
+        dummy_conn = DummyConnection(400, '')
+        def getconn():
+            return dummy_conn
+        transport._get_connection = getconn
+        self.assertRaises(xmlrpclib.ProtocolError,
+                          transport.request, 'localhost', '/', '')
+        self.assertEqual(transport.connection, None)
+        self.assertEqual(dummy_conn.closed, True)
+        self.assertEqual(dummy_conn.requestargs[0], 'POST')
+        self.assertEqual(dummy_conn.requestargs[1], '/')
+        self.assertEqual(dummy_conn.requestargs[2], '')
+        self.assertEqual(dummy_conn.requestargs[3]['Content-Length'], '0')
+        self.assertEqual(dummy_conn.requestargs[3]['Content-Type'], 'text/xml')
+        self.assertEqual(dummy_conn.requestargs[3]['Authorization'],
+                         'Basic dXNlcjpwYXNz')
+        self.assertEqual(dummy_conn.requestargs[3]['Accept'], 'text/xml')
+
+    def test_request_200_response(self):
+        import xmlrpclib
+        transport = self._makeOne('user', 'pass', 'http://127.0.0.1/')
+        response = """<?xml version="1.0"?>
+        <methodResponse>
+        <params>
+        <param>
+        <value><string>South Dakota</string></value>
+        </param>
+        </params>
+        </methodResponse>"""
+        dummy_conn = DummyConnection(200, response)
+        def getconn():
+            return dummy_conn
+        transport._get_connection = getconn
+        result = transport.request('localhost', '/', '')
+        self.assertEqual(transport.connection, dummy_conn)
+        self.assertEqual(dummy_conn.closed, False)
+        self.assertEqual(dummy_conn.requestargs[0], 'POST')
+        self.assertEqual(dummy_conn.requestargs[1], '/')
+        self.assertEqual(dummy_conn.requestargs[2], '')
+        self.assertEqual(dummy_conn.requestargs[3]['Content-Length'], '0')
+        self.assertEqual(dummy_conn.requestargs[3]['Content-Type'], 'text/xml')
+        self.assertEqual(dummy_conn.requestargs[3]['Authorization'],
+                         'Basic dXNlcjpwYXNz')
+        self.assertEqual(dummy_conn.requestargs[3]['Accept'], 'text/xml')
+        self.assertEqual(result, ('South Dakota',))
+
+    def test_works_with_py25(self):
+        instance = self._makeOne('username', 'password', 'http://127.0.0.1')
+        # the test is just to insure that this method can be called; failure
+        # would be an AttributeError for _use_datetime under Python 2.5
+        parser, unmarshaller = instance.getparser() # this uses _use_datetime
+
+class DummyResponse:
+    def __init__(self, status=200, body='', reason='reason'):
+        self.status = status
+        self.body = body
+        self.reason = reason
+
+    def read(self):
+        return self.body
+
+class DummyConnection:
+    closed = False
+    def __init__(self, status=200, body='', reason='reason'):
+        self.response = DummyResponse(status, body, reason)
+
+    def getresponse(self):
+        return self.response
+        
+    def request(self, *arg, **kw):
+        self.requestargs = arg
+        self.requestkw = kw
+
+    def close(self):
+        self.closed = True
+
 def test_suite():
     return unittest.findTestCases(sys.modules[__name__])
 

+ 122 - 1
src/supervisor/xmlrpc.py

@@ -13,12 +13,14 @@
 ##############################################################################
 
 import types
+import socket
 import xmlrpclib
+import httplib
+import urllib
 import re
 import StringIO
 import traceback
 import sys
-from supervisor.options import gettags
 
 from medusa.http_server import get_header
 from medusa.xmlrpc_handler import xmlrpc_handler
@@ -390,3 +392,122 @@ def traverse(ob, method, params):
     except TypeError:
         raise RPCError(Faults.INCORRECT_PARAMETERS)
 
+class SupervisorTransport(xmlrpclib.Transport):
+    """
+    Provides a Transport for xmlrpclib that uses
+    httplib.HTTPConnection in order to support persistent
+    connections.  Also support basic auth and UNIX domain socket
+    servers.
+    """
+    connection = None
+
+    _use_datetime = 0 # python 2.5 fwd compatibility
+    def __init__(self, username=None, password=None, serverurl=None):
+        self.username = username
+        self.password = password
+        self.verbose = False
+        self.serverurl = serverurl
+        if serverurl.startswith('http://'):
+            type, uri = urllib.splittype(serverurl)
+            host, path = urllib.splithost(uri)
+            host, port = urllib.splitport(host)
+            if port is None:
+                port = 80
+            else:
+                port = int(port)
+            def get_connection(host=host, port=port):
+                return httplib.HTTPConnection(host, port)
+            self._get_connection = get_connection
+        elif serverurl.startswith('unix://'):
+            serverurl = serverurl[7:]
+            def get_connection(serverurl=serverurl):
+                return UnixStreamHTTPConnection(serverurl)
+            self._get_connection = get_connection
+        else:
+            raise ValueError('Unknown protocol for serverurl %s' % serverurl)
+
+    def request(self, host, handler, request_body, verbose=0):
+        if not self.connection:
+            self.connection = self._get_connection()
+            self.headers = {
+                "User-Agent" : self.user_agent,
+                "Content-Type" : "text/xml",
+                "Accept": "text/xml"
+                }
+            
+            # basic auth
+            if self.username is not None and self.password is not None:
+                unencoded = "%s:%s" % (self.username, self.password)
+                encoded = unencoded.encode('base64')
+                encoded = encoded.replace('\012', '')
+                self.headers["Authorization"] = "Basic %s" % encoded
+                
+        self.headers["Content-Length"] = str(len(request_body))
+
+        self.connection.request('POST', handler, request_body, self.headers)
+
+        r = self.connection.getresponse()
+
+        if r.status != 200:
+            self.connection.close()
+            self.connection = None
+            raise xmlrpclib.ProtocolError(host + handler,
+                                          r.status,
+                                          r.reason,
+                                          '' )
+        data = r.read()
+        p, u = self.getparser()
+        p.feed(data)
+        p.close()
+        return u.close()    
+
+class UnixStreamHTTPConnection(httplib.HTTPConnection):
+    def connect(self):
+        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        # we abuse the host parameter as the socketname
+        self.sock.connect(self.host)
+
+def gettags(comment):
+    """ Parse documentation strings into JavaDoc-like tokens """
+
+    tags = []
+
+    tag = None
+    datatype = None
+    name = None
+    tag_lineno = lineno = 0
+    tag_text = []
+
+    for line in comment.split('\n'):
+        line = line.strip()
+        if line.startswith("@"):
+            tags.append((tag_lineno, tag, datatype, name, '\n'.join(tag_text)))
+            parts = line.split(None, 3)
+            if len(parts) == 1:
+                datatype = ''
+                name = ''
+                tag_text = []
+            elif len(parts) == 2:
+                datatype = parts[1]
+                name = ''
+                tag_text = []
+            elif len(parts) == 3:
+                datatype = parts[1]
+                name = parts[2]
+                tag_text = []
+            elif len(parts) == 4:
+                datatype = parts[1]
+                name = parts[2]
+                tag_text = [parts[3].lstrip()]
+            tag = parts[0][1:]
+            tag_lineno = lineno
+        else:
+            if line:
+                tag_text.append(line)
+        lineno = lineno + 1
+
+    tags.append((tag_lineno, tag, datatype, name, '\n'.join(tag_text)))
+
+    return tags
+
+