bug-wget
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[Bug-wget] [GSoC PATCH 06/11] move server classes to package server.protocol


From: Zihang Chen
Subject: [Bug-wget] [GSoC PATCH 06/11] move server classes to package server.protocol
Date: Fri, 14 Mar 2014 21:18:22 +0800

 delete mode 100644 testenv/FTPServer.py
 delete mode 100644 testenv/HTTPServer.py
 create mode 100644 testenv/misc/constants.py
 create mode 100644 testenv/server/__init__.py
 create mode 100644 testenv/server/ftp/__init__.py
 create mode 100644 testenv/server/ftp/ftp_server.py
 create mode 100644 testenv/server/http/__init__.py
 create mode 100644 testenv/server/http/http_server.py

diff --git a/testenv/ChangeLog b/testenv/ChangeLog
index 73b92e7..390becc 100644
--- a/testenv/ChangeLog
+++ b/testenv/ChangeLog
@@ -1,4 +1,13 @@
 2014-03-13  Zihang Chen  <address@hidden>
+       * server: (new package) package for the server classes
+       * server.http: (new package) package for HTTP server
+       * server.ftp: (new package) package for FTP server
+       * HTTPServer.py: Move to server/http/http_server.py. Also change the
+       CERTFILE to '../certs/wget-cert.pem'.
+       * FTPServer.py: Move to server/ftp/ftp_server.py.
+       * WgetTest.py: Optimize imports with respect to the server classes.
+       (HTTP, HTTPS): These two string constants are moved to misc/constants.py.
+2014-03-13  Zihang Chen  <address@hidden>
        * conf: (new package) package for rule classes and hook methods
        * WgetTest.py:
        (CommonMethods.Authentication): Move to conf/authentication.py.
