"""Isomorphous-difference statistics (Riso / CCiso) between two MTZ files.

The statistics are computed on reflections common to both files, overall and
optionally in equal-count resolution bins, with an optional plot of the
per-bin values.
"""

import numpy as np

# gemmi is only needed for MTZ I/O and matplotlib only for plotting; importing
# them optionally keeps the pure-NumPy statistics usable without either.
try:
    import gemmi
except ImportError:  # pragma: no cover - depends on the environment
    gemmi = None

try:
    import matplotlib.pyplot as plt
except ImportError:  # pragma: no cover - depends on the environment
    plt = None

# Resolution bins containing fewer reflections than this are skipped, both in
# the printed per-bin table and in the plotted curve.
MIN_REFLECTIONS_PER_BIN = 10


def riso_xtrapol8(F1, F2):
    """Return Riso = sum|A1 - A2| / sum(0.5 * (A1 + A2)), where A = |F|.

    Reflections whose mean amplitude 0.5 * (A1 + A2) is not strictly positive
    are excluded; this also drops NaN (missing) values because NaN > 0 is
    False. Returns NaN when no reflection survives the mask.
    """
    A1 = np.abs(np.asarray(F1))
    A2 = np.abs(np.asarray(F2))
    denom = 0.5 * (A1 + A2)
    mask = denom > 0
    if not np.any(mask):
        return np.nan
    num = np.sum(np.abs(A1[mask] - A2[mask]))
    den = np.sum(denom[mask])
    return num / den


def cciso_xtrapol8(F1, F2):
    """Return the Pearson correlation coefficient between |F1| and |F2|.

    Pairs where either amplitude is non-finite are excluded. Returns NaN when
    fewer than two pairs remain or when either amplitude set has zero variance.
    """
    A1 = np.abs(np.asarray(F1))
    A2 = np.abs(np.asarray(F2))
    mask = np.isfinite(A1) & np.isfinite(A2)
    A1 = A1[mask]
    A2 = A2[mask]
    if len(A1) < 2:
        return np.nan
    A1m = A1 - np.mean(A1)
    A2m = A2 - np.mean(A2)
    num = np.sum(A1m * A2m)
    den = np.sqrt(np.sum(A1m**2) * np.sum(A2m**2))
    return num / den if den > 0 else np.nan


def read_mtz_amplitudes(mtz_file, amp_label):
    """Read Miller indices and one amplitude column from an MTZ file.

    Returns:
        hkl  : (n, 3) int array of Miller indices from the H, K, L columns.
        F    : float array of length n from the column ``amp_label``.
        cell : the unit cell of the file (``mtz.cell``).

    Raises:
        ImportError: if gemmi is not installed.
        KeyError: if H, K, L or ``amp_label`` is not a column of the file.
    """
    if gemmi is None:
        raise ImportError("gemmi is required to read MTZ files")
    mtz = gemmi.read_mtz_file(mtz_file)

    def column(label):
        col = mtz.column_with_label(label)
        # column_with_label returns None for an unknown label; without this
        # check np.array(None, dtype=float) silently yields a 0-d NaN.
        if col is None:
            available = ", ".join(mtz.column_labels())
            raise KeyError(
                f"column {label!r} not found in {mtz_file}; "
                f"available columns: {available}"
            )
        return np.array(col, dtype=float)

    # Indices are stored as floats in MTZ; round before casting to int.
    hkl = np.vstack(
        [np.rint(column(lab)).astype(int) for lab in ("H", "K", "L")]
    ).T
    F = column(amp_label)
    return hkl, F, mtz.cell


def match_reflections(hkl1, F1, hkl2, F2):
    """Pair up reflections present in both datasets.

    Reflections are matched on their Miller indices exactly as given (no
    symmetry reduction is applied) and returned sorted by (h, k, l). If an
    index appears more than once in one dataset, its last occurrence wins.

    Returns:
        hklm : (n, 3) int array of common indices (shape (0, 3) if none).
        F1m, F2m : the corresponding values from F1 and F2.
    """
    # tolist() yields plain Python ints, giving cheap, hashable tuple keys.
    lookup1 = dict(zip(map(tuple, np.asarray(hkl1, dtype=int).tolist()), F1))
    lookup2 = dict(zip(map(tuple, np.asarray(hkl2, dtype=int).tolist()), F2))
    common = sorted(lookup1.keys() & lookup2.keys())
    F1m = np.array([lookup1[h] for h in common])
    F2m = np.array([lookup2[h] for h in common])
    hklm = np.array(common, dtype=int).reshape(-1, 3)
    return hklm, F1m, F2m


def compute_resolution(cell, hkl):
    """Return ``cell.calculate_d([h, k, l])`` for every row of ``hkl``."""
    # Pass plain Python ints rather than numpy integer scalars to the binding.
    return np.array(
        [cell.calculate_d(h) for h in np.asarray(hkl, dtype=int).tolist()],
        dtype=float,
    )


def make_equal_count_resolution_bins(d, n_bins):
    """Split reflection indices into ``n_bins`` groups of near-equal size.

    Indices are ordered by decreasing d, i.e. from low to high resolution, so
    the first bin holds the lowest-resolution reflections. Returns a list of
    index arrays (see ``numpy.array_split``).
    """
    order = np.argsort(d)[::-1]  # largest d (lowest resolution) first
    return np.array_split(order, n_bins)


