
Source Code for Module fieldpy.core.extra_classes

  1  """ 
  2  Just a few helper classes. 
  3  """ 
  4  from __future__ import division 
  5   
  6  import cPickle as pickle 
  7  import pprint 
  8   
  9  import pylab as plt 
 10  import numpy as np 
 11  import scipy.signal as signal 
 12  import scipy.integrate 
 13   
 14   
 15  import fieldpy.core.helper_fns as helper_fns 
 16   
 17  # try cython imports 
 18  try: 
 19      import pyximport; pyximport.install(setup_args={'include_dirs':[np.get_include()]}) 
 20  except: 
 21      pass 
 22   
 23  try: 
 24      from fieldpy.core.helper_fns_cython import filter_and_resample 
 25      from fieldpy.core.helper_fns_cython import get_ind_closest 
 26      cython_module = True 
 27  except: 
 28      from fieldpy.core.helper_fns import filter_and_resample_slow as filter_and_resample 
 29      cython_module = False 
 30   
class Data(object):
    """Parent class for all kinds of data. The data itself should be held
    in self.data as a masked record array (np.ma.core.MaskedArray).

    @todo: at the moment this only works for 1D data; maybe extend it to N-D
    """

    def __init__(self):
        """
        Sets up the basic structure.

        @todo: introduce a way to handle units
        """
        #: where the data is held
        self.data = np.ma.array([])
        #: where the metadata is held
        self.md = Metadata()
        #: revision of the fieldpy package with which this instance was created
        self.md.fieldpy_git_revision = helper_fns.get_current_fieldpy_git_hash()

    def check(self):
        """
        Do integrity checks; should be invoked at the end of __init__.
        [NOT IMPLEMENTED]
        """
        pass
    def pickle(self, filename):
        """
        Pickle self.

        @type filename: string
        @param filename: the name of the pickle file (a .pkl extension is recommended)
        """
        with open(filename, 'wb') as fil:
            pickle.dump(self, fil, protocol=pickle.HIGHEST_PROTOCOL)
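    # A minimal sketch of reading such a pickle back in (the file name is
    # made up for illustration):
    #
    #   with open('survey.pkl', 'rb') as fil:
    #       d = pickle.load(fil)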
    def mask_value(self, field, value):
        """
        Filters out the given value from the given field by masking it.

        @type field: string
        @param field: the name of the self.data field

        @type value: ?
        @param value: the value which should be masked
        """
        # find the value and mask it
        self.data[field][self.data[field] == value] = np.ma.masked
    def mask_jumps(self, field, jump_size):
        """
        Masks data points for which the absolute difference to the previous
        point is more than jump_size.

        @type field: string
        @param field: the name of the self.data field

        @type jump_size: ?
        @param jump_size: the size of jump which is masked
        """
        absdiff = np.abs(np.diff(self.data[field]))
        # np.diff shortens the array by one, so mask the point after each jump
        jumps = np.concatenate(([False], absdiff > jump_size))
        self.data[field][jumps] = np.ma.masked

    def mask_if_true(self, field, bool_fn):
        """
        Masks all the values of field where bool_fn is True.

        @type field: string
        @param field: the name of the self.data field

        @type bool_fn: function
        @param bool_fn: function which takes self.data[field] as argument
                        and returns a boolean array
        """
        self.data[field][bool_fn(self.data[field])] = np.ma.masked
    def mask_jumps_one_dir(self, field, jump_size):
        """
        Masks data points for which the difference to the previous point is
        more/less than jump_size, for jump_size positive/negative.

        @type field: string
        @param field: the name of the self.data field

        @type jump_size: ?
        @param jump_size: the size of jump which is filtered
        """
        diff = np.diff(self.data[field])
        jump_size = -jump_size
        # np.diff shortens the array by one, so mask the point after each jump
        if jump_size > 0:
            jumps = np.concatenate(([False], diff > jump_size))
        else:
            jumps = np.concatenate(([False], diff < jump_size))
        self.data[field][jumps] = np.ma.masked
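    # A minimal usage sketch of the masking helpers above (the 'stage' field
    # name and the numbers are made up for illustration): build a masked
    # record array, then mask sentinel values, jumps and negative readings.
    #
    #   d = Data()
    #   d.data = np.ma.mrecords.fromarrays(
    #       [np.arange(5.), np.array([1.0, -9999.0, 1.2, 9.0, 1.3])],
    #       names=['time', 'stage'])
    #   d.mask_value('stage', -9999.0)            # mask the sentinel value
    #   d.mask_jumps('stage', 5.0)                # mask points after jumps > 5
    #   d.mask_if_true('stage', lambda x: x < 0)  # mask negative readings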
    def filter_by_freq(self, field, output_field, sample_rate, cutoff,
                       bandwidth=None, mode='lowpass'):
        """Perform filtering with a windowed-sinc frequency-domain filter.
        Can be used for either lowpass or highpass filtering, useful for
        removing rapid signal variations or long-wavelength trends,
        respectively.

        The call signature is complicated and should be streamlined.

        @type field: string
        @param field: the name of the self.data field to filter

        @type output_field: string
        @param output_field: the name of the output field in self.data

        @type sample_rate: float
        @param sample_rate: the time or distance per sample, i.e. the
                            sampling interval

        @type cutoff: float
        @param cutoff: the e-folding cutoff frequency or wavenumber

        @type bandwidth: float
        @param bandwidth: transition bandwidth (narrow means a sharper
                          response at the expense of a longer filtering
                          kernel and slower convolution)

        @type mode: str
        @param mode: either 'lowpass' or 'highpass'
        """
        if bandwidth is None:
            # Guess a reasonable filter bandwidth
            bandwidth = 0.1 * cutoff

        # Compute filter parameters
        bw = bandwidth * sample_rate
        fc = cutoff * sample_rate
        # Make sure the kernel length is an odd integer
        N = int(4 / bw)
        if N % 2 == 0: N += 1

        # Generate the kernel
        kernel = np.sinc(np.arange(-(N-1)/2, (N+1)/2) * fc) \
                 * signal.firwin(N, fc, window='blackman')
        kernel = kernel / kernel.sum()

        if mode == 'highpass':
            # Spectral inversion
            kernel = -kernel
            kernel[(N-1)//2] += 1.0

        # Check kernel length
        if len(kernel) > len(self.data[field]):
            raise ValueError(
                "Kernel {0} is longer than input array {1}".format(
                    len(kernel), len(self.data[field])))

        # Convolve the data field
        filt_data = np.convolve(kernel, self.data[field], mode='same')

        # Add it to the data record array
        try:
            self.data[output_field] = filt_data
        except:
            self.data = helper_fns.append_a2mrec(self.data, filt_data, output_field)
        return
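    # Worked example of the kernel sizing above (a sketch with made-up
    # numbers): for a record sampled every 10 minutes, sample_rate = 10/60/24
    # days.  To low-pass with an e-folding period of 6 hours, cutoff =
    # 1/(6/24) = 4 per day, so fc = cutoff * sample_rate ~= 0.028.  With the
    # default bandwidth = 0.1 * cutoff, bw ~= 0.0028 and the kernel length is
    # N = int(4 / bw) ~= 1440 samples (about 10 days of data), which shows
    # why a narrow transition bandwidth needs a long record.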
    def filter_low_pass(self, field, output_field,
                        bandwidth, sample_rate, cutoff_freq):
        """
        This implements a sinc filter, which is a low-pass filter
        (i.e. a step filter in the frequency domain).

        @type field: string
        @param field: the name of the self.data field

        @type output_field: string
        @param output_field: the name of the output field of self.data

        @type bandwidth: float
        @param bandwidth: the precision of the filter; a narrower bandwidth
                          would seem better, but increases numerical
                          artefacts.  Set it to a few times more than what
                          you want to filter.

        @type sample_rate: float
        @param sample_rate: the sampling interval

        @type cutoff_freq: float
        @param cutoff_freq: the centre frequency to cut off at.
                            Example: you have day-long events and hour-long
                            noise, so use a cut-off period of a few hours.
        """
        # bandwidth is the precision of the filter: a narrower bandwidth
        # would seem better, but increases numerical artefacts
        # sample_rate is the sampling interval
        bw = bandwidth * sample_rate

        # cutoff_freq is the centre frequency to cut off at
        # example: you have day-long events and hour-long noise,
        # so use a cut-off period of a few hours
        fc = cutoff_freq * sample_rate

        # Make sure the kernel length is an odd integer
        N = int(4 // bw)
        if N % 2 == 0: N += 1

        # Generate the kernel by multiplying a truncated sinc function by a
        # window (blackman is typically used)
        kernel = np.sinc(np.arange(-(N-1)/2, (N+1)/2) * fc) \
                 * signal.firwin(N, fc, window='blackman')
        # normalize to preserve signal amplitudes
        kernel /= kernel.sum()

        if len(kernel) > len(self.data[field]):
            raise ValueError('Kernel is longer than input array: len(kernel)=%i' % len(kernel))

        filt = np.convolve(self.data[field], kernel, mode='same')
        try:
            self.data[output_field] = filt
        except:
            self.data = helper_fns.append_a2mrec(self.data, filt, output_field)

    # def filter_moving_average(self, field, output_field, numtaps, cutoff):
    #     # boxcar filter, should be ok up to a 6-sample window
    #     window = 4.
    #     kernel = np.ones(4) / 4
    #     out = np.convolve(kernel, self.data[field], mode='same')
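# A minimal usage sketch for the filters above (the 'stage' field name and
# the numbers are made up for illustration): on a record sampled every
# 10 minutes (time in days), remove variations shorter than ~6 hours, or
# remove trends longer than a day.
#
#   d.filter_by_freq('stage', 'stage_lp', sample_rate=10./60/24, cutoff=4.0,
#                    mode='lowpass')   # 4 per day = 6-hour cutoff period
#   d.filter_by_freq('stage', 'stage_hp', sample_rate=10./60/24, cutoff=1.0,
#                    mode='highpass')  # remove day-scale and longer trends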
class TimeSeries(Data):
    """
    A class to hold time series data.  self.data['time'] has to contain the
    monotonically increasing times.
    """

    def __init__(self):
        super(TimeSeries, self).__init__()
    def check(self):
        super(TimeSeries, self).check()
        self.check_monotonically_inc_time()

    def check_monotonically_inc_time(self):
        """
        Checks whether the times in self.data['time'] are monotonically
        increasing.  Raises an error if they are not.
        """
        if np.any(np.diff(self.data['time']) <= 0):
            raise ValueError("self.data['time'] is not strictly increasing!")
280 """ 281 Checks whether sampling interval in self.data['time'] is 282 constant. Raises an error if they are not. 283 """ 284 self.check_monotonically_inc_time() 285 samp_int = self.data['time'][1]-self.data['time'][0] 286 if np.any(1e-10*samp_int < np.abs(np.diff(self.data['time'])-samp_int)): 287 raise ValueError("The sampling interval in self.data['time'] is not constant.")
288
    def plot_ts(self, var_name=None, ax=None, fmt='-', **plot_kwargs):
        """
        Plots the time series data for var_name.

        @type var_name: string
        @param var_name: name of the variable to be plotted

        @type ax: axes object
        @param ax: axes to plot into (default: make a new figure)

        @type fmt: string
        @param fmt: format string to be passed to the plotter (default '-')

        @param plot_kwargs: keyword arguments for the plt.plot function
        """
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        if var_name is None:
            var_name = self.data.dtype.names[1]

        ax.plot(self.data['time'], self.data[var_name], fmt, **plot_kwargs)
        ax.set_xlabel('Time')
        ax.set_ylabel(var_name)

        return ax
    def plot_date(self, var_name=None, ax=None, fmt='-', **plot_kwargs):
        """
        Plots the time series data for var_name with plt.plot_date.

        @type var_name: string
        @param var_name: name of the variable to be plotted

        @type ax: axes object
        @param ax: axes to plot into (default: make a new figure)

        @type fmt: string
        @param fmt: format string to be passed to the plotter (default '-')

        @param plot_kwargs: keyword arguments for the plt.plot_date function
        """
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
        if var_name is None:
            var_name = self.data.dtype.names[1]

        ax.plot_date(self.data['time'], self.data[var_name], fmt, **plot_kwargs)
        # rotate the date labels (after plotting, so the date ticks exist)
        labels = ax.get_xticklabels()
        plt.setp(labels, rotation=45)
        ax.set_xlabel('Time')
        ax.set_ylabel(var_name)

        return ax
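    # A minimal plotting sketch (the 'stage' field name is made up): plot_ts
    # uses the plain time values on the x-axis, while plot_date expects
    # matplotlib date numbers in self.data['time'].
    #
    #   ax = ts.plot_ts('stage', fmt='.-', color='k')
    #   ax.set_ylabel('stage (m)')
    #   plt.show()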
    def time_filter_gauss(self, field, output_field, time_window=0.021):
        """
        Filters self.data[field] with a Gaussian-weighted moving average filter.

        Appends or overwrites the field output_field in self.data.

        @type field: string
        @param field: the name of the self.data field

        @type output_field: string
        @param output_field: the name of the output field of self.data

        @type time_window: float
        @param time_window: the filter window in units of self.data['time']

        @note: I'm not sure how exactly the missing data are handled...
        """
        filt = helper_fns.filter_gauss(self.data['time'], self.data[field], time_window)
        try:
            self.data[output_field] = filt
        except:
            self.data = helper_fns.append_ma2mrec(self.data, filt, output_field)
        # scikits_imported = True
        # if scikits_imported:
        #     mov_av = scikits.timeseries.lib.cmov_average(self.data[field], window)
        #     self.data = helper_fns.append_ma2mrec(self.data, mov_av, output_field)
        # else:
        #     helper_fns.filter_(self.data['time'], self.data[field], 10./60/24)
    def filter_and_resample(self, field, filter_window,
                            new_sample_int, first_sample=None,
                            modify_instance=False):
        """Filters and resamples a time signal with a running average over a
        given time interval +/- filter_window/2.  Note that the filter has a
        PHASE SHIFT for an UN-equally sampled signal.

        @param field: which field of self.data to use
        @param filter_window: filter time interval
        @param new_sample_int: the new sampling interval (in units of t)
        @param first_sample: the time of the first sample in the resampled
                             signal.  If None (default) it is (approximately)
                             set to time[0] + filter_window
        @param modify_instance: if True, self.data will be replaced so that
                                it only contains the new times and data (and
                                nothing is returned)

        @rtype: list of np.array
        @return: returns [new_time, new_data]
        """
        [new_t, new_d] = filter_and_resample(self.data['time'], self.data[field],
                                             filter_window, new_sample_int, first_sample)
        if modify_instance:
            self.data = np.empty(new_t.shape, dtype=[('time', float), (field, float)])
            self.data['time'] = new_t
            self.data[field] = new_d
            return
        else:
            return new_t, new_d
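    # A minimal usage sketch (the 'stage' field name and the numbers are made
    # up): smooth with a 1-hour running average and resample to a 30-minute
    # spacing, with time in days.
    #
    #   new_t, new_stage = ts.filter_and_resample('stage',
    #                                             filter_window=1./24,
    #                                             new_sample_int=0.5/24)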
    def get_ind(self, times):
        """This is the index function you want to use.

        @param times: a list of times (or a single time)

        @return: a list of indices nearest to the given times
        """
        ind = []
        try:
            for ti in times:
                ind.append(self.get_index_nearest(ti)[0])
        except TypeError:
            # times is not iterable, treat it as a single time
            ind.append(self.get_index_nearest(times)[0])
        return ind

    def get_ind_as_slice(self, times):
        """This function takes two times and returns a slice object such
        that the slice includes the start and end time.
        """
        ind = self.get_ind(times)
        return slice(ind[0], ind[1] + 1)
    def get_index_after(self, time):
        """get_index_after(t) returns the index at t or immediately after t
        and the exact time corresponding to that index.

        If there is no index afterwards, returns (None, None).

        @rtype: tuple
        @return: tuple containing the index and the time corresponding to that index
        """
        ind = (np.where(time <= self.data['time']))[0]
        if ind.shape[0] == 0:
            ind = None
            exact_time = None
        else:
            ind = np.min(ind)
            exact_time = self.data['time'][ind]
        return ind, exact_time

    def get_index_before(self, time):
        """get_index_before(t) returns the index at t or immediately before t
        and the exact time corresponding to that index.

        If there is no index before, returns (None, None).

        @rtype: tuple
        @return: tuple containing the index and the time corresponding to that index
        """
        ind = (np.where(time >= self.data['time']))[0]
        if ind.shape[0] == 0:
            ind = None
            exact_time = None
        else:
            ind = np.max(ind)
            exact_time = self.data['time'][ind]
        return ind, exact_time
    def get_index_nearest(self, time):
        """Gets the index and time which are nearest to the given time.

        @rtype: tuple
        @return: tuple containing the index and the time corresponding to that index
        """
        if cython_module:
            ind = get_ind_closest(self.data['time'], time)
            return ind, self.data['time'][ind]
        else:
            ind_b, t_b = self.get_index_before(time)
            ind_a, t_a = self.get_index_after(time)
            if ind_b is None:
                return ind_a, t_a
            if ind_a is None:
                return ind_b, t_b
            if time - t_b > t_a - time:
                return ind_a, t_a
            else:
                return ind_b, t_b
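    # A minimal usage sketch of the index lookups above (the times are made
    # up, in the units of self.data['time']):
    #
    #   ind, t_exact = ts.get_index_nearest(731946.5)    # single nearest index
    #   inds = ts.get_ind([731946.0, 731947.0])          # list of indices
    #   sli = ts.get_ind_as_slice([731946.0, 731947.0])  # slice incl. both ends
    #   chunk = ts.data[sli]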
    def cut_time_series(self, time_span):
        """Cuts a piece out of the time series and modifies the data B{in place}.

        @type time_span: sequence of two times
        @param time_span: the start and end time of the piece to keep
        """
        sli = self.get_ind_as_slice(time_span)
        self.data = self.data[sli]
    def integrate(self, field, start=None, end=None):
        """
        Integrates the field from start time to end time with the
        trapezoidal rule.
        """
        if start is None:
            ind0 = 0
        else:
            ind0 = self.get_index_nearest(start)[0]
        if end is None:
            ind1 = len(self.data)
        else:
            # +1 so that the sample nearest to 'end' is included
            ind1 = self.get_index_nearest(end)[0] + 1

        return scipy.integrate.trapz(self.data[field][ind0:ind1],
                                     self.data['time'][ind0:ind1])
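# A minimal end-to-end sketch for TimeSeries (the field name, times and data
# are made up for illustration):
#
#   ts = TimeSeries()
#   ts.data = np.ma.mrecords.fromarrays(
#       [np.linspace(0., 10., 101), np.random.rand(101)],
#       names=['time', 'discharge'])
#   ts.check()                           # monotonic-time check
#   ts.cut_time_series([2.0, 8.0])       # keep data between t=2 and t=8
#   total = ts.integrate('discharge')    # trapezoidal integral over time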
class Metadata(object):
    """Just a structure to hold the metadata.
    """
    def __init__(self):
        #: the units of the data
        self.units = []

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __repr__(self):
        return pprint.pformat(self.__dict__)

    def __eq__(self, other):
        """Compare the two underlying __dict__.
        """
        return self.__dict__ == other.__dict__
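# A minimal sketch of how Metadata is used (attribute names other than
# 'units' are made up): it is a free-form attribute container, and two
# instances compare equal when all their attributes match.
#
#   md = Metadata()
#   md.units = ['days', 'm']
#   md.station = 'hypothetical station A'
#   print md    # pretty-printed __dict__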