# we'll use json for the over the wire stuff
import simplejson
from Crypto.PublicKey import DSA
from Crypto.Util.number import getPrime, getRandomNumber
from tfplib import timeutils
import os
import sys
import pytz
import uuid
import datetime
import dateutil, dateutil.parser
import hashlib

# Maps N (the size in bits of the DSA subgroup order q) to the hash function
# whose digest is signed/verified. Keys generated by Key() in this module
# always use N=160 (SHA-1).
N_hashfuncs = dict([
	(160, hashlib.sha1),
	(224, hashlib.sha224),
	(256, hashlib.sha256),
])

def GetUnicode(String, encoding=None):
	"""Return String as unicode text.

	String may already be text, in which case it is returned unchanged, or
	a byte string. A byte string is decoded with `encoding` when one is
	given and that decode succeeds; otherwise it is decoded as UTF-8.

	Raises UnicodeDecodeError when the fallback UTF-8 decode fails. An
	unknown `encoding` name is not caught (LookupError propagates).
	"""
	# type(u"") is `unicode` on Python 2 and `str` on Python 3
	if isinstance(String, type(u"")):
		return String
	if encoding:
		# try the user-specified encoding first; fall back to UTF-8 below
		try:
			return String.decode(encoding)
		except (UnicodeDecodeError, TypeError):
			pass
	# UTF-8 is a superset of ASCII, so plain-ASCII input decodes exactly as
	# a separate ASCII attempt would; a failure here propagates to the caller.
	return String.decode("utf-8")

class InvalidKeyfile(Exception):
	"""Raised when data loaded as a key is not marked as a "Sauth Keyfile"."""

class Signature(object):
	"""A piece of text together with the signature produced over it."""

	def __init__(self, text, signature):
		# text: the signed text (Key.sign passes the unicode it hashed)
		# signature: the value returned by the underlying DSA key's sign()
		self.text = text
		self.signature = signature

	def serialize(self, to="json"):
		"""Serialize this signature.

		to: "dict" returns {"text": ..., "signature": ...};
		    "json" returns that dict encoded as a JSON string.
		Raises ValueError for any other value of `to`.
		"""
		d = {"text": self.text, "signature": self.signature}
		if to == "dict":
			return d
		elif to == "json":
			return simplejson.dumps(d)
		# previously an unknown target silently returned None
		raise ValueError("unsupported serialization target: %r" % (to,))

	def __repr__(self):
		return "Signature(%s, %s)" % (repr(self.text), repr(self.signature))

	def __str__(self):
		return self.serialize()

	def verify(self, key):
		"""Return key.verify(text, signature) for this signature."""
		return key.verify(self.text, self.signature)

	@classmethod
	def from_json(cls, json):
		"""Build a Signature from the JSON produced by serialize("json")."""
		d = simplejson.loads(json)
		return cls(d["text"], d["signature"])

