Commit 7c99a2c6530cf0c1c2fde7de9c6559be56e2fba0
1 parent
c93445bc
Exists in
master
and in
1 other branch
Updating cross-validation to new classes structure and user profile.
Showing
2 changed files
with
33 additions
and
36 deletions
Show diff stats
src/cross_validation.py
@@ -20,6 +20,8 @@ | @@ -20,6 +20,8 @@ | ||
20 | import os | 20 | import os |
21 | import sys | 21 | import sys |
22 | import logging | 22 | import logging |
23 | +import datetime | ||
24 | +from datetime import timedelta | ||
23 | 25 | ||
24 | from config import * | 26 | from config import * |
25 | from data import * | 27 | from data import * |
@@ -28,35 +30,28 @@ from similarity_measure import * | @@ -28,35 +30,28 @@ from similarity_measure import * | ||
28 | from recommender import * | 30 | from recommender import * |
29 | from strategy import * | 31 | from strategy import * |
30 | from user import * | 32 | from user import * |
31 | - | ||
32 | -def set_up_recommender(cfg): | ||
33 | - if cfg.strategy == "cta": | ||
34 | - axi_db = xapian.Database(cfg.axi) | ||
35 | - app_rec = Recommender(axi_db) | ||
36 | - app_rec.set_strategy(AxiContentBasedStrategy()) | ||
37 | - | ||
38 | - elif cfg.strategy == "ct": | ||
39 | - debtags_db = DebtagsDB(cfg.tags_db) | ||
40 | - if not debtags_db.load(): | ||
41 | - logging.error("Could not load DebtagsDB from %s." % cfg.tags_db) | ||
42 | - raise Error | ||
43 | - debtags_index = DebtagsIndex(os.path.expanduser(cfg.tags_index)) | ||
44 | - debtags_index.load(debtags_db,cfg.reindex) | ||
45 | - app_rec = Recommender(debtags_index) | ||
46 | - app_rec.set_strategy(ContentBasedStrategy()) | ||
47 | - | ||
48 | - return app_rec | ||
49 | - | ||
50 | -def cross_validation(recommender): | ||
51 | - metrics = [] | ||
52 | - metrics.append(Precision()) | ||
53 | - metrics.append(Recall()) | ||
54 | - validation = CrossValidation(0.1,10,recommender,metrics) | ||
55 | - validation.run(user) | 33 | +from error import Error |
56 | 34 | ||
57 | if __name__ == '__main__': | 35 | if __name__ == '__main__': |
58 | - cfg = Config() | ||
59 | - rec = set_up_recommender(cfg) | ||
60 | - user = LocalSystem() | ||
61 | - #result.print_result() | ||
62 | - cross_validation(rec) | 36 | + try: |
37 | + cfg = Config() | ||
38 | + rec = Recommender(cfg) | ||
39 | + user = LocalSystem() | ||
40 | + user.maximal_pkg_profile() | ||
41 | + | ||
42 | + begin_time = datetime.datetime.now() | ||
43 | + logging.debug("Cross-validation started at %s" % begin_time) | ||
44 | + | ||
45 | + metrics = [] | ||
46 | + metrics.append(Precision()) | ||
47 | + metrics.append(Recall()) | ||
48 | + validation = CrossValidation(0.3,10,rec,metrics) | ||
49 | + validation.run(user) | ||
50 | + | ||
51 | + end_time = datetime.datetime.now() | ||
52 | + logging.debug("Cross-validation completed at %s" % end_time) | ||
53 | + delta = end_time - begin_time | ||
54 | + logging.info("Time elapsed: %d seconds." % delta.seconds) | ||
55 | + | ||
56 | + except Error: | ||
57 | + logging.critical("Aborting proccess. Use '--debug' for more details.") |
src/evaluation.py
@@ -33,8 +33,7 @@ class Precision(Metric): | @@ -33,8 +33,7 @@ class Precision(Metric): | ||
33 | self.desc = " Precision " | 33 | self.desc = " Precision " |
34 | 34 | ||
35 | def run(self,evaluation): | 35 | def run(self,evaluation): |
36 | - return float(len(evaluation.predicted_real) / | ||
37 | - len(evaluation.predicted_relevant)) | 36 | + return float(len(evaluation.predicted_real))/len(evaluation.predicted_relevant) |
38 | 37 | ||
39 | class Recall(Metric): | 38 | class Recall(Metric): |
40 | """ """ | 39 | """ """ |
@@ -42,8 +41,7 @@ class Recall(Metric): | @@ -42,8 +41,7 @@ class Recall(Metric): | ||
42 | self.desc = " Recall " | 41 | self.desc = " Recall " |
43 | 42 | ||
44 | def run(self,evaluation): | 43 | def run(self,evaluation): |
45 | - return float(len(evaluation.predicted_real) / | ||
46 | - len(evaluation.real_relevant)) | 44 | + return float(len(evaluation.predicted_real))/len(evaluation.real_relevant) |
47 | 45 | ||
48 | class F1(Metric): | 46 | class F1(Metric): |
49 | """ """ | 47 | """ """ |
@@ -89,6 +87,9 @@ class Evaluation: | @@ -89,6 +87,9 @@ class Evaluation: | ||
89 | self.real_relevant = real_result.get_prediction() | 87 | self.real_relevant = real_result.get_prediction() |
90 | self.predicted_real = [v for v in self.predicted_relevant if v in | 88 | self.predicted_real = [v for v in self.predicted_relevant if v in |
91 | self.real_relevant] | 89 | self.real_relevant] |
90 | + print len(self.predicted_relevant) | ||
91 | + print len(self.real_relevant) | ||
92 | + print len(self.predicted_real) | ||
92 | 93 | ||
93 | def run(self,metric): | 94 | def run(self,metric): |
94 | return metric.run(self) | 95 | return metric.run(self) |
@@ -134,8 +135,9 @@ class CrossValidation: | @@ -134,8 +135,9 @@ class CrossValidation: | ||
134 | """ | 135 | """ |
135 | Perform cross-validation. | 136 | Perform cross-validation. |
136 | """ | 137 | """ |
137 | - partition_size = int(len(user.item_score)*self.partition_proportion) | ||
138 | - cross_item_score = user.item_score.copy() | 138 | + cross_item_score = dict.fromkeys(user.pkg_profile,1) |
139 | + partition_size = int(len(cross_item_score)*self.partition_proportion) | ||
140 | + #cross_item_score = user.item_score.copy() | ||
139 | for r in range(self.rounds): | 141 | for r in range(self.rounds): |
140 | round_partition = {} | 142 | round_partition = {} |
141 | for j in range(partition_size): | 143 | for j in range(partition_size): |