Commit 6c99e7cdab0198c100761939e427f40bd18afe41
1 parent 9d0b5695
Exists in master and in 1 other branch
Moved method get_errors to base class and implemented RMSE metric.
Showing 1 changed file with 33 additions and 18 deletions.
src/evaluation.py
@@ -18,6 +18,7 @@
 # You should have received a copy of the GNU General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 
+import math
 import random
 from collections import defaultdict
 import logging
@@ -30,7 +31,21 @@ class Metric(Singleton):
     """
     Base class for metrics. Strategy design pattern.
     """
-    pass
+    def get_errors(self,evaluation):
+        """
+        Compute prediction errors.
+        """
+        keys = evaluation.predicted_item_scores.keys()
+        keys.extend(evaluation.real_item_scores.keys())
+        errors = []
+        for k in keys:
+            if k not in evaluation.real_item_scores:
+                evaluation.real_item_scores[k] = 0.0
+            if k not in evaluation.predicted_item_scores:
+                evaluation.predicted_item_scores[k] = 0.0
+            errors.append(float(evaluation.predicted_item_scores[k]-
+                                evaluation.real_item_scores[k]))
+        return errors
 
 class Precision(Metric):
     """
@@ -95,22 +110,6 @@ class MAE(Metric):
         """
         self.desc = " MAE "
 
-    def get_errors(self,evaluation):
-        """
-        Compute prediction errors.
-        """
-        keys = evaluation.predicted_item_scores.keys()
-        keys.extend(evaluation.real_item_scores.keys())
-        errors = []
-        for k in keys:
-            if k not in evaluation.real_item_scores:
-                evaluation.real_item_scores[k] = 0.0
-            if k not in evaluation.predicted_item_scores:
-                evaluation.predicted_item_scores[k] = 0.0
-            errors.append(float(evaluation.predicted_item_scores[k]-
-                                evaluation.real_item_scores[k]))
-        return errors
-
     def run(self,evaluation):
         """
         Compute metric.
@@ -118,7 +117,7 @@ class MAE(Metric):
         errors = self.get_errors(evaluation)
         return sum(errors)/len(errors)
 
-class MSE(MAE):
+class MSE(Metric):
     """
     Prediction accuracy metric defined as the mean square error.
     """
@@ -136,6 +135,22 @@ class MSE(MAE):
         square_errors = [pow(x,2) for x in errors]
         return sum(square_errors)/len(square_errors)
 
+class RMSE(MSE):
+    """
+    Prediction accuracy metric defined as the root mean square error.
+    """
+    def __init__(self):
+        """
+        Set metric description.
+        """
+        self.desc = " RMSE "
+
+    def run(self,evaluation):
+        """
+        Compute metric.
+        """
+        return math.sqrt(MSE.run(self,evaluation))
+
 class Coverage(Metric):
     """
     Evaluation metric defined as the percentage of itens covered by the
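For quick reference, here is a minimal, self-contained sketch of how the refactored hierarchy might be exercised after this commit. The Evaluation stub below (an object carrying predicted_item_scores and real_item_scores dicts) and its sample scores are hypothetical stand-ins, the Singleton base and the __init__/desc plumbing are dropped for brevity, and get_errors is condensed: unlike the committed version, it does not write the default 0.0 scores back into the dicts.

import math

class Evaluation(object):
    """Hypothetical stub standing in for the real Evaluation object."""
    def __init__(self, predicted_item_scores, real_item_scores):
        self.predicted_item_scores = predicted_item_scores
        self.real_item_scores = real_item_scores

class Metric(object):
    def get_errors(self, evaluation):
        # Union of predicted and real item ids; a missing score counts as 0.0.
        keys = set(evaluation.predicted_item_scores) | set(evaluation.real_item_scores)
        return [float(evaluation.predicted_item_scores.get(k, 0.0) -
                      evaluation.real_item_scores.get(k, 0.0)) for k in keys]

class MSE(Metric):
    def run(self, evaluation):
        errors = self.get_errors(evaluation)
        return sum(e * e for e in errors) / len(errors)

class RMSE(MSE):
    def run(self, evaluation):
        # Root of the value computed by the parent MSE.run.
        return math.sqrt(MSE.run(self, evaluation))

ev = Evaluation({"a": 4.0, "b": 3.0}, {"a": 5.0, "c": 2.0})
print(RMSE().run(ev))  # errors -1.0, 3.0, -2.0 -> sqrt(14 / 3.0) ~= 2.16

The intent of the refactor shows up directly in this layout: get_errors lives on the Metric base, MSE only squares and averages the shared errors, and RMSE only takes the square root of the value returned by its MSE parent.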