class Key(object):
	"""A DSA key (public, or public+private) used to sign and verify text.

	The key material comes from the first truthy source among `key` (an
	existing pycrypto DSA key object), `keydata` (a tuple for
	DSA.construct), `dict`, `json` and `file` (serialized "Sauth Keyfile"
	data, see serialize()). When none is given a new key of `size` bits is
	generated with `randfunc`.

	N selects the hash from N_hashfuncs used by sign()/verify(). It comes
	from the N argument for `key`/`keydata`, from the serialized data for
	`dict`/`json`/`file`, and is forced to 160 for freshly generated keys.
	"""
	def __init__(self, key=None, keydata=None, file=None, size=1024, randfunc=os.urandom, json=None, dict=None, N=160):
		self._key = None
		if key:
			self._key = key
			self.N = N
		elif keydata:
			self._key = DSA.construct(keydata)
			self.N = N
		elif dict:
			self.load_dict(dict)
		elif json:
			self.load_json(json)
		elif file:
			self.load_file(file)
		else:
			self._key = DSA.generate(size, randfunc)
			# XXX: hardcoded at 160, pycrypto needs to fixed to allow you to specify N/the size in bits of q before you can use other values
			self.N = 160
		# expose these attributes of the underlying pycrypto key directly
		for name in ("can_encrypt", "can_sign", "decrypt", "encrypt", "has_private", "size"):
			setattr(self, name, getattr(self._key, name))

	def load_dict(self, d):
		"""loads a key from the dictionary d

		d["key"] is passed to DSA.construct; d["N"] defaults to 160.
		"""
		self.N = d.get("N", 160)
		self._key = DSA.construct(d["key"])

	def load_json(self, s):
		"""loads a key from the JSON string s

		Raises InvalidKeyfile unless its "filetype" is "Sauth Keyfile".
		"""
		d = simplejson.loads(s)
		if d.get("filetype") != "Sauth Keyfile":
			raise InvalidKeyfile("'filetype' isn't a Sauth Keyfile")
		self.load_dict(d)

	def load_file(self, f):
		"""loads a key from f, a filesystem path or a file-like object

		A path is opened and closed here; a file-like object is only read
		and is left open for its owner to close.
		"""
		if isinstance(f, basestring):
			fh = open(f, "r")
			try:
				data = fh.read()
			finally:
				fh.close()
		else:
			data = f.read()
		self.load_json(data)

	def serialize(self, to="json", type="public", f=None):
		"""Serialize the key as a "Sauth Keyfile".

		to: "dict" returns the dict, "json" returns a JSON string, and
		    "file" writes the JSON to f and returns None. f may be a path
		    (opened and closed here) or a writable file-like object (left
		    open).
		type: "public" emits only the public parts; "private" emits the
		    whole key and raises TypeError if this key has no private part.
		Raises ValueError for any other value of `to` or `type`.
		"""
		# an unrecognised type would otherwise fall through and emit the
		# private parts of a private key under an arbitrary label
		if type not in ("public", "private"):
			raise ValueError("unknown key type: %r" % (type,))
		d = {"filetype": "Sauth Keyfile", "N": self.N}
		key = self._key
		if type == "private" and not key.has_private():
			raise TypeError("this key doesn't contain a private part")
		elif type == "public" and key.has_private():
			key = key.publickey()
		d["type"] = type
		# components in key.keydata order, which is the order load_dict hands
		# back to DSA.construct; parts this key lacks are skipped
		d["key"] = [getattr(key, part) for part in key.keydata if hasattr(key, part)]
		if to == "dict":
			return d
		elif to == "json":
			return simplejson.dumps(d)
		elif to == "file":
			data = simplejson.dumps(d)
			if isinstance(f, basestring):
				fh = open(f, "w")
				try:
					fh.write(data)
				finally:
					fh.close()
			else:
				f.write(data)
			return
		raise ValueError("unsupported serialization target: %r" % (to,))

	@staticmethod
	def _random_nonce(q, rfunc):
		"""Return a DSA per-signature nonce k drawn uniformly from [2, q-1].

		rfunc(n) must return n random bytes (e.g. os.urandom). Rejection
		sampling keeps every value in range equally likely: nonces that are
		biased or much shorter than q let an attacker recover the private
		key from a few signatures.
		"""
		import binascii
		if q < 3:
			raise ValueError("q must be at least 3")
		nbits = 0
		t = q
		while t:
			nbits += 1
			t >>= 1
		nbytes = (nbits + 7) // 8
		mask = (1 << nbits) - 1
		# q >= 2**(nbits-1), so each draw is accepted with probability > 1/2
		while True:
			k = int(binascii.hexlify(rfunc(nbytes)), 16) & mask
			if 2 <= k < q:
				return k

	def sign(self, text, encoding=None, rfunc=os.urandom, rlen=128):
		"""Sign text and return a Signature.

		text is converted with GetUnicode(text, encoding), UTF-8 encoded and
		hashed with N_hashfuncs[self.N]; the digest is signed with the DSA
		key. rfunc supplies the randomness for the nonce k.
		rlen is accepted for backward compatibility but no longer used: k
		used to be a 128-bit number, far shorter than q, which leaks the
		private key; it is now uniform over the whole valid range.
		"""
		text = GetUnicode(text, encoding)
		k = self._random_nonce(self._key.q, rfunc)
		digest = N_hashfuncs[self.N](text.encode("utf-8")).digest()
		return Signature(text, self._key.sign(digest, k))

	def verify(self, text, signature, encoding=None):
		"""Return True if signature is valid for text (hashed as in sign())."""
		text = GetUnicode(text, encoding)
		return bool(self._key.verify(N_hashfuncs[self.N](text.encode("utf-8")).digest(), signature))

	def publickey(self):
		"""Return a new Key holding only the public part, with the same N."""
		return Key(key=self._key.publickey(), N=self.N)

