#!/usr/bin/python
# Copyright 2011  Lars Wirzenius
# 
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# Exercise my B-tree implementation, for simple benchmarking purposes.
# The benchmark gets a location and an operation count as command line
# arguments.
#
# If the location is the empty string, an in-memory node store is used.
# Otherwise it must be a non-existent directory name.
#
# The benchmark will do the given number of insertions into the tree, and
# measure the speed of that. Then it will look up each of those, and measure
# the lookups.


import cliapp
import cProfile
import csv
import gc
import logging
import os
import random
import shutil
import subprocess
import sys
import time
import tracing

import larch


class RefcountSpeedTest(cliapp.Application):

    def add_settings(self):
        self.settings.boolean(['profile'], 
                'profile with cProfile?')
        self.settings.boolean(['log-memory-use'], 'log VmRSS?')
        self.settings.string(['trace'], 
                'code module in which to do trace logging')
        self.settings.integer(['refs'], 
                'how many refs to test with (default is %default)',
                default=2**15)
        self.settings.integer(['times'], 
                'how many times to test each op (default is %default)',
                default=1000)

    def process_args(self, args):
        if self.settings['trace']:
            tracing.trace_add_pattern(self.settings['trace'])

        n = self.settings['refs']
        refcounts = {}
        for i in xrange(n):
            refcounts[i] = i

        # Helper functions.
        nop = lambda *args: None
        
        # Calibrate.
        looptime = self.measure(nop, 'calibrate')
        
        num_refcounts = len(refcounts)
        keys = refcounts.keys()
        encode = self.measure(
            lambda: larch.refcountstore.encode_refcounts(
                refcounts, 0, num_refcounts, keys),
            'encode')
        encoded = larch.refcountstore.encode_refcounts(
            refcounts, 0, num_refcounts, keys)
        decode = self.measure(lambda: 
                                larch.refcountstore.decode_refcounts(encoded),
                              'decode')

        # Report
        def speed(result):
            return n / (result - looptime)
        def report(label, result):
            print '%-12s: %5.3f s (%8.1f/s)' % \
                (label, result, speed(result))

        print 'refs: %d' % self.settings['refs']
        print 'times: %d' % self.settings['times']
        report('encode', encode)
        report('decode', decode)

    def measure(self, func, profname):

        def log_memory_use(stage):
            if self.settings['log-memory-use']:
                logging.info('%s memory use: %s' % (profname, stage))
                logging.info('  VmRSS: %s KiB' %  self.vmrss())
                logging.info('  # objects: %d' % len(gc.get_objects()))
                logging.info('  # garbage: %d' % len(gc.garbage))

        def helper():
            n = self.settings['times']
            log_memory_use('at start')
            finished = False
            while not finished:
                for i in xrange(n):
                    func()
                if time.clock() > start:
                    # at least one time unit passed - this is enough
                    finished = True
                else:
                    # Not a single time unit passed: we need more iterations.
                    # Multiply 'times' by 10 and execute the remaining 9 loops.
                    self.settings['times'] *= 10
                    n *= 9
            log_memory_use('after calls')

        print 'measuring', profname
        start = time.clock()
        if self.settings['profile']:
            globaldict = globals().copy()
            localdict = locals().copy()
            cProfile.runctx('helper()', globaldict, localdict, 
                            '%s.prof' % profname)
        else:
            helper()
        end = time.clock()
        return end - start

    def vmrss(self):
        f = open('/proc/self/status')
        rss = 0
        for line in f:
            if line.startswith('VmRSS'):
                rss = line.split()[1]
        f.close()
        return rss

    def format(self, value):
        return '%.0f' % value


# Run the benchmark only when executed as a script, not when imported.
if __name__ == '__main__':
    app = RefcountSpeedTest()
    app.run()
