# -*- coding: UTF-8 -*-
# dsa.py: FIPS 186-3 (DSA) implementation

# Copyright © 2007-2009 Kyle McFarland
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#	http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import with_statement

__license__ = "Apache License, Version 2.0"
__copyright__ = u"Copyright © 2007-2009 Kyle McFarland"
__author__ = "Kyle McFarland"

from Crypto.Util import number
from Crypto.Util.number import bytes_to_long, long_to_bytes, isPrime, size, getRandomNumber, inverse
from math import ceil
import os
import sys

import hashlib

# Maps the DSA `N` parameter (bit length of q) to the hashlib constructor
# whose digest is exactly N bits long; used to pick the hash for a key.
N_hashfuncs = dict([
	(160, hashlib.sha1),
	(224, hashlib.sha224),
	(256, hashlib.sha256),
	(384, hashlib.sha384),
	(512, hashlib.sha512),
])

class default_progress(object):
	"""Silent progress reporter.

	Used as ``with default_progress(name) as report: report(".")``; the
	callback handed out by ``__enter__`` discards every event.
	"""

	def __init__(self, var):
		# The label of the value being generated is accepted only so this
		# class is interchangeable with the other reporters; it is unused.
		pass

	@staticmethod
	def _ignore(c):
		"""Discard a progress character."""
		return None

	def __enter__(self):
		return self._ignore

	def __exit__(self, *args):
		# A falsy return lets exceptions from the with-body propagate.
		return None

class openssl_progress(object):
	"""OpenSSL-style progress reporter.

	On entry writes "Generating <var>: ", the callback echoes each progress
	character (e.g. "." per attempt, "+" on success) immediately, and on
	exit a newline is written.
	"""

	def __init__(self, var):
		# label of the value being generated, shown in the header
		self.var = var
		# running event counter; only the subclasses read or update it
		self.count = 0

	@staticmethod
	def _progressfunc(c):
		"""Echo one progress character and flush so it appears immediately."""
		sys.stdout.write(c)
		sys.stdout.flush()

	def __enter__(self):
		# sys.stdout.write instead of the Python 2 `print x,` statement keeps the
		# header on the same line as the progress characters on Python 2 and 3.
		sys.stdout.write("Generating %s: " % self.var)
		sys.stdout.flush()
		return self._progressfunc

	def __exit__(self, *args):
		# terminate the progress line; a None return lets exceptions propagate
		sys.stdout.write("\n")
		sys.stdout.flush()

class openssl2_progress(openssl_progress):
	"""Progress reporter that thins out long runs of dots.

	Every character bumps a running count. The first ten characters are
	always echoed; after that a "." is echoed only when the count is a
	multiple of ten. Any other character is always echoed.
	"""

	def _progressfunc(self, c):
		self.count += 1
		suppressed = c == "." and self.count > 10 and self.count % 10 != 0
		if not suppressed:
			openssl_progress._progressfunc(c)

class openssl3_progress(openssl_progress):
	"""Progress reporter that folds each run of ten characters.

	Whenever a "." arrives at the start of a ten-character cycle (other than
	the very first character), ten backspaces and a "d" are written before
	it; on a terminal this presumably leaves one "d" per completed run
	followed by the current run.
	"""

	def __init__(self, var):
		super(openssl3_progress, self).__init__(var)
		# True until the first character has been written
		self.first = True

	def _progressfunc(self, c):
		cycle_start = self.count == 0
		if c == "." and cycle_start and not self.first:
			out = sys.stdout
			out.write("\b" * 10)
			out.write("d")
			out.flush()
		openssl_progress._progressfunc(c)
		self.first = False
		# position within the current ten-character cycle
		self.count = (self.count + 1) % 10

