import ssl
from ssl import utils, bn

from ctypes import *

class Error(Exception):
    """Exception type raised by this module's DH wrappers (e.g. when DER export fails)."""

# ctypes prototype used to wrap the optional Python progress callback that
# DH.generateParameters hands to DH_generate_parameters.
# NOTE(review): the C prototype quoted at the bottom of this file is
# `void (*callback)(int, int, void *)` -- three arguments, no return value --
# whereas this declares an int return and only two arguments. Confirm against
# the linked OpenSSL version before relying on the values the callback receives.
CALLBACKFUNC = CFUNCTYPE(c_int, c_int, c_void_p)

# P is the 768-bit prime from RFC 2412 (see http://www.faqs.org/rfcs/rfc2412.html).
# BEWARE: anyone who can modify P can attack code that relies on it. If you are
# not sure the value is intact, check it against the RFC.
RFC2412_P = 1552518092300708935130918131258481755631334049434514313202351194902966239949102107258669453876591642442910007680288864229150803718918046342632727613031282983744380820890196288509170691316593175367469551763119843371637221007210577919
# Generator presumably meant to be paired with RFC2412_P (inferred from the name).
RFC2412_G = 2

class dhst(Structure):
    # ctypes view of the leading fields of libcrypto's DH structure; instances
    # are only ever reached through dhst_p pointers returned by libcrypto.
    # NOTE(review): field order and types must match the C layout of the linked
    # OpenSSL version exactly (this matches the older, non-opaque DH struct) --
    # do not reorder. Assumes bn.BN is a pointer-sized BIGNUM handle; confirm.
    _fields_ = [
        ("pad", c_int),
        ("version", c_int),
        ("p", bn.BN),         # prime modulus
        ("g", bn.BN),         # generator
        ("length", c_long),
        ("pub_key", bn.BN),
        ("priv_key", bn.BN),]

# Pointer type used as the restype/argtype for DH* in the ctypes prototypes.
dhst_p = POINTER(dhst)

