Commit 6c99e7cdab0198c100761939e427f40bd18afe41

Authored by Tássia Camões Araújo
1 parent 9d0b5695
Exists in master and in 1 other branch add_vagrant

Moved method get_errors to base class and implemented RMSE metric.

Showing 1 changed file with 33 additions and 18 deletions   Show diff stats
src/evaluation.py
... ... @@ -18,6 +18,7 @@
18 18 # You should have received a copy of the GNU General Public License
19 19 # along with this program. If not, see <http://www.gnu.org/licenses/>.
20 20  
  21 +import math
21 22 import random
22 23 from collections import defaultdict
23 24 import logging
... ... @@ -30,7 +31,21 @@ class Metric(Singleton):
30 31 """
31 32 Base class for metrics. Strategy design pattern.
32 33 """
33   - pass
  34 + def get_errors(self,evaluation):
  35 + """
  36 + Compute prediction errors.
  37 + """
  38 + keys = evaluation.predicted_item_scores.keys()
  39 + keys.extend(evaluation.real_item_scores.keys())
  40 + errors = []
  41 + for k in keys:
  42 + if k not in evaluation.real_item_scores:
  43 + evaluation.real_item_scores[k] = 0.0
  44 + if k not in evaluation.predicted_item_scores:
  45 + evaluation.predicted_item_scores[k] = 0.0
  46 + errors.append(float(evaluation.predicted_item_scores[k]-
  47 + evaluation.real_item_scores[k]))
  48 + return errors
34 49  
35 50 class Precision(Metric):
36 51 """
... ... @@ -95,22 +110,6 @@ class MAE(Metric):
95 110 """
96 111 self.desc = " MAE "
97 112  
98   - def get_errors(self,evaluation):
99   - """
100   - Compute prediction errors.
101   - """
102   - keys = evaluation.predicted_item_scores.keys()
103   - keys.extend(evaluation.real_item_scores.keys())
104   - errors = []
105   - for k in keys:
106   - if k not in evaluation.real_item_scores:
107   - evaluation.real_item_scores[k] = 0.0
108   - if k not in evaluation.predicted_item_scores:
109   - evaluation.predicted_item_scores[k] = 0.0
110   - errors.append(float(evaluation.predicted_item_scores[k]-
111   - evaluation.real_item_scores[k]))
112   - return errors
113   -
114 113 def run(self,evaluation):
115 114 """
116 115 Compute metric.
... ... @@ -118,7 +117,7 @@ class MAE(Metric):
118 117 errors = self.get_errors(evaluation)
119 118 return sum(errors)/len(errors)
120 119  
121   -class MSE(MAE):
  120 +class MSE(Metric):
122 121 """
123 122 Prediction accuracy metric defined as the mean square error.
124 123 """
... ... @@ -136,6 +135,22 @@ class MSE(MAE):
136 135 square_errors = [pow(x,2) for x in errors]
137 136 return sum(square_errors)/len(square_errors)
138 137  
  138 +class RMSE(MSE):
  139 + """
  140 + Prediction accuracy metric defined as the root mean square error.
  141 + """
  142 + def __init__(self):
  143 + """
  144 + Set metric description.
  145 + """
  146 + self.desc = " RMSE "
  147 +
  148 + def run(self,evaluation):
  149 + """
  150 + Compute metric.
  151 + """
  152 + return math.sqrt(MSE.run(evaluation))
  153 +
139 154 class Coverage(Metric):
140 155 """
141 156 Evaluation metric defined as the percentage of itens covered by the
... ...