mqha commited on
Commit
139dc75
·
1 Parent(s): 63a6bfa

复制一份别人的

Browse files
Files changed (1) hide show
  1. app.py +65 -23
app.py CHANGED
@@ -1,39 +1,81 @@
1
- # pip install transformers 依赖在requirements.txt里文件安装
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import streamlit as st
3
  from transformers import pipeline, set_seed
4
 
5
  # 设置全局随机种子,确保每次生成的结果相同
6
  set_seed(42)
7
 
 
 
 
8
 
9
- options = ['中文','英文']
10
- choice = st.radio('不同语言使用不同模型:', options)
11
 
12
- input_text = st.text_input("请输入您要生成的文本", value="")
13
- maxlen = st.text_input("请输入生成文本的最大长度,越长越慢,不要超过1000", value="30")
14
- button_generate = st.button("生成")
15
- output_text = st.empty()
16
 
17
- def generate_text(input_text):
18
- # 加载预训练模型
19
- if choice == '中文':
20
- model = 'uer/gpt2-chinese-cluecorpussmall' # 纠正后的应该可以
21
- #model = 'gpt2-chinese-cluecorpussmall' # 会自动下载
22
- generator = pipeline("text-generation", model)
23
 
24
- # 生成文本
25
- output = generator(input_text, max_length=int(maxlen), num_return_sequences=1)
26
 
27
- # 提取生成的文本
28
- generated_text = output[0]["generated_text"].strip()
29
 
30
- return generated_text
31
 
32
- if button_generate:
33
- # 生成文本
34
- generated_text = generate_text(input_text)
35
 
36
- # 显示生成的文本
37
- output_text.success(generated_text)
38
 
 
 
 
39
 
 
1
+ # # pip install transformers 依赖在requirements.txt里文件安装
2
+ # import streamlit as st
3
+ # from transformers import pipeline, set_seed
4
+
5
+ # # 设置全局随机种子,确保每次生成的结果相同
6
+ # set_seed(42)
7
+
8
+
9
+ # options = ['中文','英文']
10
+ # choice = st.radio('不同语言使用不同模型:', options)
11
+
12
+ # input_text = st.text_input("请输入您要生成的文本", value="")
13
+ # maxlen = st.text_input("请输入生成文本的最大长度,越长越慢,不要超过1000", value="30")
14
+ # button_generate = st.button("生成")
15
+ # output_text = st.empty()
16
+
17
+ # def generate_text(input_text):
18
+ # # 加载预训练模型
19
+ # if choice == '中文':
20
+ # model = 'uer/gpt2-chinese-cluecorpussmall' # 纠正后的应该可以
21
+ # #model = 'gpt2-chinese-cluecorpussmall' # 会自动下载
22
+ # generator = pipeline("text-generation", model)
23
+
24
+ # # 生成文本
25
+ # output = generator(input_text, max_length=int(maxlen), num_return_sequences=1)
26
+
27
+ # # 提取生成的文本
28
+ # generated_text = output[0]["generated_text"].strip()
29
+
30
+ # return generated_text
31
+
32
+ # if button_generate:
33
+ # # 生成文本
34
+ # generated_text = generate_text(input_text)
35
+
36
+ # # 显示生成的文本
37
+ # output_text.success(generated_text)
38
  import streamlit as st
39
  from transformers import pipeline, set_seed
40
 
41
  # 设置全局随机种子,确保每次生成的结果相同
42
  set_seed(42)
43
 
44
+ def app():
45
+ # 创建Streamlit应用程序
46
+ st.title("使用gpt2的文本生成")
47
 
48
+ options = ['中文','英文']
49
+ choice = st.radio('不同语言使用不同模型:', options)
50
 
51
+ input_text = st.text_input("请输入您要生成的文本", value="")
52
+ maxlen = st.text_input("请输入生成文本的最大长度,越长越慢,不要超过1000", value="30")
53
+ button_generate = st.button("生成")
54
+ output_text = st.empty()
55
 
56
+ def generate_text(input_text):
57
+ # 加载预训练模型
58
+ model="gpt2"
59
+ if choice == '中文':
60
+ model = 'uer/gpt2-chinese-cluecorpussmall'
61
+ generator = pipeline("text-generation", model)
62
 
63
+ # 生成文本
64
+ output = generator(input_text, max_length=int(maxlen), num_return_sequences=1)
65
 
66
+ # 提取生成的文本
67
+ generated_text = output[0]["generated_text"].strip()
68
 
69
+ return generated_text
70
 
71
+ if button_generate:
72
+ # 生成文本
73
+ generated_text = generate_text(input_text)
74
 
75
+ # 显示生成的文本
76
+ output_text.success(generated_text)
77
 
78
+ if __name__ == "__main__":
79
+ # 运行应用程序
80
+ app()
81