"""Model structure factors, extrapolated amplitudes and Riso/CCiso checks.

Computes model structure factors with gemmi for an apo model, a fully
occupied model and a series of partial-occupancy models, builds
extrapolated amplitudes ``F_extr = (F_b - F_apo) / b + F_apo`` on the apo
phases, prints Riso / CCiso statistics and plots the resultant
(vector-sum) structure factors.
"""
import gemmi
import numpy as np
import matplotlib.pyplot as plt


# -----------------------------
# Small shared helpers
# -----------------------------
def _as_miller(h):
    """Return the Miller index *h* as a plain list of three Python ints."""
    return [int(x) for x in h]


def _d_spacings(cell, hkls):
    """Return the d-spacing (Å) of every reflection in *hkls* as a float array."""
    return np.array([cell.calculate_d(_as_miller(h)) for h in hkls], dtype=float)


# -----------------------------
# Helper: compute SFs
# -----------------------------
def compute_sf(pdb_file, dmin, remove_altlocs=True):
    """Compute complex model structure factors for a coordinate file.

    Parameters
    ----------
    pdb_file : str
        Coordinate file readable by ``gemmi.read_structure``.
    dmin : float
        High-resolution limit (Å) of the generated reflection list.
    remove_altlocs : bool, optional
        If True (default, original behaviour) call
        ``remove_alternative_conformations()`` before computing SFs.
        NOTE(review): gemmi keeps only one conformer there, so a
        partial-occupancy model encoded via altlocs would lose its minor
        state -- confirm this is intended for the input models.

    Returns
    -------
    F : ndarray (complex128)
        Structure factor of every reflection in ``hkl``.
    hkl : ndarray (N, 3) of int
        Unique Miller indices (``unique=True``) up to ``dmin``.
    cell : gemmi.UnitCell
        Unit cell of the structure.

    Raises
    ------
    RuntimeError
        If no space group can be determined.
    """
    st = gemmi.read_structure(pdb_file)
    st.setup_entities()
    st.assign_label_seq_id()
    if remove_altlocs:
        st.remove_alternative_conformations()
    model = st[0]
    cell = st.cell

    sg = st.find_spacegroup()
    if sg is None:
        if st.spacegroup_hm:
            sg = gemmi.SpaceGroup(st.spacegroup_hm)
        else:
            raise RuntimeError("No space group found")

    hkl = gemmi.make_miller_array(cell=cell, spacegroup=sg, dmin=dmin, unique=True)

    calc = gemmi.StructureFactorCalculatorX(cell)
    # One gemmi call per reflection.
    F = np.array(
        [calc.calculate_sf_from_model(model, _as_miller(h)) for h in hkl],
        dtype=np.complex128,
    )
    return F, hkl, cell


def add_resolution_dependent_noise(SF, hkl, cell, alpha=0.01, B=20.0, rng=None):
    """Add resolution-dependent fractional noise to complex structure factors.

    Each amplitude is multiplied by ``1 + eps`` with
    ``eps ~ N(0, alpha * exp(B / (4 d^2)))``, so the relative noise grows as
    d decreases. Phases are preserved.

    Parameters
    ----------
    SF : ndarray (complex)
        Input structure factors.
    hkl : sequence of (h, k, l)
        Miller indices, same length and order as ``SF``.
    cell : gemmi.UnitCell
        Cell used to compute d-spacings.
    alpha : float
        Noise standard deviation in the limit d -> infinity.
    B : float
        Growth rate of the noise with resolution (Å^2, since d is in Å).
    rng : numpy.random.Generator, optional
        Random generator; defaults to ``np.random.default_rng()``.

    Returns
    -------
    SF_noisy : ndarray (complex)
        Noisy structure factors with the original phases.
    d : ndarray
        d-spacing (Å) of every reflection.

    Notes
    -----
    For large noise ``1 + eps`` can become negative, which flips the phase
    of that reflection by 180 degrees.
    """
    F = np.abs(SF)
    phase = np.angle(SF)
    d = _d_spacings(cell, hkl)

    if rng is None:
        rng = np.random.default_rng()

    scale = alpha * np.exp(B / (4.0 * d**2))
    noise = rng.normal(0.0, scale)  # one sigma per reflection

    F_noise = F * (1.0 + noise)
    return F_noise * np.exp(1j * phase), d


