Update partialable.

This commit is contained in:
Ivan Malison 2014-09-13 03:15:25 -07:00
parent d6d5adef71
commit 40f535f224

View File

@ -4,13 +4,13 @@ import inspect
class n_partialable(object): class n_partialable(object):
@staticmethod @staticmethod
def arity_evaluation_checker(function, is_method=False): def arity_evaluation_checker(function):
is_class = inspect.isclass(function) is_class = inspect.isclass(function)
if is_class: if is_class:
function = function.__init__ function = function.__init__
function_info = inspect.getargspec(function) function_info = inspect.getargspec(function)
function_args = function_info.args function_args = function_info.args
if is_class or is_method: if is_class:
# This is to handle the fact that self will get passed in automatically. # This is to handle the fact that self will get passed in automatically.
function_args = function_args[1:] function_args = function_args[1:]
def evaluation_checker(*args, **kwargs): def evaluation_checker(*args, **kwargs):
@ -24,17 +24,10 @@ class n_partialable(object):
return not needed_args or kwarg_keys.issuperset(needed_args) return not needed_args or kwarg_keys.issuperset(needed_args)
return evaluation_checker return evaluation_checker
@staticmethod def __init__(self, function, evaluation_checker=None, args=None, kwargs=None):
def count_evaluation_checker(count):
def function(*args, **kwargs):
return len(args) >= count
return function
def __init__(self, function, evaluation_checker=None, args=None, kwargs=None,
is_method=False):
self.function = function self.function = function
self.evaluation_checker = (evaluation_checker or self.evaluation_checker = (evaluation_checker or
self.arity_evaluation_checker(function, is_method)) self.arity_evaluation_checker(function))
self.args = args or () self.args = args or ()
self.kwargs = kwargs or {} self.kwargs = kwargs or {}
@ -48,5 +41,11 @@ class n_partialable(object):
return type(self)(self.function, self.evaluation_checker, return type(self)(self.function, self.evaluation_checker,
new_args, new_kwargs) new_args, new_kwargs)
def __get__(self, obj, obj_type):
bound = type(self)(self.function, self.evaluation_checker,
args=self.args + (obj,), kwargs=self.kwargs)
setattr(obj, self.function.__name__, bound)
return bound
n_partialable = n_partialable(n_partialable) n_partialable = n_partialable(n_partialable)