class DH:
    """
    Thin ctypes wrapper around a libcrypto ``DH *`` handle.

    The wrapped handle is owned by the instance and released with ``DH_free``
    when the instance is finalized.
    """

    def __init__(self, dh=None):
        """
        Wrap *dh* (a ``dhst_p`` obtained from libcrypto), or allocate a fresh,
        empty DH structure with ``DH_new`` when *dh* is None/NULL.

        Ownership of *dh* is transferred to the new instance.
        """
        self._dll = ssl.dll
        if dh:
            self._dh = dh
        else:
            self._dh = self._dll.DH_new()
            # A NULL handle would crash later calls such as DH_size.
            if not self._dh:
                raise utils.getErrorFromQueue()

    def __del__(self):
        # getattr: __init__ may have failed before these attributes were set.
        dh = getattr(self, "_dh", None)
        dll = getattr(self, "_dll", None)
        if dh and dll is not None:
            dll.DH_free(dh)
        # Drop references so a repeated finalization cannot double-free.
        self._dh = None
        self._dll = None

    def size(self):
        """Return the result of ``DH_size`` for the wrapped handle."""
        return self._dll.DH_size(self._dh)

    def generateParameters(cls, primeLen, generator, callback=None):
        """
        Class method: generate new DH parameters and return them wrapped in a
        new instance.

        primeLen  -- length of the prime, passed as DH_generate_parameters'
                     ``prime_len``.
        generator -- generator passed as DH_generate_parameters' ``generator``.
        callback  -- optional Python callable, wrapped with CALLBACKFUNC and
                     handed to OpenSSL as the progress callback.

        Raises the exception returned by ``utils.getErrorFromQueue()`` when
        OpenSSL returns NULL.
        """
        if callback:
            # The ctypes thunk stays referenced by this local for the whole call.
            callback = CALLBACKFUNC(callback)
        dh = ssl.dll.DH_generate_parameters(primeLen, generator, callback, None)
        if not dh:
            # Wrapping a NULL result would silently allocate an empty DH via
            # DH_new() instead of reporting the failure.
            raise utils.getErrorFromQueue()
        return cls(dh)
    generateParameters = classmethod(generateParameters)

    def _get_g(self):
        """Return the generator g as a Python long."""
        return long(self._dh.contents.g)

    def _set_g(self, g):
        """Replace the generator g; only 2 and 5 are accepted."""
        assert g == 2 or g == 5, "2 and 5 are the only two valid generators"
        self._dh.contents.g = bn.BN(g)

    # Read-only property: _set_g is not wired in as a setter.
    g = property(_get_g)

    def _get_p(self):
        """Return the prime modulus p as a Python long."""
        return long(self._dh.contents.p)

    def _set_p(self, p):
        """Replace the prime modulus p."""
        self._dh.contents.p = bn.BN(p)

    # Read-only property: _set_p is not wired in as a setter.
    p = property(_get_p)

    def check(self):
        """Validate the parameters (DH_check). Not implemented yet."""
        raise NotImplementedError("TODO")

    def generateKey(self):
        """
        Generate a key pair with ``DH_generate_key``.

        Raises the exception returned by ``utils.getErrorFromQueue()`` when
        the call reports failure (returns 0).
        """
        if not self._dll.DH_generate_key(self._dh):
            raise utils.getErrorFromQueue()

    def computeKey(self, key, pubKey):
        """
        Compute the shared secret from the peer's public key. Not implemented yet.

        The underlying C call is
        ``DH_compute_key(unsigned char *key, BIGNUM *pub_key, DH *dh)``;
        it returns -1 on error.
        """
        raise NotImplementedError("TODO")

    def _asDER(self):
        """
        Return the PKCS#3 DH parameters DER-encoded by ``i2d_DHparams``.

        Raises Error if libcrypto cannot encode the parameters.
        """
        # First pass with a NULL output pointer only reports the encoded size.
        size = self._dll.i2d_DHparams(self._dh, None)
        if size <= 0:
            raise Error("problems with i2d_DHparams")
        # Encode into a buffer we own, so no CRYPTO_free is needed afterwards.
        # i2d_DHparams advances the cursor pointer, not the buffer itself.
        raw = create_string_buffer(size)
        cursor = cast(raw, POINTER(c_ubyte))
        written = self._dll.i2d_DHparams(self._dh, byref(cursor))
        if written <= 0:
            raise Error("problems with i2d_DHparams")
        return raw.raw[:written]

    def asPEM(self):
        """
        Return the PKCS#3 DH parameters as a PEM text block.

        The DER encoding is base64-wrapped at 64 columns between
        ``-----BEGIN DH PARAMETERS-----`` / ``-----END DH PARAMETERS-----``.

        Raises Error if libcrypto cannot encode the parameters.
        """
        import base64

        encoded = base64.b64encode(self._asDER()).decode("ascii")
        body = "\n".join(encoded[i:i + 64] for i in range(0, len(encoded), 64))
        return ("-----BEGIN DH PARAMETERS-----\n%s\n"
                "-----END DH PARAMETERS-----\n" % body)

# Declare the ctypes signatures for the libcrypto DH entry points that return
# or take a DH* handle; the temporary alias is removed so the module
# namespace is unchanged.
_libcrypto = ssl.dll
_libcrypto.DH_new.argtypes = ()
_libcrypto.DH_new.restype = dhst_p
_libcrypto.DH_free.argtypes = (dhst_p,)
_libcrypto.DH_generate_parameters.restype = dhst_p
del _libcrypto


"""    
 DH *   DH_generate_parameters(int prime_len, int generator,
                void (*callback)(int, int, void *), void *cb_arg);
 int    DH_check(const DH *dh, int *codes);

 int    DH_generate_key(DH *dh);
 int    DH_compute_key(unsigned char *key, BIGNUM *pub_key, DH *dh);

 DH *   d2i_DHparams(DH **a, unsigned char **pp, long length);
 int    i2d_DHparams(const DH *a, unsigned char **pp);

 int    DHparams_print_fp(FILE *fp, const DH *x);
 int    DHparams_print(BIO *bp, const DH *x);
"""