diff --git a/testenv/FTPServer.py b/testenv/FTPServer.py
deleted file mode 100644
index f7d7771..0000000
--- a/testenv/FTPServer.py
+++ /dev/null
@@ -1,162 +0,0 @@
-import os
-import re
-import threading
-import socket
-import pyftpdlib.__main__
-from pyftpdlib.ioloop import IOLoop
-import pyftpdlib.handlers as Handle
-from pyftpdlib.servers import FTPServer
-from pyftpdlib.authorizers import DummyAuthorizer
-from pyftpdlib._compat import PY3, u, b, getcwdu, callable
-
-class FTPDHandler (Handle.FTPHandler):
-
-    def ftp_LIST (self, path):
-        try:
-            iterator = self.run_as_current_user(self.fs.get_list_dir, path)
-        except (OSError, FilesystemError):
-            err = sys.exc_info()[1]
-            why = _strerror (err)
-            self.respond ('550 %s. ' % why)
-        else:
-            if self.isRule ("Bad List") is True:
-                iter_list = list ()
-                for flist in iterator:
-                    line = re.compile (r'(\s+)').split (flist.decode ('utf-8'))
-                    line[8] = '0'
-                    iter_l =  ''.join (line).encode ('utf-8')
-                    iter_list.append (iter_l)
-                iterator = (n for n in iter_list)
-            producer = Handle.BufferedIteratorProducer (iterator)
-            self.push_dtp_data (producer, isproducer=True, cmd="LIST")
-            return path
-
-    def ftp_PASV (self, line):
-        if self._epsvall:
-            self.respond ("501 PASV not allowed after EPSV ALL.")
-            return
-        self._make_epasv(extmode=False)
-        if self.isRule ("FailPASV") is True:
-            del self.server.global_rules["FailPASV"]
-            self.socket.close ()
-
-    def isRule (self, rule):
-        rule_obj = self.server.global_rules[rule]
-        return False if not rule_obj else rule_obj[0]
-
-class FTPDServer (FTPServer):
-
-    def set_global_rules (self, rules):
-        self.global_rules = rules
-
-class FTPd(threading.Thread):
-    """A threaded FTP server used for running tests.
-
-    This is basically a modified version of the FTPServer class which
-    wraps the polling loop into a thread.
-
-    The instance returned can be used to start(), stop() and
-    eventually re-start() the server.
-    """
-    handler = FTPDHandler
-    server_class = FTPDServer
-
-    def __init__(self, addr=None):
-        os.mkdir ('server')
-        os.chdir ('server')
-        try:
-            HOST = socket.gethostbyname ('localhost')
-        except socket.error:
-            HOST = 'localhost'
-        USER = 'user'
-        PASSWD = '12345'
-        HOME = getcwdu ()
-
-        threading.Thread.__init__(self)
-        self.__serving = False
-        self.__stopped = False
-        self.__lock = threading.Lock()
-        self.__flag = threading.Event()
-        if addr is None:
-            addr = (HOST, 0)
-
-        authorizer = DummyAuthorizer()
-        authorizer.add_user(USER, PASSWD, HOME, perm='elradfmwM')  # full perms
-        authorizer.add_anonymous(HOME)
-        self.handler.authorizer = authorizer
-        # lowering buffer sizes = more cycles to transfer data
-        # = less false positive test failures
-        self.handler.dtp_handler.ac_in_buffer_size = 32768
-        self.handler.dtp_handler.ac_out_buffer_size = 32768
-        self.server = self.server_class(addr, self.handler)
-        self.host, self.port = self.server.socket.getsockname()[:2]
-        os.chdir ('..')
-
-    def set_global_rules (self, rules):
-        self.server.set_global_rules (rules)
-
-    def __repr__(self):
-        status = [self.__class__.__module__ + "." + self.__class__.__name__]
-        if self.__serving:
-            status.append('active')
-        else:
-            status.append('inactive')
-        status.append('%s:%s' % self.server.socket.getsockname()[:2])
-        return '<%s at %#x>' % (' '.join(status), id(self))
-
-    @property
-    def running(self):
-        return self.__serving
-
-    def start(self, timeout=0.001):
-        """Start serving until an explicit stop() request.
-        Polls for shutdown every 'timeout' seconds.
-        """
-        if self.__serving:
-            raise RuntimeError("Server already started")
-        if self.__stopped:
-            # ensure the server can be started again
-            FTPd.__init__(self, self.server.socket.getsockname(), self.handler)
-        self.__timeout = timeout
-        threading.Thread.start(self)
-        self.__flag.wait()
-
-    def run(self):
-        self.__serving = True
-        self.__flag.set()
-        while self.__serving:
-            self.__lock.acquire()
-            self.server.serve_forever(timeout=self.__timeout, blocking=False)
-            self.__lock.release()
-        self.server.close_all()
-
-    def stop(self):
-        """Stop serving (also disconnecting all currently connected
-        clients) by telling the serve_forever() loop to stop and
-        waits until it does.
-        """
-        if not self.__serving:
-            raise RuntimeError("Server not started yet")
-        self.__serving = False
-        self.__stopped = True
-        self.join()
-
-
-def mk_file_sys (file_list):
-    os.chdir ('server')
-    for name, content in file_list.items ():
-        file_h = open (name, 'w')
-        file_h.write (content)
-        file_h.close ()
-    os.chdir ('..')
-
-def filesys ():
-    fileSys = dict ()
-    os.chdir ('server')
-    for parent, dirs, files in os.walk ('.'):
-        for filename in files:
-            file_handle = open (filename, 'r')
-            file_content = file_handle.read ()
-            fileSys[filename] = file_content
-    os.chdir ('..')
-    return fileSys
diff --git a/testenv/HTTPServer.py b/testenv/HTTPServer.py
deleted file mode 100644
index e554a10..0000000
--- a/testenv/HTTPServer.py
+++ /dev/null
@@ -1,467 +0,0 @@
-from http.server import HTTPServer, BaseHTTPRequestHandler
-from socketserver import BaseServer
-from posixpath import basename, splitext
-from base64 import b64encode
-from random import random
-from hashlib import md5
-import threading
-import socket
-import re
-import ssl
-import os
-
-
-class InvalidRangeHeader (Exception):
-
-    """ Create an Exception for handling of invalid Range Headers. """
-    # TODO: Eliminate this exception and use only ServerError
-
-    def __init__ (self, err_message):
-        self.err_message = err_message
-
-class ServerError (Exception):
-    def __init__ (self, err_message):
-        self.err_message = err_message
-
-
-class StoppableHTTPServer (HTTPServer):
-
-    request_headers = list ()
-
-    """ Define methods for configuring the Server. """
-
-    def server_conf (self, filelist, conf_dict):
-        """ Set Server Rules and File System for this instance. """
-        self.server_configs = conf_dict
-        self.fileSys = filelist
-
-    def server_sett (self, settings):
-        for settings_key in settings:
-            setattr (self.RequestHandlerClass, settings_key, 
settings[settings_key])
-
-    def get_req_headers (self):
-        return self.request_headers
-
-class HTTPSServer (StoppableHTTPServer):
-
-   def __init__ (self, address, handler):
-         BaseServer.__init__ (self, address, handler)
-         print (os.getcwd())
-         CERTFILE = os.path.abspath (os.path.join ('..', 'certs', 
'wget-cert.pem'))
-         print (CERTFILE)
-         fop = open (CERTFILE)
-         print (fop.readline())
-         self.socket = ssl.wrap_socket (
-               sock = socket.socket (self.address_family, self.socket_type),
-               ssl_version = ssl.PROTOCOL_TLSv1,
-               certfile = CERTFILE,
-               server_side = True
-               )
-         self.server_bind ()
-         self.server_activate ()
-
-class WgetHTTPRequestHandler (BaseHTTPRequestHandler):
-
-    """ Define methods for handling Test Checks. """
-
-    def get_rule_list (self, name):
-        r_list = self.rules.get (name) if name in self.rules else None
-        return r_list
-
-
-class _Handler (WgetHTTPRequestHandler):
-
-    """ Define Handler Methods for different Requests. """
-
-    InvalidRangeHeader = InvalidRangeHeader
-    protocol_version = 'HTTP/1.1'
-
-    """ Define functions for various HTTP Requests. """
-
-    def do_HEAD (self):
-        self.send_head ("HEAD")
-
-    def do_GET (self):
-        content, start = self.send_head ("GET")
-        if content:
-            if start is None:
-                self.wfile.write (content.encode ('utf-8'))
-            else:
-                self.wfile.write (content.encode ('utf-8')[start:])
-
-    def do_POST (self):
-        path = self.path[1:]
-        self.rules = self.server.server_configs.get (path)
-        if not self.custom_response ():
-            return (None, None)
-        if path in self.server.fileSys:
-            body_data = self.get_body_data ()
-            self.send_response (200)
-            self.send_header ("Content-type", "text/plain")
-            content = self.server.fileSys.pop (path) + "\n" + body_data
-            total_length = len (content)
-            self.server.fileSys[path] = content
-            self.send_header ("Content-Length", total_length)
-            self.finish_headers ()
-            try:
-                self.wfile.write (content.encode ('utf-8'))
-            except Exception:
-                pass
-        else:
-            self.send_put (path)
-
-    def do_PUT (self):
-        path = self.path[1:]
-        self.rules = self.server.server_configs.get (path)
-        if not self.custom_response ():
-            return (None, None)
-        self.server.fileSys.pop (path, None)
-        self.send_put (path)
-
-    """ End of HTTP Request Method Handlers. """
-
-    """ Helper functions for the Handlers. """
-
-    def parse_range_header (self, header_line, length):
-        if header_line is None:
-            return None
-        if not header_line.startswith ("bytes="):
-            raise InvalidRangeHeader ("Cannot parse header Range: %s" %
-                                     (header_line))
-        regex = re.match (r"^bytes=(\d*)\-$", header_line)
-        range_start = int (regex.group (1))
-        if range_start >= length:
-            raise InvalidRangeHeader ("Range Overflow")
-        return range_start
-
-    def get_body_data (self):
-        cLength_header = self.headers.get ("Content-Length")
-        cLength = int (cLength_header) if cLength_header is not None else 0
-        body_data = self.rfile.read (cLength).decode ('utf-8')
-        return body_data
-
-    def send_put (self, path):
-        body_data = self.get_body_data ()
-        self.send_response (201)
-        self.server.fileSys[path] = body_data
-        self.send_header ("Content-type", "text/plain")
-        self.send_header ("Content-Length", len (body_data))
-        self.finish_headers ()
-        try:
-            self.wfile.write (body_data.encode ('utf-8'))
-        except Exception:
-            pass
-
-    def SendHeader (self, header_obj):
-        pass
-#        headers_list = header_obj.headers
-#        for header_line in headers_list:
-#            print (header_line + " : " + headers_list[header_line])
-#            self.send_header (header_line, headers_list[header_line])
-
-    def send_cust_headers (self):
-        header_obj = self.get_rule_list ('SendHeader')
-        if header_obj:
-            for header in header_obj.headers:
-                self.send_header (header, header_obj.headers[header])
-
-    def finish_headers (self):
-        self.send_cust_headers ()
-        self.end_headers ()
-
-    def Response (self, resp_obj):
-        self.send_response (resp_obj.response_code)
-        self.finish_headers ()
-        raise ServerError ("Custom Response code sent.")
-
-    def custom_response (self):
-        codes = self.get_rule_list ('Response')
-        if codes:
-            self.send_response (codes.response_code)
-            self.finish_headers ()
-            return False
-        else:
-            return True
-
-    def base64 (self, data):
-        string = b64encode (data.encode ('utf-8'))
-        return string.decode ('utf-8')
-
-    def send_challenge (self, auth_type):
-        if auth_type == "Both":
-            self.send_challenge ("Digest")
-            self.send_challenge ("Basic")
-            return
-        if auth_type == "Basic":
-            challenge_str = 'Basic realm="Wget-Test"'
-        elif auth_type == "Digest" or auth_type == "Both_inline":
-            self.nonce = md5 (str (random ()).encode ('utf-8')).hexdigest ()
-            self.opaque = md5 (str (random ()).encode ('utf-8')).hexdigest ()
-            challenge_str = 'Digest realm="Test", nonce="%s", opaque="%s"' %(
-                                                                   self.nonce,
-                                                                   self.opaque)
-            challenge_str += ', qop="auth"'
-            if auth_type == "Both_inline":
-                challenge_str = 'Basic realm="Wget-Test", ' + challenge_str
-        self.send_header ("WWW-Authenticate", challenge_str)
-
-    def authorize_Basic (self, auth_header, auth_rule):
-        if auth_header is None or auth_header.split(' ')[0] != 'Basic':
-            return False
-        else:
-            self.user = auth_rule.auth_user
-            self.passw = auth_rule.auth_pass
-            auth_str = "Basic " + self.base64 (self.user + ":" + self.passw)
-            return True if auth_str == auth_header else False
-
-    def parse_auth_header (self, auth_header):
-        n = len("Digest ")
-        auth_header = auth_header[n:].strip()
-        items = auth_header.split(", ")
-        key_values = [i.split("=", 1) for i in items]
-        key_values = [(k.strip(), v.strip().replace('"', '')) for k, v in 
key_values]
-        return dict(key_values)
-
-    def KD (self, secret, data):
-        return self.H (secret + ":" + data)
-
-    def H (self, data):
-        return md5 (data.encode ('utf-8')).hexdigest ()
-
-    def A1 (self):
-        return "%s:%s:%s" % (self.user, "Test", self.passw)
-
-    def A2 (self, params):
-        return "%s:%s" % (self.command, params["uri"])
-
-    def check_response (self, params):
-        if "qop" in params:
-            data_str = params['nonce'] \
-                        + ":" + params['nc'] \
-                        + ":" + params['cnonce'] \
-                        + ":" + params['qop'] \
-                        + ":" + self.H (self.A2 (params))
-        else:
-            data_str = params['nonce'] + ":" + self.H (self.A2 (params))
-        resp = self.KD (self.H (self.A1 ()), data_str)
-
-        return True if resp == params['response'] else False
-
-    def authorize_Digest (self, auth_header, auth_rule):
-        if auth_header is None or auth_header.split(' ')[0] != 'Digest':
-            return False
-        else:
-            self.user = auth_rule.auth_user
-            self.passw = auth_rule.auth_pass
-            params = self.parse_auth_header (auth_header)
-            pass_auth = True
-            if self.user != params['username'] or \
-              self.nonce != params['nonce'] or self.opaque != params['opaque']:
-                pass_auth = False
-            req_attribs = ['username', 'realm', 'nonce', 'uri', 'response']
-            for attrib in req_attribs:
-                if not attrib in params:
-                    pass_auth = False
-            if not self.check_response (params):
-                pass_auth = False
-            return pass_auth
-
-    def authorize_Both (self, auth_header, auth_rule):
-        return False
-
-    def authorize_Both_inline (self, auth_header, auth_rule):
-        return False
-
-    def Authentication (self, auth_rule):
-        try:
-            self.handle_auth (auth_rule)
-        except ServerError as se:
-            self.send_response (401, "Authorization Required")
-            self.send_challenge (auth_rule.auth_type)
-            self.finish_headers ()
-            raise ServerError (se.__str__())
-
-    def handle_auth (self, auth_rule):
-        is_auth = True
-        auth_header = self.headers.get ("Authorization")
-        required_auth = auth_rule.auth_type
-        if required_auth == "Both" or required_auth == "Both_inline":
-            auth_type = auth_header.split(' ')[0] if auth_header else 
required_auth
-        else:
-            auth_type = required_auth
-        try:
-            assert hasattr (self, "authorize_" + auth_type)
-            is_auth = getattr (self, "authorize_" + auth_type) (auth_header, 
auth_rule)
-        except AssertionError:
-            raise ServerError ("Authentication Mechanism " + auth_rule + " not 
supported")
-        except AttributeError as ae:
-            raise ServerError (ae.__str__())
-        if is_auth is False:
-            raise ServerError ("Unable to Authenticate")
-
-    def is_authorized (self):
-        is_auth = True
-        auth_rule = self.get_rule_list ('Authentication')
-        if auth_rule:
-            auth_header = self.headers.get ("Authorization")
-            req_auth = auth_rule.auth_type
-            if req_auth == "Both" or req_auth == "Both_inline":
-                auth_type = auth_header.split(' ')[0] if auth_header else 
req_auth
-            else:
-                auth_type = req_auth
-            assert hasattr (self, "authorize_" + auth_type)
-            is_auth = getattr (self, "authorize_" + auth_type) (auth_header, 
auth_rule)
-            if is_auth is False:
-                self.send_response (401)
-                self.send_challenge (auth_type)
-                self.finish_headers ()
-        return is_auth
-
-    def ExpectHeader (self, header_obj):
-        exp_headers = header_obj.headers
-        for header_line in exp_headers:
-            header_recd = self.headers.get (header_line)
-            if header_recd is None or header_recd != exp_headers[header_line]:
-                self.send_error (400, "Expected Header " + header_line + " not 
found")
-                self.finish_headers ()
-                raise ServerError ("Header " + header_line + " not found")
-
-    def expect_headers (self):
-        """ This is modified code to handle a few changes. Should be removed 
ASAP """
-        exp_headers_obj = self.get_rule_list ('ExpectHeader')
-        if exp_headers_obj:
-            exp_headers = exp_headers_obj.headers
-            for header_line in exp_headers:
-                header_re = self.headers.get (header_line)
-                if header_re is None or header_re != exp_headers[header_line]:
-                    self.send_error (400, 'Expected Header not Found')
-                    self.end_headers ()
-                    return False
-        return True
-
-    def RejectHeader (self, header_obj):
-        rej_headers = header_obj.headers
-        for header_line in rej_headers:
-            header_recd = self.headers.get (header_line)
-            if header_recd is not None and header_recd == 
rej_headers[header_line]:
-                self.send_error (400, 'Blackisted Header ' + header_line + ' 
received')
-                self.finish_headers ()
-                raise ServerError ("Header " + header_line + ' received')
-
-    def reject_headers (self):
-        rej_headers = self.get_rule_list ("RejectHeader")
-        if rej_headers:
-            rej_headers = rej_headers.headers
-            for header_line in rej_headers:
-                header_re = self.headers.get (header_line)
-                if header_re is not None and header_re == 
rej_headers[header_line]:
-                    self.send_error (400, 'Blacklisted Header was Sent')
-                    self.end_headers ()
-                    return False
-        return True
-
-    def __log_request (self, method):
-        req = method + " " + self.path
-        self.server.request_headers.append (req)
-
-    def send_head (self, method):
-        """ Common code for GET and HEAD Commands.
-        This method is overriden to use the fileSys dict.
-
-        The method variable contains whether this was a HEAD or a GET Request.
-        According to RFC 2616, the server should not differentiate between
-        the two requests, however, we use it here for a specific test.
-        """
-
-        if self.path == "/":
-            path = "index.html"
-        else:
-            path = self.path[1:]
-
-        self.__log_request (method)
-
-        if path in self.server.fileSys:
-            self.rules = self.server.server_configs.get (path)
-
-            for rule_name in self.rules:
-                try:
-                    assert hasattr (self, rule_name)
-                    getattr (self, rule_name) (self.rules [rule_name])
-                except AssertionError as ae:
-                    msg = "Method " + rule_name + " not defined"
-                    self.send_error (500, msg)
-                    return (None, None)
-                except ServerError as se:
-                    print (se.__str__())
-                    return (None, None)
-
-            content = self.server.fileSys.get (path)
-            content_length = len (content)
-            try:
-                self.range_begin = self.parse_range_header (
-                    self.headers.get ("Range"), content_length)
-            except InvalidRangeHeader as ae:
-                # self.log_error("%s", ae.err_message)
-                if ae.err_message == "Range Overflow":
-                    self.send_response (416)
-                    self.finish_headers ()
-                    return (None, None)
-                else:
-                    self.range_begin = None
-            if self.range_begin is None:
-                self.send_response (200)
-            else:
-                self.send_response (206)
-                self.send_header ("Accept-Ranges", "bytes")
-                self.send_header ("Content-Range",
-                                  "bytes %d-%d/%d" % (self.range_begin,
-                                                      content_length - 1,
-                                                      content_length))
-                content_length -= self.range_begin
-            cont_type = self.guess_type (path)
-            self.send_header ("Content-type", cont_type)
-            self.send_header ("Content-Length", content_length)
-            self.finish_headers ()
-            return (content, self.range_begin)
-        else:
-            self.send_error (404, "Not Found")
-            return (None, None)
-
-    def guess_type (self, path):
-        base_name = basename ("/" + path)
-        name, ext = splitext (base_name)
-        extension_map = {
-        ".txt"   :   "text/plain",
-        ".css"   :   "text/css",
-        ".html"  :   "text/html"
-        }
-        if ext in extension_map:
-            return extension_map[ext]
-        else:
-            return "text/plain"
-
-
-class HTTPd (threading.Thread):
-    server_class = StoppableHTTPServer
-    handler = _Handler
-    def __init__ (self, addr=None):
-        threading.Thread.__init__ (self)
-        if addr is None:
-            addr = ('localhost', 0)
-        self.server_inst = self.server_class (addr, self.handler)
-        self.server_address = self.server_inst.socket.getsockname()[:2]
-
-    def run (self):
-       self.server_inst.serve_forever ()
-
-    def server_conf (self, file_list, server_rules):
-        self.server_inst.server_conf (file_list, server_rules)
-
-    def server_sett (self, settings):
-         self.server_inst.server_sett (settings)
-
-class HTTPSd (HTTPd):
-
-   server_class = HTTPSServer
-
-# vim: set ts=4 sts=4 sw=4 tw=80 et :
diff --git a/testenv/WgetTest.py b/testenv/WgetTest.py
index 6076012..92e4138 100644
--- a/testenv/WgetTest.py
+++ b/testenv/WgetTest.py
@@ -8,15 +8,11 @@ import time
 from subprocess import call
 from difflib import unified_diff
 
-import HTTPServer
 import conf
 from exc.test_failed import TestFailed
 from misc.colour_terminal import print_red, print_green, print_blue
-
-
-HTTP = "HTTP"
-HTTPS = "HTTPS"
-
+from misc.constants import HTTP, HTTPS
+from server.http import http_server
 
 
 """ Class that defines methods common to both HTTP and FTP Tests. """
@@ -220,12 +216,12 @@ class HTTPTest (CommonMethods):
         self.hook_call(post_hook, 'Post Test Function')
 
     def init_HTTP_Server (self):
-        server = HTTPServer.HTTPd ()
+        server = http_server.HTTPd ()
         server.start ()
         return server
 
     def init_HTTPS_Server (self):
-        server = HTTPServer.HTTPSd ()
+        server = http_server.HTTPSd ()
         server.start ()
         return server
 
diff --git a/testenv/misc/constants.py b/testenv/misc/constants.py
new file mode 100644
index 0000000..5fad2f8
--- /dev/null
+++ b/testenv/misc/constants.py
@@ -0,0 +1,3 @@
+
+HTTP = "HTTP"
+HTTPS = "HTTPS"
\ No newline at end of file
diff --git a/testenv/server/__init__.py b/testenv/server/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/testenv/server/ftp/__init__.py b/testenv/server/ftp/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/testenv/server/ftp/ftp_server.py b/testenv/server/ftp/ftp_server.py
new file mode 100644
index 0000000..f7d7771
--- /dev/null
+++ b/testenv/server/ftp/ftp_server.py
@@ -0,0 +1,162 @@
+import os
+import re
+import threading
+import socket
+import pyftpdlib.__main__
+from pyftpdlib.ioloop import IOLoop
+import pyftpdlib.handlers as Handle
+from pyftpdlib.servers import FTPServer
+from pyftpdlib.authorizers import DummyAuthorizer
+from pyftpdlib._compat import PY3, u, b, getcwdu, callable
+
+class FTPDHandler (Handle.FTPHandler):
+
+    def ftp_LIST (self, path):
+        try:
+            iterator = self.run_as_current_user(self.fs.get_list_dir, path)
+        except (OSError, FilesystemError):
+            err = sys.exc_info()[1]
+            why = _strerror (err)
+            self.respond ('550 %s. ' % why)
+        else:
+            if self.isRule ("Bad List") is True:
+                iter_list = list ()
+                for flist in iterator:
+                    line = re.compile (r'(\s+)').split (flist.decode ('utf-8'))
+                    line[8] = '0'
+                    iter_l =  ''.join (line).encode ('utf-8')
+                    iter_list.append (iter_l)
+                iterator = (n for n in iter_list)
+            producer = Handle.BufferedIteratorProducer (iterator)
+            self.push_dtp_data (producer, isproducer=True, cmd="LIST")
+            return path
+
+    def ftp_PASV (self, line):
+        if self._epsvall:
+            self.respond ("501 PASV not allowed after EPSV ALL.")
+            return
+        self._make_epasv(extmode=False)
+        if self.isRule ("FailPASV") is True:
+            del self.server.global_rules["FailPASV"]
+            self.socket.close ()
+
+    def isRule (self, rule):
+        rule_obj = self.server.global_rules[rule]
+        return False if not rule_obj else rule_obj[0]
+
+class FTPDServer (FTPServer):
+
+    def set_global_rules (self, rules):
+        self.global_rules = rules
+
+class FTPd(threading.Thread):
+    """A threaded FTP server used for running tests.
+
+    This is basically a modified version of the FTPServer class which
+    wraps the polling loop into a thread.
+
+    The instance returned can be used to start(), stop() and
+    eventually re-start() the server.
+    """
+    handler = FTPDHandler
+    server_class = FTPDServer
+
+    def __init__(self, addr=None):
+        os.mkdir ('server')
+        os.chdir ('server')
+        try:
+            HOST = socket.gethostbyname ('localhost')
+        except socket.error:
+            HOST = 'localhost'
+        USER = 'user'
+        PASSWD = '12345'
+        HOME = getcwdu ()
+
+        threading.Thread.__init__(self)
+        self.__serving = False
+        self.__stopped = False
+        self.__lock = threading.Lock()
+        self.__flag = threading.Event()
+        if addr is None:
+            addr = (HOST, 0)
+
+        authorizer = DummyAuthorizer()
+        authorizer.add_user(USER, PASSWD, HOME, perm='elradfmwM')  # full perms
+        authorizer.add_anonymous(HOME)
+        self.handler.authorizer = authorizer
+        # lowering buffer sizes = more cycles to transfer data
+        # = less false positive test failures
+        self.handler.dtp_handler.ac_in_buffer_size = 32768
+        self.handler.dtp_handler.ac_out_buffer_size = 32768
+        self.server = self.server_class(addr, self.handler)
+        self.host, self.port = self.server.socket.getsockname()[:2]
+        os.chdir ('..')
+
+    def set_global_rules (self, rules):
+        self.server.set_global_rules (rules)
+
+    def __repr__(self):
+        status = [self.__class__.__module__ + "." + self.__class__.__name__]
+        if self.__serving:
+            status.append('active')
+        else:
+            status.append('inactive')
+        status.append('%s:%s' % self.server.socket.getsockname()[:2])
+        return '<%s at %#x>' % (' '.join(status), id(self))
+
+    @property
+    def running(self):
+        return self.__serving
+
+    def start(self, timeout=0.001):
+        """Start serving until an explicit stop() request.
+        Polls for shutdown every 'timeout' seconds.
+        """
+        if self.__serving:
+            raise RuntimeError("Server already started")
+        if self.__stopped:
+            # ensure the server can be started again
+            FTPd.__init__(self, self.server.socket.getsockname(), self.handler)
+        self.__timeout = timeout
+        threading.Thread.start(self)
+        self.__flag.wait()
+
+    def run(self):
+        self.__serving = True
+        self.__flag.set()
+        while self.__serving:
+            self.__lock.acquire()
+            self.server.serve_forever(timeout=self.__timeout, blocking=False)
+            self.__lock.release()
+        self.server.close_all()
+
+    def stop(self):
+        """Stop serving (also disconnecting all currently connected
+        clients) by telling the serve_forever() loop to stop and
+        waits until it does.
+        """
+        if not self.__serving:
+            raise RuntimeError("Server not started yet")
+        self.__serving = False
+        self.__stopped = True
+        self.join()
+
+
+def mk_file_sys (file_list):
+    os.chdir ('server')
+    for name, content in file_list.items ():
+        file_h = open (name, 'w')
+        file_h.write (content)
+        file_h.close ()
+    os.chdir ('..')
+
+def filesys ():
+    fileSys = dict ()
+    os.chdir ('server')
+    for parent, dirs, files in os.walk ('.'):
+        for filename in files:
+            file_handle = open (filename, 'r')
+            file_content = file_handle.read ()
+            fileSys[filename] = file_content
+    os.chdir ('..')
+    return fileSys
diff --git a/testenv/server/http/__init__.py b/testenv/server/http/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/testenv/server/http/http_server.py b/testenv/server/http/http_server.py
new file mode 100644
index 0000000..946fb79
--- /dev/null
+++ b/testenv/server/http/http_server.py
@@ -0,0 +1,467 @@
+from http.server import HTTPServer, BaseHTTPRequestHandler
+from socketserver import BaseServer
+from posixpath import basename, splitext
+from base64 import b64encode
+from random import random
+from hashlib import md5
+import threading
+import socket
+import re
+import ssl
+import os
+
+
+class InvalidRangeHeader (Exception):
+
+    """ Raised when a request's Range header cannot be honoured.
+
+    'err_message' carries the reason; the request handler treats the
+    message "Range Overflow" specially.
+    """
+    # TODO: Eliminate this exception and use only ServerError
+
+    def __init__ (self, err_message):
+        super ().__init__ (err_message)
+        self.err_message = err_message
+
+class ServerError (Exception):
+
+    """ Raised by the test server when a request cannot be served normally,
+    for instance after a rule has already sent its own response.
+
+    'err_message' carries the human-readable reason.
+    """
+
+    def __init__ (self, err_message):
+        super ().__init__ (err_message)
+        self.err_message = err_message
+
+
+class StoppableHTTPServer (HTTPServer):
+
+    """ HTTPServer that carries the configuration of one Wget test.
+
+    server_conf() installs the in-memory file system and the per-path
+    rules consulted by the request handler.  request_headers collects the
+    "METHOD path" strings logged by the request handler.
+    """
+
+    # Class-level fallback for subclasses whose __init__ does not call this
+    # class's __init__; instances created through __init__ get their own list.
+    request_headers = list ()
+
+    def __init__ (self, *args, **kwargs):
+        # One log per server, so requests seen by one server are not
+        # reported by another server living in the same process.
+        self.request_headers = list ()
+        HTTPServer.__init__ (self, *args, **kwargs)
+
+    def server_conf (self, filelist, conf_dict):
+        """ Set Server Rules and File System for this instance. """
+        self.server_configs = conf_dict
+        self.fileSys = filelist
+
+    def server_sett (self, settings):
+        """ Copy each settings entry onto the request handler class.
+
+        This mutates RequestHandlerClass itself, so the values are seen by
+        every server that uses the same handler class.
+        """
+        for settings_key, value in settings.items ():
+            setattr (self.RequestHandlerClass, settings_key, value)
+
+    def get_req_headers (self):
+        """ Return the list of requests logged so far. """
+        return self.request_headers
+
+class HTTPSServer (StoppableHTTPServer):
+
+    """ StoppableHTTPServer whose listening socket is wrapped in TLS.
+
+    The certificate is read from ../certs/wget-cert.pem relative to the
+    current working directory.  Since no separate key file is given, that
+    file must also contain the private key.
+    """
+
+    def __init__ (self, address, handler):
+        BaseServer.__init__ (self, address, handler)
+        # BaseServer.__init__ is called directly (bypassing the parent
+        # classes' initialisation), so give this instance its own request log.
+        self.request_headers = list ()
+        CERTFILE = os.path.abspath (os.path.join ('../', 'certs',
+                                                  'wget-cert.pem'))
+        # Same settings the removed ssl.wrap_socket() call applied: TLSv1
+        # only, certificate and key loaded from CERTFILE.
+        context = ssl.SSLContext (ssl.PROTOCOL_TLSv1)
+        context.load_cert_chain (CERTFILE)
+        self.socket = context.wrap_socket (
+            socket.socket (self.address_family, self.socket_type),
+            server_side = True)
+        try:
+            self.server_bind ()
+            self.server_activate ()
+        except Exception:
+            # Do not leak the listening socket if binding fails.
+            self.server_close ()
+            raise
+
+class WgetHTTPRequestHandler (BaseHTTPRequestHandler):
+
+    """ Define methods for handling Test Checks. """
+
+    def get_rule_list (self, name):
+        """ Return the rule registered under 'name' for the current path,
+        or None.
+
+        self.rules is the per-path rule dict set by the request handler.
+        It is None when the path has no entry in the server's configuration,
+        and that case must not raise.
+        """
+        if not self.rules:
+            return None
+        return self.rules.get (name)
+
+
+class _Handler (WgetHTTPRequestHandler):
+
+    """ Define Handler Methods for different Requests.
+
+    Files are served from the in-memory dict self.server.fileSys, and each
+    request is subject to the per-path rules in self.server.server_configs.
+    File contents are str but are sent as UTF-8, so every length and offset
+    put on the wire is counted in bytes of that encoding.
+    """
+
+    InvalidRangeHeader = InvalidRangeHeader
+    protocol_version = 'HTTP/1.1'
+
+    # --- Handlers for the individual HTTP request methods -------------------
+
+    def do_HEAD (self):
+        self.send_head ("HEAD")
+
+    def do_GET (self):
+        content, start = self.send_head ("GET")
+        if content:
+            payload = content.encode ('utf-8')
+            # 'start' is the byte offset requested via "Range: bytes=N-".
+            self.wfile.write (payload if start is None else payload[start:])
+
+    def do_POST (self):
+        """ Append the request body to an existing file, or create it. """
+        path = self.path[1:]
+        # A path without configured rules behaves like one with no rules.
+        self.rules = self.server.server_configs.get (path) or dict ()
+        if not self.custom_response ():
+            return
+        if path in self.server.fileSys:
+            body_data = self.get_body_data ()
+            self.send_response (200)
+            self.send_header ("Content-type", "text/plain")
+            content = self.server.fileSys.pop (path) + "\n" + body_data
+            self.server.fileSys[path] = content
+            payload = content.encode ('utf-8')
+            self.send_header ("Content-Length", len (payload))
+            self.finish_headers ()
+            self._write_body (payload)
+        else:
+            self.send_put (path)
+
+    def do_PUT (self):
+        """ Replace (or create) the file at the request path. """
+        path = self.path[1:]
+        self.rules = self.server.server_configs.get (path) or dict ()
+        if not self.custom_response ():
+            return
+        self.server.fileSys.pop (path, None)
+        self.send_put (path)
+
+    # --- Helper functions for the handlers ----------------------------------
+
+    def parse_range_header (self, header_line, length):
+        """ Return the first byte offset requested by a Range header.
+
+        Only the open-ended form "bytes=N-" is supported.  Returns None when
+        no header was sent.  Raises InvalidRangeHeader when the header
+        cannot be parsed, or with the message "Range Overflow" when
+        N >= length.
+        """
+        if header_line is None:
+            return None
+        regex = re.match (r"^bytes=(\d+)-$", header_line)
+        if regex is None:
+            raise InvalidRangeHeader ("Cannot parse header Range: %s" %
+                                      (header_line))
+        range_start = int (regex.group (1))
+        if range_start >= length:
+            raise InvalidRangeHeader ("Range Overflow")
+        return range_start
+
+    def get_body_data (self):
+        """ Read Content-Length bytes of request body and decode as UTF-8. """
+        cLength_header = self.headers.get ("Content-Length")
+        cLength = int (cLength_header) if cLength_header is not None else 0
+        body_data = self.rfile.read (cLength).decode ('utf-8')
+        return body_data
+
+    def _write_body (self, payload):
+        """ Best-effort write of an encoded response body.
+
+        Write failures are deliberately ignored, presumably so that a client
+        which has already hung up does not crash the handler.
+        """
+        try:
+            self.wfile.write (payload)
+        except Exception:
+            pass
+
+    def send_put (self, path):
+        """ Store the request body as the file at 'path' and echo it back. """
+        body_data = self.get_body_data ()
+        self.send_response (201)
+        self.server.fileSys[path] = body_data
+        payload = body_data.encode ('utf-8')
+        self.send_header ("Content-type", "text/plain")
+        self.send_header ("Content-Length", len (payload))
+        self.finish_headers ()
+        self._write_body (payload)
+
+    def SendHeader (self, header_obj):
+        """ Rule hook for 'SendHeader'; intentionally a no-op.
+
+        The headers of a SendHeader rule are emitted by send_cust_headers()
+        as part of finish_headers(), so nothing needs to happen when
+        send_head() dispatches this rule.
+        """
+        pass
+
+    def send_cust_headers (self):
+        """ Send the headers of this path's 'SendHeader' rule, if any. """
+        header_obj = self.get_rule_list ('SendHeader')
+        if header_obj:
+            for header in header_obj.headers:
+                self.send_header (header, header_obj.headers[header])
+
+    def finish_headers (self):
+        """ Add the custom headers, then terminate the header block. """
+        self.send_cust_headers ()
+        self.end_headers ()
+
+    def Response (self, resp_obj):
+        """ Rule hook: answer with the configured status code only. """
+        self.send_response (resp_obj.response_code)
+        self.finish_headers ()
+        raise ServerError ("Custom Response code sent.")
+
+    def custom_response (self):
+        """ Send the 'Response' rule's status code if one is configured.
+
+        Returns False when that response was sent (the caller must stop),
+        True otherwise.
+        """
+        codes = self.get_rule_list ('Response')
+        if codes:
+            self.send_response (codes.response_code)
+            self.finish_headers ()
+            return False
+        return True
+
+    def base64 (self, data):
+        """ Return the Base64 encoding of the UTF-8 bytes of 'data'. """
+        string = b64encode (data.encode ('utf-8'))
+        return string.decode ('utf-8')
+
+    def send_challenge (self, auth_type):
+        """ Emit WWW-Authenticate header(s) for 'auth_type'.
+
+        "Both" sends separate Digest and Basic challenges; "Both_inline"
+        sends a single header carrying both.  A Digest challenge stores a
+        fresh nonce/opaque pair on the handler for authorize_Digest().
+        Raises ServerError for any other scheme.
+        """
+        if auth_type == "Both":
+            self.send_challenge ("Digest")
+            self.send_challenge ("Basic")
+            return
+        if auth_type == "Basic":
+            challenge_str = 'Basic realm="Wget-Test"'
+        elif auth_type in ("Digest", "Both_inline"):
+            self.nonce = md5 (str (random ()).encode ('utf-8')).hexdigest ()
+            self.opaque = md5 (str (random ()).encode ('utf-8')).hexdigest ()
+            challenge_str = 'Digest realm="Test", nonce="%s", opaque="%s"' % (
+                self.nonce, self.opaque)
+            challenge_str += ', qop="auth"'
+            if auth_type == "Both_inline":
+                challenge_str = 'Basic realm="Wget-Test", ' + challenge_str
+        else:
+            raise ServerError ("Authentication Mechanism " + auth_type +
+                               " not supported")
+        self.send_header ("WWW-Authenticate", challenge_str)
+
+    def authorize_Basic (self, auth_header, auth_rule):
+        """ Check a Basic Authorization header against auth_rule. """
+        if auth_header is None or auth_header.split(' ')[0] != 'Basic':
+            return False
+        self.user = auth_rule.auth_user
+        self.passw = auth_rule.auth_pass
+        auth_str = "Basic " + self.base64 (self.user + ":" + self.passw)
+        return auth_str == auth_header
+
+    def parse_auth_header (self, auth_header):
+        """ Parse 'Digest k1="v1", k2=v2, ...' into a dict.
+
+        Double quotes are stripped from the values.  Items without '=' are
+        ignored, so a malformed header cannot raise here.
+        """
+        n = len("Digest ")
+        auth_header = auth_header[n:].strip()
+        items = auth_header.split(", ")
+        key_values = [i.split("=", 1) for i in items if "=" in i]
+        key_values = [(k.strip(), v.strip().replace('"', ''))
+                      for k, v in key_values]
+        return dict(key_values)
+
+    def KD (self, secret, data):
+        """ RFC 2617 KD(secret, data) = H(secret ":" data). """
+        return self.H (secret + ":" + data)
+
+    def H (self, data):
+        """ RFC 2617 H(): hex MD5 digest of the UTF-8 encoded data. """
+        return md5 (data.encode ('utf-8')).hexdigest ()
+
+    def A1 (self):
+        """ RFC 2617 A1 = user ":" realm ":" password (realm "Test"). """
+        return "%s:%s:%s" % (self.user, "Test", self.passw)
+
+    def A2 (self, params):
+        """ RFC 2617 A2 = method ":" digest-uri. """
+        return "%s:%s" % (self.command, params["uri"])
+
+    def check_response (self, params):
+        """ Verify the client's Digest 'response' value (RFC 2617). """
+        if "qop" in params:
+            if "nc" not in params or "cnonce" not in params:
+                # nc and cnonce are mandatory whenever qop is present.
+                return False
+            data_str = params['nonce'] \
+                        + ":" + params['nc'] \
+                        + ":" + params['cnonce'] \
+                        + ":" + params['qop'] \
+                        + ":" + self.H (self.A2 (params))
+        else:
+            data_str = params['nonce'] + ":" + self.H (self.A2 (params))
+        resp = self.KD (self.H (self.A1 ()), data_str)
+        return resp == params['response']
+
+    def authorize_Digest (self, auth_header, auth_rule):
+        """ Check a Digest Authorization header against auth_rule's
+        credentials and the nonce/opaque issued by send_challenge(). """
+        if auth_header is None or auth_header.split(' ')[0] != 'Digest':
+            return False
+        self.user = auth_rule.auth_user
+        self.passw = auth_rule.auth_pass
+        params = self.parse_auth_header (auth_header)
+        req_attribs = ['username', 'realm', 'nonce', 'uri', 'response']
+        # Check presence first so that a malformed header is rejected
+        # instead of raising KeyError below.
+        if any (attrib not in params for attrib in req_attribs):
+            return False
+        if self.user != params['username'] or \
+           self.nonce != params['nonce'] or \
+           self.opaque != params.get ('opaque'):
+            return False
+        return self.check_response (params)
+
+    def authorize_Both (self, auth_header, auth_rule):
+        return False
+
+    def authorize_Both_inline (self, auth_header, auth_rule):
+        return False
+
+    def Authentication (self, auth_rule):
+        """ Rule hook: on failed authentication, send a 401 challenge and
+        re-raise as ServerError so that send_head() stops processing. """
+        try:
+            self.handle_auth (auth_rule)
+        except ServerError as se:
+            self.send_response (401, "Authorization Required")
+            self.send_challenge (auth_rule.auth_type)
+            self.finish_headers ()
+            raise ServerError (se.__str__())
+
+    def _select_auth_type (self, auth_header, required_auth):
+        """ Pick the scheme to verify for a rule requiring 'required_auth'.
+
+        For "Both"/"Both_inline", the scheme named in the client's
+        Authorization header wins; otherwise the rule's own scheme is used.
+        """
+        if required_auth in ("Both", "Both_inline") and auth_header:
+            return auth_header.split(' ')[0]
+        return required_auth
+
+    def handle_auth (self, auth_rule):
+        """ Raise ServerError unless the request satisfies auth_rule. """
+        auth_header = self.headers.get ("Authorization")
+        auth_type = self._select_auth_type (auth_header, auth_rule.auth_type)
+        authorize = getattr (self, "authorize_" + auth_type, None)
+        if authorize is None:
+            raise ServerError ("Authentication Mechanism " + auth_type +
+                               " not supported")
+        try:
+            is_auth = authorize (auth_header, auth_rule)
+        except AttributeError as ae:
+            # e.g. a Digest reply on a connection where send_challenge()
+            # has not yet stored self.nonce / self.opaque.
+            raise ServerError (ae.__str__())
+        if is_auth is False:
+            raise ServerError ("Unable to Authenticate")
+
+    def is_authorized (self):
+        """ Return True unless an Authentication rule rejects the request,
+        in which case a 401 challenge has already been sent. """
+        is_auth = True
+        auth_rule = self.get_rule_list ('Authentication')
+        if auth_rule:
+            auth_header = self.headers.get ("Authorization")
+            auth_type = self._select_auth_type (auth_header,
+                                                auth_rule.auth_type)
+            assert hasattr (self, "authorize_" + auth_type)
+            is_auth = getattr (self, "authorize_" + auth_type) (auth_header,
+                                                                auth_rule)
+            if is_auth is False:
+                self.send_response (401)
+                self.send_challenge (auth_type)
+                self.finish_headers ()
+        return is_auth
+
+    def ExpectHeader (self, header_obj):
+        """ Rule hook: answer 400 and raise ServerError unless every
+        expected header arrived with exactly the expected value. """
+        exp_headers = header_obj.headers
+        for header_line in exp_headers:
+            header_recd = self.headers.get (header_line)
+            if header_recd is None or header_recd != exp_headers[header_line]:
+                # send_error() writes a complete response, headers included;
+                # ending the headers again would put stray bytes on the wire.
+                self.send_error (400, "Expected Header " + header_line +
+                                 " not found")
+                raise ServerError ("Header " + header_line + " not found")
+
+    def expect_headers (self):
+        """ This is modified code to handle a few changes. Should be removed
+        ASAP.
+
+        Returns False (after sending a 400) when a header of the
+        'ExpectHeader' rule is missing or differs, True otherwise.
+        """
+        exp_headers_obj = self.get_rule_list ('ExpectHeader')
+        if exp_headers_obj:
+            exp_headers = exp_headers_obj.headers
+            for header_line in exp_headers:
+                header_re = self.headers.get (header_line)
+                if header_re is None or header_re != exp_headers[header_line]:
+                    self.send_error (400, 'Expected Header not Found')
+                    return False
+        return True
+
+    def RejectHeader (self, header_obj):
+        """ Rule hook: answer 400 and raise ServerError if any blacklisted
+        header arrived with the blacklisted value. """
+        rej_headers = header_obj.headers
+        for header_line in rej_headers:
+            header_recd = self.headers.get (header_line)
+            if header_recd is not None and \
+               header_recd == rej_headers[header_line]:
+                self.send_error (400, 'Blacklisted Header ' + header_line +
+                                 ' received')
+                raise ServerError ("Header " + header_line + ' received')
+
+    def reject_headers (self):
+        """ Return False (after sending a 400) if a header of the
+        'RejectHeader' rule was received with its value, True otherwise. """
+        rej_headers = self.get_rule_list ("RejectHeader")
+        if rej_headers:
+            rej_headers = rej_headers.headers
+            for header_line in rej_headers:
+                header_re = self.headers.get (header_line)
+                if header_re is not None and \
+                   header_re == rej_headers[header_line]:
+                    self.send_error (400, 'Blacklisted Header was Sent')
+                    return False
+        return True
+
+    def __log_request (self, method):
+        """ Record "METHOD path" in the server's request log. """
+        req = method + " " + self.path
+        self.server.request_headers.append (req)
+
+    def send_head (self, method):
+        """ Common code for GET and HEAD Commands.
+        This method is overriden to use the fileSys dict.
+
+        The method variable contains whether this was a HEAD or a GET Request.
+        According to RFC 2616, the server should not differentiate between
+        the two requests, however, we use it here for a specific test.
+
+        Returns (content, range_start): the file's str content and the byte
+        offset to send from (None for the whole file), or (None, None) when
+        an error or rule-generated response has already been sent.
+        """
+        if self.path == "/":
+            path = "index.html"
+        else:
+            path = self.path[1:]
+
+        self.__log_request (method)
+
+        if path not in self.server.fileSys:
+            self.send_error (404, "Not Found")
+            return (None, None)
+
+        self.rules = self.server.server_configs.get (path) or dict ()
+        for rule_name in self.rules:
+            rule_method = getattr (self, rule_name, None)
+            if rule_method is None:
+                self.send_error (500, "Method " + rule_name + " not defined")
+                return (None, None)
+            try:
+                rule_method (self.rules[rule_name])
+            except ServerError as se:
+                # Rules raise ServerError after sending their own response.
+                print (se.__str__())
+                return (None, None)
+
+        content = self.server.fileSys.get (path)
+        # Bytes, not characters: do_GET() sends the UTF-8 encoded content.
+        content_length = len (content.encode ('utf-8'))
+        try:
+            self.range_begin = self.parse_range_header (
+                self.headers.get ("Range"), content_length)
+        except InvalidRangeHeader as ae:
+            if ae.err_message == "Range Overflow":
+                self.send_response (416)
+                self.finish_headers ()
+                return (None, None)
+            # Any other unusable Range header is ignored: send the whole file.
+            self.range_begin = None
+        if self.range_begin is None:
+            self.send_response (200)
+        else:
+            self.send_response (206)
+            self.send_header ("Accept-Ranges", "bytes")
+            self.send_header ("Content-Range",
+                              "bytes %d-%d/%d" % (self.range_begin,
+                                                  content_length - 1,
+                                                  content_length))
+            content_length -= self.range_begin
+        cont_type = self.guess_type (path)
+        self.send_header ("Content-type", cont_type)
+        self.send_header ("Content-Length", content_length)
+        self.finish_headers ()
+        return (content, self.range_begin)
+
+    def guess_type (self, path):
+        """ Map the extension of 'path' to a Content-Type (default
+        text/plain). """
+        extension_map = {
+            ".txt"   :   "text/plain",
+            ".css"   :   "text/css",
+            ".html"  :   "text/html",
+        }
+        _, ext = splitext (basename ("/" + path))
+        return extension_map.get (ext, "text/plain")
+
+
+class HTTPd (threading.Thread):
+    """ Run an HTTP test server in a background thread.
+
+    The server is created (and its socket bound) in the constructor, by
+    default on port 0 of localhost so that the OS picks a free port.  The
+    address actually bound is exposed as 'server_address'.
+    """
+    server_class = StoppableHTTPServer
+    handler = _Handler
+
+    def __init__ (self, addr=None):
+        threading.Thread.__init__ (self)
+        if addr is None:
+            addr = ('localhost', 0)
+        self.server_inst = self.server_class (addr, self.handler)
+        self.server_address = self.server_inst.socket.getsockname()[:2]
+
+    def run (self):
+        self.server_inst.serve_forever ()
+
+    def stop (self):
+        """ Stop serve_forever(), wait for the thread and close the socket.
+
+        shutdown() would block forever if serve_forever() never ran, so it
+        is only called while the thread is alive.
+        """
+        if self.is_alive ():
+            self.server_inst.shutdown ()
+            self.join ()
+        self.server_inst.server_close ()
+
+    def server_conf (self, file_list, server_rules):
+        """ Forward the file system and per-path rules to the server. """
+        self.server_inst.server_conf (file_list, server_rules)
+
+    def server_sett (self, settings):
+        """ Forward handler settings to the server. """
+        self.server_inst.server_sett (settings)
+
+class HTTPSd (HTTPd):
+
+    """ HTTPd variant whose server is an HTTPSServer (TLS-wrapped). """
+
+    server_class = HTTPSServer
+
+# vim: set ts=4 sts=4 sw=4 tw=80 et :
-- 
1.8.3.2




reply via email to

[Prev in Thread] Current Thread [Next in Thread]