Commit 7c99a2c6530cf0c1c2fde7de9c6559be56e2fba0

Authored by Tássia Camões Araújo
1 parent c93445bc
Exists in master and in 1 other branch add_vagrant

Updating cross-validation to new classes structure and user profile.

Showing 2 changed files with 33 additions and 36 deletions   Show diff stats
src/cross_validation.py
@@ -20,6 +20,8 @@ @@ -20,6 +20,8 @@
20 import os 20 import os
21 import sys 21 import sys
22 import logging 22 import logging
  23 +import datetime
  24 +from datetime import timedelta
23 25
24 from config import * 26 from config import *
25 from data import * 27 from data import *
@@ -28,35 +30,28 @@ from similarity_measure import * @@ -28,35 +30,28 @@ from similarity_measure import *
28 from recommender import * 30 from recommender import *
29 from strategy import * 31 from strategy import *
30 from user import * 32 from user import *
31 -  
32 -def set_up_recommender(cfg):  
33 - if cfg.strategy == "cta":  
34 - axi_db = xapian.Database(cfg.axi)  
35 - app_rec = Recommender(axi_db)  
36 - app_rec.set_strategy(AxiContentBasedStrategy())  
37 -  
38 - elif cfg.strategy == "ct":  
39 - debtags_db = DebtagsDB(cfg.tags_db)  
40 - if not debtags_db.load():  
41 - logging.error("Could not load DebtagsDB from %s." % cfg.tags_db)  
42 - raise Error  
43 - debtags_index = DebtagsIndex(os.path.expanduser(cfg.tags_index))  
44 - debtags_index.load(debtags_db,cfg.reindex)  
45 - app_rec = Recommender(debtags_index)  
46 - app_rec.set_strategy(ContentBasedStrategy())  
47 -  
48 - return app_rec  
49 -  
50 -def cross_validation(recommender):  
51 - metrics = []  
52 - metrics.append(Precision())  
53 - metrics.append(Recall())  
54 - validation = CrossValidation(0.1,10,recommender,metrics)  
55 - validation.run(user) 33 +from error import Error
56 34
57 if __name__ == '__main__': 35 if __name__ == '__main__':
58 - cfg = Config()  
59 - rec = set_up_recommender(cfg)  
60 - user = LocalSystem()  
61 - #result.print_result()  
62 - cross_validation(rec) 36 + try:
  37 + cfg = Config()
  38 + rec = Recommender(cfg)
  39 + user = LocalSystem()
  40 + user.maximal_pkg_profile()
  41 +
  42 + begin_time = datetime.datetime.now()
  43 + logging.debug("Cross-validation started at %s" % begin_time)
  44 +
  45 + metrics = []
  46 + metrics.append(Precision())
  47 + metrics.append(Recall())
  48 + validation = CrossValidation(0.3,10,rec,metrics)
  49 + validation.run(user)
  50 +
  51 + end_time = datetime.datetime.now()
  52 + logging.debug("Cross-validation completed at %s" % end_time)
  53 + delta = end_time - begin_time
  54 + logging.info("Time elapsed: %d seconds." % delta.seconds)
  55 +
  56 + except Error:
  57 + logging.critical("Aborting proccess. Use '--debug' for more details.")
src/evaluation.py
@@ -33,8 +33,7 @@ class Precision(Metric): @@ -33,8 +33,7 @@ class Precision(Metric):
33 self.desc = " Precision " 33 self.desc = " Precision "
34 34
35 def run(self,evaluation): 35 def run(self,evaluation):
36 - return float(len(evaluation.predicted_real) /  
37 - len(evaluation.predicted_relevant)) 36 + return float(len(evaluation.predicted_real))/len(evaluation.predicted_relevant)
38 37
39 class Recall(Metric): 38 class Recall(Metric):
40 """ """ 39 """ """
@@ -42,8 +41,7 @@ class Recall(Metric): @@ -42,8 +41,7 @@ class Recall(Metric):
42 self.desc = " Recall " 41 self.desc = " Recall "
43 42
44 def run(self,evaluation): 43 def run(self,evaluation):
45 - return float(len(evaluation.predicted_real) /  
46 - len(evaluation.real_relevant)) 44 + return float(len(evaluation.predicted_real))/len(evaluation.real_relevant)
47 45
48 class F1(Metric): 46 class F1(Metric):
49 """ """ 47 """ """
@@ -89,6 +87,9 @@ class Evaluation: @@ -89,6 +87,9 @@ class Evaluation:
89 self.real_relevant = real_result.get_prediction() 87 self.real_relevant = real_result.get_prediction()
90 self.predicted_real = [v for v in self.predicted_relevant if v in 88 self.predicted_real = [v for v in self.predicted_relevant if v in
91 self.real_relevant] 89 self.real_relevant]
  90 + print len(self.predicted_relevant)
  91 + print len(self.real_relevant)
  92 + print len(self.predicted_real)
92 93
93 def run(self,metric): 94 def run(self,metric):
94 return metric.run(self) 95 return metric.run(self)
@@ -134,8 +135,9 @@ class CrossValidation: @@ -134,8 +135,9 @@ class CrossValidation:
134 """ 135 """
135 Perform cross-validation. 136 Perform cross-validation.
136 """ 137 """
137 - partition_size = int(len(user.item_score)*self.partition_proportion)  
138 - cross_item_score = user.item_score.copy() 138 + cross_item_score = dict.fromkeys(user.pkg_profile,1)
  139 + partition_size = int(len(cross_item_score)*self.partition_proportion)
  140 + #cross_item_score = user.item_score.copy()
139 for r in range(self.rounds): 141 for r in range(self.rounds):
140 round_partition = {} 142 round_partition = {}
141 for j in range(partition_size): 143 for j in range(partition_size):