hspark1212 commited on
Commit
40fc2b7
·
1 Parent(s): 1511665

initial commit

Browse files
Files changed (8) hide show
  1. .gitignore +1 -0
  2. app.py +295 -0
  3. assets/beaker.svg +1 -0
  4. assets/logo.gif +0 -0
  5. assets/logo_static.jpg +0 -0
  6. assets/style.css +4 -0
  7. requirements.txt +1 -0
  8. utils.py +44 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
app.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import json
3
+ import random
4
+ import base64
5
+ from io import BytesIO
6
+ from fire import Fire
7
+
8
+ import streamlit as st
9
+ from ase.atoms import Atoms
10
+ from ase.build import bulk
11
+ from ase.io import write
12
+ from chemeleon import Chemeleon
13
+ from chemeleon.visualize import Visualizer
14
+
15
+ from utils import dict_to_atoms
16
+
17
+ # Constants
18
+ TIMESTEPS = 1000
19
+ TRAJECTORY_STEPS = 100
20
+ DEFAULT_NUM_SAMPLES = 3
21
+ DEMO = False
22
+
23
+ # Set page configuration
24
+ st.set_page_config(page_title="Chemeleon", layout="wide")
25
+
26
+ # Hide Streamlit's default menu and footer for a cleaner look
27
+ hide_streamlit_style = """
28
+ <style>
29
+ #MainMenu {visibility: hidden;}
30
+ footer {visibility: hidden;}
31
+ </style>
32
+ """
33
+ st.markdown(hide_streamlit_style, unsafe_allow_html=True)
34
+
35
+
36
+ def demo_generator_structures(num_atoms, text_input, num_samples):
37
+ """
38
+ Generate crystal structures for demonstration purposes.
39
+ """
40
+ elements = random.choices(["Si", "Ge", "C", "Na", "Cl"], k=num_samples)
41
+ random_elements = random.choices(elements, k=num_atoms)
42
+
43
+ for step in range(TIMESTEPS):
44
+ time.sleep(0.001)
45
+ random_atoms = Atoms(
46
+ "Li",
47
+ positions=[[random.random() * 5 for _ in range(3)]],
48
+ )
49
+ atoms_list = [bulk(element, "fcc", a=5.43) for element in random_elements]
50
+ new_atoms_list = []
51
+
52
+ for atoms in atoms_list:
53
+ # Adding random atoms to each bulk structure
54
+ combined_atoms = atoms + random_atoms
55
+ new_atoms_list.append(combined_atoms)
56
+
57
+ yield new_atoms_list
58
+
59
+
60
+ def generator_structures_chemeleon(
61
+ num_atoms, test_input, num_samples, use_client=False
62
+ ):
63
+ """
64
+ Generate crystal structures based on the given number of atoms and input text.
65
+ """
66
+ if use_client:
67
+ response = client(
68
+ url="https://8000-01j80snre5xdhq828s1q5brs0m.cloudspaces.litng.ai/predict",
69
+ n_samples=num_samples,
70
+ n_atoms=num_atoms,
71
+ text_input=test_input,
72
+ )
73
+
74
+ for line in response.iter_lines():
75
+ output = json.loads(line)["output"]
76
+ atom_dict = json.loads(output)
77
+ atoms_list = [dict_to_atoms(atoms_dict) for atoms_dict in atom_dict]
78
+ yield atoms_list
79
+ else:
80
+ chemeleon = Chemeleon.load_general_text_model()
81
+ for atoms_list in chemeleon.sample(
82
+ text_input=test_input,
83
+ n_atoms=num_atoms,
84
+ n_samples=num_samples,
85
+ stream=True,
86
+ ):
87
+ yield atoms_list
88
+
89
+
90
+ def visualize_structure(atoms):
91
+ """
92
+ Visualize the given atomic structure using Plotly.
93
+ """
94
+ visualizer = Visualizer([atoms], atomic_size=0.6, resolution=20)
95
+ fig = visualizer.view()
96
+ return fig
97
+
98
+
99
+ def visualize_trajectory(atoms_list):
100
+ """
101
+ Visualize the given atomic structure trajectory using Plotly.
102
+ """
103
+ visualizer = Visualizer(atoms_list, atomic_size=0.6, resolution=20)
104
+ fig = visualizer.view_trajectory(duration=1000)
105
+ return fig
106
+
107
+
108
+ # Main application function
109
+ def main(use_client=False):
110
+ # Initialize session state
111
+ if "structures" not in st.session_state:
112
+ st.session_state.structures = []
113
+ if "trajectory" not in st.session_state:
114
+ st.session_state.trajectory = []
115
+ if "progress_in_generating" not in st.session_state:
116
+ st.session_state["progress_in_generating"] = False
117
+
118
+ # Sidebar for user inputs
119
+ with st.sidebar:
120
+ st.image("assets/logo_static.jpg", width=200)
121
+ st.markdown(
122
+ """
123
+ <h1 style='text-align: center; color: #4CAF50;'>Chemeleon</h1>
124
+ <h3 style='text-align: center;'>A text-guided diffusion model for crystal structure generation</h3>
125
+ """,
126
+ unsafe_allow_html=True,
127
+ )
128
+ st.markdown("---")
129
+ description = st.text_input(
130
+ "Input your text prompt to generate crystal structures",
131
+ "A Crystal Structure of LiMnO4 with orthorhombic symmetry",
132
+ help="Examples: 'LiMnO4' or 'A Crystal Structure of BaTiO3 with cubic symmetry'",
133
+ )
134
+ num_atoms = st.slider(
135
+ "🔢 Number of Atoms:",
136
+ min_value=1,
137
+ max_value=20,
138
+ value=6,
139
+ help="Select the number of atoms in the unit cell.",
140
+ )
141
+ num_samples = st.number_input(
142
+ "🧪 Number of Samples:",
143
+ min_value=1,
144
+ max_value=5,
145
+ value=DEFAULT_NUM_SAMPLES,
146
+ step=1,
147
+ help="Determine how many structure samples to generate.",
148
+ )
149
+
150
+ # Generate Structures when button is clicked
151
+ if st.session_state["progress_in_generating"]:
152
+ # Clear previous structures
153
+ st.session_state.structures = []
154
+ st.session_state.trajectory = []
155
+
156
+ # Initialize progress bar in the sidebar
157
+ progress_placeholder = st.empty()
158
+ progress_bar = progress_placeholder.progress(0)
159
+
160
+ # Initialize loading animation
161
+ image_placeholder = st.empty()
162
+ with st.spinner("Generating structures..."):
163
+ with image_placeholder:
164
+ data_url = base64.b64encode(
165
+ open("assets/logo.gif", "rb").read()
166
+ ).decode()
167
+ image_placeholder.markdown(
168
+ f'<img src="data:image/gif;base64,{data_url}" width=100>',
169
+ unsafe_allow_html=True,
170
+ )
171
+
172
+ # Generate structures
173
+ trajectory = []
174
+ if DEMO:
175
+ generator = demo_generator_structures(num_atoms, description, num_samples)
176
+ else:
177
+ generator = generator_structures_chemeleon(
178
+ num_atoms, description, num_samples, use_client
179
+ )
180
+ for step, atoms_list in enumerate(generator):
181
+ progress_bar.progress((step + 1) / TIMESTEPS)
182
+ if step % TRAJECTORY_STEPS == 0 or step == TIMESTEPS - 1:
183
+ st.session_state.structures = atoms_list
184
+ trajectory.append(atoms_list)
185
+
186
+ st.session_state.trajectory = trajectory
187
+
188
+ # Remove the progress bar
189
+ progress_placeholder.empty()
190
+
191
+ # Remove the loading animation
192
+ image_placeholder.empty()
193
+
194
+ # Reset the progress state
195
+ st.session_state["progress_in_generating"] = False
196
+
197
+ # Display success message
198
+ st.sidebar.success("✨ Structures generated successfully!")
199
+
200
+ with st.sidebar:
201
+ if st.button(
202
+ "Generate Structures 🚀",
203
+ disabled=st.session_state["progress_in_generating"],
204
+ ):
205
+ st.session_state["progress_in_generating"] = True
206
+ st.rerun()
207
+
208
+ # Check if structures are generated
209
+ if st.session_state.structures:
210
+ # Tabs for visualization
211
+ tabs = st.tabs(["Structure Visualization", "Trajectory Analysis"])
212
+
213
+ # Structure Visualization Tab
214
+ with tabs[0]:
215
+ col1, col2 = st.columns([1, 3])
216
+ with col1:
217
+ st.session_state.selected_sample_index = (
218
+ st.radio(
219
+ "Select Sample",
220
+ options=list(range(1, num_samples + 1)),
221
+ index=0,
222
+ help="Choose which sample to visualize.",
223
+ )
224
+ - 1
225
+ ) # Adjust for zero-based indexing
226
+ # Download file
227
+ atoms = st.session_state.structures[
228
+ st.session_state.selected_sample_index
229
+ ]
230
+ buffer = BytesIO()
231
+ write(buffer, atoms, format="cif")
232
+ buffer.seek(0)
233
+ st.download_button(
234
+ label="Download CIF File",
235
+ data=buffer,
236
+ file_name=f"{str(atoms.symbols)}.cif",
237
+ mime="chemical/cif",
238
+ )
239
+ with col2:
240
+ atoms = st.session_state.structures[
241
+ st.session_state.selected_sample_index
242
+ ]
243
+ fig = visualize_structure(atoms)
244
+ st.plotly_chart(fig, use_container_width=True)
245
+
246
+ # Trajectory Analysis Tab
247
+ with tabs[1]:
248
+ if st.session_state.trajectory:
249
+ trajectory = [
250
+ traj[st.session_state.selected_sample_index]
251
+ for traj in st.session_state.trajectory
252
+ ]
253
+ tabs_2 = st.tabs(["Animation", "Step View"])
254
+ # Animation
255
+ with tabs_2[0]:
256
+ fig = visualize_trajectory(trajectory)
257
+ st.plotly_chart(fig, use_container_width=True)
258
+ # Slider
259
+ with tabs_2[1]:
260
+ trajectory_index = st.slider(
261
+ "Select Trajectory Step",
262
+ min_value=0,
263
+ max_value=len(trajectory) - 1,
264
+ value=0,
265
+ step=1,
266
+ help="Navigate through different steps of the structure generation.",
267
+ )
268
+ selected_atoms = trajectory[trajectory_index]
269
+ trajectory_fig = visualize_structure(selected_atoms)
270
+ st.plotly_chart(trajectory_fig, use_container_width=True)
271
+ else:
272
+ st.info("No trajectory data available.")
273
+
274
+ # Footer
275
+ st.markdown(
276
+ """
277
+ <div style="text-align: center; color: grey; margin-top: 50px;">
278
+ <p style="font-size: 14px; margin: 0;">
279
+ Developed by
280
+ <a href="https://hspark1212.github.io" target="_blank">Hyunsoo Park</a>,
281
+ as a part of <a href="https://github.com/wmd-group" target="_blank">Materials Design Group</a>
282
+ at Imperial College London
283
+ </p>
284
+ <p>
285
+ <a href="https://chemrxiv.org/engage/chemrxiv/article-details/6728e27cf9980725cf118177" target="_blank">Research Paper</a> |
286
+ <a href="https://github.com/hspark1212/chemeleon" target="_blank">Repository</a>
287
+ </p>
288
+ </div>
289
+ """,
290
+ unsafe_allow_html=True,
291
+ )
292
+
293
+
294
+ if __name__ == "__main__":
295
+ Fire(main) # Usage example: streamlit run app/streamlit_app.py -- --use_client=True
assets/beaker.svg ADDED
assets/logo.gif ADDED
assets/logo_static.jpg ADDED
assets/style.css ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ /* In your assets/style.css */
2
+ #logo-image {
3
+ transition: all 0.5s ease-in-out;
4
+ }
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ chemeleon
utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from ase import Atoms
3
+ import plotly.graph_objects as go
4
+
5
+ empty_fig = go.Figure(
6
+ data=[
7
+ go.Scatter(
8
+ x=[],
9
+ y=[],
10
+ mode="markers",
11
+ marker=dict(color="rgba(0,0,0,0)"),
12
+ )
13
+ ],
14
+ layout=go.Layout(
15
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
16
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
17
+ template="plotly_white",
18
+ ),
19
+ )
20
+
21
+
22
+ def atoms_to_dict(atoms):
23
+ """
24
+ Converts an ASE Atoms object to a dictionary for storing.
25
+ """
26
+ return {
27
+ "symbols": atoms.get_chemical_symbols(),
28
+ "positions": atoms.get_positions().tolist(),
29
+ "cell": atoms.get_cell().tolist(),
30
+ "pbc": atoms.get_pbc().tolist(),
31
+ }
32
+
33
+
34
+ def dict_to_atoms(data):
35
+ """
36
+ Converts a dictionary back to an ASE Atoms object.
37
+ """
38
+ atoms = Atoms(
39
+ symbols=data["symbols"],
40
+ positions=np.array(data["positions"]),
41
+ cell=np.array(data["cell"]),
42
+ pbc=np.array(data["pbc"]),
43
+ )
44
+ return atoms