Procházet zdrojové kódy

Allow a signal name like "SIGTERM" and validate signal number

Mike Naberezny před 11 roky
rodič
revize
34c899bd25
2 změnil soubory, kde provedl 46 přidání a 8 odebrání
  1. 13 8
      supervisor/datatypes.py
  2. 33 0
      supervisor/tests/test_datatypes.py

+ 13 - 8
supervisor/datatypes.py

@@ -396,17 +396,22 @@ def url(value):
         return value
     raise ValueError("value %s is not a URL" % value)
 
+# all valid signal numbers
+SIGNUMS = [ getattr(signal, k) for k in dir(signal) if k.startswith('SIG') ]
+
 def signal_number(value):
-    result = None
-    try:
-        result = int(value)
-    except (ValueError, TypeError):
-        result = getattr(signal, 'SIG'+value, None)
     try:
-        result = int(result)
-        return result
+        num = int(value)
     except (ValueError, TypeError):
-        raise ValueError('value %s is not a signal name/number' % value)
+        name = value.strip().upper()
+        if not name.startswith('SIG'):
+            name = 'SIG' + name
+        num = getattr(signal, name, None)
+        if num is None:
+            raise ValueError('value %s is not a valid signal name' % value)
+    if num not in SIGNUMS:
+        raise ValueError('value %s is not a valid signal number' % value)
+    return num
 
 class RestartWhenExitUnexpected:
     pass

+ 33 - 0
supervisor/tests/test_datatypes.py

@@ -526,3 +526,36 @@ class TestOctalType(unittest.TestCase):
         except ValueError, e:
             expected = '1.2 can not be converted to an octal type'
             self.assertEqual(e.args[0], expected)
+
+class TestSignalNumber(unittest.TestCase):
+    def _callFUT(self, arg):
+        from supervisor.datatypes import signal_number
+        return signal_number(arg)
+
+    def test_converts_number(self):
+        import signal
+        self.assertEqual(self._callFUT(signal.SIGTERM), signal.SIGTERM)
+
+    def test_converts_name(self):
+        import signal
+        self.assertEqual(self._callFUT(' term '), signal.SIGTERM)
+
+    def test_converts_signame(self):
+        import signal
+        self.assertEqual(self._callFUT('SIGTERM'), signal.SIGTERM)
+
+    def test_raises_for_bad_number(self):
+        try:
+            self._callFUT('12345678')
+            self.fail()
+        except ValueError, e:
+            expected = "value 12345678 is not a valid signal number"
+            self.assertEqual(e.args[0], expected)
+
+    def test_raises_for_bad_name(self):
+        try:
+            self._callFUT('BADSIG')
+            self.fail()
+        except ValueError, e:
+            expected = "value BADSIG is not a valid signal name"
+            self.assertEqual(e.args[0], expected)