Commit 6c99e7cdab0198c100761939e427f40bd18afe41

Authored by Tássia Camões Araújo
1 parent 9d0b5695
Exists in master and in 1 other branch add_vagrant

Moved method get_errors to base class and implemented RMSE metric.

Showing 1 changed file with 33 additions and 18 deletions   Show diff stats
src/evaluation.py
@@ -18,6 +18,7 @@ @@ -18,6 +18,7 @@
18 # You should have received a copy of the GNU General Public License 18 # You should have received a copy of the GNU General Public License
19 # along with this program. If not, see <http://www.gnu.org/licenses/>. 19 # along with this program. If not, see <http://www.gnu.org/licenses/>.
20 20
  21 +import math
21 import random 22 import random
22 from collections import defaultdict 23 from collections import defaultdict
23 import logging 24 import logging
@@ -30,7 +31,21 @@ class Metric(Singleton): @@ -30,7 +31,21 @@ class Metric(Singleton):
30 """ 31 """
31 Base class for metrics. Strategy design pattern. 32 Base class for metrics. Strategy design pattern.
32 """ 33 """
33 - pass 34 + def get_errors(self,evaluation):
  35 + """
  36 + Compute prediction errors.
  37 + """
  38 + keys = evaluation.predicted_item_scores.keys()
  39 + keys.extend(evaluation.real_item_scores.keys())
  40 + errors = []
  41 + for k in keys:
  42 + if k not in evaluation.real_item_scores:
  43 + evaluation.real_item_scores[k] = 0.0
  44 + if k not in evaluation.predicted_item_scores:
  45 + evaluation.predicted_item_scores[k] = 0.0
  46 + errors.append(float(evaluation.predicted_item_scores[k]-
  47 + evaluation.real_item_scores[k]))
  48 + return errors
34 49
35 class Precision(Metric): 50 class Precision(Metric):
36 """ 51 """
@@ -95,22 +110,6 @@ class MAE(Metric): @@ -95,22 +110,6 @@ class MAE(Metric):
95 """ 110 """
96 self.desc = " MAE " 111 self.desc = " MAE "
97 112
98 - def get_errors(self,evaluation):  
99 - """  
100 - Compute prediction errors.  
101 - """  
102 - keys = evaluation.predicted_item_scores.keys()  
103 - keys.extend(evaluation.real_item_scores.keys())  
104 - errors = []  
105 - for k in keys:  
106 - if k not in evaluation.real_item_scores:  
107 - evaluation.real_item_scores[k] = 0.0  
108 - if k not in evaluation.predicted_item_scores:  
109 - evaluation.predicted_item_scores[k] = 0.0  
110 - errors.append(float(evaluation.predicted_item_scores[k]-  
111 - evaluation.real_item_scores[k]))  
112 - return errors  
113 -  
114 def run(self,evaluation): 113 def run(self,evaluation):
115 """ 114 """
116 Compute metric. 115 Compute metric.
@@ -118,7 +117,7 @@ class MAE(Metric): @@ -118,7 +117,7 @@ class MAE(Metric):
118 errors = self.get_errors(evaluation) 117 errors = self.get_errors(evaluation)
119 return sum(errors)/len(errors) 118 return sum(errors)/len(errors)
120 119
121 -class MSE(MAE): 120 +class MSE(Metric):
122 """ 121 """
123 Prediction accuracy metric defined as the mean square error. 122 Prediction accuracy metric defined as the mean square error.
124 """ 123 """
@@ -136,6 +135,22 @@ class MSE(MAE): @@ -136,6 +135,22 @@ class MSE(MAE):
136 square_errors = [pow(x,2) for x in errors] 135 square_errors = [pow(x,2) for x in errors]
137 return sum(square_errors)/len(square_errors) 136 return sum(square_errors)/len(square_errors)
138 137
  138 +class RMSE(MSE):
  139 + """
  140 + Prediction accuracy metric defined as the root mean square error.
  141 + """
  142 + def __init__(self):
  143 + """
  144 + Set metric description.
  145 + """
  146 + self.desc = " RMSE "
  147 +
  148 + def run(self,evaluation):
  149 + """
  150 + Compute metric.
  151 + """
  152 + return math.sqrt(MSE.run(evaluation))
  153 +
139 class Coverage(Metric): 154 class Coverage(Metric):
140 """ 155 """
141 Evaluation metric defined as the percentage of itens covered by the 156 Evaluation metric defined as the percentage of itens covered by the