From 7c99a2c6530cf0c1c2fde7de9c6559be56e2fba0 Mon Sep 17 00:00:00 2001 From: Tássia Camões Araújo Date: Thu, 17 Mar 2011 20:24:01 -0300 Subject: [PATCH] Updating cross-validation to new classes structure and user profile. --- src/cross_validation.py | 55 +++++++++++++++++++++++++------------------------------ src/evaluation.py | 14 ++++++++------ 2 files changed, 33 insertions(+), 36 deletions(-) diff --git a/src/cross_validation.py b/src/cross_validation.py index 0aa8708..97dbd93 100755 --- a/src/cross_validation.py +++ b/src/cross_validation.py @@ -20,6 +20,8 @@ import os import sys import logging +import datetime +from datetime import timedelta from config import * from data import * @@ -28,35 +30,28 @@ from similarity_measure import * from recommender import * from strategy import * from user import * - -def set_up_recommender(cfg): - if cfg.strategy == "cta": - axi_db = xapian.Database(cfg.axi) - app_rec = Recommender(axi_db) - app_rec.set_strategy(AxiContentBasedStrategy()) - - elif cfg.strategy == "ct": - debtags_db = DebtagsDB(cfg.tags_db) - if not debtags_db.load(): - logging.error("Could not load DebtagsDB from %s." % cfg.tags_db) - raise Error - debtags_index = DebtagsIndex(os.path.expanduser(cfg.tags_index)) - debtags_index.load(debtags_db,cfg.reindex) - app_rec = Recommender(debtags_index) - app_rec.set_strategy(ContentBasedStrategy()) - - return app_rec - -def cross_validation(recommender): - metrics = [] - metrics.append(Precision()) - metrics.append(Recall()) - validation = CrossValidation(0.1,10,recommender,metrics) - validation.run(user) +from error import Error if __name__ == '__main__': - cfg = Config() - rec = set_up_recommender(cfg) - user = LocalSystem() - #result.print_result() - cross_validation(rec) + try: + cfg = Config() + rec = Recommender(cfg) + user = LocalSystem() + user.maximal_pkg_profile() + + begin_time = datetime.datetime.now() + logging.debug("Cross-validation started at %s" % begin_time) + + metrics = [] + metrics.append(Precision()) + metrics.append(Recall()) + validation = CrossValidation(0.3,10,rec,metrics) + validation.run(user) + + end_time = datetime.datetime.now() + logging.debug("Cross-validation completed at %s" % end_time) + delta = end_time - begin_time + logging.info("Time elapsed: %d seconds." % delta.seconds) + + except Error: + logging.critical("Aborting proccess. Use '--debug' for more details.") diff --git a/src/evaluation.py b/src/evaluation.py index 6e39d61..2c55f17 100644 --- a/src/evaluation.py +++ b/src/evaluation.py @@ -33,8 +33,7 @@ class Precision(Metric): self.desc = " Precision " def run(self,evaluation): - return float(len(evaluation.predicted_real) / - len(evaluation.predicted_relevant)) + return float(len(evaluation.predicted_real))/len(evaluation.predicted_relevant) class Recall(Metric): """ """ @@ -42,8 +41,7 @@ class Recall(Metric): self.desc = " Recall " def run(self,evaluation): - return float(len(evaluation.predicted_real) / - len(evaluation.real_relevant)) + return float(len(evaluation.predicted_real))/len(evaluation.real_relevant) class F1(Metric): """ """ @@ -89,6 +87,9 @@ class Evaluation: self.real_relevant = real_result.get_prediction() self.predicted_real = [v for v in self.predicted_relevant if v in self.real_relevant] + print len(self.predicted_relevant) + print len(self.real_relevant) + print len(self.predicted_real) def run(self,metric): return metric.run(self) @@ -134,8 +135,9 @@ class CrossValidation: """ Perform cross-validation. """ - partition_size = int(len(user.item_score)*self.partition_proportion) - cross_item_score = user.item_score.copy() + cross_item_score = dict.fromkeys(user.pkg_profile,1) + partition_size = int(len(cross_item_score)*self.partition_proportion) + #cross_item_score = user.item_score.copy() for r in range(self.rounds): round_partition = {} for j in range(partition_size): -- libgit2 0.21.2