from typesystem import UnknownType, Rational, Integer, Function, Or
from ir import Application

class TypecheckError(Exception):
	"""Raised when two types cannot be unified with each other.

	Attributes:
		a: the first of the two values that failed to unify.
		b: the second of the two values that failed to unify.
	"""

	def __init__(self, a, b):
		self.a, self.b = a, b

	def __repr__(self):
		return f'TypecheckError({self.a!r}, {self.b!r})'

	def __str__(self):
		# Reuse the repr so error messages show both operands unambiguously.
		return self.__repr__()

def unify(a, b):
	"""Unify two types and return the resulting type.

	Rules, checked in this order:
	- If either side is an ``UnknownType``, the other side is returned as-is.
	- If either side is an ``Or``, each of its possibilities is tried in order
	  against the other side; the first successful unification is returned.
	- Two ``Function`` types must have the same number of arguments; their
	  arguments are unified pairwise and their results unified, producing a
	  new ``Function``.
	- Two ``Rational`` (or two ``Integer``) types unify to a fresh instance of
	  that same type.

	Raises:
		TypecheckError: if no rule applies, if no ``Or`` possibility unifies,
			if ``Function`` arities differ, or if a nested unification fails.
	"""
	if isinstance(a, UnknownType):
		return b
	elif isinstance(b, UnknownType):
		return a
	elif isinstance(a, Or):
		# TODO: Handle cases where the result should be an Or type (i.e. when
		# more than one possibility unifies), instead of taking the first match.
		for possibility in a.possibilities:
			try:
				return unify(possibility, b)
			except TypecheckError:
				# This possibility is incompatible; try the next one.
				continue
		# Only fail once every possibility has been rejected.
		raise TypecheckError(a, b)
	elif isinstance(b, Or):
		# TODO: Same as above - first match wins; an Or result is not produced.
		for possibility in b.possibilities:
			try:
				return unify(a, possibility)
			except TypecheckError:
				continue
		raise TypecheckError(a, b)
	elif isinstance(a, Function) and isinstance(b, Function):
		if len(a.args) != len(b.args):
			raise TypecheckError(a, b)
		args = [unify(aarg, barg) for aarg, barg in zip(a.args, b.args)]
		result = unify(a.result, b.result)
		return Function(args, result)
	elif type(a) == type(b) and type(a) in (Rational, Integer):
		# Exact type match (not isinstance) so subclasses do not unify implicitly.
		return type(a)()

	raise TypecheckError(a, b)

def typecheck(context, nodes):
	"""Propagate types through the nodes of an IR graph, mutating them in place.

	Nodes are visited in sequence order. For every node whose ``value`` is an
	``Application``:
	1. Build the function type implied by the node: its inputs' current types
	   as arguments and the node's current type as the result.
	2. Unify it with the operator's declared type, ``context[value.op]``.
	3. Write the unified argument types back onto the input nodes and the
	   unified result type onto the node itself.
	Other nodes are left untouched.

	Args:
		context: mapping from an application's ``op`` to its declared type.
		nodes: indexable sequence of nodes, each with a ``value`` and a
			mutable ``type``; ``Application.inputs`` holds indices into it.

	Returns:
		None; results are stored on the nodes' ``type`` attributes.

	Raises:
		KeyError: if an application's ``op`` is missing from ``context``.
		TypecheckError: if a node's types cannot be unified with its
			operator's declared type.
	"""
	for node in nodes:
		application = node.value
		if not isinstance(application, Application):
			continue
		inputNodes = [nodes[i] for i in application.inputs]
		nodeType = Function([inputNode.type for inputNode in inputNodes], node.type)
		functionType = context[application.op]
		unifiedType = unify(nodeType, functionType)
		# unify() with a Function on one side either raises or returns a Function.
		assert isinstance(unifiedType, Function)
		# Same write order as indexing inputs[0..n-1], so repeated inputs end
		# up with the type of their last occurrence.
		for inputNode, argType in zip(inputNodes, unifiedType.args):
			inputNode.type = argType
		node.type = unifiedType.result

# TODO: unit tests