Skip to content
Open

Mse #31

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions metrics/MSE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import numpy as np
import random
import logging

def mse(X, obs):
'''
Retourne la MSE
'''
random_idx = random.sample(range(X.shape[0]), 1)[0]
_mse = np.nanmean(((X[random_idx] - obs.squeeze())**2),axis=(-2,-1))
return _mse
20 changes: 20 additions & 0 deletions metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from metrics import CRPS_calc
from metrics import area_proportion as ap
from metrics import object_detection as obj
from metrics import MSE

from metrics.metrics import Metric, PreprocessCondObs, PreprocessDist, PreprocessStandalone

Expand Down Expand Up @@ -474,6 +475,25 @@ def _calculateCore(self, processed_data):

return GM.relative_std_diff(real_data,fake_data)


#####################################################################
############################ Determinstic metrics ###################
#####################################################################

class mse(PreprocessCondObs):
def __init__(self, *args, **kwargs):
super().__init__(isBatched=True)

def _calculateCore(self, processed_data):
if not self.isOnReal:
exp_data = processed_data['fake_data']
else:
exp_data = processed_data['real_data']
obs_data = processed_data['obs_data']
return MSE.mse(exp_data, obs_data)



#####################################################################
######################################################
#####################################################################
3 changes: 2 additions & 1 deletion metrics/wind_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def computeWindDir(U, V, xRef=None, yRef=None, proj=None):
@rtype: tuple
@return: vitesse (m/s) et direction du vent exprimée en degrés météo (0 = vent du nord).
"""
ff = np.sqrt(U * U + V * V)

ff = np.sqrt(np.maximum(U * U + V * V, np.zeros_like(U)))

dd3 = (180 + 180 / np.pi * np.arctan2(U, V)) % 360
# logging.debug(dd3)
Expand Down
12 changes: 7 additions & 5 deletions preprocess/rrPreprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def init_normalization(self):
normalization_type = self.normalization["type"]
if normalization_type == "mean":
means, stds = self.load_stat_files(normalization_type, "mean", "std")
logging.debug(f"stat constants {means, stds}")
return None, None, means, stds
elif normalization_type == "minmax":
maxs, mins = self.load_stat_files(normalization_type, "max", "min")
Expand All @@ -64,16 +65,16 @@ def load_stat_files(self, normalization_type, str1, str2):
std_or_min_filename += "_ppx"
mean_or_max_filename += ".npy"
std_or_min_filename += ".npy"
logging.debug(f"{mean_or_max_filename}", f"{std_or_min_filename}")
logging.debug(f"Normalization set to {normalization_type}")
# logging.debug(f"{mean_or_max_filename}", f"{std_or_min_filename}")
# logging.debug(f"Normalization set to {normalization_type}")
stat_folder = self.config_data["stat_folder"]
file_path = os.path.join(self.config_data["real_data_dir"], stat_folder, mean_or_max_filename)
means_or_maxs = np.load(file_path).astype('float32')
logging.debug(f"{str1} file found, {means_or_maxs.shape}")
# logging.debug(f"{str1} file found, {means_or_maxs.shape}")

file_path = os.path.join(self.config_data["real_data_dir"], stat_folder, std_or_min_filename)
stds_or_mins = np.load(file_path).astype('float32')
logging.debug(f"{str2} file found, {stds_or_mins.shape}")
# logging.debug(f"{str2} file found, {stds_or_mins.shape}")
return means_or_maxs, stds_or_mins

def detransform(self, data):
Expand Down Expand Up @@ -178,4 +179,5 @@ def __init__(self, config_data, sizeH, sizeW, variables, **kwargs):
super().__init__(config_data, sizeH, sizeW, variables, **kwargs)

def process_batch(self, batch):
return self.detransform(batch)
res = self.detransform(batch)
return res
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
PyYaml==6.0.1
torch==2.2.0+cu121
torch==2.2.0
numpy==1.26.4
tqdm==4.66.1
pandas==2.2.0
Expand All @@ -9,4 +9,5 @@ pyproj==3.6.1
properscoring==0.1
geopy==2.4.1
astropy==6.0.0
scikit-image
CRPS