import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import scipy
import random
import sys

# Written by Boaz Nadler, April/2023
# Demo code for the algorithm developed in
# "A statistical approach to estimate seismic monitoring
# station biases and error levels"
# by Y. Radzyner, M. Galun and B. Nadler.
# Variable notation below is the same as in the paper.


# ---- simulation parameters ----
N = 10      # N is the total number of monitoring stations
M = 50000   # M is the total number of seismic events
            # (the name M is rebound further below to the N x N matrix of
            #  joint report counts)

y_min = 4   # minimal and maximal generated (true) seismic magnitude
y_max = 8

n_min = 2                      # minimal number of reporting stations for an event
n_max = min(N, max(7, N - 4))  # maximal number of reporting stations for an event
                               # (capped at N: an event cannot have more reporters than stations)

bias_min = -0.4   # range from which the station biases are drawn
bias_max = 0.4

error_min = 0.2   # range from which the station error levels (std) are drawn
error_max = 0.5

# ---- generate random data ----
y = np.random.uniform(y_min, y_max, size=M)   # actual (true) event magnitudes
station_bias = np.random.uniform(bias_min, bias_max, size=N)   # station biases
station_std = np.random.uniform(error_min, error_max, size=N)  # station magnitude error levels

# normalize so that the station biases sum to zero
station_bias = station_bias - np.sum(station_bias) / N

# X[s, i] is the magnitude station s observed for event i;
# an entry that stays 0 means "station s did not report event i".
X = np.zeros((N, M))

# generate station data in matrix X
for i in range(M):
    # np.random.randint excludes its upper bound, hence n_max + 1 so that
    # the number of reporting stations really ranges over n_min..n_max.
    num_reporting_stations = np.random.randint(n_min, n_max + 1)
    # a random subset of stations: the first entries of a random permutation
    rep_station = np.random.permutation(N)[:num_reporting_stations]
    # observation error = station bias + station-specific Gaussian noise
    mag_error = (station_bias[rep_station]
                 + station_std[rep_station]
                 * np.random.normal(0, 1, num_reporting_stations))
    X[rep_station, i] = y[i] + mag_error

# compute matrices M, D, V as defined in our paper
# M[i,j] = number of events reported by both stations i and j
# D[i,j] = mean of difference in magnitude of events reported by stations i and j
# V[i,j] = variance of the difference in magnitude between stations i and j
V = np.zeros((N, N))
D = np.zeros((N, N))
M = np.zeros((N, N))

# reported[s, e] is True when station s reported event e
# (a zero entry of X marks "not reported"); computed once for all pairs.
reported = X != 0

for i in range(N):
    for j in range(i + 1, N):
        both = reported[i] & reported[j]   # events jointly reported by stations i and j
        num_joint = np.count_nonzero(both)
        if num_joint == 0:
            # No common events: leave M, D, V at zero for this pair.
            # np.mean / np.var of an empty array would yield NaN, and
            # NaN * 0 is still NaN in the weighted sums computed later.
            continue
        diff = X[i, both] - X[j, both]     # magnitude differences on the common events
        V[i, j] = np.var(diff)
        V[j, i] = V[i, j]                  # V is symmetric
        D[i, j] = np.mean(diff)
        D[j, i] = -D[i, j]                 # D is antisymmetric: D[j,i] = -D[i,j]
        M[i, j] = num_joint
        M[j, i] = M[i, j]                  # M is symmetric


# function that estimates the stations squared error levels 
def estimate_station_error_level(M, V):
    """Estimate the squared error level (variance) of every station.

    Solves the N x N linear system ``A s = rhs`` where

        A[i, j] = M[i, j]             for i != j
        A[i, i] = sum_j M[i, j]
        rhs[i]  = sum_j M[i, j] * V[i, j]

    For a symmetric M with zero diagonal these are the normal equations of
    the least-squares fit ``V[i, j] ~ s[i] + s[j]`` weighted by ``M[i, j]``.

    Parameters
    ----------
    M : (N, N) array_like
        M[i, j] is the number of events reported by both stations i and j.
    V : (N, N) array_like
        V[i, j] is the variance of the magnitude difference between
        stations i and j.

    Returns
    -------
    numpy.ndarray of shape (N,)
        Estimated error variance of each station.  The solution is
        unconstrained, so nothing forces the entries to be non-negative.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the system matrix A is singular.
    """
    M = np.asarray(M, dtype=float)
    V = np.asarray(V, dtype=float)

    # right-hand side: row-wise sums of M[i, j] * V[i, j]
    rhs = np.sum(V * M, axis=1)

    # system matrix: M off the diagonal, row sums of M on the diagonal
    A = M.copy()
    np.fill_diagonal(A, np.sum(M, axis=1))

    # Solve directly rather than forming the explicit inverse of A:
    # cheaper and numerically more accurate.
    return np.linalg.solve(A, rhs)

