Commit 7c99a2c6530cf0c1c2fde7de9c6559be56e2fba0
1 parent
c93445bc
Exists in
master
and in
1 other branch
Updating cross-validation to new classes structure and user profile.
Showing
2 changed files
with
33 additions
and
36 deletions
Show diff stats
src/cross_validation.py
| ... | ... | @@ -20,6 +20,8 @@ |
| 20 | 20 | import os |
| 21 | 21 | import sys |
| 22 | 22 | import logging |
| 23 | +import datetime | |
| 24 | +from datetime import timedelta | |
| 23 | 25 | |
| 24 | 26 | from config import * |
| 25 | 27 | from data import * |
| ... | ... | @@ -28,35 +30,28 @@ from similarity_measure import * |
| 28 | 30 | from recommender import * |
| 29 | 31 | from strategy import * |
| 30 | 32 | from user import * |
| 31 | - | |
| 32 | -def set_up_recommender(cfg): | |
| 33 | - if cfg.strategy == "cta": | |
| 34 | - axi_db = xapian.Database(cfg.axi) | |
| 35 | - app_rec = Recommender(axi_db) | |
| 36 | - app_rec.set_strategy(AxiContentBasedStrategy()) | |
| 37 | - | |
| 38 | - elif cfg.strategy == "ct": | |
| 39 | - debtags_db = DebtagsDB(cfg.tags_db) | |
| 40 | - if not debtags_db.load(): | |
| 41 | - logging.error("Could not load DebtagsDB from %s." % cfg.tags_db) | |
| 42 | - raise Error | |
| 43 | - debtags_index = DebtagsIndex(os.path.expanduser(cfg.tags_index)) | |
| 44 | - debtags_index.load(debtags_db,cfg.reindex) | |
| 45 | - app_rec = Recommender(debtags_index) | |
| 46 | - app_rec.set_strategy(ContentBasedStrategy()) | |
| 47 | - | |
| 48 | - return app_rec | |
| 49 | - | |
| 50 | -def cross_validation(recommender): | |
| 51 | - metrics = [] | |
| 52 | - metrics.append(Precision()) | |
| 53 | - metrics.append(Recall()) | |
| 54 | - validation = CrossValidation(0.1,10,recommender,metrics) | |
| 55 | - validation.run(user) | |
| 33 | +from error import Error | |
| 56 | 34 | |
| 57 | 35 | if __name__ == '__main__': |
| 58 | - cfg = Config() | |
| 59 | - rec = set_up_recommender(cfg) | |
| 60 | - user = LocalSystem() | |
| 61 | - #result.print_result() | |
| 62 | - cross_validation(rec) | |
| 36 | + try: | |
| 37 | + cfg = Config() | |
| 38 | + rec = Recommender(cfg) | |
| 39 | + user = LocalSystem() | |
| 40 | + user.maximal_pkg_profile() | |
| 41 | + | |
| 42 | + begin_time = datetime.datetime.now() | |
| 43 | + logging.debug("Cross-validation started at %s" % begin_time) | |
| 44 | + | |
| 45 | + metrics = [] | |
| 46 | + metrics.append(Precision()) | |
| 47 | + metrics.append(Recall()) | |
| 48 | + validation = CrossValidation(0.3,10,rec,metrics) | |
| 49 | + validation.run(user) | |
| 50 | + | |
| 51 | + end_time = datetime.datetime.now() | |
| 52 | + logging.debug("Cross-validation completed at %s" % end_time) | |
| 53 | + delta = end_time - begin_time | |
| 54 | + logging.info("Time elapsed: %d seconds." % delta.seconds) | |
| 55 | + | |
| 56 | + except Error: | |
| 57 | + logging.critical("Aborting proccess. Use '--debug' for more details.") | ... | ... |
src/evaluation.py
| ... | ... | @@ -33,8 +33,7 @@ class Precision(Metric): |
| 33 | 33 | self.desc = " Precision " |
| 34 | 34 | |
| 35 | 35 | def run(self,evaluation): |
| 36 | - return float(len(evaluation.predicted_real) / | |
| 37 | - len(evaluation.predicted_relevant)) | |
| 36 | + return float(len(evaluation.predicted_real))/len(evaluation.predicted_relevant) | |
| 38 | 37 | |
| 39 | 38 | class Recall(Metric): |
| 40 | 39 | """ """ |
| ... | ... | @@ -42,8 +41,7 @@ class Recall(Metric): |
| 42 | 41 | self.desc = " Recall " |
| 43 | 42 | |
| 44 | 43 | def run(self,evaluation): |
| 45 | - return float(len(evaluation.predicted_real) / | |
| 46 | - len(evaluation.real_relevant)) | |
| 44 | + return float(len(evaluation.predicted_real))/len(evaluation.real_relevant) | |
| 47 | 45 | |
| 48 | 46 | class F1(Metric): |
| 49 | 47 | """ """ |
| ... | ... | @@ -89,6 +87,9 @@ class Evaluation: |
| 89 | 87 | self.real_relevant = real_result.get_prediction() |
| 90 | 88 | self.predicted_real = [v for v in self.predicted_relevant if v in |
| 91 | 89 | self.real_relevant] |
| 90 | + print len(self.predicted_relevant) | |
| 91 | + print len(self.real_relevant) | |
| 92 | + print len(self.predicted_real) | |
| 92 | 93 | |
| 93 | 94 | def run(self,metric): |
| 94 | 95 | return metric.run(self) |
| ... | ... | @@ -134,8 +135,9 @@ class CrossValidation: |
| 134 | 135 | """ |
| 135 | 136 | Perform cross-validation. |
| 136 | 137 | """ |
| 137 | - partition_size = int(len(user.item_score)*self.partition_proportion) | |
| 138 | - cross_item_score = user.item_score.copy() | |
| 138 | + cross_item_score = dict.fromkeys(user.pkg_profile,1) | |
| 139 | + partition_size = int(len(cross_item_score)*self.partition_proportion) | |
| 140 | + #cross_item_score = user.item_score.copy() | |
| 139 | 141 | for r in range(self.rounds): |
| 140 | 142 | round_partition = {} |
| 141 | 143 | for j in range(partition_size): | ... | ... |