from pathlib import Path import numpy as np import pandas as pd import plotly.colors as pcolors import plotly.graph_objects as go import streamlit as st from ase.data import chemical_symbols from plotly.subplots import make_subplots from scipy.interpolate import CubicSpline color_sequence = pcolors.qualitative.Plotly st.markdown("# Homonuclear diatomics") # button to toggle plots container = st.container(border=True) energy_plot = container.checkbox("Show energy curves", value=True) force_plot = container.checkbox("Show force curves", value=True) ncols = 2 DATA_DIR = Path("mlip_arena/tasks/diatomics") mlips = ["MACE-MP", "CHGNet"] dfs = [pd.read_json(DATA_DIR / mlip.lower() / "homonuclear-diatomics.json") for mlip in mlips] df = pd.concat(dfs, ignore_index=True) df.drop_duplicates(inplace=True, subset=["name", "method"]) for i, symbol in enumerate(chemical_symbols[1:]): if i % ncols == 0: cols = st.columns(ncols) rows = df[df["name"] == symbol + symbol] if rows.empty: continue fig = make_subplots(specs=[[{"secondary_y": True}]]) elo, flo = float("inf"), float("inf") for j, method in enumerate(rows["method"].unique()): row = rows[rows["method"] == method].iloc[0] rs = np.array(row["R"]) es = np.array(row["E"]) fs = np.array(row["F"]) rs = np.array(rs) ind = np.argsort(rs) es = np.array(es) fs = np.array(fs) rs = rs[ind] es = es[ind] es = es - es[-1] fs = fs[ind] xs = np.linspace(rs.min()*0.99, rs.max()*1.01, int(5e2)) if energy_plot: cs = CubicSpline(rs, es) ys = cs(xs) elo = min(elo, ys.min()*1.2, -1) fig.add_trace( go.Scatter( x=xs, y=ys, mode="lines", line=dict( color=color_sequence[j % len(color_sequence)], width=2, ), name=method, ), secondary_y=False, ) if force_plot: cs = CubicSpline(rs, fs) ys = cs(xs) flo = min(flo, ys.min()*1.2) fig.add_trace( go.Scatter( x=xs, y=ys, mode="lines", line=dict( color=color_sequence[j % len(color_sequence)], width=1, dash="dot", ), name=method, showlegend=False if energy_plot else True, ), secondary_y=True, ) fig.update_layout( showlegend=True, title_text=f"{symbol}-{symbol}", title_x=0.5, # yaxis_range=[ylo, 2*(abs(ylo))], ) # Set x-axis title fig.update_xaxes(title_text="Bond length (Å)") # Set y-axes titles if energy_plot: fig.update_layout( yaxis=dict( title=dict(text="Energy [eV]"), side="left", range=[elo, 2*(abs(elo))], ) ) # fig.update_yaxes(title_text="Energy [eV]", secondary_y=False) if force_plot: fig.update_layout( yaxis2=dict( title=dict(text="Force [eV/Å]"), side="right", range=[flo, 2*(abs(flo))], overlaying="y", tickmode="sync", ), ) # fig.update_yaxes(title_text="Force [eV/Å]", secondary_y=True) # cols[i % ncols].title(f"{row['name']}") cols[i % ncols].plotly_chart(fig, use_container_width=True, height=250)