# function that estimates the station biases, under normalization of sum equal zero
def estimate_station_bias(M, D, st_var_estimate):
    """Estimate the station biases, normalized so that they sum to zero.

    Builds the N x N linear system ``B b = rhs`` with

        w[k, j] = M[k, j] / (st_var_estimate[k] + st_var_estimate[j])
        B[k, j] = w[k, j]                  for j > k, mirrored to B[j, k]
        B[k, k] = -sum_j w[k, j]
        rhs[k]  = -sum_j w[k, j] * D[k, j]

    and returns its minimum-norm least-squares solution, shifted so that
    its entries sum to zero.

    Parameters
    ----------
    M : (N, N) array_like
        M[k, j] is the number of events reported by both stations k and j.
    D : (N, N) array_like
        D[k, j] is the mean magnitude difference between stations k and j.
    st_var_estimate : (N,) array_like
        Estimated error variance of each station.  The pairwise sums
        ``st_var_estimate[k] + st_var_estimate[j]`` are used as divisors
        and are not checked for being zero.

    Returns
    -------
    numpy.ndarray of shape (N,)
        Estimated station biases, summing to zero.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.linalg`` is loaded on every SciPy version.
    from scipy.linalg import lstsq

    M = np.asarray(M, dtype=float)
    D = np.asarray(D, dtype=float)
    st_var_estimate = np.asarray(st_var_estimate, dtype=float)

    # sigma2[k, j] = sigma_k^2 + sigma_j^2
    sigma2 = st_var_estimate[:, None] + st_var_estimate[None, :]
    weights = M / sigma2

    rhs = -np.sum((M * D) / sigma2, axis=1)

    # Off-diagonal part: upper triangle of the weights, mirrored below the
    # diagonal; the diagonal holds minus the row sums of the weights.
    upper = np.triu(weights, k=1)
    B = upper + upper.T
    np.fill_diagonal(B, -np.sum(weights, axis=1))

    # linear system is rank deficient, compute least norm solution
    bias_estimate = lstsq(B, rhs)[0]

    # enforce the normalization: biases sum to zero
    return bias_estimate - np.sum(bias_estimate) / len(M)


# ---- run the two estimation steps ----
st_var_estimate = estimate_station_error_level(M, V)
# NOTE: the variance estimates come from an unconstrained linear solve, so a
# negative entry (which np.sqrt would turn into NaN) is not ruled out.
st_std_estimate = np.sqrt(st_var_estimate)

bias_estimate = estimate_station_bias(M, D, st_var_estimate)

print('Estimated Station Biases:')
print(bias_estimate)
print('True Station Biases:')
print(station_bias)

print('Difference:')
print(bias_estimate - station_bias)

# ---- figure 1: true vs. estimated station error levels ----
plt.figure(1)
plt.clf()
plt.plot(station_std, st_std_estimate, 'ro')
plt.plot([error_min, error_max], [error_min, error_max], 'k-')  # reference line y = x
plt.grid(True)
# raw string: '\s' is not a valid escape sequence in a normal string literal
plt.xlabel(r'magnitude error level $\sigma$')
plt.ylabel('estimated error level')
plt.show()

# ---- figure 2: true vs. estimated station biases ----
plt.figure(2)
plt.clf()
plt.plot(station_bias, bias_estimate, 'ro')
plt.plot([bias_min, bias_max], [bias_min, bias_max], 'k-')  # reference line y = x
plt.grid(True)
plt.xlabel('station bias $b$')
plt.ylabel('estimated station bias')
plt.show()