def make_equal_count_bins(d, n_bins):
    """Alias of ``make_equal_count_resolution_bins`` kept for compatibility."""
    return make_equal_count_resolution_bins(d, n_bins)


def riso_cciso_by_resolution(
    d, F1, F2, n_bins=20, min_count=MIN_REFLECTIONS_PER_BIN
):
    """Compute Riso and CCiso in equal-count resolution bins.

    Bins with fewer than ``min_count`` reflections are skipped.

    Returns:
        d_mid  : midpoint (0.5 * (d_min + d_max)) of each kept bin
        Riso   : array of per-bin Riso
        CCiso  : array of per-bin CCiso
        counts : reflections per kept bin
    """
    d_mid = []
    Riso = []
    CCiso = []
    counts = []

    for idx in make_equal_count_resolution_bins(d, n_bins):
        if len(idx) < min_count:
            continue
        d_bin = d[idx]
        F1b = F1[idx]
        F2b = F2[idx]
        d_mid.append(0.5 * (d_bin.min() + d_bin.max()))
        Riso.append(riso_xtrapol8(F1b, F2b))
        CCiso.append(cciso_xtrapol8(F1b, F2b))
        counts.append(len(idx))

    return (
        np.array(d_mid),
        np.array(Riso),
        np.array(CCiso),
        np.array(counts),
    )


def plot_riso_cciso_vs_resolution(d_mid, Riso, CCiso):
    """Plot per-bin Riso (left axis) and CCiso (right axis) against 1/d.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError("matplotlib is required for plotting")
    d_mid = np.asarray(d_mid, dtype=float)
    CCiso = np.asarray(CCiso, dtype=float)
    inv_d = 1.0 / d_mid

    fig, ax1 = plt.subplots(figsize=(6, 4))

    # Riso
    ax1.plot(inv_d, Riso, marker='o', label='Riso')
    # The x values are 1/d, so the axis is in reciprocal Ångström.
    ax1.set_xlabel('1/d (Å⁻¹)')
    ax1.set_ylabel('Riso')
    ax1.grid(True, alpha=0.3)

    # CCiso
    ax2 = ax1.twinx()
    ax2.plot(inv_d, CCiso, marker='s', linestyle='--', label='CCiso')
    ax2.set_ylabel('CCiso')
    # A correlation coefficient can be negative; do not clip such bins.
    finite_cc = CCiso[np.isfinite(CCiso)]
    lower = min(0.0, float(finite_cc.min())) if finite_cc.size else 0.0
    ax2.set_ylim(lower, 1)

    # Combined legend
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines + lines2, labels + labels2, loc='best')

    fig.tight_layout()
    plt.show()


def riso_cciso_from_mtz(
    mtz1, mtz2, amp1_label, amp2_label, n_bins=None
):
    """Print Riso/CCiso between two MTZ amplitude columns and return them.

    Reflections are matched on Miller indices. Resolution is computed with
    the unit cell of ``mtz1``. When ``n_bins`` is given, per-bin statistics
    for ``n_bins`` equal-count resolution bins are printed and plotted.

    Returns:
        (Riso, CCiso) over all matched reflections.
    """
    hkl1, F1, cell1 = read_mtz_amplitudes(mtz1, amp1_label)
    hkl2, F2, _cell2 = read_mtz_amplitudes(mtz2, amp2_label)

    hkl, F1m, F2m = match_reflections(hkl1, F1, hkl2, F2)
    d = compute_resolution(cell1, hkl)

    print(f"Matched reflections: {len(hkl)}")

    # Overall statistics
    Riso = riso_xtrapol8(F1m, F2m)
    CCiso = cciso_xtrapol8(F1m, F2m)

    print("\nOverall:")
    print(f" Riso = {Riso:.4f}")
    print(f" CCiso = {CCiso:.4f}")

    # Per-bin statistics
    if n_bins is not None:
        print("\nPer-resolution-bin statistics:")
        for i, idx in enumerate(make_equal_count_bins(d, n_bins)):
            if len(idx) < MIN_REFLECTIONS_PER_BIN:
                continue
            r_bin = riso_xtrapol8(F1m[idx], F2m[idx])
            cc_bin = cciso_xtrapol8(F1m[idx], F2m[idx])
            dmin = np.min(d[idx])
            dmax = np.max(d[idx])
            print(
                f"Bin {i+1:02d} "
                f"{dmax:5.2f}–{dmin:5.2f} Å "
                f"Riso={r_bin:6.3f} CCiso={cc_bin:6.3f} n={len(idx)}"
            )

        # Use the caller's n_bins so the plot matches the printed table.
        d_mid, riso_bins, cc_bins, _counts = riso_cciso_by_resolution(
            d, F1m, F2m, n_bins=n_bins
        )
        plot_riso_cciso_vs_resolution(d_mid, riso_bins, cc_bins)

    return Riso, CCiso


if __name__ == "__main__":
    mtz1 = "../data/apo_100k_refine_5.mtz"
    mtz2 = "../data/on_100k_refine_141.mtz"
    amp1_label = "F-obs-filtered"
    amp2_label = "F-obs-filtered"

    riso_cciso_from_mtz(mtz1, mtz2, amp1_label, amp2_label, n_bins=None)