class openssl4_progress(openssl_progress):
	"""Progress reporter that draws a filling circle per ten dots.

	On the 1st, 3rd, 5th, 7th and 9th dot of each ten-dot cycle a glyph from
	`steps` is written; from the second glyph of a cycle on, a backspace is
	written first so the glyph replaces its predecessor. "+" is echoed
	unchanged and any other character is ignored.
	"""
	# 1-based dot position within a cycle -> glyph written at that dot
	steps = {
		1: u"○",
		3: u"◔",
		5: u"◑",
		7: u"◕",
		9: u"●",
	}

	@staticmethod
	def _encodeForStdout(text):
		"""Return `text` in a form sys.stdout accepts.

		Glyphs the stream's encoding cannot represent become "?". A missing
		encoding (e.g. piped output on Python 2) falls back to UTF-8. Python 3
		text streams get `str` back; Python 2 gets encoded bytes.
		"""
		encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
		data = text.encode(encoding, "replace")
		if sys.version_info[0] >= 3:
			return data.decode(encoding)
		return data

	def _progressfunc(self, c):
		if c == ".":
			glyph = self.steps.get(self.count + 1)
			if glyph:
				if self.count != 0:
					# back up over the previous glyph of this cycle
					sys.stdout.write("\b")
				sys.stdout.write(self._encodeForStdout(glyph))
				sys.stdout.flush()
			self.count = (self.count + 1) % 10
		elif c == "+":
			openssl_progress._progressfunc(c)

class number_progress(openssl3_progress):
	"""Progress reporter showing a running attempt counter.

	Each "." erases the previous "N attempt(s)... " text with backspaces
	(except the first time) and writes the incremented one; "+" writes
	"done". Other characters are ignored.
	"""

	@staticmethod
	def _attemptsText(count):
		"""Return the counter text for `count` attempts."""
		if count > 1:
			return "%d attempts... " % count
		return "%d attempt... " % count

	def _progressfunc(self, c):
		if c == "+":
			sys.stdout.write("done")
			sys.stdout.flush()
			return
		if c != ".":
			return
		if not self.first:
			previous = self._attemptsText(self.count)
			sys.stdout.write("\b" * len(previous))
		self.count += 1
		sys.stdout.write(self._attemptsText(self.count))
		sys.stdout.flush()
		self.first = False

#default_progress = DefaultProgress

#def default_progress(*args):
#	import sys
#	sys.stdout.write(args[0])
#	sys.stdout.flush()

class Invalid(Exception):
	"""Raised when a DSA parameter fails generation or validation checks.

	`part` names the offending value, `value` is the value itself and
	`reason` explains the failure.
	"""

	def __init__(self, part, value, reason):
		super(Invalid, self).__init__(part, value, reason)
		self.part, self.value, self.reason = part, value, reason

	def __str__(self):
		details = (self.part, self.value, self.reason)
		return "Invalid Value for '%s' (%s): %s" % details

class NoPrivate(Exception):
	"""Raised when an operation needs the private key `x` but the key has none."""

