EmTpro01 committed on
Commit
8d8d3b6
1 Parent(s): 13be039

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -32
app.py CHANGED
@@ -1,33 +1,91 @@
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the merged model once at module import (default device: CPU).
model_name = "EmTpro01/gemma-paraphraser-4bit"  # Replace with your merged model path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)  # Default device is CPU

# Streamlit UI
st.title("Text Paraphrasing ")
st.write("Provide a paragraph, and this AI will paraphrase it for you.")

# Input paragraph
paragraph = st.text_area("Enter a paragraph to paraphrase:", height=200)

if st.button("Paraphrase"):
    if paragraph.strip():
        with st.spinner("Paraphrasing..."):
            # Alpaca-style prompt the model was fine-tuned on.
            alpaca_prompt = f"Below is a paragraph, paraphrase it.\n### paragraph: {paragraph}\n### paraphrased:"

            # Tokenize input (tensors stay on CPU by default).
            inputs = tokenizer(alpaca_prompt, return_tensors="pt")

            # Generate paraphrased text.
            output = model.generate(**inputs, max_new_tokens=200)
            paraphrased = tokenizer.decode(output[0], skip_special_tokens=True)

            # BUG FIX: split("### paraphrased:")[1] raised IndexError whenever
            # the decoded text lacked the marker (e.g. special tokens stripped
            # differently); partition degrades gracefully to the full text.
            _, marker, tail = paraphrased.partition("### paraphrased:")
            result = tail.strip() if marker else paraphrased.strip()
            st.text_area("Paraphrased Output:", result, height=200)
    else:
        st.warning("Please enter a paragraph to paraphrase.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+
5
# Page configuration — must run before any other st.* call in the script.
st.set_page_config(
    page_icon="✍️",
    page_title="Gemma Paraphraser",
)
8
# Model/tokenizer loading, cached so Streamlit reruns reuse a single copy.
@st.cache_resource
def load_model():
    """Load the Gemma paraphraser and its tokenizer (CPU, fp16 weights).

    Returns:
        A ``(model, tokenizer)`` pair ready for generation.
    """
    repo_id = "EmTpro01/gemma-paraphraser-16bit"
    loaded_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
        device_map="cpu",
    )
    return loaded_model, loaded_tokenizer
19
+
20
+ # Paraphrase function
21
+ def paraphrase_text(text, model, tokenizer):
22
+ # Prepare the prompt using Alpaca format
23
+ system_prompt = "Below is provided a paragraph, paraphrase it"
24
+ prompt = f"{system_prompt}\n\n### Input:\n{text}\n\n### Output:\n"
25
+
26
+ # Tokenize input
27
+ inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
28
+
29
+ # Generate paraphrased text
30
+ outputs = model.generate(
31
+ inputs.input_ids,
32
+ max_length=512, # Adjust based on your needs
33
+ num_return_sequences=1,
34
+ temperature=0.7,
35
+ do_sample=True
36
+ )
37
+
38
+ # Decode and clean the output
39
+ paraphrased = tokenizer.decode(outputs[0], skip_special_tokens=True)
40
+
41
+ # Extract the output part (after "### Output:")
42
+ output_start = paraphrased.find("### Output:") + len("### Output:")
43
+ paraphrased_text = paraphrased[output_start:].strip()
44
+
45
+ return paraphrased_text
46
+
47
# Streamlit App
def main():
    """Render the paraphrasing UI: input area, generate button, sidebar tips."""
    st.title("📝 Gemma Paraphraser")
    st.write("Paraphrase your text using the Gemma model")

    # Load model (cached across reruns by @st.cache_resource).
    try:
        model, tokenizer = load_model()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return

    # Input text area
    input_text = st.text_area("Enter text to paraphrase:", height=200)

    # Paraphrase button
    if st.button("Paraphrase"):
        if input_text:
            with st.spinner("Generating paraphrase..."):
                try:
                    paraphrased_text = paraphrase_text(input_text, model, tokenizer)

                    # Display results
                    st.subheader("Paraphrased Text:")
                    # BUG FIX: the original nested a "Copy to Clipboard"
                    # button here whose on_click lambda only ran on the NEXT
                    # Streamlit rerun — after this branch (and the result)
                    # was gone — so it never copied anything. st.code renders
                    # the text with a built-in copy-to-clipboard control.
                    st.code(paraphrased_text, language=None)
                except Exception as e:
                    st.error(f"Error during paraphrasing: {e}")
        else:
            st.warning("Please enter some text to paraphrase.")

    # Additional information
    st.sidebar.info(
        "Model: EmTpro01/gemma-paraphraser-16bit\n\n"
        "Tips:\n"
        "- Enter a paragraph to paraphrase\n"
        "- Click 'Paraphrase' to generate\n"
        "- Running on CPU with 16-bit precision"
    )

if __name__ == "__main__":
    main()