Commit 7c99a2c6530cf0c1c2fde7de9c6559be56e2fba0

Authored by Tássia Camões Araújo
1 parent c93445bc
Exists in master and in 1 other branch add_vagrant

Updating cross-validation to new classes structure and user profile.

Showing 2 changed files with 33 additions and 36 deletions   Show diff stats
src/cross_validation.py
... ... @@ -20,6 +20,8 @@
20 20 import os
21 21 import sys
22 22 import logging
  23 +import datetime
  24 +from datetime import timedelta
23 25  
24 26 from config import *
25 27 from data import *
... ... @@ -28,35 +30,28 @@ from similarity_measure import *
28 30 from recommender import *
29 31 from strategy import *
30 32 from user import *
31   -
32   -def set_up_recommender(cfg):
33   - if cfg.strategy == "cta":
34   - axi_db = xapian.Database(cfg.axi)
35   - app_rec = Recommender(axi_db)
36   - app_rec.set_strategy(AxiContentBasedStrategy())
37   -
38   - elif cfg.strategy == "ct":
39   - debtags_db = DebtagsDB(cfg.tags_db)
40   - if not debtags_db.load():
41   - logging.error("Could not load DebtagsDB from %s." % cfg.tags_db)
42   - raise Error
43   - debtags_index = DebtagsIndex(os.path.expanduser(cfg.tags_index))
44   - debtags_index.load(debtags_db,cfg.reindex)
45   - app_rec = Recommender(debtags_index)
46   - app_rec.set_strategy(ContentBasedStrategy())
47   -
48   - return app_rec
49   -
50   -def cross_validation(recommender):
51   - metrics = []
52   - metrics.append(Precision())
53   - metrics.append(Recall())
54   - validation = CrossValidation(0.1,10,recommender,metrics)
55   - validation.run(user)
  33 +from error import Error
56 34  
57 35 if __name__ == '__main__':
58   - cfg = Config()
59   - rec = set_up_recommender(cfg)
60   - user = LocalSystem()
61   - #result.print_result()
62   - cross_validation(rec)
  36 + try:
  37 + cfg = Config()
  38 + rec = Recommender(cfg)
  39 + user = LocalSystem()
  40 + user.maximal_pkg_profile()
  41 +
  42 + begin_time = datetime.datetime.now()
  43 + logging.debug("Cross-validation started at %s" % begin_time)
  44 +
  45 + metrics = []
  46 + metrics.append(Precision())
  47 + metrics.append(Recall())
  48 + validation = CrossValidation(0.3,10,rec,metrics)
  49 + validation.run(user)
  50 +
  51 + end_time = datetime.datetime.now()
  52 + logging.debug("Cross-validation completed at %s" % end_time)
  53 + delta = end_time - begin_time
  54 + logging.info("Time elapsed: %d seconds." % delta.seconds)
  55 +
  56 + except Error:
  57 + logging.critical("Aborting proccess. Use '--debug' for more details.")
... ...
src/evaluation.py
... ... @@ -33,8 +33,7 @@ class Precision(Metric):
33 33 self.desc = " Precision "
34 34  
35 35 def run(self,evaluation):
36   - return float(len(evaluation.predicted_real) /
37   - len(evaluation.predicted_relevant))
  36 + return float(len(evaluation.predicted_real))/len(evaluation.predicted_relevant)
38 37  
39 38 class Recall(Metric):
40 39 """ """
... ... @@ -42,8 +41,7 @@ class Recall(Metric):
42 41 self.desc = " Recall "
43 42  
44 43 def run(self,evaluation):
45   - return float(len(evaluation.predicted_real) /
46   - len(evaluation.real_relevant))
  44 + return float(len(evaluation.predicted_real))/len(evaluation.real_relevant)
47 45  
48 46 class F1(Metric):
49 47 """ """
... ... @@ -89,6 +87,9 @@ class Evaluation:
89 87 self.real_relevant = real_result.get_prediction()
90 88 self.predicted_real = [v for v in self.predicted_relevant if v in
91 89 self.real_relevant]
  90 + print len(self.predicted_relevant)
  91 + print len(self.real_relevant)
  92 + print len(self.predicted_real)
92 93  
93 94 def run(self,metric):
94 95 return metric.run(self)
... ... @@ -134,8 +135,9 @@ class CrossValidation:
134 135 """
135 136 Perform cross-validation.
136 137 """
137   - partition_size = int(len(user.item_score)*self.partition_proportion)
138   - cross_item_score = user.item_score.copy()
  138 + cross_item_score = dict.fromkeys(user.pkg_profile,1)
  139 + partition_size = int(len(cross_item_score)*self.partition_proportion)
  140 + #cross_item_score = user.item_score.copy()
139 141 for r in range(self.rounds):
140 142 round_partition = {}
141 143 for j in range(partition_size):
... ...