def scale_sf_against_ref_mtz(ref_mtz_file, F_gen, hkl_gen, n_bins=20,
                             f_col="F-obs-filtered", phi_col="PHIF-model"):
    """Scale generated complex SFs against a reference MTZ.

    A global RMS scale is applied first, followed by an RMS refinement in
    equal-population resolution bins.

    Parameters
    ----------
    ref_mtz_file : str
        Reference MTZ filename.
    F_gen : ndarray (complex)
        Generated structure factors.
    hkl_gen : ndarray (N, 3)
        HKLs for the generated SFs.
    n_bins : int
        Number of equal-population resolution bins.
    f_col, phi_col : str
        Amplitude and phase (degrees) column labels in the MTZ.

    Returns
    -------
    Fgen_scaled : ndarray (complex)
        Scaled generated SFs, matched reflections only, ordered from low to
        high resolution.
    hkls_common : list of tuple
        HKLs used, in the same order as ``Fgen_scaled``.

    Raises
    ------
    RuntimeError
        If a column is missing, no HKLs are shared, or the generated SFs
        have zero amplitude on the shared HKLs.
    """
    # -----------------------------
    # Read reference MTZ
    # -----------------------------
    mtz = gemmi.read_mtz_file(ref_mtz_file)
    f_column = mtz.column_with_label(f_col)
    phi_column = mtz.column_with_label(phi_col)
    if f_column is None or phi_column is None:
        raise RuntimeError(
            "Reference MTZ {0} lacks column {1!r} or {2!r}".format(
                ref_mtz_file, f_col, phi_col
            )
        )

    F_ref = np.asarray(f_column.array)
    PHI_ref = np.asarray(phi_column.array)
    hkl_ref = mtz.make_miller_array()
    cell = mtz.cell

    F_ref_cplx = F_ref * np.exp(1j * np.deg2rad(PHI_ref))
    # Keep positive amplitudes with a finite phase; a NaN (missing value)
    # would otherwise poison every RMS sum below.
    valid = (F_ref > 0.0) & np.isfinite(PHI_ref)

    # -----------------------------
    # HKL -> SF dictionaries and common reflections
    # -----------------------------
    ref_dict = {
        tuple(map(int, h)): F_ref_cplx[i]
        for i, h in enumerate(hkl_ref)
        if valid[i]
    }
    gen_dict = {tuple(map(int, h)): F_gen[i] for i, h in enumerate(hkl_gen)}

    hkls_common = sorted(ref_dict.keys() & gen_dict.keys())
    if not hkls_common:
        raise RuntimeError("No common HKLs between reference MTZ and generated SFs")

    Fref = np.array([ref_dict[h] for h in hkls_common])
    Fgen = np.array([gen_dict[h] for h in hkls_common])
    d = _d_spacings(cell, hkls_common)

    # Sort by resolution (low -> high, i.e. descending d)
    order = np.argsort(d)[::-1]
    Fref, Fgen, d = Fref[order], Fgen[order], d[order]
    hkls_common = [hkls_common[i] for i in order]

    # -----------------------------
    # GLOBAL RMS SCALING
    # -----------------------------
    gen_power = np.mean(np.abs(Fgen) ** 2)
    if gen_power <= 0:
        raise RuntimeError("Generated SFs have zero amplitude on the common HKLs")
    k0 = np.sqrt(np.mean(np.abs(Fref) ** 2) / gen_power)
    Fgen_scaled = Fgen * k0

    # -----------------------------
    # BIN-WISE RMS REFINEMENT
    # -----------------------------
    for b in np.array_split(np.arange(len(d)), n_bins):
        if len(b) == 0:
            continue
        den = np.sum(np.abs(Fgen_scaled[b]) ** 2)
        if den <= 0:
            continue
        Fgen_scaled[b] *= np.sqrt(np.sum(np.abs(Fref[b]) ** 2) / den)

    return Fgen_scaled, hkls_common


def resultant(SF):
    """Return ``(|sum(SF)|, phase of sum(SF) in degrees)``."""
    R = np.sum(SF)  # vector sum
    return np.abs(R), np.angle(R, deg=True)


def riso_x8(F1, F2):
    """Xtrapol8-style Riso between the amplitudes of *F1* and *F2*.

    ``Riso = sum|A1 - A2| / sum(0.5 * (A1 + A2))`` over reflections whose
    mean amplitude is positive. Returns NaN if there are none.
    """
    A1 = np.abs(F1)
    A2 = np.abs(F2)
    denom = 0.5 * (A1 + A2)
    mask = denom > 0
    if not np.any(mask):
        return np.nan
    return np.sum(np.abs(A1[mask] - A2[mask])) / np.sum(denom[mask])


