From bb8d206b3a8d722c6c54f32f50f50059417f8dc3 Mon Sep 17 00:00:00 2001 From: Tássia Camões Araújo Date: Fri, 4 Mar 2011 17:43:42 -0300 Subject: [PATCH] Implementation of cross validation completed. The result is printed in a matrix format. (closes #3) --- src/app_recommender.py | 6 ++++++ src/data.py | 4 ++-- src/evaluation.py | 96 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------- 3 files changed, 75 insertions(+), 31 deletions(-) diff --git a/src/app_recommender.py b/src/app_recommender.py index 2cb7f1f..5414700 100755 --- a/src/app_recommender.py +++ b/src/app_recommender.py @@ -50,3 +50,9 @@ if __name__ == '__main__': result = recommender.generate_recommendation(user) result.print_result() + + metrics = [] + metrics.append(Precision()) + metrics.append(Recall()) + validation = CrossValidation(0.1,10,recommender,metrics) + validation.run(user) diff --git a/src/data.py b/src/data.py index 3a7395e..851272e 100644 --- a/src/data.py +++ b/src/data.py @@ -76,8 +76,8 @@ class DebtagsIndex: """ Load an existing debtags index. """ if not reindex: try: - print ("Opening existing debtags xapian index at \'%s\'" % - self.path) + #print ("Opening existing debtags xapian index at \'%s\'" % + # self.path) self.index = xapian.Database(self.path) except DatabaseError: print "Could not open debtags xapian index" diff --git a/src/evaluation.py b/src/evaluation.py index ec9615c..01dd19e 100644 --- a/src/evaluation.py +++ b/src/evaluation.py @@ -17,13 +17,18 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import random +from collections import defaultdict +from user import * +from recommender import * + class Metric: """ """ class Precision(Metric): """ """ - def __init_(self): - self.desc = "Precision" + def __init__(self): + self.desc = " Precision " def run(self,evaluation): return float(len(evaluation.predicted_real) / @@ -31,8 +36,8 @@ class Precision(Metric): class Recall(Metric): """ """ - def __init_(self): - self.desc = "Recall" + def __init__(self): + self.desc = " Recall " def run(self,evaluation): return float(len(evaluation.predicted_real) / @@ -40,8 +45,8 @@ class Recall(Metric): class F1(Metric): """ """ - def __init_(self): - self.desc = "F1" + def __init__(self): + self.desc = " F1 " def run(self,evaluation): p = Precision().run(evaluation) @@ -50,24 +55,24 @@ class F1(Metric): class MAE(Metric): """ """ - def __init_(self): - self.desc = "MAE" + def __init__(self): + self.desc = " MAE " def run(self,evaluation): print "run" class MSE(Metric): """ """ - def __init_(self): - self.desc = "MSE" + def __init__(self): + self.desc = " MSE " def run(self,evaluation): print "run" class Coverage(Metric): """ """ - def __init_(self): - self.desc = "Coverage" + def __init__(self): + self.desc = " Coverage " def run(self,evaluation): print "run" @@ -77,9 +82,9 @@ class Evaluation: def __init__(self,predicted_result,real_result): """ """ self.predicted_item_scores = predicted_result.item_score - self.predicted_relevant = predicted_result.get_prediction.keys() + self.predicted_relevant = predicted_result.get_prediction() self.real_item_scores = real_result.item_score - self.real_relevant = real_result.get_prediction.keys() + self.real_relevant = real_result.get_prediction() self.predicted_real = [v for v in self.predicted_relevant if v in self.real_relevant] @@ -88,27 +93,60 @@ class Evaluation: class CrossValidation: """ Cross-validation method """ - def __init__(self,partition_size,rounds,rec,metrics_list): + def __init__(self,partition_proportion,rounds,rec,metrics_list): """ Set parameters: partition_size, rounds, recommender and metrics_list """ - self.partition_size = partition_size + if partition_proportion<1 and partition_proportion>0: + self.partition_proportion = partition_proportion + else: + print "A proporcao de particao deve ser um avalor ente 0 e 1." + exit(1) self.rounds = rounds self.recommender = rec - self.metrics_list = self.metrics_list + self.metrics_list = metrics_list + self.cross_results = defaultdict(list) + + def print_result(self): + print "" + metrics_desc = "" + for metric in self.metrics_list: + metrics_desc += "%s|" % (metric.desc) + print "| Round |%s" % metrics_desc + for r in range(self.rounds): + metrics_result = "" + for metric in self.metrics_list: + metrics_result += (" %.2f |" % + (self.cross_results[metric.desc][r])) + print "| %d |%s" % (r,metrics_result) + metrics_mean = "" + for metric in self.metrics_list: + mean = float(sum(self.cross_results[metric.desc]) / + len(self.cross_results[metric.desc])) + metrics_mean += " %.2f |" % (mean) + print "| Mean |%s" % (metrics_mean) def run(self,user): """ Perform cross-validation. """ - for i in rounds: - cross_result = {} - for metric in self.metrics_list: - cross_results[metric.desc] = [] - cross_user = User(user.item_score) # FIXME: choose subset - predicted_result = self.recommender.gererateRecommendation() - evaluation = Evaluation(predicted_result,user.item_score) + partition_size = int(len(user.item_score)*self.partition_proportion) + cross_item_score = user.item_score.copy() + for r in range(self.rounds): + round_partition = {} + for j in range(partition_size): + if len(cross_item_score)>0: + random_key = random.choice(cross_item_score.keys()) + else: + print "cross_item_score vazio" + exit(1) + round_partition[random_key] = cross_item_score.pop(random_key) + round_user = User(cross_item_score) + predicted_result = self.recommender.generate_recommendation(round_user) + real_result = RecommendationResult(round_partition,len(round_partition)) + evaluation = Evaluation(predicted_result,real_result) for metric in self.metrics_list: - cross_results[metric.desc].append(evaluation.run(metric)) - for metric in self.metrics_list: - mean = (sum(cross_result[metric.desc]) / - len(cross_result[metric.desc])) - print "Mean %d: %2f" % (metric.desc,mean) + result = evaluation.run(metric) + self.cross_results[metric.desc].append(result) + while len(round_partition)>0: + item,score = round_partition.popitem() + cross_item_score[item] = score + self.print_result() -- libgit2 0.21.2