|
import gradio as gr |
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
from pymatgen.core import Structure |
|
from pymatgen.analysis.diffraction.xrd import XRDCalculator |
|
import tempfile |
|
import os |
|
import traceback |
|
|
|
|
|
# Diffraction-angle window (degrees 2θ) passed to the XRD calculation; also the
# fallback x-axis range when no peaks fall inside it.
_TWO_THETA_RANGE = (10, 90)


def _resolve_cif_path(cif_file):
    """Return a filesystem path for the uploaded CIF.

    Accepts either a path (``str`` / ``os.PathLike``), which is what
    ``gr.File(type="filepath")`` supplies, or a file-like object exposing the
    path as ``.name``.
    """
    if isinstance(cif_file, (str, os.PathLike)):
        return os.fspath(cif_file)
    return cif_file.name


def _peaks_to_dataframe(pattern):
    """Tabulate peak positions, intensities and one (hkl) label per peak."""
    # Each entry of pattern.hkls is a (possibly empty) list of dicts with an
    # "hkl" key; only the first entry is used as the peak's label.
    miller_indices = [
        str(tuple(families[0]["hkl"])) if families else "N/A"
        for families in pattern.hkls
    ]
    return pd.DataFrame({
        "2θ (°)": [round(x, 3) for x in pattern.x],
        "Intensity (norm)": [round(y, 3) for y in pattern.y],
        "Miller Indices (hkl)": miller_indices,
    })


def _build_xrd_figure(data, formula):
    """Render the peak table as a stick-style Plotly bar chart titled with *formula*."""
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=data["2θ (°)"],
        y=data["Intensity (norm)"],
        hovertext=[
            f"2θ: {t:.3f}<br>Intensity: {i:.1f}<br>hkl: {m}"
            for t, i, m in zip(
                data["2θ (°)"], data["Intensity (norm)"], data["Miller Indices (hkl)"]
            )
        ],
        hoverinfo="text",
        width=0.1,  # very narrow bars so each reflection reads as a line
        marker_color="#4682B4",
        marker_line_width=0,
        name='Peaks',
    ))

    # With no peaks, fall back to the simulated window and a 0-100 intensity scale.
    if data.empty:
        min_2theta, max_2theta = _TWO_THETA_RANGE
        max_intensity = 100
    else:
        min_2theta = data["2θ (°)"].min()
        max_2theta = data["2θ (°)"].max()
        max_intensity = data["Intensity (norm)"].max()

    fig.update_layout(
        title=dict(text=f"Simulated XRD Pattern: {formula}", x=0.5, xanchor='center'),
        xaxis_title="2θ (°)",
        yaxis_title="Intensity (Arb. Unit)",
        xaxis_title_font_size=16,
        yaxis_title_font_size=16,
        xaxis=dict(
            range=[min_2theta - 2, max_2theta + 2],  # 2° padding on both sides
            showline=True, linewidth=1.5, linecolor='black', mirror=True,
            ticks='outside', tickwidth=1.5, tickcolor='black',
            tickfont_size=12,
        ),
        yaxis=dict(
            range=[0, max_intensity * 1.05],  # 5% headroom above the tallest peak
            showline=True, linewidth=1.5, linecolor='black', mirror=True,
            ticks='outside', tickwidth=1.5, tickcolor='black',
            tickfont_size=12,
        ),
        plot_bgcolor='white',
        paper_bgcolor='white',
        bargap=0.9,
        font=dict(family="Arial, sans-serif", size=12, color="black"),
        margin=dict(l=70, r=30, t=60, b=70),
        height=450,
    )
    fig.update_xaxes(showgrid=False, zeroline=False)
    fig.update_yaxes(showgrid=False, zeroline=False)
    return fig


def _write_peak_csv(data):
    """Write *data* to a new temporary CSV file and return its path."""
    # delete=False keeps the file on disk after the handle closes, so the
    # returned path stays valid for the caller.
    with tempfile.NamedTemporaryFile(
        mode='w', delete=False, suffix='.csv', newline='', encoding='utf-8'
    ) as temp_csv:
        # Write through the open handle so its newline=''/utf-8 settings apply
        # (passing temp_csv.name would reopen the file and ignore them).
        data.to_csv(temp_csv, index=False)
    return temp_csv.name


