diff --git a/coherent_difference.py b/coherent_difference.py new file mode 100644 index 0000000..c6c2abd --- /dev/null +++ b/coherent_difference.py @@ -0,0 +1,248 @@ +import gemmi +import numpy as np + +def coherent_difference_fraction(F1, F2): + """ + Coherent Difference Fraction (CDF) + + Parameters + ---------- + F1, F2 : array-like of complex + Complex structure factors on a common HKL set + + Returns + ------- + cdf : float + Fraction of coherent difference (0..1) + """ + dF = F2 - F1 + + num = np.abs(np.sum(dF)) + den = np.sum(np.abs(dF)) + + if den <= 0: + return np.nan + + return num / den + +def resultant_angle(F1, F2, degrees=True): + """ + Angle between resultant structure-factor vectors. + + Parameters + ---------- + F1, F2 : array-like of complex + Complex structure factors + degrees : bool + Return angle in degrees (default) or radians + + Returns + ------- + angle : float + Angle between resultants + """ + R1 = np.sum(F1) + R2 = np.sum(F2) + + if np.abs(R1) == 0 or np.abs(R2) == 0: + return np.nan + + cosang = np.real(R1 * np.conj(R2)) / (np.abs(R1) * np.abs(R2)) + cosang = np.clip(cosang, -1.0, 1.0) + + angle = np.arccos(cosang) + + if degrees: + angle = np.degrees(angle) + + return angle + +def riso_xtrapol8(F1, F2): + A1 = np.abs(F1) + A2 = np.abs(F2) + + denom = 0.5 * (A1 + A2) + mask = denom > 0 + + if np.sum(mask) == 0: + return np.nan + + num = np.sum(np.abs(A1[mask] - A2[mask])) + den = np.sum(denom[mask]) + + return num / den + +def extrapolation_score(F1, F2): + """ + Hybrid extrapolation success metric. + + Returns + ------- + score : float + 0..1 (higher = better) + details : dict + Breakdown of components + """ + cdf = coherent_difference_fraction(F1, F2) + theta = resultant_angle(F1, F2, degrees=False) + riso = riso_xtrapol8(F1, F2) + + if np.isnan(cdf) or np.isnan(theta): + return np.nan, {} + + score = cdf * np.cos(theta) + + details = { + "CDF": cdf, + "theta_rad": theta, + "theta_deg": np.degrees(theta), + "cos_theta": np.cos(theta), + "riso" : riso + } + + return score, details + +def read_sf_from_mtz(mtz_file, + F_col="F-obs-filtered", + PHI_col="PHIF-model", + require_positive_F=True): + """ + Read complex structure factors from an MTZ. + + Returns + ------- + F : np.ndarray (complex) + hkl : np.ndarray (N, 3) int + cell : gemmi.UnitCell + """ + + mtz = gemmi.read_mtz_file(mtz_file) + cell = mtz.cell + + # --- Extract columns --- + F = mtz.column_with_label(F_col).array + PHI = mtz.column_with_label(PHI_col).array + hkl = mtz.make_miller_array().astype(int) + + if len(F) != len(hkl): + raise RuntimeError("HKL and F column length mismatch") + + # --- Optional filtering --- + if require_positive_F: + mask = F > 0.0 + F = F[mask] + PHI = PHI[mask] + hkl = hkl[mask] + + # --- Build complex SFs --- + F_complex = F * np.exp(1j * np.deg2rad(PHI)) + + return F_complex, hkl, cell + +def read_sf_from_pdb(pdb_file, dmin): + """ + Calculate structure factors from a PDB using Gemmi. + Returns complex SFs and matching HKLs. 
+ """ + + # ----------------------------- + # Read structure + # ----------------------------- + st = gemmi.read_structure(pdb_file) + st.setup_entities() + model = st[0] + + cell = st.cell + sg = gemmi.find_spacegroup_by_name(st.spacegroup_hm) + + # ----------------------------- + # Generate HKLs + # ----------------------------- + hkl = np.array( + gemmi.make_miller_array(cell, sg, dmin), + dtype=int + ) + + # ----------------------------- + # SF calculator + # ----------------------------- + calc = gemmi.StructureFactorCalculatorX(cell) + + # ----------------------------- + # Compute SFs one-by-one + # ----------------------------- + F = np.empty(len(hkl), dtype=complex) + + for i, (h, k, l) in enumerate(hkl): + F[i] = calc.calculate_sf_from_model( + model, + [int(h), int(k), int(l)] + ) + + return F, hkl, cell + + + +def match_hkls(F1, hkl1, F2, hkl2): + """ + Match two SF arrays on common HKLs. + + Returns + ------- + F1c, F2c : np.ndarray (complex) + hkl_common : np.ndarray (N, 3) + """ + + dict1 = {tuple(h): F1[i] for i, h in enumerate(hkl1)} + dict2 = {tuple(h): F2[i] for i, h in enumerate(hkl2)} + + common = sorted(dict1.keys() & dict2.keys()) + + if len(common) == 0: + raise RuntimeError("No common HKLs between datasets") + + F1c = np.array([dict1[h] for h in common]) + F2c = np.array([dict2[h] for h in common]) + hklc = np.array(common, dtype=int) + + return F1c, F2c, hklc + +def extrapolation_metrics(F1, F2): + """ + Compute all coherence-based extrapolation metrics. + """ + return { + "CDF": coherent_difference_fraction(F1, F2), + "theta_deg": resultant_angle(F1, F2, degrees=True), + "theta_rad": resultant_angle(F1, F2, degrees=False), + "score": extrapolation_score(F1, F2)[0], + "riso" : riso_xtrapol8(F1, F2) + } + +#=============================================================================== +# mtz_ref = "../data/apo_100k_refine_5.mtz" +# SF_ref, hkl_ref, cell_ref = read_sf_from_mtz( mtz_ref ) +# +# mtz_trig = "../data/on_100k_refine_141.mtz" +# SF_trig, hkl_trig, cell_trig = read_sf_from_mtz( mtz_trig ) +# +# SF_ref, SF_trig, hkl_comb = match_hkls(SF_ref, hkl_ref, SF_trig, hkl_trig) +#=============================================================================== + +dmin = 8.0 +pdb_ref = "../data/hewl_apo.pdb" +#pdb_ref = "../data/ser-arg-loop-apo.pdb" +print("generating ref SFs") +SF_ref, hkl_ref, cell_ref = read_sf_from_pdb(pdb_ref, dmin) +print("DONE") + + +pdb_trig = "../data/hewl_1.0-I.pdb" +#pdb_trig = "../data/ser-arg-loop-1.0-switch.pdb" +print("generating trig. 
SFs") +SF_trig, hkl_trig, cell_trig = read_sf_from_pdb(pdb_trig, dmin) +print("DONE") + +SF_ref, SF_trig, hkl_comb = match_hkls(SF_ref, hkl_ref, SF_trig, hkl_trig) + +print( extrapolation_metrics(SF_ref, SF_trig) ) \ No newline at end of file diff --git a/combine_resultants.py b/combine_resultants.py new file mode 100644 index 0000000..7e043c1 --- /dev/null +++ b/combine_resultants.py @@ -0,0 +1,126 @@ +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd + +def combine_structure_factors(F_A, phi_A, + F_B, phi_B, + a, b, + plot=True): + """ + Combine two resultant structure factors: + F_comb = a*F1 + b*F2 + + Returns: + magnitude, phase_deg, complex_value + """ + + # Convert to complex SFs + SF_A = F_A * np.exp(1j * np.deg2rad(phi_A)) + SF_B = F_B * np.exp(1j * np.deg2rad(phi_B)) + + # Linear combination + SF_AB_comb = a * SF_A + b * SF_B + F_AB_comb = np.abs(SF_AB_comb) + SF_AB_phi_A = F_AB_comb * np.exp(1j * np.deg2rad(phi_A)) + + SF_B_extr = ( SF_AB_comb - a * SF_A ) / b + + + F_AB = a * F_A + b * F_B + SF_AB_x8 = F_AB * np.exp(1j * np.deg2rad(phi_A)) + + # x8 combination + SF_extr_comb = (SF_AB_comb - SF_A)/b + SF_A + + F_extr = (F_AB_comb - F_A)/b + F_A + SF_extr_x8 = F_extr * np.exp(1j * np.deg2rad(phi_A)) + + # Optional Argand plot + origin = 0 + 0j + if plot: + plt.figure(figsize=(6, 6)) + plt.arrow(origin.real, origin.imag, SF_A.real, SF_A.imag, + color="blue", alpha=0.6, + head_width=0.2, head_length=0.2, + length_includes_head=True, label="SFA") + plt.arrow(origin.real, origin.imag, SF_B.real, SF_B.imag, + color="orange", alpha=0.6, + head_width=0.4, head_length=0.4, + length_includes_head=True, label="SFB") + plt.arrow(origin.real, origin.imag, SF_AB_comb.real, SF_AB_comb.imag, + color="green", linewidth=3, + head_width=0.2, head_length=0.2, + length_includes_head=True, label="SF_extr") + plt.arrow(origin.real, origin.imag, SF_B_extr.real, SF_B_extr.imag, + color="yellow", linewidth=3, + head_width=0.2, head_length=0.2, + length_includes_head=True, label="SF_B_extr") + plt.arrow(origin.real, origin.imag, F_extr.real, F_extr.imag, + color="pink", linewidth=3, + head_width=0.2, head_length=0.2, + length_includes_head=True, label="F_extr") + plt.axhline(0) + plt.axvline(0) + plt.axis("equal") + plt.xlabel("Re(F)") + plt.ylabel("Im(F)") + plt.title(f"F_comb = aF1 + bF2 (a={a}, b={b})") + plt.legend() + plt.grid(True) + plt.show() + + return SF_AB_comb, SF_AB_x8, SF_B_extr, SF_extr_x8 + + +# ----------------------------- +# Example usage +# ----------------------------- +F_A = 100 +#phi_A = 90 + +F_B = 50 +phi_B = 0 + +phi_lst = np.arange(0,360,10) +b_lst = [ 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1 ] +df = pd.DataFrame() + +for phi_A in phi_lst: + + for b in b_lst: + + a = 1 - b + SF_AB_comb, SF_AB_x8, SF_extr_comb, SF_extr_x8 = combine_structure_factors( + F_A, phi_A, + F_B, phi_B, + a, b, + plot=True + ) + + data = [ { "F_A" : F_A, + "phi" : phi_A, + "F_B" : F_B, + "phi_B" : phi_B, + "b" : b, + "SF_AB_comb_abs": np.abs(SF_AB_comb), + "SF_AB_comb_phi": np.angle(SF_AB_comb, deg=True), + "SF_AB_x8_abs": np.abs(SF_AB_x8), + "SF_AB_x8_phi": np.angle(SF_AB_x8, deg=True), + "SF_B_extr_abs": np.abs(SF_extr_comb), + "SF_B_extr_comb_phi": np.angle(SF_extr_comb, deg=True), + "SF_extr_x8_abs": np.abs(SF_extr_x8), + "SF_extr_x8_phi": np.angle(SF_extr_x8, deg=True), + } ] + df_1 = pd.DataFrame( data ) + df = pd.concat( ( df, df_1 ) ) + + df = df.reset_index( drop=True ) + +print( df ) + +df.to_csv( "../results/complete_varied_b_and_phi.csv", sep="," ) +df_mean = pd.DataFrame() 
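+# Cross-check (sketch, not from the original script): the same per-occupancy
+# summary as the transform('mean') / drop_duplicates block below, via a plain
+# groupby mean. Assumes the df assembled in the loop above and the true
+# amplitude F_B defined earlier; values are the mean extrapolated |F_B| per
+# occupancy b, expressed as a percentage of F_B.
+recovery_pct = df.groupby("b")["SF_extr_x8_abs"].mean() / F_B * 100
+print(recovery_pct)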
+df_mean[ "b" ] = df.b +df_mean[ "SF_extr_x8_abs" ] = ( df.groupby('b')["SF_extr_x8_abs"].transform('mean') / F_B ) *100 + +print(df_mean.drop_duplicates()) diff --git a/compare_mtz.py b/compare_mtz.py new file mode 100644 index 0000000..7d6bfff --- /dev/null +++ b/compare_mtz.py @@ -0,0 +1,207 @@ +import numpy as np +import gemmi +import matplotlib.pyplot as plt + +def riso_xtrapol8(F1, F2): + A1 = np.abs(F1) + A2 = np.abs(F2) + + denom = 0.5 * (A1 + A2) + mask = denom > 0 + + if np.sum(mask) == 0: + return np.nan + + num = np.sum(np.abs(A1[mask] - A2[mask])) + den = np.sum(denom[mask]) + + return num / den + +def cciso_xtrapol8(F1, F2): + A1 = np.abs(F1) + A2 = np.abs(F2) + + mask = np.isfinite(A1) & np.isfinite(A2) + + A1 = A1[mask] + A2 = A2[mask] + + if len(A1) < 2: + return np.nan + + A1m = A1 - np.mean(A1) + A2m = A2 - np.mean(A2) + + num = np.sum(A1m * A2m) + den = np.sqrt(np.sum(A1m**2) * np.sum(A2m**2)) + + return num / den if den > 0 else np.nan + +def read_mtz_amplitudes(mtz_file, amp_label): + mtz = gemmi.read_mtz_file(mtz_file) + + H = np.array(mtz.column_with_label('H'), dtype=int) + K = np.array(mtz.column_with_label('K'), dtype=int) + L = np.array(mtz.column_with_label('L'), dtype=int) + + hkl = np.vstack([H, K, L]).T + + F = np.array(mtz.column_with_label(amp_label), dtype=float) + + cell = mtz.cell + + return hkl, F, cell + + +def match_reflections(hkl1, F1, hkl2, F2): + map1 = {tuple(h): f for h, f in zip(hkl1, F1)} + map2 = {tuple(h): f for h, f in zip(hkl2, F2)} + + common = sorted(set(map1.keys()) & set(map2.keys())) + + F1m = np.array([map1[h] for h in common]) + F2m = np.array([map2[h] for h in common]) + + hklm = np.array(common, dtype=int) + + return hklm, F1m, F2m + +def compute_resolution(cell, hkl): + return np.array([cell.calculate_d(list(h)) for h in hkl]) + +def make_equal_count_bins(d, n_bins): + order = np.argsort(d)[::-1] # low → high resolution + bins = np.array_split(order, n_bins) + return bins + +def make_equal_count_resolution_bins(d, n_bins): + """ + Returns a list of index arrays, each containing + roughly the same number of reflections. + """ + order = np.argsort(d)[::-1] # low → high resolution + return np.array_split(order, n_bins) + +def riso_cciso_by_resolution(d, F1, F2, n_bins=20): + """ + Compute Xtrapol8 Riso and CCiso in equal-count resolution bins. 
+ + Returns: + d_mid : midpoint resolution per bin + Riso : array + CCiso : array + counts : reflections per bin + """ + bins = make_equal_count_resolution_bins(d, n_bins) + + d_mid = [] + Riso = [] + CCiso = [] + counts = [] + + for idx in bins: + if len(idx) < 10: + continue + + d_bin = d[idx] + F1b = F1[idx] + F2b = F2[idx] + + d_mid.append(0.5 * (d_bin.min() + d_bin.max())) + Riso.append(riso_xtrapol8(F1b, F2b)) + CCiso.append(cciso_xtrapol8(F1b, F2b)) + counts.append(len(idx)) + + return ( + np.array(d_mid), + np.array(Riso), + np.array(CCiso), + np.array(counts), + ) + + + +def plot_riso_cciso_vs_resolution(d_mid, Riso, CCiso): + fig, ax1 = plt.subplots(figsize=(6, 4)) + + # Riso + ax1.plot(1/d_mid, Riso, marker='o', label='Riso') + ax1.set_xlabel('Resolution (Å)') + ax1.set_ylabel('Riso') + ax1.grid(True, alpha=0.3) + + # CCiso + ax2 = ax1.twinx() + ax2.plot(1/d_mid, CCiso, marker='s', linestyle='--', label='CCiso') + ax2.set_ylabel('CCiso') + ax2.set_ylim(0, 1) + + # Combined legend + lines, labels = ax1.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax1.legend(lines + lines2, labels + labels2, loc='best') + + plt.tight_layout() + plt.show() + + +def riso_cciso_from_mtz( + + mtz1, mtz2, + amp1_label, amp2_label, + n_bins=None +): + hkl1, F1, cell1 = read_mtz_amplitudes(mtz1, amp1_label) + hkl2, F2, cell2 = read_mtz_amplitudes(mtz2, amp2_label) + + hkl, F1m, F2m = match_reflections(hkl1, F1, hkl2, F2) + + d = compute_resolution(cell1, hkl) + + print(f"Matched reflections: {len(hkl)}") + + # Overall statistics + Riso = riso_xtrapol8(F1m, F2m) + CCiso = cciso_xtrapol8(F1m, F2m) + + print(f"\nOverall:") + print(f" Riso = {Riso:.4f}") + print(f" CCiso = {CCiso:.4f}") + + # Per-bin statistics + if n_bins is not None: + print(f"\nPer-resolution-bin statistics:") + bins = make_equal_count_bins(d, n_bins) + + for i, idx in enumerate(bins): + if len(idx) < 10: + continue + + Rb = riso_xtrapol8(F1m[idx], F2m[idx]) + CCb = cciso_xtrapol8(F1m[idx], F2m[idx]) + + dmin = np.min(d[idx]) + dmax = np.max(d[idx]) + + print( + f"Bin {i+1:02d} " + f"{dmax:5.2f}–{dmin:5.2f} Å " + f"Riso={Rb:6.3f} CCiso={CCb:6.3f} n={len(idx)}" + ) + + d_mid, Rb, CCb, nref = riso_cciso_by_resolution( d, F1m, F2m, n_bins=20 ) + + plot_riso_cciso_vs_resolution(d_mid, Rb, CCb) + + + return Riso, CCiso + + + + +mtz1 = "../data/apo_100k_refine_5.mtz" +mtz2 = "../data/on_100k_refine_141.mtz" +amp1_label = "F-obs-filtered" +amp2_label = "F-obs-filtered" + +riso_cciso_from_mtz( mtz1, mtz2, amp1_label, amp2_label, n_bins=None) \ No newline at end of file diff --git a/extrapolation_blow-up.py b/extrapolation_blow-up.py new file mode 100644 index 0000000..f2c72f6 --- /dev/null +++ b/extrapolation_blow-up.py @@ -0,0 +1,66 @@ +import numpy as np +import matplotlib.pyplot as plt + +# ----------------------------- +# Fake but realistic SFs +# ----------------------------- +np.random.seed(1) + +n_ref = 30 + +# Apo structure factors +Fa_amp = np.random.uniform(50, 200, n_ref) +Fa_phi = np.random.uniform(-np.pi, np.pi, n_ref) +Fa = Fa_amp * np.exp(1j * Fa_phi) + +# Difference signal (e.g. 
ligand / backbone change) +dF_amp = np.random.uniform(5, 40, n_ref) +dF_phi = Fa_phi + np.random.normal(0, 0.4, n_ref) +dF = dF_amp * np.exp(1j * dF_phi) + +# Mixed data +Fab = Fa + dF + +# ----------------------------- +# Choose occupancies +# ----------------------------- +b_values = [1.0, 0.5, 0.2, 0.1] + +# ----------------------------- +# Plot +# ----------------------------- +fig, axes = plt.subplots(2, len(b_values), figsize=(14, 6), sharex=True, sharey=True) + +for i, b in enumerate(b_values): + Fb = Fa + (Fab - Fa) / b # extrapolated SF + dF_scaled = (Fab - Fa) / b # pure difference term + + # ----- Extrapolated vectors ----- + ax = axes[0, i] + ax.set_title(f"Fb extrapolated (b={b})") + for z in Fb: + ax.arrow(0, 0, z.real, z.imag, + head_width=2, length_includes_head=True, + alpha=0.6) + ax.axhline(0, color='grey', lw=0.5) + ax.axvline(0, color='grey', lw=0.5) + ax.set_aspect('equal') + + # ----- Difference vectors ----- + ax = axes[1, i] + ax.set_title("ΔF / b") + for z in dF_scaled: + ax.arrow(0, 0, z.real, z.imag, + head_width=2, length_includes_head=True, + alpha=0.6, color='crimson') + ax.axhline(0, color='grey', lw=0.5) + ax.axvline(0, color='grey', lw=0.5) + ax.set_aspect('equal') + +axes[0, 0].set_ylabel("Imag(F)") +axes[1, 0].set_ylabel("Imag(F)") +for ax in axes[1]: + ax.set_xlabel("Real(F)") + +plt.tight_layout() +plt.show() diff --git a/extrapolation_blow-up_v2.py b/extrapolation_blow-up_v2.py new file mode 100644 index 0000000..31571c2 --- /dev/null +++ b/extrapolation_blow-up_v2.py @@ -0,0 +1,76 @@ +import numpy as np +import matplotlib.pyplot as plt + +# ----------------------------- +# Generate realistic fake SFs +# ----------------------------- +np.random.seed(2) + +n_ref = 2000 + +# Apo SFs +Fa_amp = np.random.uniform(50, 200, n_ref) +Fa_phi = np.random.uniform(-np.pi, np.pi, n_ref) +Fa = Fa_amp * np.exp(1j * Fa_phi) + +# Difference signal (correlated phases!) 
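+# (Annotation: dF_phi is drawn close to the phase of its parent Fa, so each
+# difference vector points roughly along its Fa and the amplitude change
+# |Fab| - |Fa| tracks |dF|. If dF_phi were drawn independently of Fa_phi,
+# the same |dF| would often leave |Fab| nearly unchanged or even smaller.)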
+dF_amp = np.random.uniform(5, 25, n_ref) +dF_phi = Fa_phi + np.random.normal(0, 0.3, n_ref) +dF = dF_amp * np.exp(1j * dF_phi) + +# Mixed SFs +Fab = Fa + dF + +# ----------------------------- +# Occupancies to test +# ----------------------------- +b_values = [1.0, 0.5, 0.2, 0.1, 0.05] + +# ----------------------------- +# Compute resultants +# ----------------------------- +R_Fa = np.sum(Fa) +R_Fab = np.sum(Fab) +R_dF = np.sum(Fab - Fa) + +R_Fb = [] +R_dF_scaled = [] + +for b in b_values: + Fb = Fa + (Fab - Fa) / b + R_Fb.append(np.sum(Fb)) + R_dF_scaled.append(R_dF / b) + +# ----------------------------- +# Plot resultants +# ----------------------------- +plt.figure(figsize=(6, 6)) + +# Reference resultants +plt.arrow(0, 0, R_Fa.real, R_Fa.imag, + head_width=200, length_includes_head=True, + label="R(Fa)", color="black") + +plt.arrow(0, 0, R_Fab.real, R_Fab.imag, + head_width=200, length_includes_head=True, + label="R(Fab)", color="blue") + +# Extrapolated vs difference resultants +for b, Rb, Rd in zip(b_values, R_Fb, R_dF_scaled): + plt.arrow(0, 0, Rb.real, Rb.imag, + head_width=200, length_includes_head=True, + alpha=0.7, label=f"R(Fb), b={b}") + plt.arrow(0, 0, Rd.real, Rd.imag, + head_width=200, length_includes_head=True, + linestyle="dashed", alpha=0.7) + +plt.axhline(0, color="grey", lw=0.5) +plt.axvline(0, color="grey", lw=0.5) +plt.gca().set_aspect("equal") + +plt.xlabel("Real") +plt.ylabel("Imag") +plt.title("Resultant vectors: extrapolation vs difference") +plt.legend(fontsize=8) +plt.tight_layout() +plt.show() diff --git a/generate_argand_v4.py b/generate_argand_v4.py new file mode 100644 index 0000000..d4e07c2 --- /dev/null +++ b/generate_argand_v4.py @@ -0,0 +1,59 @@ +import gemmi +import numpy as np +import matplotlib.pyplot as plt + +# ----------------------------- +# User input +# ----------------------------- +mtz_file = "../data/Fext_tests_refine_5.mtz" +F_label = "F-model" +PHI_label = "PHIF-model" +n_max = 30000 + +# ----------------------------- +# Load MTZ +# ----------------------------- +mtz = gemmi.read_mtz_file(mtz_file) + +F = np.array(mtz.column_with_label(F_label))[:n_max] +PHI = np.array(mtz.column_with_label(PHI_label))[:n_max] + +# Convert to complex structure factors +C = F * np.exp(1j * np.deg2rad(PHI)) + +# ----------------------------- +# Head-to-tail vector plot +# ----------------------------- +plt.figure(figsize=(7, 7)) + +x_start, y_start = 0.0, 0.0 # first vector starts at origin + +for z in C: + dx, dy = z.real, z.imag + + # draw arrow FROM (x_start, y_start) TO (x_start+dx, y_start+dy) + plt.arrow( + x_start, y_start, + dx, dy, + length_includes_head=True, + head_width=0.03 * abs(z), + alpha=0.6 + ) + + # update start point for next vector + x_start += dx + y_start += dy + + + +# Mark start and end explicitly +plt.scatter(0, 0, c="black", s=40, label="Start") +plt.scatter(x_start, y_start, c="red", s=60, label="End") + +plt.xlabel("Re(F)") +plt.ylabel("Im(F)") +plt.title("Head-to-tail Argand plot of structure factors") +plt.axis("equal") +plt.grid(True) +plt.legend() +plt.show() diff --git a/generate_argand_v5.py b/generate_argand_v5.py new file mode 100644 index 0000000..84cafc7 --- /dev/null +++ b/generate_argand_v5.py @@ -0,0 +1,138 @@ +import gemmi +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.colors import Normalize +from matplotlib.cm import ScalarMappable + +# ----------------------------- +# User input +# ----------------------------- +#mtz_file = "../data/on_100k_refine_141.mtz" +#F_label = "F-model" 
+#PHI_label = "PHIF-model" +mtz_file = "../data/apo_100k_refine_5.mtz" +F_label = "F-obs-filtered" +PHI_label = "PHIF-model" +n_max = 10000 +n_bins = 10 + +# ----------------------------- +# Load MTZ +# ----------------------------- +mtz = gemmi.read_mtz_file(mtz_file) + +# Unit cell +cell = mtz.cell + +# Structure factor data +F = np.array(mtz.column_with_label(F_label)) +PHI = np.array(mtz.column_with_label(PHI_label)) + +# Gemmi-native HKL +# HKL (must be integers!) +H = np.array(mtz.column_with_label("H"), dtype=int) +K = np.array(mtz.column_with_label("K"), dtype=int) +L = np.array(mtz.column_with_label("L"), dtype=int) + +# Remove invalid reflections +mask = np.isfinite(F) & np.isfinite(PHI) + +F = F[mask] +PHI = PHI[mask] +H = H[mask] +K = K[mask] +L = L[mask] + +# Complex structure factors +C = F * np.exp(1j * np.deg2rad(PHI)) + +# ----------------------------- +# Compute resolution (d-spacing) +# ----------------------------- +# d = 1 / |h*| where |h*| comes from reciprocal cell +d = np.array([cell.calculate_d([h, k, l]) for h, k, l in zip(H, K, L)]) + +# ----------------------------- +# Sort reflections: low → high resolution +# (large d → small d) +# ----------------------------- +idx = np.argsort(-d) + +C = C[idx][:n_max] +d = d[idx][:n_max] + +# ----------------------------- +# Assign resolution bins +# ----------------------------- +bins = np.linspace(d.min(), d.max(), n_bins + 1) +colors_array = plt.cm.coolwarm(np.linspace(0, 1, n_bins)) # colormap + +bin_indices = np.digitize(d, bins) - 1 +bin_indices = np.clip(bin_indices, 0, n_bins - 1) # <- fix +# ----------------------------- +# Head-to-tail Argand walk +# ----------------------------- +plt.figure(figsize=(7, 7)) + +x, y = 0.0, 0.0 # first vector starts at origin + +for i, z in enumerate(C): + dx, dy = z.real, z.imag + + # draw arrow FROM (x_start, y_start) TO (x_start+dx, y_start+dy) + plt.arrow( + x, y, + dx, dy, + color=colors_array[bin_indices[i]], + length_includes_head=True, + head_width=0.03 * abs(z), + alpha=0.6 + ) + + # update start point for next vector + x += dx + y += dy + +# ----------------------------- +# Final tip on top +# ----------------------------- +plt.scatter(x, y, c="red", s=80, label="Final tip") + +# ----------------------------- +# Resultant vector (sum of all SFs) +# ----------------------------- +sum_C = np.sum(C) +plt.arrow( + 0, 0, + sum_C.real, sum_C.imag, + length_includes_head=True, + head_width=0.07 * abs(sum_C), + color="green", + alpha=0.9, + label="Resultant ΣF" +) + + +# Start / end markers +plt.scatter(0, 0, c="black", s=40, label="Start") + +plt.xlabel("Re(F)") +plt.ylabel("Im(F)") +plt.title("Head-to-tail Argand walk (sorted by resolution)") +plt.axis("equal") +plt.grid(True) +plt.legend() +plt.show() + + +# Assuming C is your filtered, complex structure factor array +sum_C = np.sum(C) # sum all complex vectors + +# Magnitude +magnitude = np.abs(sum_C) + +# Phase in degrees (-180 to +180) +phase_deg = np.angle(sum_C, deg=True) + +print(f"Resultant vector magnitude: {magnitude:.3f}") +print(f"Resultant vector phase: {phase_deg:.2f}°") \ No newline at end of file diff --git a/model_vectors.py b/model_vectors.py new file mode 100644 index 0000000..f3a6da5 --- /dev/null +++ b/model_vectors.py @@ -0,0 +1,104 @@ +import gemmi +import numpy as np +import matplotlib.pyplot as plt +# ----------------------------- +# User inputs +# ----------------------------- +apo_pdb = "../data/hewl_apo.pdb" +occ100_pdb = "../data/hewl_1.0-I.pdb" +occb_pdb = "../data/hewl_0.5-I.pdb" + +b = 0.5 +a = 
1.0 - b +dmin = 2.0 # resolution cutoff (Å) + +# ----------------------------- +# Compute structure factors +# ----------------------------- +def compute_sf(pdb_file, dmin): + # Read structure + st = gemmi.read_structure(pdb_file) + st.setup_entities() + st.assign_label_seq_id() + st.remove_alternative_conformations() + + model = st[0] + cell = st.cell + + # Space group + sg = st.find_spacegroup() + if sg is None: + if st.spacegroup_hm: + sg = gemmi.SpaceGroup(st.spacegroup_hm) + else: + raise RuntimeError("No space group found") + + # Generate HKLs + hkl = gemmi.make_miller_array( + cell=cell, + spacegroup=sg, + dmin=dmin, + unique=True + ) + + # Structure factor calculator + calc = gemmi.StructureFactorCalculatorX(cell) + + # Compute SFs ONE reflection at a time (this is required) + F = np.array([ + calc.calculate_sf_from_model(model, [int(h), int(k), int(l)]) + for h, k, l in hkl + ], dtype=np.complex128) + + return hkl, F + + +hkl, Fa = compute_sf(apo_pdb, dmin) +_, F100 = compute_sf(occ100_pdb, dmin) +_, Fb_model = compute_sf(occb_pdb, dmin) + +# ----------------------------- +# Sanity: same HKLs? +# ----------------------------- +assert len(Fa) == len(F100) == len(Fb_model) + +# ----------------------------- +# Test forward relation: +# Fab = a Fa + b F100 +# ----------------------------- +Fab = a * Fa + b * F100 +print(Fab) + +# Compare to model b-occupancy SFs +delta_forward = Fab - Fb_model + +print("FORWARD TEST: Fab = aFa + bF100") +print("Mean |ΔF|:", np.mean(np.abs(delta_forward))) +print("Mean phase Δ (deg):", + np.mean(np.abs(np.angle(Fab / Fb_model, deg=True)))) + +# ----------------------------- +# Test inverse relation: +# F100 = (Fab − Fa) / b +# ----------------------------- +F100_recon = (Fb_model - Fa) / b + Fa + +delta_inverse = F100_recon - F100 + +print("\nINVERSE TEST: F100 = (Fab − Fa)/b + Fa") +print("Mean |ΔF|:", np.mean(np.abs(delta_inverse))) +print("Mean phase Δ (deg):", + np.mean(np.abs(np.angle(F100_recon / F100, deg=True)))) + +# ----------------------------- +# Optional: inspect worst reflections +# ----------------------------- +idx = np.argsort(np.abs(delta_inverse))[-5:] + +print("\nWorst reflections (inverse):") +for i in idx: + print( + f"|F100|={abs(F100[i]):.2f} " + f"|ΔF|={abs(delta_inverse[i]):.2f} " + f"Δφ={np.angle(F100_recon[i]/F100[i],deg=True):.1f}°" + ) \ No newline at end of file diff --git a/model_vectors_publication_images_v3.py b/model_vectors_publication_images_v3.py new file mode 100644 index 0000000..ad6e111 --- /dev/null +++ b/model_vectors_publication_images_v3.py @@ -0,0 +1,425 @@ +import gemmi +import numpy as np +import matplotlib.pyplot as plt + + +# ----------------------------- +# Helper: compute SFs +# ----------------------------- + +def compute_sf(pdb_file, dmin): + + st = gemmi.read_structure(pdb_file) + st.setup_entities() + st.assign_label_seq_id() + st.remove_alternative_conformations() + model = st[0] + cell = st.cell + + sg = st.find_spacegroup() + if sg is None: + if st.spacegroup_hm: + sg = gemmi.SpaceGroup(st.spacegroup_hm) + else: + raise RuntimeError("No space group found") + + hkl = gemmi.make_miller_array( + cell=cell, + spacegroup=sg, + dmin=dmin, + unique=True + ) + + calc = gemmi.StructureFactorCalculatorX(cell) + F = np.array([ + calc.calculate_sf_from_model(model, [int(h), int(k), int(l)]) + for h, k, l in hkl + ], dtype=np.complex128) + + return F, hkl, cell + + +def add_resolution_dependent_noise(SF, hkl, cell, alpha=0.01, B=20.0, rng=None): + """ + Add resolution-dependent fractional noise to complex 
structure factors. + + Noise is applied to amplitudes only, phases preserved. + """ + F = np.abs(SF) + phase = np.angle(SF) + + d = np.array([cell.calculate_d(list(h)) for h in hkl]) + + if rng is None: + rng = np.random.default_rng() + + scale = alpha * np.exp(B / (4.0 * d**2)) + noise = rng.normal(0.0, scale) # <-- FIX HERE + + F_noise = F * (1.0 + noise) + + return F_noise * np.exp(1j * phase), d + +def scale_sf_against_ref_mtz(ref_mtz_file, F_gen, hkl_gen,n_bins=20 ): + """ + Scale generated complex structure factors against a reference MTZ + using global RMS scaling followed by resolution-binned RMS refinement. + + Parameters + ---------- + ref_mtz_file : str + Reference MTZ filename + F_gen : ndarray (complex) + Generated structure factors + hkl_gen : ndarray (N,3) + HKLs for generated SFs + n_bins : int + Number of equal-population resolution bins + + Returns + ------- + Fgen_scaled : ndarray (complex) + Scaled generated SFs (matched reflections only) + hkls_common : list of tuple + HKLs used + d : ndarray + Resolution values (Å) + """ + + # ----------------------------- + # Read reference MTZ + # ----------------------------- + mtz = gemmi.read_mtz_file(ref_mtz_file) + + F_col = "F-obs-filtered" + PHI_col = "PHIF-model" + + F_ref = mtz.column_with_label(F_col).array + PHI_ref = mtz.column_with_label(PHI_col).array + hkl_ref = mtz.make_miller_array() + + cell = mtz.cell + + # Build complex reference SFs + F_ref_cplx = F_ref * np.exp(1j * np.deg2rad(PHI_ref)) + + # ----------------------------- + # HKL → SF dictionaries + # ----------------------------- + ref_dict = { + tuple(map(int, h)): F_ref_cplx[i] + for i, h in enumerate(hkl_ref) + if F_ref[i] > 0.0 + } + + gen_dict = { + tuple(map(int, h)): F_gen[i] + for i, h in enumerate(hkl_gen) + } + + # ----------------------------- + # Find common reflections + # ----------------------------- + hkls_common = sorted(ref_dict.keys() & gen_dict.keys()) + if len(hkls_common) == 0: + raise RuntimeError("No common HKLs between reference MTZ and generated SFs") + + Fref = np.array([ref_dict[h] for h in hkls_common]) + Fgen = np.array([gen_dict[h] for h in hkls_common]) + + # ----------------------------- + # Resolution + # ----------------------------- + d = np.array([cell.calculate_d(list(h)) for h in hkls_common]) + + # ----------------------------- + # Sort by resolution (low → high) + # ----------------------------- + order = np.argsort(d)[::-1] + Fref = Fref[order] + Fgen = Fgen[order] + d = d[order] + hkls_common = [hkls_common[i] for i in order] + + # ----------------------------- + # GLOBAL RMS SCALING + # ----------------------------- + k0 = np.sqrt( + np.mean(np.abs(Fref)**2) / + np.mean(np.abs(Fgen)**2) + ) + Fgen *= k0 + + # ----------------------------- + # BIN-WISE RMS REFINEMENT + # ----------------------------- + bins = np.array_split(np.arange(len(d)), n_bins) + Fgen_scaled = Fgen.copy() + + for b in bins: + if len(b) == 0: + continue + + amp_ref = np.abs(Fref[b]) + amp_gen = np.abs(Fgen_scaled[b]) + + den = np.sum(amp_gen**2) + if den <= 0: + continue + + k = np.sqrt(np.sum(amp_ref**2) / den) + Fgen_scaled[b] *= k + + return Fgen_scaled, hkls_common + + +def resultant(SF): + + R = np.sum(SF) # vector sum + magnitude = np.abs(R) # |sum| + phase_deg = np.angle(R, deg=True) # phase in degrees + return magnitude, phase_deg + +def riso_x8(F1, F2): + A1 = np.abs(F1) + A2 = np.abs(F2) + + denom = 0.5 * (A1 + A2) + mask = denom > 0 + + if np.sum(mask) == 0: + return np.nan + + num = np.sum(np.abs(A1[mask] - A2[mask])) + den = 
np.sum(denom[mask]) + + return num / den + +def cciso_x8(F_ref, F_test): + """ + Xtrapol8 definition of CCiso: + Pearson CC of amplitudes F_ref and F_test + """ + A_ref = np.abs(F_ref) + A_test = np.abs(F_test) + + # mask out invalids if any + mask = np.isfinite(A_ref) & np.isfinite(A_test) + + A_ref = A_ref[mask] + A_test = A_test[mask] + + if len(A_ref) < 2: + return np.nan + + A_ref_mean = np.mean(A_ref) + A_test_mean = np.mean(A_test) + + num = np.sum((A_ref - A_ref_mean)*(A_test - A_test_mean)) + den = np.sqrt(np.sum((A_ref - A_ref_mean)**2)*np.sum((A_test - A_test_mean)**2)) + + return num/den if den != 0 else np.nan + +def make_equal_population_resolution_bins( d, nbins=20): + + d = np.asarray(d) + + # Sort reflections by resolution (low → high res) + order = np.argsort(d)[::-1] # large d first + d_sorted = d[order] + + n = len(d) + bin_size = n // nbins + + bins = [] + + for i in range(nbins): + start = i * bin_size + end = (i + 1) * bin_size if i < nbins - 1 else n + + idx = order[start:end] + + bins.append({ + "dmax": d_sorted[start], # low resolution edge + "dmin": d_sorted[end - 1], # high resolution edge + "indices": idx + }) + + return bins + +def match_hkls(F1, hkl1, F2, hkl2, F3, hkl3): + """ + Match two SF arrays on common HKLs. + + Returns + ------- + F1c, F2c : np.ndarray (complex) + hkl_common : np.ndarray (N, 3) + """ + + dict1 = {tuple(h): F1[i] for i, h in enumerate(hkl1)} + dict2 = {tuple(h): F2[i] for i, h in enumerate(hkl2)} + dict3 = {tuple(h): F3[i] for i, h in enumerate(hkl3)} + + common = sorted(dict1.keys() & dict2.keys() & dict3.keys()) + + if len(common) == 0: + raise RuntimeError("No common HKLs between datasets") + + F1c = np.array([dict1[h] for h in common]) + F2c = np.array([dict2[h] for h in common]) + F3c = np.array([dict3[h] for h in common]) + hklc = np.array(common, dtype=int) + + return F1c, F2c, F3c, hklc + +def riso_cciso_vs_resolution(Fref, Ftest, d, nbins, min_per_bin=20): + + bins = make_equal_population_resolution_bins(d, nbins) + + dmid = [] + riso_vals = [] + cciso_vals = [] + + for b in bins: + idx = b["indices"] + + if len(idx) < min_per_bin: + continue + + F1 = Fref[idx] + F2 = Ftest[idx] + + dmid.append(0.5 * (b["dmin"] + b["dmax"])) + riso_vals.append(riso(F1, F2)) + cciso_vals.append(cciso(F1, F2)) + + return np.array(dmid), np.array(riso_vals), np.array(cciso_vals) + + +def plot_riso_cciso_vs_resolution(d_centers, riso_vals, cciso_vals): + + fig, ax1 = plt.subplots() + + ax1.plot(1/d_centers, riso_vals, color="red", marker='o') + ax1.set_xlabel("Resolution (Å)") + ax1.set_ylabel("Riso", color="red") + + ax2 = ax1.twinx() + ax2.plot(1/d_centers, cciso_vals, color="blue", marker='s') + ax2.set_ylabel("CCiso", color="blue") + + plt.title("Riso and CCiso vs Resolution") + plt.tight_layout() + plt.show() + +# ----------------------------- +# User inputs +# ----------------------------- +#apo_pdb = "../data/hewl_apo.pdb" +#apo_pdb = "../data/ser-arg-loop-apo.pdb" +#occ100_pdb = "../data/hewl_1.0-I.pdb" +#occ100_pdb = "../data/ser-arg-loop-1.0-switch.pdb" +#occb_pdb = "../data/hewl_0.1-I.pdb" +#occb_pdb = "../data/ser-arg-loop-0.1-switch.pdb" +b_val = 0.1 +a_val = 1.0 - b_val +dmin = 1.5 +alpha = 0.0001 +ref_mtz = "../data/apo_100k_refine_5.mtz" + +occ_lst = [ 0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 0.9 ] + +print( "calculating apo" ) +apo_pdb = "../data/hewl_apo.pdb" +#apo_pdb = "../data/ser-arg-loop-apo.pdb" +SF_apo, hkl_apo, cell_apo = compute_sf(apo_pdb, dmin) +print( "done" ) + +print( "calculating 100 %" ) +occ100_pdb = 
"../data/hewl_1.0-I.pdb" +#occ100_pdb = "../data/ser-arg-loop-1.0-switch.pdb" +SF_100, hkl_100, cell_100 = compute_sf(occ100_pdb, dmin) +print( "done" ) + +for b_val in occ_lst: + + print( "occ = {0}".format( b_val ) ) + a_val = 1 - b_val + + print( "calculating b SFs" ) + occb_pdb = "../data/hewl_{0}-I.pdb".format( b_val ) + #occb_pdb = "../data/ser-arg-loop-{0}-switch.pdb".format( b_val ) + SF_b, hkl_b, cell_b = compute_sf(occb_pdb, dmin) + print( "done" ) + + SF_apo, SF_100, SF_b, hklc = match_hkls(SF_apo, hkl_apo, SF_100, hkl_100, SF_b, hkl_b) + d = np.array([cell_apo.calculate_d(list(h)) for h in hklc]) + + F_apo = np.abs(SF_apo) + phi_apo = np.angle(SF_apo) + + F_100 = np.abs(SF_100) + phi_100 = np.angle(SF_100) + + F_b = np.abs(SF_b) + phi_b = np.angle(SF_b) + + F_extr = (F_b - F_apo)/b_val + F_apo + SF_extr = F_extr * np.exp(1j * phi_apo) + +# d_centers, riso_vals, cciso_vals = riso_cciso_vs_resolution(SF_100, SF_extr, d, 20) +# plot_riso_cciso_vs_resolution(d_centers, riso_vals, cciso_vals) + + delta_Fextr = SF_extr - SF_100 + + riso = riso_x8(SF_apo, SF_b) + print( "Riso = {0}".format( riso ) ) + cciso = cciso_x8(SF_apo, SF_b) + print( "CCiso = {0}".format( cciso ) ) + + #print("Fextr TEST: SF_extr = (F_b - F_apo)/b_val + F_apo") + print("Mean |ΔF|:", np.mean(np.abs(delta_Fextr))) + print("Mean phase Δ (deg):", + np.mean(np.abs(np.angle(SF_extr / SF_100, deg=True)))) + + #Fextr, phi_extr = resultant(SF_extr) + #F_100_model, phi_100_model = resultant(SF_100) + + #print("Sum SF_extr calculated: |F| = {0}, phi = {1}".format( F_extr, phi_extr ) ) + #print("Sum SF_100 model : |F| = {0}, phi = {1}".format( F_100_model, phi_100_model ) ) + + # ----------------------------- + # Resultant vectors for plotting + # ----------------------------- + sum_SF_apo = np.sum(SF_apo) + sum_SF_100 = np.sum(SF_100) + sum_SF_b = np.sum(SF_b) + sum_SF_extr = np.sum(SF_extr) + + # ----------------------------- + # Plot resultant vectors + # ----------------------------- + plt.figure(figsize=(8,8)) + ax = plt.gca() + ax.set_aspect('equal') + origin = 0 + 0j + + ax.arrow(origin.real, origin.imag, sum_SF_apo.real, sum_SF_apo.imag, + head_width=500, head_length=500, fc='blue', ec='blue', label='Fapo model') + + ax.arrow(origin.real, origin.imag, sum_SF_100.real, sum_SF_100.imag, + head_width=500, head_length=500, fc='red', ec='red', label='F100 model') + + ax.arrow(origin.real, origin.imag, sum_SF_b.real, sum_SF_b.imag, + head_width=500, head_length=500, fc='green', ec='green', label='Fab model') + + ax.arrow(origin.real, origin.imag, sum_SF_extr.real, sum_SF_extr.imag, + head_width=500, head_length=500, fc='black', ec='black', label='Fextr') + + ax.set_xlabel("Re(F)") + ax.set_ylabel("Im(F)") + ax.set_title("SF Resultant Vectors") + ax.grid(True) + ax.legend() + #plt.show() diff --git a/model_vectors_v2.py b/model_vectors_v2.py new file mode 100644 index 0000000..10e7312 --- /dev/null +++ b/model_vectors_v2.py @@ -0,0 +1,472 @@ +import gemmi +import numpy as np +import matplotlib.pyplot as plt +from Bio.SeqUtils.ProtParamData import fs + +# ----------------------------- +# User inputs +# ----------------------------- +#apo_pdb = "../data/hewl_apo.pdb" +apo_pdb = "../data/ser-arg-loop-apo.pdb" +#occ100_pdb = "../data/hewl_1.0-I.pdb" +occ100_pdb = "../data/ser-arg-loop-1.0-switch.pdb" +#occb_pdb = "../data/hewl_0.1-I.pdb" +occb_pdb = "../data/ser-arg-loop-0.1-switch.pdb" +b_val = 0.1 +a_val = 1.0 - b_val +dmin = 2.5 +alpha = 0.0001 +ref_mtz = "../data/apo_100k_refine_5.mtz" + + +# ----------------------------- 
+# Helper: compute SFs +# ----------------------------- + +def compute_sf(pdb_file, dmin): + + st = gemmi.read_structure(pdb_file) + st.setup_entities() + st.assign_label_seq_id() + st.remove_alternative_conformations() + model = st[0] + cell = st.cell + + sg = st.find_spacegroup() + if sg is None: + if st.spacegroup_hm: + sg = gemmi.SpaceGroup(st.spacegroup_hm) + else: + raise RuntimeError("No space group found") + + hkl = gemmi.make_miller_array( + cell=cell, + spacegroup=sg, + dmin=dmin, + unique=True + ) + + calc = gemmi.StructureFactorCalculatorX(cell) + F = np.array([ + calc.calculate_sf_from_model(model, [int(h), int(k), int(l)]) + for h, k, l in hkl + ], dtype=np.complex128) + + return F, hkl, cell + + +def add_resolution_dependent_noise(SF, hkl, cell, alpha=0.01, B=20.0, rng=None): + """ + Add resolution-dependent fractional noise to complex structure factors. + + Noise is applied to amplitudes only, phases preserved. + """ + F = np.abs(SF) + phase = np.angle(SF) + + d = np.array([cell.calculate_d(list(h)) for h in hkl]) + + if rng is None: + rng = np.random.default_rng() + + scale = alpha * np.exp(B / (4.0 * d**2)) + noise = rng.normal(0.0, scale) # <-- FIX HERE + + F_noise = F * (1.0 + noise) + + return F_noise * np.exp(1j * phase), d + +def scale_sf_against_ref_mtz(ref_mtz_file, F_gen, hkl_gen,n_bins=20 ): + """ + Scale generated complex structure factors against a reference MTZ + using global RMS scaling followed by resolution-binned RMS refinement. + + Parameters + ---------- + ref_mtz_file : str + Reference MTZ filename + F_gen : ndarray (complex) + Generated structure factors + hkl_gen : ndarray (N,3) + HKLs for generated SFs + n_bins : int + Number of equal-population resolution bins + + Returns + ------- + Fgen_scaled : ndarray (complex) + Scaled generated SFs (matched reflections only) + hkls_common : list of tuple + HKLs used + d : ndarray + Resolution values (Å) + """ + + # ----------------------------- + # Read reference MTZ + # ----------------------------- + mtz = gemmi.read_mtz_file(ref_mtz_file) + + F_col = "F-obs-filtered" + PHI_col = "PHIF-model" + + F_ref = mtz.column_with_label(F_col).array + PHI_ref = mtz.column_with_label(PHI_col).array + hkl_ref = mtz.make_miller_array() + + cell = mtz.cell + + # Build complex reference SFs + F_ref_cplx = F_ref * np.exp(1j * np.deg2rad(PHI_ref)) + + # ----------------------------- + # HKL → SF dictionaries + # ----------------------------- + ref_dict = { + tuple(map(int, h)): F_ref_cplx[i] + for i, h in enumerate(hkl_ref) + if F_ref[i] > 0.0 + } + + gen_dict = { + tuple(map(int, h)): F_gen[i] + for i, h in enumerate(hkl_gen) + } + + # ----------------------------- + # Find common reflections + # ----------------------------- + hkls_common = sorted(ref_dict.keys() & gen_dict.keys()) + if len(hkls_common) == 0: + raise RuntimeError("No common HKLs between reference MTZ and generated SFs") + + Fref = np.array([ref_dict[h] for h in hkls_common]) + Fgen = np.array([gen_dict[h] for h in hkls_common]) + + # ----------------------------- + # Resolution + # ----------------------------- + d = np.array([cell.calculate_d(list(h)) for h in hkls_common]) + + # ----------------------------- + # Sort by resolution (low → high) + # ----------------------------- + order = np.argsort(d)[::-1] + Fref = Fref[order] + Fgen = Fgen[order] + d = d[order] + hkls_common = [hkls_common[i] for i in order] + + # ----------------------------- + # GLOBAL RMS SCALING + # ----------------------------- + k0 = np.sqrt( + np.mean(np.abs(Fref)**2) / + 
np.mean(np.abs(Fgen)**2) + ) + Fgen *= k0 + + # ----------------------------- + # BIN-WISE RMS REFINEMENT + # ----------------------------- + bins = np.array_split(np.arange(len(d)), n_bins) + Fgen_scaled = Fgen.copy() + + for b in bins: + if len(b) == 0: + continue + + amp_ref = np.abs(Fref[b]) + amp_gen = np.abs(Fgen_scaled[b]) + + den = np.sum(amp_gen**2) + if den <= 0: + continue + + k = np.sqrt(np.sum(amp_ref**2) / den) + Fgen_scaled[b] *= k + + return Fgen_scaled, hkls_common + + +def resultant(SF): + + R = np.sum(SF) # vector sum + magnitude = np.abs(R) # |sum| + phase_deg = np.angle(R, deg=True) # phase in degrees + return magnitude, phase_deg + +def riso(F_ref, F_test): + """ + Xtrapol8 definition of Riso: + sum(|F_ref - F_test|) / sum((F_ref + F_test)/2) + """ + # amplitudes (ensure real and non-negative) + A_ref = np.abs(F_ref) + A_test = np.abs(F_test) + + denom = 0.5*(A_ref + A_test) + mask = denom > 0 # avoid zero denominator + + num = np.sum(np.abs(A_ref - A_test)) + den = np.sum(denom) + + return num/den if den != 0 else np.nan + +def cciso(F_ref, F_test): + """ + Xtrapol8 definition of CCiso: + Pearson CC of amplitudes F_ref and F_test + """ + A_ref = np.abs(F_ref) + A_test = np.abs(F_test) + + # mask out invalids if any + mask = np.isfinite(A_ref) & np.isfinite(A_test) + + A_ref = A_ref[mask] + A_test = A_test[mask] + + if len(A_ref) < 2: + return np.nan + + A_ref_mean = np.mean(A_ref) + A_test_mean = np.mean(A_test) + + num = np.sum((A_ref - A_ref_mean)*(A_test - A_test_mean)) + den = np.sqrt(np.sum((A_ref - A_ref_mean)**2)*np.sum((A_test - A_test_mean)**2)) + + return num/den if den != 0 else np.nan + +def make_equal_population_resolution_bins( d, nbins=20): + + d = np.asarray(d) + + # Sort reflections by resolution (low → high res) + order = np.argsort(d)[::-1] # large d first + d_sorted = d[order] + + n = len(d) + bin_size = n // nbins + + bins = [] + + for i in range(nbins): + start = i * bin_size + end = (i + 1) * bin_size if i < nbins - 1 else n + + idx = order[start:end] + + bins.append({ + "dmax": d_sorted[start], # low resolution edge + "dmin": d_sorted[end - 1], # high resolution edge + "indices": idx + }) + + return bins + + +def riso_cciso_vs_resolution(Fref, Ftest, d, nbins, min_per_bin=20): + + bins = make_equal_population_resolution_bins(d, nbins) + + dmid = [] + riso_vals = [] + cciso_vals = [] + + for b in bins: + idx = b["indices"] + + if len(idx) < min_per_bin: + continue + + F1 = Fref[idx] + F2 = Ftest[idx] + + dmid.append(0.5 * (b["dmin"] + b["dmax"])) + riso_vals.append(riso(F1, F2)) + cciso_vals.append(cciso(F1, F2)) + + return np.array(dmid), np.array(riso_vals), np.array(cciso_vals) + + +def plot_riso_cciso_vs_resolution(d_centers, riso_vals, cciso_vals): + + fig, ax1 = plt.subplots() + + ax1.plot(1/d_centers, riso_vals, color="red", marker='o') + ax1.set_xlabel("Resolution (Å)") + ax1.set_ylabel("Riso", color="red") + + ax2 = ax1.twinx() + ax2.plot(1/d_centers, cciso_vals, color="blue", marker='s') + ax2.set_ylabel("CCiso", color="blue") + + plt.title("Riso and CCiso vs Resolution") + plt.tight_layout() + plt.show() + +# ----------------------------- +# Compute structure factors +# ----------------------------- +print( "making apo SFs" ) +SF_apo, hkl_apo, cell_apo = compute_sf(apo_pdb, dmin) +SF_apo, hkl_apo = scale_sf_against_ref_mtz( ref_mtz, SF_apo, hkl_apo ) +#SF_apo, d = add_resolution_dependent_noise(SF_apo, hkl_apo, cell_apo, alpha, B=20.0, rng=None) +print( "done" ) + +print( "making 100 % SFs" ) +SF_100, hkl_100, cell_100 = 
compute_sf(occ100_pdb, dmin) # Fb_full +SF_100, hkl_100 = scale_sf_against_ref_mtz( ref_mtz, SF_100, hkl_100 ) +#SF_100, d = add_resolution_dependent_noise( SF_100, hkl_100, cell_100, alpha, B=20.0, rng=None ) +print( "done" ) + +print( "making b % SFs" ) +SF_b, hkl_b, cell_b = compute_sf(occb_pdb, dmin) # Fb_partial +SF_b, hkl_b = scale_sf_against_ref_mtz( ref_mtz, SF_b, hkl_b ) +SF_b, d = add_resolution_dependent_noise( SF_b, hkl_b, cell_b, alpha, B=20.0, rng=None ) +print( "done" ) + +d_centers, riso_vals, cciso_vals = riso_cciso_vs_resolution(SF_apo, SF_b, d, 20) +plot_riso_cciso_vs_resolution(d_centers, riso_vals, cciso_vals) + + +F_apo = np.abs(SF_apo) +phi_apo = np.angle(SF_apo) + +F_100 = np.abs(SF_100) +phi_100 = np.angle(SF_100) + +F_b = np.abs(SF_b) +phi_b = np.angle(SF_b) + +SF_100_phi_apo = F_100 * np.exp(1j * np.deg2rad(phi_apo)) + +SF_b_phi_apo = F_b * np.exp(1j * np.deg2rad(phi_b)) + +# Use only amplitudes of partial/full occupancy, phases from Fa + +#Fb_partial_extrap = np.abs(Fb_partial) * np.exp(1j*phi_apo) +#Fb_full_extrap = np.abs(Fb_full) * np.exp(1j*phi_apo) + +# ----------------------------- +# correct method +# ----------------------------- +# Forward extrapolation: Fab = aFa + bFb_partial +SF_ab = a_val*SF_apo + b_val*SF_100 + +delta_forward = SF_ab - SF_b + +print("Correct method TEST: SFab = aSFa + bSF100") +print("Mean |ΔF|:", np.mean(np.abs(delta_forward))) +print("Mean phase Δ (deg):", + np.mean(np.abs(np.angle(SF_ab / SF_b, deg=True)))) + +F_ab_calc, phi_ab_calc = resultant(SF_ab) +F_ab_model, phi_ab_model = resultant(SF_b) + +print("Sum SF_ab calculated: |F| = {0}, phi = {1}".format( F_ab_calc, phi_ab_calc ) ) +print("Sum SF_b model : |F| = {0}, phi = {1}".format( F_ab_model, phi_ab_model ) ) + +# ----------------------------- +# Fext test +# ----------------------------- +SF_extr = (SF_b - SF_apo)/b_val + SF_apo + +delta_Fextr = SF_extr - SF_100 + +print("Fextr SF TEST: SF_extr = (SF_b - SF_apo)/b_val + SF_apo") +print("Mean |ΔF|:", np.mean(np.abs(delta_Fextr))) +print("Mean phase Δ (deg):", + np.mean(np.abs(np.angle(SF_extr / SF_100, deg=True)))) + +F_100_calc, phi_100_calc = resultant(SF_extr) +F_100_model, phi_100_model = resultant(SF_100) + +print("Sum SF_extr calculated: |F| = {0}, phi = {1}".format( F_100_calc, phi_100_calc ) ) +print("Sum SF_100 model : |F| = {0}, phi = {1}".format( F_100_model, phi_100_model ) ) + +# ----------------------------- +# correct method test with only Fs and +# ----------------------------- +# Forward extrapolation: Fab = aFa + bFb_partial +F_ab = a_val*F_apo + b_val*F_100 +SF_ab = F_ab * np.exp(1j * phi_apo) + +delta_forward = SF_ab - SF_b + +print("Correct method TEST: Fab = aFa + bF100") +print("Mean |ΔF|:", np.mean(np.abs(delta_forward))) +print("Mean phase Δ (deg):", + np.mean(np.abs(np.angle(SF_ab / SF_b, deg=True)))) + +F_ab_calc, phi_ab_calc = resultant(SF_ab) +F_ab_model, phi_ab_model = resultant(SF_b) + +print("Sum SF_ab calculated: |F| = {0}, phi = {1}".format( F_ab_calc, phi_ab_calc ) ) +print("Sum SF_b model : |F| = {0}, phi = {1}".format( F_ab_model, phi_ab_model ) ) + +# ----------------------------- +# Fext test just Fs +# ----------------------------- +F_extr = (F_b - F_apo)/b_val + F_apo +SF_extr = F_extr * np.exp(1j * phi_apo) + +delta_Fextr = SF_extr - SF_100 + +print("Fextr TEST: SF_extr = (F_b - F_apo)/b_val + F_apo") +print("Mean |ΔF|:", np.mean(np.abs(delta_Fextr))) +print("Mean phase Δ (deg):", + np.mean(np.abs(np.angle(SF_extr / SF_100, deg=True)))) + +F_100_calc, phi_100_calc = 
resultant(SF_extr) +F_100_model, phi_100_model = resultant(SF_100) + +print("Sum SF_extr calculated: |F| = {0}, phi = {1}".format( F_100_calc, phi_100_calc ) ) +print("Sum SF_100 model : |F| = {0}, phi = {1}".format( F_100_model, phi_100_model ) ) + + +F_apo_model, phi_apo_model = resultant(SF_apo) + +print("Sum SF_apo : |F| = {0}, phi = {1}".format( F_apo_model, phi_apo_model ) ) + + +# ----------------------------- +# Resultant vectors for plotting +# ----------------------------- +sum_SF_apo = np.sum(SF_apo) +sum_SF_100 = np.sum(SF_100) +sum_SF_b = np.sum(SF_b) +sum_SF_ab_calc = np.sum(SF_ab) +sum_SF_100_calc = np.sum(SF_extr) + +# ----------------------------- +# Plot resultant vectors +# ----------------------------- +plt.figure(figsize=(8,8)) +ax = plt.gca() +ax.set_aspect('equal') +origin = 0 + 0j + +ax.arrow(origin.real, origin.imag, sum_SF_apo.real, sum_SF_apo.imag, + head_width=500, head_length=500, fc='blue', ec='blue', label='Fapo model') + +ax.arrow(origin.real, origin.imag, sum_SF_100.real, sum_SF_100.imag, + head_width=500, head_length=500, fc='red', ec='red', label='F100 model') + +ax.arrow(origin.real, origin.imag, sum_SF_b.real, sum_SF_b.imag, + head_width=500, head_length=500, fc='green', ec='green', label='Fab model') + +# Vector showing addition to Fa +#ax.arrow(origin.real, origin.imag, sum_SF_ab_calc.real, sum_SF_ab_calc.imag, +# head_width=500, head_length=500, fc='yellow', ec='yellow', label='Fab calc') + +ax.arrow(origin.real, origin.imag, sum_SF_100_calc.real, sum_SF_100_calc.imag, + head_width=500, head_length=500, fc='black', ec='black', label='Fextr calc') + +ax.set_xlabel("Re(F)") +ax.set_ylabel("Im(F)") +ax.set_title("SF Resultant Vectors") +ax.grid(True) +ax.legend() +plt.show() diff --git a/riso_calc.py b/riso_calc.py new file mode 100644 index 0000000..485ce67 --- /dev/null +++ b/riso_calc.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 + +import argparse +import sys +import numpy as np +import gemmi + + +def read_mtz_amplitudes(mtz_path, f_label): + """ + Reads an MTZ file and returns a dictionary: + {(h,k,l): amplitude} + """ + mtz = gemmi.read_mtz_file(mtz_path) + mtz.setup_batched() + + try: + col = mtz.column_with_label(f_label) + except RuntimeError: + sys.exit(f"ERROR: Column '{f_label}' not found in {mtz_path}") + + hkl_dict = {} + + for row in mtz: + h, k, l = int(row[0]), int(row[1]), int(row[2]) + fval = row[col.idx] + + if not np.isnan(fval): + hkl_dict[(h, k, l)] = float(fval) + + return hkl_dict + + +def match_reflections(dict1, dict2): + """ + Returns matched amplitude arrays for common HKLs. + """ + common_hkls = set(dict1.keys()) & set(dict2.keys()) + + if len(common_hkls) == 0: + sys.exit("ERROR: No matching HKLs between files.") + + f1 = np.array([dict1[hkl] for hkl in common_hkls]) + f2 = np.array([dict2[hkl] for hkl in common_hkls]) + + return f1, f2, len(common_hkls) + + +def compute_riso(f1, f2): + """ + Computes Riso: + Riso = sum | |F1| - |F2| | / sum |F1| + """ + numerator = np.sum(np.abs(np.abs(f1) - np.abs(f2))) + denominator = np.sum(np.abs(f1)) + + if denominator == 0: + sys.exit("ERROR: Denominator is zero.") + + return numerator / denominator + + +def main(): + parser = argparse.ArgumentParser( + description="Compute Riso between two MTZ files." 
+ ) + + parser.add_argument("mtz1", help="Reference MTZ file") + parser.add_argument("mtz2", help="Second MTZ file") + parser.add_argument("--f1", required=True, + help="Amplitude column label in first MTZ") + parser.add_argument("--f2", required=True, + help="Amplitude column label in second MTZ") + + args = parser.parse_args() + + print("Reading MTZ files...") + dict1 = read_mtz_amplitudes(args.mtz1, args.f1) + dict2 = read_mtz_amplitudes(args.mtz2, args.f2) + + print("Matching reflections...") + f1, f2, nmatch = match_reflections(dict1, dict2) + + print(f"Matched reflections: {nmatch}") + + riso_value = compute_riso(f1, f2) + + print(f"\nRiso = {riso_value:.6f}") + + +if __name__ == "__main__": + main()
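A minimal, self-contained sketch of the extrapolation relation the scripts above rely on, using synthetic data only (no MTZ or PDB input, no repository paths). Under these assumptions it confirms that F_extr = F_apo + (F_ab - F_apo)/b recovers the fully occupied structure factors to floating-point precision in the noise-free complex case, and that amplitude noise on the mixed-state data is amplified by roughly 1/b, the blow-up illustrated in extrapolation_blow-up.py and extrapolation_blow-up_v2.py.

    import numpy as np

    rng = np.random.default_rng(0)
    n = 5000

    # synthetic "apo" and "triggered" complex structure factors
    F_apo = rng.uniform(50, 200, n) * np.exp(1j * rng.uniform(-np.pi, np.pi, n))
    F_100 = F_apo + rng.uniform(5, 30, n) * np.exp(1j * rng.uniform(-np.pi, np.pi, n))

    for b in (0.5, 0.2, 0.05):
        a = 1.0 - b
        F_ab = a * F_apo + b * F_100          # mixed-state SFs at occupancy b

        # noise-free extrapolation recovers F_100 (up to rounding error)
        F_extr = F_apo + (F_ab - F_apo) / b
        exact_err = np.max(np.abs(F_extr - F_100))

        # amplitude noise on the mixed data propagates into F_extr scaled by 1/b
        sigma = 2.0
        F_ab_noisy = (np.abs(F_ab) + rng.normal(0.0, sigma, n)) * np.exp(1j * np.angle(F_ab))
        F_extr_noisy = F_apo + (F_ab_noisy - F_apo) / b
        rms_err = np.sqrt(np.mean(np.abs(F_extr_noisy - F_100) ** 2))

        print(f"b={b:4.2f}  max exact error={exact_err:.2e}  "
              f"rms noisy error={rms_err:6.2f}  (sigma/b = {sigma / b:5.1f})")

The printed rms error tracks sigma/b, which is the quantitative version of the widening difference vectors shown in the extrapolation blow-up plots.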