def cciso_x8(F_ref, F_test):
    """Xtrapol8 definition of CCiso: Pearson CC of the amplitudes.

    Non-finite amplitudes are ignored. Returns NaN for fewer than two valid
    reflections or zero variance.
    """
    A_ref = np.abs(F_ref)
    A_test = np.abs(F_test)

    mask = np.isfinite(A_ref) & np.isfinite(A_test)
    A_ref = A_ref[mask]
    A_test = A_test[mask]
    if len(A_ref) < 2:
        return np.nan

    dr = A_ref - np.mean(A_ref)
    dt = A_test - np.mean(A_test)
    den = np.sqrt(np.sum(dr**2) * np.sum(dt**2))
    return np.sum(dr * dt) / den if den != 0 else np.nan


def make_equal_population_resolution_bins(d, nbins=20):
    """Split reflections into resolution bins of (nearly) equal population.

    Parameters
    ----------
    d : array_like
        d-spacing per reflection.
    nbins : int
        Requested number of bins.

    Returns
    -------
    list of dict
        Ordered from low to high resolution; each dict has ``"dmax"`` (low
        resolution edge), ``"dmin"`` (high resolution edge) and
        ``"indices"`` (positions into *d*). Bin sizes differ by at most one;
        empty bins (fewer reflections than ``nbins``) are omitted.
    """
    d = np.asarray(d, dtype=float)
    order = np.argsort(d)[::-1]  # large d first
    bins = []
    for idx in np.array_split(order, nbins):
        if idx.size == 0:
            continue
        bins.append({
            "dmax": d[idx[0]],   # low resolution edge
            "dmin": d[idx[-1]],  # high resolution edge
            "indices": idx,
        })
    return bins


def match_hkls(F1, hkl1, F2, hkl2, F3, hkl3):
    """Restrict three SF arrays to the HKLs they all share.

    Returns
    -------
    F1c, F2c, F3c : ndarray
        Values of F1, F2 and F3 on the common HKLs, in matching order.
    hklc : ndarray (N, 3) of int
        The common HKLs, sorted lexicographically.

    Raises
    ------
    ValueError
        If an SF array and its HKL list differ in length.
    RuntimeError
        If the datasets share no HKLs.
    """
    lookups = []
    for F, hkl in ((F1, hkl1), (F2, hkl2), (F3, hkl3)):
        if len(F) != len(hkl):
            raise ValueError(
                "SF array length {0} != HKL count {1}".format(len(F), len(hkl))
            )
        lookups.append({tuple(map(int, h)): F[i] for i, h in enumerate(hkl)})

    common = sorted(lookups[0].keys() & lookups[1].keys() & lookups[2].keys())
    if not common:
        raise RuntimeError("No common HKLs between datasets")

    F1c, F2c, F3c = (np.array([lut[h] for h in common]) for lut in lookups)
    return F1c, F2c, F3c, np.array(common, dtype=int)


def riso_cciso_vs_resolution(Fref, Ftest, d, nbins, min_per_bin=20):
    """Riso and CCiso of *Fref* vs *Ftest* in equal-population resolution bins.

    Bins with fewer than ``min_per_bin`` reflections are skipped.

    Returns
    -------
    dmid, riso_vals, cciso_vals : ndarray
        Bin mid-point d (mean of the bin edges) and the statistics per bin.
    """
    Fref = np.asarray(Fref)
    Ftest = np.asarray(Ftest)
    dmid, riso_vals, cciso_vals = [], [], []
    for b in make_equal_population_resolution_bins(d, nbins):
        idx = b["indices"]
        if len(idx) < min_per_bin:
            continue
        dmid.append(0.5 * (b["dmin"] + b["dmax"]))
        riso_vals.append(riso_x8(Fref[idx], Ftest[idx]))
        cciso_vals.append(cciso_x8(Fref[idx], Ftest[idx]))
    return np.array(dmid), np.array(riso_vals), np.array(cciso_vals)


def plot_riso_cciso_vs_resolution(d_centers, riso_vals, cciso_vals):
    """Plot Riso (left axis) and CCiso (right axis) against 1/d."""
    inv_d = 1.0 / np.asarray(d_centers, dtype=float)

    fig, ax1 = plt.subplots()
    ax1.plot(inv_d, riso_vals, color="red", marker="o")
    # The x values are 1/d, not d, so label them accordingly.
    ax1.set_xlabel(r"1/d (Å$^{-1}$)")
    ax1.set_ylabel("Riso", color="red")

    ax2 = ax1.twinx()
    ax2.plot(inv_d, cciso_vals, color="blue", marker="s")
    ax2.set_ylabel("CCiso", color="blue")

    plt.title("Riso and CCiso vs Resolution")
    fig.tight_layout()
    plt.show()