def generate_xrd_pattern(cif_file):
    """
    Process an uploaded CIF file, calculate its XRD pattern, and return a
    Plotly figure, a Pandas DataFrame, and the path to a CSV file.

    Args:
        cif_file: The value of the Gradio ``gr.File`` component: a file path
            (str / PathLike) or a file-like object exposing ``.name``, or None.

    Returns:
        tuple: (plotly_fig, dataframe, csv_filepath), or (None, None, None)
        when ``cif_file`` is None.
            plotly_fig: A Plotly figure object.
            dataframe: A Pandas DataFrame containing the peak data.
            csv_filepath: Path to the generated temporary CSV file.

    Raises:
        gr.Error: If the file cannot be parsed or the pattern cannot be
            computed; the original exception is chained as the cause.
    """
    if cif_file is None:
        return None, None, None

    try:
        structure = Structure.from_file(_resolve_cif_path(cif_file))
        pattern = XRDCalculator().get_pattern(structure, two_theta_range=_TWO_THETA_RANGE)

        data = _peaks_to_dataframe(pattern)
        fig = _build_xrd_figure(data, structure.formula)
        csv_filepath_out = _write_peak_csv(data)

        return fig, data, csv_filepath_out

    except Exception as e:
        # UI boundary: log the full traceback server-side and surface a
        # readable message to the user.
        print(f"Error processing file: {e}")
        traceback.print_exc()
        raise gr.Error(
            f"Failed to process CIF file. Please ensure it's a valid CIF. Error: {str(e)}"
        ) from e
|
|
|
|
|
|
|
|
|
|
|
# Visual theme for the whole app: Gradio's "Soft" preset with custom hues.
theme = gr.themes.Soft(
    primary_hue="sky",
    secondary_hue="blue",
    neutral_hue="slate"
)

# UI definition. Gradio Blocks lays components out in the order they are
# created inside the nested Row / Column / Tabs context managers, so the
# statement order below is the page layout.
with gr.Blocks(theme=theme, title="XRD Pattern Generator") as demo:
    gr.Markdown(
        """
        # XRD Pattern Simulator from CIF
        Upload a Crystallographic Information File (.cif) to generate its simulated
        X-ray Diffraction (XRD) pattern using [pymatgen](https://github.com/materialsproject/pymatgen).
        """
    )

    with gr.Row():
        # Left (narrow) column: the CIF upload widget.
        with gr.Column(scale=1):
            # type="filepath": the callback receives the uploaded file's path.
            cif_input = gr.File(
                label="Upload CIF File",
                file_types=[".cif"],
                type="filepath"
            )
            gr.Markdown("*(Example source: [Crystallography Open Database](http://crystallography.net/cod/))*")

        # Right column (3x the left's width): the results, split across tabs.
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("📊 XRD Plot"):
                    plot_output = gr.Plot(label="XRD Pattern")

                with gr.TabItem("📄 Peak Data Table"):
                    # Headers match the DataFrame columns built by generate_xrd_pattern.
                    dataframe_output = gr.DataFrame(
                        label="Calculated Peak Data",
                        headers=["2θ (°)", "Intensity (norm)", "Miller Indices (hkl)"],
                        wrap=True,
                    )

                with gr.TabItem("⬇️ Download Data"):
                    csv_output = gr.File(label="Download Peak Data as CSV")
                    gr.Markdown("Click the link above to download the full data.")

    # Blank all three outputs when the uploaded file is removed.
    # NOTE(review): generate_xrd_pattern also returns (None, None, None) for a
    # None input, so if .change fires on clear as well, this handler is
    # redundant. Confirm before removing.
    cif_input.clear(
        lambda: (None, None, None),
        inputs=[],
        outputs=[plot_output, dataframe_output, csv_output]
    )

    # Recompute plot, table and CSV whenever the file input's value changes.
    cif_input.change(
        fn=generate_xrd_pattern,
        inputs=cif_input,
        outputs=[plot_output, dataframe_output, csv_output],
    )

    # Clickable example inputs that populate cif_input. The paths are relative,
    # presumably to the directory the app is launched from (TODO confirm the
    # example_cif/ files ship with the app).
    examples = gr.Examples(
        examples=[
            ["example_cif/NaCl_1000041.cif"],
            ["example_cif/Al2O3_1000017.cif"],
        ],
        inputs=[cif_input],
    )


# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|