pretzinger commited on
Commit
aacdac0
·
1 Parent(s): efd757b

Enhance app.py with error handling and logging

Browse files
Files changed (3) hide show
  1. app.py +94 -2
  2. requirements.txt +3 -0
  3. style.css +52 -0
app.py CHANGED
@@ -1,4 +1,96 @@
1
  import streamlit as st
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import os
5
 
6
+ # Apply custom CSS for retro 80s green theme
7
+ def apply_custom_css():
8
+ try:
9
+ with open("style.css") as f:
10
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
11
+ except FileNotFoundError:
12
+ st.warning("style.css not found. Using default styles.")
13
+
14
+ @st.cache_resource
15
+ def load_model():
16
+ model_path = "HuggingFaceH4/zephyr-7b-beta"
17
+ peft_model_path = "yitzashapiro/FDA-guidance-zephyr-7b-beta-PEFT"
18
+
19
+ try:
20
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ model_path,
23
+ device_map="auto",
24
+ torch_dtype=torch.float16 # Adjust if necessary
25
+ ).eval()
26
+ model.load_adapter(peft_model_path)
27
+ st.success("Model loaded successfully.")
28
+ except Exception as e:
29
+ st.error(f"Error loading model: {e}")
30
+ st.stop()
31
+
32
+ return tokenizer, model
33
+
34
+ def generate_response(tokenizer, model, user_input):
35
+ messages = [
36
+ {"role": "user", "content": user_input}
37
+ ]
38
+
39
+ try:
40
+ if hasattr(tokenizer, 'apply_chat_template'):
41
+ input_ids = tokenizer.apply_chat_template(
42
+ conversation=messages,
43
+ max_length=45,
44
+ tokenize=True,
45
+ add_generation_prompt=True,
46
+ return_tensors='pt'
47
+ )
48
+ else:
49
+ input_ids = tokenizer(
50
+ user_input,
51
+ return_tensors='pt',
52
+ truncation=True,
53
+ max_length=45
54
+ )['input_ids']
55
+
56
+ pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
57
+ attention_mask = (input_ids != pad_token_id).long()
58
+
59
+ output_ids = model.generate(
60
+ input_ids.to(model.device),
61
+ max_length=2048,
62
+ max_new_tokens=500,
63
+ attention_mask=attention_mask.to(model.device)
64
+ )
65
+
66
+ response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
67
+ return response
68
+ except Exception as e:
69
+ st.error(f"Error generating response: {e}")
70
+ return "An error occurred while generating the response."
71
+
72
+ def main():
73
+ apply_custom_css()
74
+
75
+ st.set_page_config(page_title="FDA NDA Submission Assistant", layout="centered")
76
+ st.title("FDA NDA Submission Assistant")
77
+ st.write("Ask the model about submitting an NDA to the FDA.")
78
+
79
+ tokenizer, model = load_model()
80
+
81
+ user_input = st.text_input("Enter your question:", "What's the best way to submit an NDA to the FDA?")
82
+
83
+ if st.button("Generate Response"):
84
+ if user_input.strip() == "":
85
+ st.error("Please enter a valid question.")
86
+ else:
87
+ try:
88
+ with st.spinner("Generating response..."):
89
+ response = generate_response(tokenizer, model, user_input)
90
+ st.success("Response:")
91
+ st.write(response)
92
+ except Exception as e:
93
+ st.error(f"An error occurred: {e}")
94
+
95
+ if __name__ == "__main__":
96
+ main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit==1.24.1
2
+ transformers==4.30.2
3
+ torch==2.0.1
style.css ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* style.css */
2
+
3
+ /* Set the background color to black */
4
+ body {
5
+ background-color: #000000;
6
+ color: #00FF00;
7
+ font-family: 'Courier New', Courier, monospace;
8
+ }
9
+
10
+ /* Style the header */
11
+ h1, h2, h3, h4, h5, h6 {
12
+ color: #00FF00;
13
+ }
14
+
15
+ /* Style the text input and buttons */
16
+ .stTextInput > div > div > input {
17
+ background-color: #000000;
18
+ color: #00FF00;
19
+ border: 2px solid #00FF00;
20
+ border-radius: 5px;
21
+ }
22
+
23
+ .stButton > button {
24
+ background-color: #00FF00;
25
+ color: #000000;
26
+ border: 2px solid #00FF00;
27
+ border-radius: 5px;
28
+ }
29
+
30
+ .stButton > button:hover {
31
+ background-color: #00CC00;
32
+ color: #FFFFFF;
33
+ border: 2px solid #00CC00;
34
+ }
35
+
36
+ /* Style the spinner and success/error messages */
37
+ div[data-testid="stSpinner"] > div {
38
+ border-top-color: #00FF00;
39
+ }
40
+
41
+ .stAlert {
42
+ background-color: #000000;
43
+ color: #00FF00;
44
+ border: 2px solid #00FF00;
45
+ border-radius: 5px;
46
+ }
47
+
48
+ /* Remove Streamlit's default padding */
49
+ .css-18e3th9 {
50
+ padding-top: 1rem;
51
+ padding-bottom: 1rem;
52
+ }