#!/usr/bin/env python

# Copyright (C) 2004-2016 CS-SI. All Rights Reserved.
# Author: Nicolas Delon <nicolas.delon@prelude-ids.com>
#
# This file is part of the Prewikka program.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import sys
import shutil
from optparse import OptionParser
import multiprocessing
import urllib, urlparse
import BaseHTTPServer
import ssl

from prewikka.web import request
from prewikka import main, siteconfig, localization



class PrewikkaServer(BaseHTTPServer.HTTPServer):
    """Plain BaseHTTPServer.HTTPServer used to run Prewikka.

    No behavior is added here; serve_forever() in this script attaches the
    Prewikka core to the instance as its ``core`` attribute before serving.
    """


class PrewikkaRequestHandler(request.Request, BaseHTTPServer.BaseHTTPRequestHandler):
    """Adapt BaseHTTPServer requests to the Prewikka request interface.

    request.Request comes first in the MRO, so any method it defines takes
    precedence over a BaseHTTPRequestHandler method of the same name.
    """

    def __init__(self, *args, **kwargs):
        # Prewikka state is set up first: BaseHTTPRequestHandler.__init__
        # handles the request (setup/handle/finish) before it returns.
        request.Request.__init__(self, *args, **kwargs)
        BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, *args, **kwargs)

    def init(self, method):
        """Initialize the Prewikka request for the given HTTP *method*."""
        request.Request.init(self, self.server.core)
        self.method = method

        # These flags are only ever raised here; their default values are
        # presumably provided by request.Request -- confirm there.
        if self.headers.get("x-requested-with", "") == "XMLHttpRequest":
            self.is_xhr = True

        if self.headers.get("accept", "") == "text/event-stream":
            self.is_stream = True

    def log_request(self, code="-", size="-"):
        """Disable BaseHTTPServer's per-request access logging."""
        pass

    def _process_dynamic(self, arguments):
        """Store *arguments* on the request and hand it to the Prewikka core.

        :param arguments: mapping of parameter name to a list of values;
            one-element lists are unwrapped to their single value.
        """
        for name, value in arguments.items():
            # Unlike ``cond and a or b``, this keeps a falsy single value
            # instead of falling back to the list.
            self.arguments[name] = value[0] if len(value) == 1 else value

        self.server.core.process(self)

    def send_headers(self, headers=None, code=200, status_text=None):
        """Send the HTTP status line for *code*, then *headers*."""
        # BaseHTTPRequestHandler is named explicitly because
        # ``self.send_response`` is looked up on request.Request first.
        BaseHTTPServer.BaseHTTPRequestHandler.send_response(self, code, status_text)
        request.Request.send_headers(self, headers)

    def send_error(self, code, status_text=None):
        """Send an HTTP error response with status *code*.

        *status_text* is optional because BaseHTTPServer itself calls
        ``send_error(code)`` with a single argument (e.g. 414 for an
        over-long request line).
        """
        self.send_response(None, code=code, status_text=status_text)

    def do_GET(self):
        """Serve a static file when the path maps to one, else a dynamic page."""
        self.init("GET")

        path = self._resolve_static(self.path)
        if path:
            return self._process_static(path, lambda fd: shutil.copyfileobj(fd, self.wfile))

        # Record the query string as do_POST does, so get_query_string()
        # and get_raw_uri(include_qs=True) also work for GET requests.
        self._query_string = self._uri.query
        self._process_dynamic(urlparse.parse_qs(self._uri.query))

    def do_HEAD(self):
        """Handle HEAD exactly like GET."""
        # NOTE(review): do_GET writes the response body as well; HTTP expects
        # HEAD responses without one.
        self.do_GET()

    def do_POST(self):
        """Decode a multipart or URL-encoded POST body and dispatch it."""
        self.init("POST")

        if self.headers.get("content-type", "").startswith("multipart/form-data"):
            self._query_string = ""
            self.body = arguments = self._handle_multipart(fp=self.rfile, headers=self.headers, environ={'REQUEST_METHOD': 'POST'})
        else:
            # A request without Content-Length is treated as having an empty
            # body instead of failing with a KeyError.
            length = int(self.headers.get("Content-Length", 0))
            self.body = qs = self.rfile.read(length)
            if self._uri.query:
                # URL parameters are merged with the form-encoded body.
                qs = "&".join((self._uri.query, qs))

            self._query_string = qs
            arguments = urlparse.parse_qs(qs)

        self._process_dynamic(arguments)

    def write(self, data):
        """Write *data* straight to the client connection."""
        self.wfile.write(data)

    def get_remote_addr(self):
        """Return the client's IP address."""
        return self.client_address[0]

    def get_remote_port(self):
        """Return the client's source port."""
        return self.client_address[1]

    def get_query_string(self):
        """Return the query string recorded by do_GET/do_POST."""
        return self._query_string

    def get_cookie(self):
        """Return the raw Cookie header, or None when absent."""
        return self.headers.get("Cookie")

    def get_raw_uri(self, include_qs=False):
        """Return the request path, optionally followed by ``?<query string>``.

        :param include_qs: append the query string when it is non-empty.
        """
        uri = self._path

        if include_qs:
            qs = self.get_query_string()
            if qs:
                uri = "?".join((uri, qs))

        return uri

def serve_forever(server, config_file):
    """Load the Prewikka core from *config_file* into *server*, then serve.

    Blocks in server.serve_forever(); an interrupt (Ctrl-C) stops serving
    and returns normally instead of propagating.
    """
    core = main.get_core_from_config(config_file)
    server.core = core

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the server: exit quietly.
        return

if __name__ == "__main__":
    cpu_count = multiprocessing.cpu_count()
    parser = OptionParser(epilog=" ")

    parser.add_option("-a", "--address", action="store", type="string", dest="addr", default="0.0.0.0", help="IP to bind to (default: %default)")
    parser.add_option("-p", "--port", action="store", type="int", dest="port", default=8000, help="port number to use (default: %default)")
    parser.add_option("", "--key", action="store", type="string", dest="key", default=None, help="SSL private key to use (default: no SSL)")
    parser.add_option("", "--cert", action="store", type="string", dest="cert", default=None, help="SSL certificate to use (default: no SSL)")
    parser.add_option("-c", "--config", action="store", type="string", dest="config", default="%s/prewikka.conf" % siteconfig.conf_dir, help="configuration file to use (default: %default)")
    parser.add_option("-m", "--multiprocess", action="store", type="int", dest="num_process", default=cpu_count,
                      help="number of processes to use. Default value matches the number of available CPUs (i.e. %d)" % cpu_count)

    (options, args) = parser.parse_args()

    # SSL is only enabled when both files are given; refuse a half-specified
    # setup rather than silently serving plain HTTP.
    if bool(options.key) != bool(options.cert):
        parser.error("--key and --cert must be used together")

    server = PrewikkaServer((options.addr, options.port), PrewikkaRequestHandler)
    if options.key and options.cert:
        server.socket = ssl.wrap_socket(server.socket, keyfile=options.key, certfile=options.cert, server_side=True)

    # The listening socket is bound once and shared: each extra process runs
    # its own serve loop on it, and this process acts as the last worker.
    for i in range(options.num_process - 1):
        multiprocessing.Process(target=serve_forever, args=(server, options.config)).start()

    serve_forever(server, options.config)