# -----------------------------
# User inputs
# -----------------------------
b_val = 0.1
a_val = 1.0 - b_val
dmin = 1.5
alpha = 0.0001                             # not used by the driver below
ref_mtz = "../data/apo_100k_refine_5.mtz"  # not used by the driver below
occ_lst = [0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 0.9]


if __name__ == "__main__":
    print("calculating apo")
    apo_pdb = "../data/hewl_apo.pdb"
    # apo_pdb = "../data/ser-arg-loop-apo.pdb"
    SF_apo, hkl_apo, cell_apo = compute_sf(apo_pdb, dmin)
    print("done")

    print("calculating 100 %")
    occ100_pdb = "../data/hewl_1.0-I.pdb"
    # occ100_pdb = "../data/ser-arg-loop-1.0-switch.pdb"
    SF_100, hkl_100, cell_100 = compute_sf(occ100_pdb, dmin)
    print("done")

    for b_val in occ_lst:
        print("occ = {0}".format(b_val))
        a_val = 1 - b_val
        print("calculating b SFs")
        occb_pdb = "../data/hewl_{0}-I.pdb".format(b_val)
        # occb_pdb = "../data/ser-arg-loop-{0}-switch.pdb".format(b_val)
        SF_b, hkl_b, cell_b = compute_sf(occb_pdb, dmin)
        print("done")

        # Matched arrays get NEW names: SF_apo / SF_100 must stay paired
        # with hkl_apo / hkl_100 for the next occupancy in the loop.
        SFc_apo, SFc_100, SFc_b, hklc = match_hkls(
            SF_apo, hkl_apo, SF_100, hkl_100, SF_b, hkl_b
        )
        d = _d_spacings(cell_apo, hklc)

        F_apo = np.abs(SFc_apo)
        phi_apo = np.angle(SFc_apo)
        F_b = np.abs(SFc_b)

        # Extrapolated amplitudes, combined with the apo phases.
        F_extr = (F_b - F_apo) / b_val + F_apo
        SF_extr = F_extr * np.exp(1j * phi_apo)

        # d_centers, riso_vals, cciso_vals = riso_cciso_vs_resolution(SFc_100, SF_extr, d, 20)
        # plot_riso_cciso_vs_resolution(d_centers, riso_vals, cciso_vals)

        delta_Fextr = SF_extr - SFc_100

        riso_val = riso_x8(SFc_apo, SFc_b)
        print("Riso = {0}".format(riso_val))
        cciso_val = cciso_x8(SFc_apo, SFc_b)
        print("CCiso = {0}".format(cciso_val))

        print("Mean |ΔF|:", np.mean(np.abs(delta_Fextr)))
        # angle(a * conj(b)) equals angle(a / b) but never divides by zero.
        phase_diff = np.angle(SF_extr * np.conj(SFc_100), deg=True)
        print("Mean phase Δ (deg):", np.mean(np.abs(phase_diff)))

        # Fextr_mag, phi_extr = resultant(SF_extr)
        # F_100_mag, phi_100_model = resultant(SFc_100)
        # print("Sum SF_extr calculated: |F| = {0}, phi = {1}".format(Fextr_mag, phi_extr))
        # print("Sum SF_100 model : |F| = {0}, phi = {1}".format(F_100_mag, phi_100_model))

        # -----------------------------
        # Resultant vectors for plotting
        # -----------------------------
        sum_SF_apo = np.sum(SFc_apo)
        sum_SF_100 = np.sum(SFc_100)
        sum_SF_b = np.sum(SFc_b)
        sum_SF_extr = np.sum(SF_extr)

        # -----------------------------
        # Plot resultant vectors
        # -----------------------------
        plt.figure(figsize=(8, 8))
        ax = plt.gca()
        ax.set_aspect("equal")
        for vec, colour, label in (
            (sum_SF_apo, "blue", "Fapo model"),
            (sum_SF_100, "red", "F100 model"),
            (sum_SF_b, "green", "Fab model"),
            (sum_SF_extr, "black", "Fextr"),
        ):
            ax.arrow(0.0, 0.0, vec.real, vec.imag,
                     head_width=500, head_length=500,
                     fc=colour, ec=colour, label=label)
        ax.set_xlabel("Re(F)")
        ax.set_ylabel("Im(F)")
        ax.set_title("SF Resultant Vectors")
        ax.grid(True)
        ax.legend()
        # plt.show()