Floki00 commited on
Commit
1ff7cbe
1 Parent(s): 722ba78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -16
app.py CHANGED
@@ -9,26 +9,26 @@ from genQC.util import infer_torch_device
9
  #--------------------------------
10
  # download model into storage
11
 
12
- save_destination = "saves/"
 
 
13
 
14
- url_config = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/config.yaml"
15
- url_weights = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/model.pt"
 
 
 
16
 
17
- def download(url, dst_dir):
18
- if not os.path.exists(dst_dir): os.mkdir(dst_dir)
19
- filename = os.path.join(dst_dir, os.path.basename(url))
20
- if not os.path.exists(filename): filename = wget.download(url + "?raw=true", out=filename)
21
- return filename
22
-
23
- config_file = download(url_config, save_destination)
24
- weigths_file = download(url_weights, save_destination)
25
 
26
  #--------------------------------
27
  # setup
28
 
29
  @st.cache_resource
30
  def load_pipeline():
31
- pipeline = DiffusionPipeline.from_config_file(save_destination, infer_torch_device())
 
32
  pipeline.scheduler.set_timesteps(20)
33
  return pipeline
34
 
@@ -56,7 +56,7 @@ def get_qcs(srv, num_of_qubits, max_gates, g):
56
 
57
  for qc,is_svr,ax in zip(qc_list, srv_list, axs.flatten()):
58
  ax.clear()
59
- qc.draw("mpl", plot_barriers=False, ax=ax, style="clifford")
60
  ax.set_title(f"{'Correct' if is_svr==srv else 'NOT correct'}, is SRV = {is_svr}")
61
  status.update(label="Generation complete!", state="complete", expanded=False)
62
 
@@ -76,10 +76,10 @@ Generating quantum circuits with diffusion models. Official demo of [[paper-arxi
76
 
77
  col1, col2 = st.columns(2)
78
 
79
- srv = col1.text_input('SRV', "[1,1,1,2,2]")
80
- num_of_qubits = col1.radio('Number of qubits (should match SRV)', [3,4,5,6,7,8], index=2)
81
  max_gates = col1.select_slider('Max gates', options=[4,8,12,16,20,24,28], value=16)
82
- g = col1.slider('Guidance scale', min_value=0.0, max_value=15.0, value=7.5)
83
 
84
  srv_list = ast.literal_eval(srv)
85
  if len(srv_list)!=num_of_qubits:
 
9
  #--------------------------------
10
  # download model into storage
11
 
12
+ #save_destination = "saves/"
13
+ #url_config = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/config.yaml"
14
+ #url_weights = "https://github.com/FlorianFuerrutter/genQC/blob/044f7da6ebe907bd796d3db293024db223cc1852/saves/qc_unet_config_SRV_3to8_qubit/model.pt"
15
 
16
+ #def download(url, dst_dir):
17
+ # if not os.path.exists(dst_dir): os.mkdir(dst_dir)
18
+ # filename = os.path.join(dst_dir, os.path.basename(url))
19
+ # if not os.path.exists(filename): filename = wget.download(url + "?raw=true", out=filename)
20
+ # return filename
21
 
22
+ #config_file = download(url_config, save_destination)
23
+ #weigths_file = download(url_weights, save_destination)
 
 
 
 
 
 
24
 
25
  #--------------------------------
26
  # setup
27
 
28
  @st.cache_resource
29
  def load_pipeline():
30
+ #pipeline = DiffusionPipeline.from_config_file(save_destination, infer_torch_device())
31
+ pipeline = DiffusionPipeline.from_pretrained("Floki00/qc_srv_3to8qubit", "cpu")
32
  pipeline.scheduler.set_timesteps(20)
33
  return pipeline
34
 
 
56
 
57
  for qc,is_svr,ax in zip(qc_list, srv_list, axs.flatten()):
58
  ax.clear()
59
+ qc.draw("mpl", plot_barriers=False, ax=ax)
60
  ax.set_title(f"{'Correct' if is_svr==srv else 'NOT correct'}, is SRV = {is_svr}")
61
  status.update(label="Generation complete!", state="complete", expanded=False)
62
 
 
76
 
77
  col1, col2 = st.columns(2)
78
 
79
+ srv = col1.text_input('SRV', "[1,1,1,2,2,2]")
80
+ num_of_qubits = col1.radio('Number of qubits (should match SRV)', [3,4,5,6,7,8], index=3)
81
  max_gates = col1.select_slider('Max gates', options=[4,8,12,16,20,24,28], value=16)
82
+ g = col1.slider('Guidance scale', min_value=0.0, max_value=15.0, value=10)
83
 
84
  srv_list = ast.literal_eval(srv)
85
  if len(srv_list)!=num_of_qubits: