Przeglądaj źródła

Raise an exception if a user in the config file does not exist

Mike Naberezny 12 lat temu
rodzic
commit
448b9e622e
3 zmienionych plików z 114 dodań i 44 usunięć
  1. 34 38
      supervisor/datatypes.py
  2. 17 6
      supervisor/options.py
  3. 63 0
      supervisor/tests/test_datatypes.py

+ 34 - 38
supervisor/datatypes.py

@@ -262,75 +262,71 @@ class UnixStreamSocketConfig(SocketConfig):
                 raise ValueError("Could not change ownership of socket file: "
                                     + "%s" % (e))
 
-
 def colon_separated_user_group(arg):
+    """ Find a user ID and group ID from a string like 'user:group'.  Returns
+        a tuple (uid, gid).  If the string only contains a user like 'user'
+        then (uid, -1) will be returned.  Raises ValueError if either
+        the user or group can't be resolved to valid IDs on the system. """
     try:
-        result = arg.split(':', 1)
-        if len(result) == 1:
-            username = result[0]
-            uid = name_to_uid(username)
-            if uid is None:
-                raise ValueError('Invalid user name %s' % username)
-            return (uid, -1)
+        parts = arg.split(':', 1)
+        if len(parts) == 1:
+            uid = name_to_uid(parts[0])
+            gid = -1
         else:
-            username = result[0]
-            groupname = result[1]
-            uid = name_to_uid(username)
-            gid = name_to_gid(groupname)
-            if uid is None:
-                raise ValueError('Invalid user name %s' % username)
-            if gid is None:
-                raise ValueError('Invalid group name %s' % groupname)
-            return (uid, gid)
-        return result
+            uid = name_to_uid(parts[0])
+            gid = name_to_gid(parts[1])
+        return (uid, gid)
     except:
-        raise ValueError, 'Invalid user.group definition %s' % arg
-
-def octal_type(arg):
-    try:
-        return int(arg, 8)
-    except TypeError:
-        raise ValueError('%s is not convertable to an octal type' % arg)
+        raise ValueError, 'Invalid user:group definition %s' % arg
 
 def name_to_uid(name):
-    if name is None:
-        return None
-
+    """ Find a user ID from a string containing a user name or ID.
+        Raises ValueError if the string can't be resolved to a valid
+        user ID on the system. """
     try:
         uid = int(name)
     except ValueError:
         try:
-            pwrec = pwd.getpwnam(name)
+            pwdrec = pwd.getpwnam(name)
         except KeyError:
-            return None
-        uid = pwrec[2]
+            raise ValueError("Invalid user name %s" % name)
+        uid = pwdrec[2]
     else:
         try:
-            pwrec = pwd.getpwuid(uid)
+            pwd.getpwuid(uid) # check if uid is valid
         except KeyError:
-            return None
+            raise ValueError("Invalid user id %s" % name)
     return uid
 
 def name_to_gid(name):
+    """ Find a group ID from a string containing a group name or ID.
+        Raises ValueError if the string can't be resolved to a valid
+        group ID on the system. """
     try:
         gid = int(name)
     except ValueError:
         try:
-            pwrec = grp.getgrnam(name)
+            grprec = grp.getgrnam(name)
         except KeyError:
-            return None
-        gid = pwrec[2]
+            raise ValueError("Invalid group name %s" % name)
+        gid = grprec[2]
     else:
         try:
-            pwrec = grp.getgrgid(gid)
+            grp.getgrgid(gid) # check if gid is valid
         except KeyError:
-            return None
+            raise ValueError("Invalid group id %s" % name)
     return gid
 
 def gid_for_uid(uid):
     pwrec = pwd.getpwuid(uid)
     return pwrec[3]
 
+def octal_type(arg):
+    try:
+        return int(arg, 8)
+    except TypeError:
+        raise ValueError('%s can not be converted to an octal type' % arg)
+
 def existing_directory(v):
     nv = v % {'here':here}
     nv = os.path.expanduser(nv)

+ 17 - 6
supervisor/options.py

@@ -449,9 +449,10 @@ class ServerOptions(Options):
 
         # Additional checking of user option; set uid and gid
         if self.user is not None:
-            uid = name_to_uid(self.user)
-            if uid is None:
-                self.usage("No such user %s" % self.user)
+            try:
+                uid = name_to_uid(self.user)
+            except ValueError, msg:
+                self.usage(msg) # invalid user
             self.uid = uid
             self.gid = gid_for_uid(uid)
 
@@ -683,7 +684,12 @@ class ServerOptions(Options):
             program_name = section.split(':', 1)[1]
             priority = integer(get(section, 'priority', 999))
 
