Commit 6c99e7cdab0198c100761939e427f40bd18afe41
1 parent
9d0b5695
Exists in
master
and in
1 other branch
Moved method get_errors to base class and implemented RMSE metric.
Showing
1 changed file
with
33 additions
and
18 deletions
Show diff stats
src/evaluation.py
| ... | ... | @@ -18,6 +18,7 @@ |
| 18 | 18 | # You should have received a copy of the GNU General Public License |
| 19 | 19 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
| 20 | 20 | |
| 21 | +import math | |
| 21 | 22 | import random |
| 22 | 23 | from collections import defaultdict |
| 23 | 24 | import logging |
| ... | ... | @@ -30,7 +31,21 @@ class Metric(Singleton): |
| 30 | 31 | """ |
| 31 | 32 | Base class for metrics. Strategy design pattern. |
| 32 | 33 | """ |
| 33 | - pass | |
| 34 | + def get_errors(self,evaluation): | |
| 35 | + """ | |
| 36 | + Compute prediction errors. | |
| 37 | + """ | |
| 38 | + keys = evaluation.predicted_item_scores.keys() | |
| 39 | + keys.extend(evaluation.real_item_scores.keys()) | |
| 40 | + errors = [] | |
| 41 | + for k in keys: | |
| 42 | + if k not in evaluation.real_item_scores: | |
| 43 | + evaluation.real_item_scores[k] = 0.0 | |
| 44 | + if k not in evaluation.predicted_item_scores: | |
| 45 | + evaluation.predicted_item_scores[k] = 0.0 | |
| 46 | + errors.append(float(evaluation.predicted_item_scores[k]- | |
| 47 | + evaluation.real_item_scores[k])) | |
| 48 | + return errors | |
| 34 | 49 | |
| 35 | 50 | class Precision(Metric): |
| 36 | 51 | """ |
| ... | ... | @@ -95,22 +110,6 @@ class MAE(Metric): |
| 95 | 110 | """ |
| 96 | 111 | self.desc = " MAE " |
| 97 | 112 | |
| 98 | - def get_errors(self,evaluation): | |
| 99 | - """ | |
| 100 | - Compute prediction errors. | |
| 101 | - """ | |
| 102 | - keys = evaluation.predicted_item_scores.keys() | |
| 103 | - keys.extend(evaluation.real_item_scores.keys()) | |
| 104 | - errors = [] | |
| 105 | - for k in keys: | |
| 106 | - if k not in evaluation.real_item_scores: | |
| 107 | - evaluation.real_item_scores[k] = 0.0 | |
| 108 | - if k not in evaluation.predicted_item_scores: | |
| 109 | - evaluation.predicted_item_scores[k] = 0.0 | |
| 110 | - errors.append(float(evaluation.predicted_item_scores[k]- | |
| 111 | - evaluation.real_item_scores[k])) | |
| 112 | - return errors | |
| 113 | - | |
| 114 | 113 | def run(self,evaluation): |
| 115 | 114 | """ |
| 116 | 115 | Compute metric. |
| ... | ... | @@ -118,7 +117,7 @@ class MAE(Metric): |
| 118 | 117 | errors = self.get_errors(evaluation) |
| 119 | 118 | return sum(errors)/len(errors) |
| 120 | 119 | |
| 121 | -class MSE(MAE): | |
| 120 | +class MSE(Metric): | |
| 122 | 121 | """ |
| 123 | 122 | Prediction accuracy metric defined as the mean square error. |
| 124 | 123 | """ |
| ... | ... | @@ -136,6 +135,22 @@ class MSE(MAE): |
| 136 | 135 | square_errors = [pow(x,2) for x in errors] |
| 137 | 136 | return sum(square_errors)/len(square_errors) |
| 138 | 137 | |
| 138 | +class RMSE(MSE): | |
| 139 | + """ | |
| 140 | + Prediction accuracy metric defined as the root mean square error. | |
| 141 | + """ | |
| 142 | + def __init__(self): | |
| 143 | + """ | |
| 144 | + Set metric description. | |
| 145 | + """ | |
| 146 | + self.desc = " RMSE " | |
| 147 | + | |
| 148 | + def run(self,evaluation): | |
| 149 | + """ | |
| 150 | + Compute metric. | |
| 151 | + """ | |
| 152 | + return math.sqrt(MSE.run(evaluation)) | |
| 153 | + | |
| 139 | 154 | class Coverage(Metric): |
| 140 | 155 | """ |
| 141 | 156 | Evaluation metric defined as the percentage of itens covered by the | ... | ... |