class BasicDataSource(object):
	"""The most basic data source for sauth requests. It supplies no
	security or identity information and is not meant to be used on its
	own. Calling an instance returns a dict with these keys:
		time: the current UTC time as an ISO 8601 string; the server should
			check that it is reasonably recent, allowing for the drift key
		drift: how far (+/-) the server should let the time deviate when
			checking it, as a timedelta string or an offset in seconds.
			Example: the client sends at 01:02 by its clock and the server
			receives at 01:05; this validates only if drift >= 3 minutes.
			The same applies in the other direction (client clock 01:08).
			DEFAULTS TO: 5 minutes
			NOTE(review): the value is always sent as str(drift), so an
			integer drift goes out as e.g. "300" rather than a number;
			confirm the receiving side parses that.
		uuid: a unique id for this request, mainly so an eavesdropper cannot
			replay it while the server would still accept it. The server
			should remember uuids so one cannot be used twice within the
			drift period; a uuid can be forgotten after time+drift+5 seconds
	"""
	def __init__(self, drift=None):
		self.drift = datetime.timedelta(minutes=5) if drift is None else drift

	def __call__(self):
		payload = {}
		payload["time"] = datetime.datetime.now(pytz.utc).isoformat()
		payload["drift"] = str(self.drift)
		# NOTE(review): uuid1 embeds this host's network address and a
		# timestamp (per the Python uuid docs); uuid4 may be preferable.
		payload["uuid"] = str(uuid.uuid1())
		return payload

class UsernameIdentSource(object):
	"""Data source that adds username-based identity to a request.

	Calling an instance returns a dict with a single key:
		username: the username given to the constructor
	"""
	def __init__(self, username):
		self.username = username

	def __call__(self):
		return dict(username=self.username)

def CheckBasic(d, uuidstore):
	"""Most basic auth data checker; validates the keys produced by
	BasicDataSource:
		time: must be within "drift" of the current UTC time
		drift: an int/long number of seconds, or a string parsed with
			timeutils.parse_timedelta
		uuid: must not already be in uuidstore
			Note: uuidstore is only checked, not updated; add the uuid to
			it yourself once the request has been accepted

	Returns True when both checks pass, False otherwise. Reads d["uuid"],
	d["drift"] and d["time"]; for a dict d a missing key raises KeyError.
	"""
	# the replay check must use the uuid sent with this request, not the
	# `uuid` module (which is never in the store, so every replay passed)
	request_uuid = d["uuid"]
	drift = d["drift"]
	if isinstance(drift, (int, long)):
		drift = datetime.timedelta(seconds=drift)
	else:
		drift = timeutils.parse_timedelta(drift)
	sent = dateutil.parser.parse(d["time"])
	now = datetime.datetime.now(pytz.utc)
	# NOTE(review): a "time" without a UTC offset parses to a naive datetime;
	# comparing it with the aware `now` may raise TypeError in in_range.
	in_window = timeutils.in_range(sent, now - drift, now + drift)
	return bool(in_window) and request_uuid not in uuidstore