-            proc_uid = name_to_uid(get(section, 'user', None))
+            # find proc_uid from "user" option
+            proc_user = get(section, 'user', None)
+            if proc_user is None:
+                proc_uid = None
+            else:
+                proc_uid = name_to_uid(proc_user)
 
             socket_owner = get(section, 'socket_owner', None)
             if socket_owner is not None:
@@ -769,13 +775,11 @@ class ServerOptions(Options):
         programs = []
         get = parser.saneget
         program_name = section.split(':', 1)[1]
-
         priority = integer(get(section, 'priority', 999))
         autostart = boolean(get(section, 'autostart', 'true'))
         autorestart = auto_restart(get(section, 'autorestart', 'unexpected'))
         startsecs = integer(get(section, 'startsecs', 1))
         startretries = integer(get(section, 'startretries', 3))
-        uid = name_to_uid(get(section, 'user', None))
         stopsignal = signal_number(get(section, 'stopsignal', 'TERM'))
         stopwaitsecs = integer(get(section, 'stopwaitsecs', 10))
         stopasgroup = boolean(get(section, 'stopasgroup', 'false'))
@@ -795,6 +799,13 @@ class ServerOptions(Options):
         if serverurl and serverurl.strip().upper() == 'AUTO':
             serverurl = None
 
+        # find uid from "user" option
+        user = get(section, 'user', None)
+        if user is None:
+            uid = None
+        else:
+            uid = name_to_uid(user)
+
         umask = get(section, 'umask', None)
         if umask is not None:
             umask = octal_type(umask)

+ 63 - 0
supervisor/tests/test_datatypes.py

@@ -211,6 +211,69 @@ class DatatypesTest(unittest.TestCase):
         bad_url = "unix://"
         self.assertRaises(ValueError, datatypes.url, bad_url)
 
+    @patch("pwd.getpwnam", Mock(return_value=[0,0,42]))
+    def test_name_to_uid_gets_uid_from_username(self):
+        uid = datatypes.name_to_uid("foo")
+        self.assertEquals(uid, 42)
+
+    @patch("pwd.getpwuid", Mock(return_value=[0,0,42]))
+    def test_name_to_uid_gets_uid_from_user_id(self):
+        uid = datatypes.name_to_uid("42")
+        self.assertEquals(uid, 42)
+
+    @patch("pwd.getpwnam", Mock(side_effect=KeyError("bad username")))
+    def test_name_to_uid_raises_for_bad_username(self):
+        self.assertRaises(ValueError, datatypes.name_to_uid, "foo")
+
+    @patch("pwd.getpwuid", Mock(side_effect=KeyError("bad user id")))
+    def test_name_to_uid_raises_for_bad_user_id(self):
+        self.assertRaises(ValueError, datatypes.name_to_uid, "42")
+
+    @patch("grp.getgrnam", Mock(return_value=[0,0,42]))
+    def test_name_to_gid_gets_gid_from_group_name(self):
+        gid = datatypes.name_to_gid("foo")
+        self.assertEquals(gid, 42)
+
+    @patch("grp.getgrgid", Mock(return_value=[0,0,42]))
+    def test_name_to_gid_gets_gid_from_group_id(self):
+        gid = datatypes.name_to_gid("42")
+        self.assertEquals(gid, 42)
+
+    @patch("grp.getgrnam", Mock(side_effect=KeyError("bad group name")))
+    def test_name_to_gid_raises_for_bad_group_name(self):
+        self.assertRaises(ValueError, datatypes.name_to_gid, "foo")
+
+    @patch("grp.getgrgid", Mock(side_effect=KeyError("bad group id")))
+    def test_name_to_gid_raises_for_bad_group_name(self):
+        self.assertRaises(ValueError, datatypes.name_to_gid, "42")
+
+    def test_colon_separated_user_group_returns_both(self):
+        name_to_uid = Mock(return_value=12)
+        name_to_gid = Mock(return_value=34)
+
+        @patch("supervisor.datatypes.name_to_uid", name_to_uid)
+        @patch("supervisor.datatypes.name_to_gid", name_to_gid)
+        def colon_separated():
+            return datatypes.colon_separated_user_group("foo:bar")
+
+        uid, gid = colon_separated()
+        name_to_uid.assert_called_with("foo")
+        self.assertEquals(12, uid)
+        name_to_gid.assert_called_with("bar")
+        self.assertEquals(34, gid)
+
+    def test_colon_separated_user_group_returns_user_only(self):
+        name_to_uid = Mock(return_value=42)
+
+        @patch("supervisor.datatypes.name_to_uid", name_to_uid)
+        def colon_separated():
+            return datatypes.colon_separated_user_group("foo")
+
+        uid, gid = colon_separated()
+        name_to_uid.assert_called_with("foo")
+        self.assertEquals(42, uid)
+        self.assertEquals(-1, gid)
+
 class InetStreamSocketConfigTests(unittest.TestCase):
     def _getTargetClass(self):
         return datatypes.InetStreamSocketConfig