Commit 6c99e7cdab0198c100761939e427f40bd18afe41
1 parent
9d0b5695
Exists in
master
and in
1 other branch
Moved method get_errors to base class and implemented RMSE metric.
Showing
1 changed file
with
33 additions
and
18 deletions
Show diff stats
src/evaluation.py
| @@ -18,6 +18,7 @@ | @@ -18,6 +18,7 @@ | ||
| 18 | # You should have received a copy of the GNU General Public License | 18 | # You should have received a copy of the GNU General Public License |
| 19 | # along with this program. If not, see <http://www.gnu.org/licenses/>. | 19 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
| 20 | 20 | ||
| 21 | +import math | ||
| 21 | import random | 22 | import random |
| 22 | from collections import defaultdict | 23 | from collections import defaultdict |
| 23 | import logging | 24 | import logging |
| @@ -30,7 +31,21 @@ class Metric(Singleton): | @@ -30,7 +31,21 @@ class Metric(Singleton): | ||
| 30 | """ | 31 | """ |
| 31 | Base class for metrics. Strategy design pattern. | 32 | Base class for metrics. Strategy design pattern. |
| 32 | """ | 33 | """ |
| 33 | - pass | 34 | + def get_errors(self,evaluation): |
| 35 | + """ | ||
| 36 | + Compute prediction errors. | ||
| 37 | + """ | ||
| 38 | + keys = evaluation.predicted_item_scores.keys() | ||
| 39 | + keys.extend(evaluation.real_item_scores.keys()) | ||
| 40 | + errors = [] | ||
| 41 | + for k in keys: | ||
| 42 | + if k not in evaluation.real_item_scores: | ||
| 43 | + evaluation.real_item_scores[k] = 0.0 | ||
| 44 | + if k not in evaluation.predicted_item_scores: | ||
| 45 | + evaluation.predicted_item_scores[k] = 0.0 | ||
| 46 | + errors.append(float(evaluation.predicted_item_scores[k]- | ||
| 47 | + evaluation.real_item_scores[k])) | ||
| 48 | + return errors | ||
| 34 | 49 | ||
| 35 | class Precision(Metric): | 50 | class Precision(Metric): |
| 36 | """ | 51 | """ |
| @@ -95,22 +110,6 @@ class MAE(Metric): | @@ -95,22 +110,6 @@ class MAE(Metric): | ||
| 95 | """ | 110 | """ |
| 96 | self.desc = " MAE " | 111 | self.desc = " MAE " |
| 97 | 112 | ||
| 98 | - def get_errors(self,evaluation): | ||
| 99 | - """ | ||
| 100 | - Compute prediction errors. | ||
| 101 | - """ | ||
| 102 | - keys = evaluation.predicted_item_scores.keys() | ||
| 103 | - keys.extend(evaluation.real_item_scores.keys()) | ||
| 104 | - errors = [] | ||
| 105 | - for k in keys: | ||
| 106 | - if k not in evaluation.real_item_scores: | ||
| 107 | - evaluation.real_item_scores[k] = 0.0 | ||
| 108 | - if k not in evaluation.predicted_item_scores: | ||
| 109 | - evaluation.predicted_item_scores[k] = 0.0 | ||
| 110 | - errors.append(float(evaluation.predicted_item_scores[k]- | ||
| 111 | - evaluation.real_item_scores[k])) | ||
| 112 | - return errors | ||
| 113 | - | ||
| 114 | def run(self,evaluation): | 113 | def run(self,evaluation): |
| 115 | """ | 114 | """ |
| 116 | Compute metric. | 115 | Compute metric. |
| @@ -118,7 +117,7 @@ class MAE(Metric): | @@ -118,7 +117,7 @@ class MAE(Metric): | ||
| 118 | errors = self.get_errors(evaluation) | 117 | errors = self.get_errors(evaluation) |
| 119 | return sum(errors)/len(errors) | 118 | return sum(errors)/len(errors) |
| 120 | 119 | ||
| 121 | -class MSE(MAE): | 120 | +class MSE(Metric): |
| 122 | """ | 121 | """ |
| 123 | Prediction accuracy metric defined as the mean square error. | 122 | Prediction accuracy metric defined as the mean square error. |
| 124 | """ | 123 | """ |
| @@ -136,6 +135,22 @@ class MSE(MAE): | @@ -136,6 +135,22 @@ class MSE(MAE): | ||
| 136 | square_errors = [pow(x,2) for x in errors] | 135 | square_errors = [pow(x,2) for x in errors] |
| 137 | return sum(square_errors)/len(square_errors) | 136 | return sum(square_errors)/len(square_errors) |
| 138 | 137 | ||
| 138 | +class RMSE(MSE): | ||
| 139 | + """ | ||
| 140 | + Prediction accuracy metric defined as the root mean square error. | ||
| 141 | + """ | ||
| 142 | + def __init__(self): | ||
| 143 | + """ | ||
| 144 | + Set metric description. | ||
| 145 | + """ | ||
| 146 | + self.desc = " RMSE " | ||
| 147 | + | ||
| 148 | + def run(self,evaluation): | ||
| 149 | + """ | ||
| 150 | + Compute metric. | ||
| 151 | + """ | ||
| 152 | + return math.sqrt(MSE.run(evaluation)) | ||
| 153 | + | ||
| 139 | class Coverage(Metric): | 154 | class Coverage(Metric): |
| 140 | """ | 155 | """ |
| 141 | Evaluation metric defined as the percentage of itens covered by the | 156 | Evaluation metric defined as the percentage of itens covered by the |