Commit 7c99a2c6530cf0c1c2fde7de9c6559be56e2fba0
1 parent
c93445bc
Exists in
master
and in
1 other branch
Updating cross-validation to new classes structure and user profile.
Showing
2 changed files
with
33 additions
and
36 deletions
Show diff stats
src/cross_validation.py
... | ... | @@ -20,6 +20,8 @@ |
20 | 20 | import os |
21 | 21 | import sys |
22 | 22 | import logging |
23 | +import datetime | |
24 | +from datetime import timedelta | |
23 | 25 | |
24 | 26 | from config import * |
25 | 27 | from data import * |
... | ... | @@ -28,35 +30,28 @@ from similarity_measure import * |
28 | 30 | from recommender import * |
29 | 31 | from strategy import * |
30 | 32 | from user import * |
31 | - | |
32 | -def set_up_recommender(cfg): | |
33 | - if cfg.strategy == "cta": | |
34 | - axi_db = xapian.Database(cfg.axi) | |
35 | - app_rec = Recommender(axi_db) | |
36 | - app_rec.set_strategy(AxiContentBasedStrategy()) | |
37 | - | |
38 | - elif cfg.strategy == "ct": | |
39 | - debtags_db = DebtagsDB(cfg.tags_db) | |
40 | - if not debtags_db.load(): | |
41 | - logging.error("Could not load DebtagsDB from %s." % cfg.tags_db) | |
42 | - raise Error | |
43 | - debtags_index = DebtagsIndex(os.path.expanduser(cfg.tags_index)) | |
44 | - debtags_index.load(debtags_db,cfg.reindex) | |
45 | - app_rec = Recommender(debtags_index) | |
46 | - app_rec.set_strategy(ContentBasedStrategy()) | |
47 | - | |
48 | - return app_rec | |
49 | - | |
50 | -def cross_validation(recommender): | |
51 | - metrics = [] | |
52 | - metrics.append(Precision()) | |
53 | - metrics.append(Recall()) | |
54 | - validation = CrossValidation(0.1,10,recommender,metrics) | |
55 | - validation.run(user) | |
33 | +from error import Error | |
56 | 34 | |
57 | 35 | if __name__ == '__main__': |
58 | - cfg = Config() | |
59 | - rec = set_up_recommender(cfg) | |
60 | - user = LocalSystem() | |
61 | - #result.print_result() | |
62 | - cross_validation(rec) | |
36 | + try: | |
37 | + cfg = Config() | |
38 | + rec = Recommender(cfg) | |
39 | + user = LocalSystem() | |
40 | + user.maximal_pkg_profile() | |
41 | + | |
42 | + begin_time = datetime.datetime.now() | |
43 | + logging.debug("Cross-validation started at %s" % begin_time) | |
44 | + | |
45 | + metrics = [] | |
46 | + metrics.append(Precision()) | |
47 | + metrics.append(Recall()) | |
48 | + validation = CrossValidation(0.3,10,rec,metrics) | |
49 | + validation.run(user) | |
50 | + | |
51 | + end_time = datetime.datetime.now() | |
52 | + logging.debug("Cross-validation completed at %s" % end_time) | |
53 | + delta = end_time - begin_time | |
54 | + logging.info("Time elapsed: %d seconds." % delta.seconds) | |
55 | + | |
56 | + except Error: | |
57 | + logging.critical("Aborting proccess. Use '--debug' for more details.") | ... | ... |
src/evaluation.py
... | ... | @@ -33,8 +33,7 @@ class Precision(Metric): |
33 | 33 | self.desc = " Precision " |
34 | 34 | |
35 | 35 | def run(self,evaluation): |
36 | - return float(len(evaluation.predicted_real) / | |
37 | - len(evaluation.predicted_relevant)) | |
36 | + return float(len(evaluation.predicted_real))/len(evaluation.predicted_relevant) | |
38 | 37 | |
39 | 38 | class Recall(Metric): |
40 | 39 | """ """ |
... | ... | @@ -42,8 +41,7 @@ class Recall(Metric): |
42 | 41 | self.desc = " Recall " |
43 | 42 | |
44 | 43 | def run(self,evaluation): |
45 | - return float(len(evaluation.predicted_real) / | |
46 | - len(evaluation.real_relevant)) | |
44 | + return float(len(evaluation.predicted_real))/len(evaluation.real_relevant) | |
47 | 45 | |
48 | 46 | class F1(Metric): |
49 | 47 | """ """ |
... | ... | @@ -89,6 +87,9 @@ class Evaluation: |
89 | 87 | self.real_relevant = real_result.get_prediction() |
90 | 88 | self.predicted_real = [v for v in self.predicted_relevant if v in |
91 | 89 | self.real_relevant] |
90 | + print len(self.predicted_relevant) | |
91 | + print len(self.real_relevant) | |
92 | + print len(self.predicted_real) | |
92 | 93 | |
93 | 94 | def run(self,metric): |
94 | 95 | return metric.run(self) |
... | ... | @@ -134,8 +135,9 @@ class CrossValidation: |
134 | 135 | """ |
135 | 136 | Perform cross-validation. |
136 | 137 | """ |
137 | - partition_size = int(len(user.item_score)*self.partition_proportion) | |
138 | - cross_item_score = user.item_score.copy() | |
138 | + cross_item_score = dict.fromkeys(user.pkg_profile,1) | |
139 | + partition_size = int(len(cross_item_score)*self.partition_proportion) | |
140 | + #cross_item_score = user.item_score.copy() | |
139 | 141 | for r in range(self.rounds): |
140 | 142 | round_partition = {} |
141 | 143 | for j in range(partition_size): | ... | ... |