tests/testlib/badserverext.py
author Pierre-Yves David <pierre-yves.david@octobus.net>
Sun, 23 Jan 2022 21:25:01 +0100
changeset 48611 f91f98e9834a
parent 48610 caa6694dac45
child 48612 11e5cb170d36
permissions -rw-r--r--
test-http-bad-server: unify log printing for `sendall` and `write` The `write` function was logging before calling the function while the `sendall` function was logging after calling the function. We align the `write` behavior on the `sendall` behavior before unifying the code. Differential Revision: https://phab.mercurial-scm.org/D12042

# badserverext.py - Extension making servers behave badly
#
# Copyright 2017 Gregory Szorc <gregory.szorc@gmail.com>
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2 or any later version.

# no-check-code

"""Extension to make servers behave badly.

This extension is useful for testing Mercurial behavior when various network
events occur.

Various config options in the [badserver] section influence behavior:

close-before-accept
   If true, close() the server socket when a new connection arrives before
   accept() is called. The server will then exit.

close-after-accept
   If true, the server will close() the client socket immediately after
   accept().

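close-after-recv-bytes
   If defined, close the client socket after receiving this many bytes.
   A comma-separated list may be given: each successive request uses the
   next value, a value of 0 means no limit for that request, and requests
   past the end of the list are not limited.

close-after-send-bytes
   If defined, close the client socket after sending this many bytes.
   Accepts the same comma-separated, per-request list as
   close-after-recv-bytes.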
"""

from __future__ import absolute_import

import socket

from mercurial import (
    pycompat,
    registrar,
)

from mercurial.hgweb import server

# Table collecting the config options declared by this extension;
# `configitem` is the registration helper bound to that table.
configtable = {}
configitem = registrar.configitem(configtable)

# Boolean switch, read with `ui.configbool` in `badserver.get_request`.
configitem(
    b'badserver',
    b'close-after-accept',
    default=False,
)
# Comma-separated byte counts, parsed by `process_config`. A `0` entry is
# turned into `None`, i.e. no receive limit for that request.
configitem(
    b'badserver',
    b'close-after-recv-bytes',
    default=b'0',
)
# Same format as `close-after-recv-bytes`, but applied to bytes sent.
configitem(
    b'badserver',
    b'close-after-send-bytes',
    default=b'0',
)
# Boolean switch, read with `ui.configbool` in `badserver.get_request`.
configitem(
    b'badserver',
    b'close-before-accept',
    default=False,
)


class ConditionTracker(object):
    """Track the byte limits after which the current connection is closed.

    The tracker is fed two queues of limits, one for received bytes and one
    for sent bytes. Each call to `start_next_request` consumes the next entry
    of each queue and exposes it through the `target_*_bytes` and
    `remaining_*_bytes` attributes. `None` means "no limit".
    """

    def __init__(self, close_after_recv_bytes, close_after_send_bytes):
        # Keep private copies so consuming the queues never mutates the
        # sequences owned by the caller. A falsy argument is an empty queue.
        self._all_close_after_recv_bytes = list(close_after_recv_bytes or ())
        self._all_close_after_send_bytes = list(close_after_send_bytes or ())

        # No limit is active before the first request starts. Defining the
        # attributes here keeps `might_close` usable at any time instead of
        # raising AttributeError.
        self.target_recv_bytes = None
        self.remaining_recv_bytes = None
        self.target_send_bytes = None
        self.remaining_send_bytes = None

    def start_next_request(self):
        """move to the next set of close condition"""
        if self._all_close_after_recv_bytes:
            self.target_recv_bytes = self._all_close_after_recv_bytes.pop(0)
            self.remaining_recv_bytes = self.target_recv_bytes
        else:
            # Queue exhausted: no receive limit for this request.
            self.target_recv_bytes = None
            self.remaining_recv_bytes = None

        if self._all_close_after_send_bytes:
            self.target_send_bytes = self._all_close_after_send_bytes.pop(0)
            self.remaining_send_bytes = self.target_send_bytes
        else:
            # Queue exhausted: no send limit for this request.
            self.target_send_bytes = None
            self.remaining_send_bytes = None

    def might_close(self):
        """True, if any processing will be needed"""
        return (
            self.remaining_recv_bytes is not None
            or self.remaining_send_bytes is not None
        )


# We can't adjust __class__ on a socket instance. So we define a proxy type.
class socketproxy(object):
    """Wrapper around a socket that logs sends and enforces a send budget.

    Only ``makefile``, ``sendall`` and ``_writelog`` are implemented here;
    any other attribute read, assignment or deletion is forwarded to the
    wrapped socket.
    """

    __slots__ = ('_orig', '_logfp', '_cond')

    def __init__(self, obj, logfp, condition_tracked):
        # __setattr__ below forwards to the wrapped socket, so the slots have
        # to be filled through the base implementation.
        for slot, value in (
            ('_orig', obj),
            ('_logfp', logfp),
            ('_cond', condition_tracked),
        ):
            object.__setattr__(self, slot, value)

    def __getattribute__(self, name):
        if name not in ('makefile', 'sendall', '_writelog', '_cond_close'):
            wrapped = object.__getattribute__(self, '_orig')
            return getattr(wrapped, name)

        return object.__getattribute__(self, name)

    def __delattr__(self, name):
        wrapped = object.__getattribute__(self, '_orig')
        delattr(wrapped, name)

    def __setattr__(self, name, value):
        wrapped = object.__getattribute__(self, '_orig')
        setattr(wrapped, name, value)

    def _writelog(self, msg):
        """Write ``msg`` to the log as one line, with CR and LF escaped."""
        escaped = msg.replace(b'\r', b'\\r').replace(b'\n', b'\\n')

        log = object.__getattribute__(self, '_logfp')
        log.write(escaped)
        log.write(b'\n')
        log.flush()

    def makefile(self, mode, bufsize):
        """Return the wrapped socket's file object inside a logging proxy."""
        wrapped = object.__getattribute__(self, '_orig')
        fileobj = wrapped.makefile(mode, bufsize)

        return fileobjectproxy(
            fileobj,
            object.__getattribute__(self, '_logfp'),
            object.__getattribute__(self, '_cond'),
        )

    def sendall(self, data, flags=0):
        """Send ``data``, truncated to whatever is left of the send budget.

        Raises Exception, after shutting the socket down, once the budget
        drops to zero.
        """
        tracker = object.__getattribute__(self, '_cond')
        wrapped = object.__getattribute__(self, '_orig')
        budget = tracker.remaining_send_bytes

        # A falsy budget (None or 0) means there is no write limit: forward
        # the call untouched.
        if not budget:
            outcome = wrapped.sendall(data, flags)
            self._writelog(b'sendall(%d) -> %s' % (len(data), data))
            return outcome

        chunk = data[0:budget] if len(data) > budget else data
        budget -= len(chunk)

        outcome = wrapped.sendall(chunk, flags)

        self._writelog(
            b'sendall(%d from %d) -> (%d) %s'
            % (len(chunk), len(data), budget, chunk)
        )

        tracker.remaining_send_bytes = budget

        if budget > 0:
            return outcome

        self._writelog(b'write limit reached; closing socket')
        wrapped.shutdown(socket.SHUT_RDWR)

        raise Exception('connection closed after sending N bytes')


# We can't adjust __class__ on socket._fileobject, so define a proxy.
class fileobjectproxy(object):
    """Proxy around a socket file object that logs I/O and enforces limits.

    Only ``_close``, ``read``, ``readline``, ``write`` and ``_writelog`` are
    served by the proxy itself; every other attribute read, assignment and
    deletion is forwarded to the wrapped file object.

    The byte budgets are read from, and written back to, the shared condition
    tracker (``remaining_recv_bytes`` / ``remaining_send_bytes``). A falsy
    budget (``None`` or 0) disables the corresponding limit.
    """

    __slots__ = ('_orig', '_logfp', '_cond')

    def __init__(self, obj, logfp, condition_tracked):
        # Use object.__setattr__ because our own __setattr__ forwards to the
        # wrapped object.
        object.__setattr__(self, '_orig', obj)
        object.__setattr__(self, '_logfp', logfp)
        object.__setattr__(self, '_cond', condition_tracked)

    def __getattribute__(self, name):
        if name in ('_close', 'read', 'readline', 'write', '_writelog'):
            return object.__getattribute__(self, name)

        return getattr(object.__getattribute__(self, '_orig'), name)

    def __delattr__(self, name):
        delattr(object.__getattribute__(self, '_orig'), name)

    def __setattr__(self, name, value):
        setattr(object.__getattribute__(self, '_orig'), name, value)

    def _writelog(self, msg):
        """Write ``msg`` to the log as one line, with CR and LF escaped."""
        msg = msg.replace(b'\r', b'\\r').replace(b'\n', b'\\n')

        logfp = object.__getattribute__(self, '_logfp')
        logfp.write(msg)
        logfp.write(b'\n')
        logfp.flush()

    def _close(self):
        """Shut down the underlying connection in both directions."""
        # Python 3 uses an io.BufferedIO instance. Python 2 uses some file
        # object wrapper.
        if pycompat.ispy3:
            orig = object.__getattribute__(self, '_orig')

            if hasattr(orig, 'raw'):
                orig.raw._sock.shutdown(socket.SHUT_RDWR)
            else:
                # `close` is not served by the proxy, so this closes the
                # wrapped file object.
                self.close()
        else:
            # `_sock` is looked up on the wrapped file object.
            self._sock.shutdown(socket.SHUT_RDWR)

    def read(self, size=-1):
        """Read up to ``size`` bytes, capped by the remaining receive budget.

        Returns the bytes read. Raises Exception, after closing the
        connection, once the receive budget drops to zero.
        """
        cond = object.__getattribute__(self, '_cond')
        remaining = cond.remaining_recv_bytes

        # No read limit. Call original function.
        if not remaining:
            result = object.__getattribute__(self, '_orig').read(size)
            # Three fields for three values: the previous four-field format
            # string raised TypeError on every unlimited read.
            self._writelog(
                b'read(%d) -> (%d) %s' % (size, len(result), result)
            )
            return result

        origsize = size

        # A negative size means "read everything"; never ask for more than
        # the budget allows.
        if size < 0:
            size = remaining
        else:
            size = min(remaining, size)

        result = object.__getattribute__(self, '_orig').read(size)
        remaining -= len(result)

        self._writelog(
            b'read(%d from %d) -> (%d) %s'
            % (size, origsize, len(result), result)
        )

        cond.remaining_recv_bytes = remaining

        if remaining <= 0:
            self._writelog(b'read limit reached; closing socket')
            self._close()

            # This is the easiest way to abort the current request.
            raise Exception('connection closed after receiving N bytes')

        return result

    def readline(self, size=-1):
        """Read one line, capped by the remaining receive budget.

        Returns the bytes read. Raises Exception, after closing the
        connection, once the receive budget drops to zero.
        """
        cond = object.__getattribute__(self, '_cond')
        remaining = cond.remaining_recv_bytes

        # No read limit. Call original function.
        if not remaining:
            result = object.__getattribute__(self, '_orig').readline(size)
            self._writelog(
                b'readline(%d) -> (%d) %s' % (size, len(result), result)
            )
            return result

        origsize = size

        # A negative size means "no cap on the line length"; never ask for
        # more than the budget allows.
        if size < 0:
            size = remaining
        else:
            size = min(remaining, size)

        result = object.__getattribute__(self, '_orig').readline(size)
        remaining -= len(result)

        self._writelog(
            b'readline(%d from %d) -> (%d) %s'
            % (size, origsize, len(result), result)
        )

        cond.remaining_recv_bytes = remaining

        if remaining <= 0:
            self._writelog(b'read limit reached; closing socket')
            self._close()

            # This is the easiest way to abort the current request.
            raise Exception('connection closed after receiving N bytes')

        return result

    def write(self, data):
        """Write ``data``, truncated to the remaining send budget.

        Returns whatever the wrapped ``write`` returns. Raises Exception,
        after closing the connection, once the send budget drops to zero.
        """
        cond = object.__getattribute__(self, '_cond')
        remaining = cond.remaining_send_bytes

        # No byte limit on this operation. Call original function.
        if not remaining:
            result = object.__getattribute__(self, '_orig').write(data)
            self._writelog(b'write(%d) -> %s' % (len(data), data))
            return result

        if len(data) > remaining:
            newdata = data[0:remaining]
        else:
            newdata = data

        remaining -= len(newdata)

        result = object.__getattribute__(self, '_orig').write(newdata)

        self._writelog(
            b'write(%d from %d) -> (%d) %s'
            % (len(newdata), len(data), remaining, newdata)
        )

        cond.remaining_send_bytes = remaining

        if remaining <= 0:
            self._writelog(b'write limit reached; closing socket')
            self._close()

            raise Exception('connection closed after sending N bytes')

        return result


def process_config(value):
    """Turn a comma-separated bytes value into a list of byte limits.

    Empty fields are skipped. Each remaining field is converted with `int`;
    a result of 0 is replaced by None, any other integer is kept as is.
    """
    limits = []
    for field in value.split(b','):
        if not field:
            continue
        number = int(field)
        limits.append(number if number else None)
    return limits


def extsetup(ui):
    """Replace hgweb's HTTP server class with one that can misbehave.

    The replacement class reads the ``[badserver]`` options and, depending on
    them, closes sockets around ``accept()`` or wraps accepted sockets in a
    `socketproxy` that counts bytes.
    """

    # The specially named methods overridden below are the hooks described in
    # SocketServer.BaseServer.
    class badserver(server.MercurialHTTPServer):
        def __init__(self, ui, *args, **kwargs):
            self._ui = ui
            super(badserver, self).__init__(ui, *args, **kwargs)

            recv_limits = process_config(
                self._ui.config(b'badserver', b'close-after-recv-bytes')
            )
            send_limits = process_config(
                self._ui.config(b'badserver', b'close-after-send-bytes')
            )
            self._cond = ConditionTracker(recv_limits, send_limits)

            # Header values pinned so that test output is deterministic.
            pinned_headers = {
                'date': 'Fri, 14 Apr 2017 00:00:00 GMT',
                'server': 'badhttpserver',
            }

            # `object` is listed as a base so that super() works.
            class badrequesthandler(self.RequestHandlerClass, object):
                def send_header(self, name, value):
                    value = pinned_headers.get(name.lower(), value)
                    parent = super(badrequesthandler, self)
                    return parent.send_header(name, value)

            self.RequestHandlerClass = badrequesthandler

        # Hook invoked to accept() a pending socket.
        def get_request(self):
            if self._ui.configbool(b'badserver', b'close-before-accept'):
                self.socket.close()

                # Meant to tell the server to stop processing more requests.
                # NOTE(review): name mangling stores this as
                # `_badserver__shutdown_request`; confirm the base server
                # actually reads that attribute.
                self.__shutdown_request = True

                # Raising simulates a failure so this request is dropped.
                raise socket.error('close before accept')

            if not self._ui.configbool(b'badserver', b'close-after-accept'):
                return super(badserver, self).get_request()

            conn, _peer = super(badserver, self).get_request()
            conn.close()
            raise socket.error('close after accept')

        # Hook doing the heavy lifting for one request. It invokes
        # self.finish_request(), which calls self.RequestHandlerClass(), which
        # is a hgweb.server._httprequesthandler.
        def process_request(self, socket, address):
            tracker = self._cond
            tracker.start_next_request()

            # Only pay for the proxy when a byte limit is active.
            if tracker.might_close():
                socket = socketproxy(
                    socket, self.errorlog, condition_tracked=tracker
                )

            return super(badserver, self).process_request(socket, address)

    server.MercurialHTTPServer = badserver