Vokturz commited on
Commit
6826b0f
·
1 Parent(s): f6e2171

cached model list to memory (exclude falcon)

Browse files
Files changed (1) hide show
  1. src/app.py +18 -6
src/app.py CHANGED
@@ -47,14 +47,24 @@ st.markdown(
47
  """,
48
  unsafe_allow_html=True,
49
  )
 
 
 
 
 
 
 
 
 
 
50
  @st.cache_resource
51
  def get_gpu_specs():
52
  return pd.read_csv("data/gpu_specs.csv")
53
 
54
- @st.cache_resource
55
- def get_mistralai_table():
56
- model = get_model("mistralai/Mistral-7B-v0.1", library="transformers", access_token="")
57
- return calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
58
 
59
  def show_gpu_info(info, trainable_params=0, vendor=""):
60
  for var in ['Inference', 'Full Training Adam', 'LoRa Fine-tuning']:
@@ -91,7 +101,9 @@ def get_name(index):
91
 
92
  def custom_ceil(a, precision=0):
93
  return np.round(a + 0.5 * 10**(-precision), precision)
 
94
  gpu_specs = get_gpu_specs()
 
95
 
96
  _, col, _ = st.columns([1,3,1])
97
  with col.expander("Information", expanded=True):
@@ -123,8 +135,8 @@ if model_name not in st.session_state:
123
  del st.session_state[st.session_state['actual_model']]
124
  del st.session_state['actual_model']
125
  gc.collect()
126
- if model_name == "mistralai/Mistral-7B-v0.1": # cache Mistral
127
- st.session_state[model_name] = get_mistralai_table()
128
  else:
129
  model = get_model(model_name, library="transformers", access_token=access_token)
130
  st.session_state[model_name] = calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
 
47
  """,
48
  unsafe_allow_html=True,
49
  )
50
+
51
+ @st.cache_resource()
52
+ def cache_model_list():
53
+ model_list_info = {}
54
+ for model_name in model_list:
55
+ if not "tiiuae/falcon" in model_name: # Exclude Falcon models
56
+ model = get_model(model_name, library="transformers", access_token="")
57
+ model_list_info[model_name] = calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
58
+ return model_list_info
59
+
60
  @st.cache_resource
61
  def get_gpu_specs():
62
  return pd.read_csv("data/gpu_specs.csv")
63
 
64
+ # @st.cache_resource
65
+ # def get_mistralai_table():
66
+ # model = get_model("mistralai/Mistral-7B-v0.1", library="transformers", access_token="")
67
+ # return calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
68
 
69
  def show_gpu_info(info, trainable_params=0, vendor=""):
70
  for var in ['Inference', 'Full Training Adam', 'LoRa Fine-tuning']:
 
101
 
102
  def custom_ceil(a, precision=0):
103
  return np.round(a + 0.5 * 10**(-precision), precision)
104
+
105
  gpu_specs = get_gpu_specs()
106
+ model_list_info = cache_model_list()
107
 
108
  _, col, _ = st.columns([1,3,1])
109
  with col.expander("Information", expanded=True):
 
135
  del st.session_state[st.session_state['actual_model']]
136
  del st.session_state['actual_model']
137
  gc.collect()
138
+ if model_name in model_list_info.keys():
139
+ st.session_state[model_name] = model_list_info[model_name]
140
  else:
141
  model = get_model(model_name, library="transformers", access_token=access_token)
142
  st.session_state[model_name] = calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])