class DSA(object):
	"""A DSA key pair together with the FIPS 186-3 domain parameters behind it.

	`p`, `q` and `g` are the domain parameters, `y` is the public key and `x`
	the private key (None for a public-only key). `L`, `N`,
	`domain_parameter_seed`, `counter` and `gindex` record how `p`, `q` and
	`g` were generated, so `validate` can re-derive and check them.
	"""
	# Generation stages in order; the names match the labels this class
	# passes to its progress reporters. Not otherwise read in this class.
	generation_steps = [
		{"name": "q"},
		{"name": "p", "max_attempts": 4096},
		{"name": "g"},
		{"name": "x"},
		{"name": "y"},
	]

	def __init__(self, p, q, g, y, x=None, L=None, N=None, domain_parameter_seed=None, counter=None, gindex=None):
		"""Build a key from its components.

		Every value is passed through int(), so string-typed input is
		accepted; the optional values may be None.
		"""
		# do conversions for easy interop with php, XXX: remove this when php sucks less
		self._L = self._optionalInt(L)
		self._N = self._optionalInt(N)
		self.domain_parameter_seed = self._optionalInt(domain_parameter_seed)
		self.counter = self._optionalInt(counter)
		self.p = int(p)
		self.q = int(q)
		self.g = int(g)
		self.gindex = self._optionalInt(gindex)
		self.y = int(y)
		self.x = self._optionalInt(x)

	@staticmethod
	def _optionalInt(value):
		"""Return int(value), passing None through unchanged."""
		if value is None:
			return None
		return int(value)

	@property
	def L(self):
		"""The `L` given at construction if set, otherwise size(p)."""
		return self._L or size(self.p)

	@property
	def N(self):
		"""The `N` given at construction if set, otherwise size(q)."""
		return self._N or size(self.q)

	def _get_keydata(self):
		"""Return the data needed to rebuild this key via DSA(**keydata).

		`L` and `N` come from the properties, so they fall back to the sizes
		of `p` and `q` when they were not given originally.
		"""
		# XXX: should I use an automatic sizing of L here (see the L/N properties) or just what was specified originally?
		return {
			"L": self.L,
			"N": self.N,
			"domain_parameter_seed": self.domain_parameter_seed,
			"counter": self.counter,
			"p": self.p,
			"q": self.q,
			"g": self.g,
			"gindex": self.gindex,
			"y": self.y,
			"x": self.x,
		}

	keydata = property(_get_keydata)

	@staticmethod
	def _pBlocks(L, outlen):
		"""Return n = ceil(L / outlen) - 1 (FIPS 186-3 A.1.1.2 step 3), in integer arithmetic."""
		return (L + outlen - 1) // outlen - 1

	@staticmethod
	def _computeQ(Hash, domain_parameter_seed, N):
		"""Derive the candidate `q` from the seed (FIPS 186-3 A.1.1.2 steps 6-7).

		Shared by generation and validation so both always agree.
		"""
		U = bytes_to_long(Hash(long_to_bytes(domain_parameter_seed)).digest()) % 2**(N-1)
		# sets the 2**(N-1) bit and forces q to be odd
		return 2**(N-1) + U + 1 - (U % 2)

	@staticmethod
	def _computeP(Hash, domain_parameter_seed, offset, q, L, outlen, seedlen):
		"""Derive one candidate `p` at `offset` (FIPS 186-3 A.1.1.2 steps 11.1-11.5).

		The caller still has to check p >= 2**(L-1) and primality. Shared by
		generation and validation so both always agree.
		"""
		n = DSA._pBlocks(L, outlen)
		b = L - 1 - (n * outlen)
		two_seedlen = 2**seedlen
		# V_j = Hash((seed + offset + j) mod 2**seedlen) for j = 0..n
		V = [bytes_to_long(Hash(long_to_bytes((domain_parameter_seed + offset + j) % two_seedlen)).digest())
			for j in range(n + 1)]
		# W = V_0 + V_1*2**outlen + ... + V_(n-1)*2**((n-1)*outlen) + (V_n mod 2**b)*2**(n*outlen)
		W = sum([V[j] << (outlen * j) for j in range(n)]) + ((V[n] % 2**b) << (outlen * n))
		X = W + 2**(L-1)
		c = X % (2*q)
		# subtracting c - 1 makes p congruent to 1 mod 2q
		return X - (c - 1)

	@staticmethod
	def _computeG(p, q, domain_parameter_seed, index, Hash, on_attempt=None):
		"""Derive the generator `g` for `index` (FIPS 186-3 A.2.3).

		`on_attempt`, if given, is called with "." before every attempt.
		"""
		e = (p - 1) // q
		raw_dps = long_to_bytes(domain_parameter_seed)
		count = 0
		g = 0
		while g < 2:
			if on_attempt is not None:
				on_attempt(".")
			count += 1
			# XXX: FIPS 186-3 A.2.3 appears to call for an 8-bit index and a 16-bit
			# count rather than decimal text; kept as-is so existing keys still validate.
			U = raw_dps + "ggen" + str(index) + str(count)
			W = bytes_to_long(Hash(U).digest())
			g = pow(W, e, p)
		return g

	@staticmethod
	def _generatePQ(L, N, randfunc=os.urandom, seedlen=None, outlen=None, progress=default_progress):
		"""Generate the Primes `p` and `q` needed for DSA keypair generation.

		`L` should be the "size" of the key you want (commonly 1024, 2048, etc.),
		`N` is the secondary size, and must correspond to the size of an SHA-* hash result (160, 224, 256, 384, 512),
		`randfunc` is the function used for random data generation, `os.urandom` typically works fairly well,
		`seedlen` and `outlen` both default to the value of `N` if not specified (or smaller than `N`) and most of the time don't need manually specifying,
		`progress` is a progress-reporter class, used as progress("q") and progress("p").

		Return a dict with the following keys:
		`status` is "valid" if generating everything succeeded,
		`p` and `q` are needed for public/private key generation and use,
		`domain_parameter_seed` and `counter` aren't strictly needed for key usage, but are needed to generate `g` and to validate `p`, `q` and `g`.

		Uses the algorithm from A.1.1.2 in FIPS 186-3. If none of the 4*L
		candidates for `p` is usable, a fresh seed (and so a fresh `q`) is
		tried. Raises KeyError if `N` has no entry in N_hashfuncs.
		"""
		# TODO: 1. Check that the (L,N) pair is in the list of acceptable (L, N pairs) (see Section 4.2). If the pair is not in the list, then return INVALID.
		Hash = N_hashfuncs[N]
		if not seedlen or seedlen < N:
			seedlen = N
		if not outlen or outlen < N:
			outlen = N
		min_L = 2**(L-1)
		step = DSA._pBlocks(L, outlen) + 1
		while True:
			# Steps 5-9: draw seeds until the derived q is prime.
			q = None
			with progress("q") as progressfunc:
				while q is None or not isPrime(q):
					progressfunc(".")
					domain_parameter_seed = getRandomNumber(seedlen, randfunc)
					q = DSA._computeQ(Hash, domain_parameter_seed, N)
				progressfunc("+")
			# Steps 10-15: try counter = 0 .. 4*L-1, advancing offset by n+1 each time.
			offset = 1
			with progress("p") as progressfunc:
				for counter in xrange(4*L):
					progressfunc(".")
					p = DSA._computeP(Hash, domain_parameter_seed, offset, q, L, outlen, seedlen)
					if p >= min_L and isPrime(p):
						progressfunc("+")
						return {"status": "valid", "p": p, "q": q, "domain_parameter_seed": domain_parameter_seed, "counter": counter}
					offset += step

	@staticmethod
	def _validatePQ(p, q, domain_parameter_seed, counter, L=None, N=None):
		"""Validate the primes `p`, `q`, and the `counter`.

		By validating them you make sure that they were correctly generated based on the input, according to this implementation anyway.
		`L` and `N` default to the bit sizes of `p` and `q`.

		Return True if `p`, `q` and `counter` are valid, else raise
		an Invalid Exception.

		Uses the algorithm from A.1.1.3 in FIPS 186-3.
		"""
		if L is None:
			L = size(p)
		if N is None:
			N = size(q)
		outlen = N
		Hash = N_hashfuncs[N]
		# TODO: 3. Check that the (L, N) pair is in the list of acceptable (L, N) pairs (see Section 4.2). If the pair is not in the list, return INVALID.
		if counter is None or counter < 0:
			raise Invalid("counter", counter, "missing or negative")
		if counter > 4*L - 1:
			raise Invalid("counter", counter, "larger than 4*L - 1 (%d)" % (4*L - 1))
		seedlen = size(domain_parameter_seed)
		if seedlen < N:
			raise Invalid("seed length", seedlen, "smaller than N (%d)" % N)
		computed_q = DSA._computeQ(Hash, domain_parameter_seed, N)
		if computed_q != q:
			raise Invalid("q", q, "not the same as the computed value for q (%d)" % computed_q)
		elif not isPrime(computed_q):
			raise Invalid("q", computed_q, "not prime")
		min_L = 2**(L-1)
		step = DSA._pBlocks(L, outlen) + 1
		offset = 1
		# Re-run the search; the first acceptable candidate must appear exactly at `counter`.
		for i in xrange(counter+1):
			computed_p = DSA._computeP(Hash, domain_parameter_seed, offset, q, L, outlen, seedlen)
			if computed_p >= min_L and isPrime(computed_p):
				break
			offset += step
		if i != counter:
			raise Invalid("counter", counter, "not the same as the computed counter value (%d)" % i)
		elif computed_p != p:
			raise Invalid("p", p, "not the same as the computed value for p (%d)" % computed_p)
		elif not isPrime(computed_p):
			raise Invalid("p", computed_p, "not prime")
		return True

	@staticmethod
	def _generateG(p, q, domain_parameter_seed, index=1, progress=default_progress):
		"""Generate the mathematical generator `g` needed for signing/verifying.

		`p`, `q`, and `domain_parameter_seed` should be the equivalent variables generated by `_generatePQ`,
		`index` is a numerical "salt" index, if you want to generate multiple keys using the same `p`/`q` it's a good idea to specify a different index for each of them.

		Return a dict with the following keys:
		`status` is again "valid" if everything went fine,
		`g`,
		`gindex` just reiterates the index you gave.

		Raises Invalid if `index` is 256 or more.

		Uses the algorithm from A.2.3 in FIPS 186-3.
		"""
		if index >= 2**8:
			raise Invalid("gindex", index, "larger than 255")
		Hash = N_hashfuncs[size(q)]
		with progress("g") as progressfunc:
			g = DSA._computeG(p, q, domain_parameter_seed, index, Hash, progressfunc)
			progressfunc("+")
		return {"status": "valid", "g": g, "gindex": index}

	@staticmethod
	def _validateG(p, q, domain_parameter_seed, index, g, N=None):
		"""Validate the mathematical generator `g`.

		Checks 2 <= g < p and g**q mod p == 1, then re-derives `g` from the
		seed and `index`. Return True if valid, else raise Invalid.
		"""
		if index >= 2**8:
			raise Invalid("gindex", index, "larger than 255")
		if not (2 <= g < p):
			if g < 2:
				raise Invalid("g", g, "smaller than 2")
			raise Invalid("g", g, "larger or equal to p (%d)" % p)
		if pow(g, q, p) != 1:
			raise Invalid("g", g, "g**q mod p (q == %d, p == %d) != 1" % (q, p))
		if N is None:
			N = size(q)
		computed_g = DSA._computeG(p, q, domain_parameter_seed, index, N_hashfuncs[N])
		if computed_g != g:
			raise Invalid("g", g, "not the same as the computed value for g (%d)" % computed_g)
		return True

	@classmethod
	def generate(cls, L, N, randfunc=os.urandom, validate=False, progress=default_progress):
		"""Generate a new public/private DSA keypair and return the private key.

		`L` should be the "size" of the key you want (commonly 1024, 2048, etc.),
		`N` is the secondary size, and must correspond to the size of an SHA-* hash result (160, 224, 256, 384, 512),
		`randfunc` is the function used for random data generation, `os.urandom` typically works fairly well,
		if `validate` is True we also validate that the generated `p`, `q` and `g` values are correct (raising Invalid if not).

		Generation of `x` and `y` are done with the algorithm from B.1.1 in FIPS 186-3.
		"""
		pq = cls._generatePQ(L, N, randfunc, progress=progress)
		p, q = pq["p"], pq["q"]
		domain_parameter_seed, counter = pq["domain_parameter_seed"], pq["counter"]
		gd = cls._generateG(p, q, domain_parameter_seed, progress=progress)
		g, gindex = gd["g"], gd["gindex"]
		if validate:
			# both raise Invalid on failure
			cls._validatePQ(p, q, domain_parameter_seed, counter)
			cls._validateG(p, q, domain_parameter_seed, gindex, g)
		# Generate x, y: this uses a weak version of the Using Extra Random Bits method (B.1.1)
		with progress("x") as progressfunc:
			c = getRandomNumber(N+64, randfunc)
			x = (c % (q-1)) + 1
			progressfunc("+")
		with progress("y") as progressfunc:
			y = pow(g, x, p)
			progressfunc("+")
		return cls(L=L, N=N, domain_parameter_seed=domain_parameter_seed, counter=counter, p=p, q=q, g=g, gindex=gindex, y=y, x=x)

	def sign(self, M, k=None, hash=True, randfunc=os.urandom):
		"""Sign the message `M` and Return the signature as a tuple (r, s).

		`k` is the per-signature random number, if left undefined (or 0) it's generated automatically.
		If `hash` is True the message is hashed using the SHA-* hashing algorithm with the output size the same as `self.N` as per FIPS 186-3, if False then we don't hash the message and let the chips fall where they may (not hashing could have unexpected results if it translated to a number is larger than `self.q` and is really only around for compatability with pycrypto).
		`randfunc` is used to generate `k`, again `os.urandom` is good enough most of the time.

		If r == 0 or s == 0 a new `k` is generated and the signature recomputed;
		when `k` was supplied by the caller, Invalid is raised instead.
		Raises NoPrivate if this key has no private part.

		Uses the algorithm from 4.6 in FIPS 186-3
		"""
		if not self.x:
			raise NoPrivate("Private Key not available in this object")
		N = self.N
		Hash = N_hashfuncs[N]
		z = M
		if hash:
			z = Hash(M).digest()[:N//8]
		z = bytes_to_long(z)
		auto_k = not k
		while True:
			if auto_k:
				# generate k, this is done with a weak version of the Using Extra Random Bits method (B.2.1)
				kinverse = 0
				while not (0 < kinverse < self.q):
					c = getRandomNumber(N+64, randfunc)
					k = (c % (self.q-1)) + 1
					kinverse = inverse(k, self.q)
			else:
				kinverse = inverse(k, self.q)
			r = pow(self.g, k, self.p) % self.q
			s = (kinverse * (z + self.x*r)) % self.q
			# FIPS 186-3 4.6: r == 0 or s == 0 requires a new k
			if r != 0 and s != 0:
				return (r, s)
			if not auto_k:
				raise Invalid("k", k, "gives r == 0 or s == 0; a different k is needed")

	def verify(self, M, signature, hash=True):
		"""Return True if the signature is valid, False otherwise.

		If `hash` is True `signature` is checked against the hash of `M`, if the message was hashed when signing you want to have this as True, note that pycrypto does not hash messages, so in the case the signature came from it you want to specify it as False, but otherwise True is typically good (see the docs for `self.sign` for details).

		Uses the algorithm from 4.7 in FIPS 186-3.
		"""
		r, s = signature
		N = self.N
		Hash = N_hashfuncs[N]
		# out-of-range components can never be valid
		if not (0 < r < self.q and 0 < s < self.q):
			return False
		w = inverse(s, self.q)
		z = M
		if hash:
			z = Hash(M).digest()[:N//8]
		z = bytes_to_long(z)
		u1 = (z * w) % self.q
		u2 = (r * w) % self.q
		v = ((pow(self.g, u1, self.p) * pow(self.y, u2, self.p)) % self.p) % self.q
		return v == r

	def validate(self):
		"""Validate `p`, `q` and `g`.

		Return True if everything is ok, otherwise raises an Invalid Exception.
		"""
		pqv = self._validatePQ(self.p, self.q, self.domain_parameter_seed, self.counter, L=self.L, N=self.N)
		gv = self._validateG(self.p, self.q, self.domain_parameter_seed, self.gindex, self.g, N=self.N)
		return all((pqv, gv))

	def publickey(self):
		"""Return a copy of this key without private key data."""
		keydata = self.keydata
		# keydata is a fresh dict, so deleting from it leaves this key intact
		del keydata["x"]
		return self.__class__(**keydata)

if __name__ == "__main__":
	# Smoke test: generate a 1024/160 key, validate it, then sign and verify a message.
	# Single-argument print(...) prints the same as the print statement on Python 2
	# and keeps this module parseable by Python 3.
	key = DSA.generate(1024, 160)
	import pprint
	pprint.pprint(key.keydata)
	print(key.validate())
	signature = key.sign("foo")
	print(signature)
	print(key.verify("foo", signature))
	# The same steps using the lower-level helpers:
	#stuff = DSA._generatePQ(1024, 160)
	#print(stuff)
	#print(DSA._validatePQ(stuff["p"], stuff["q"], stuff["domain_parameter_seed"], stuff["counter"]))
	#gd = DSA._generateG(stuff["p"], stuff["q"], stuff["domain_parameter_seed"])
	#print(gd)
	#print(DSA._validateG(stuff["p"], stuff["q"], stuff["domain_parameter_seed"], gd["gindex"], gd["g"]))

