minoD commited on
Commit
ba34bf3
·
1 Parent(s): 22ba7d0
Files changed (3) hide show
  1. README.md +5 -5
  2. app.py +80 -0
  3. requirements.txt +65 -0
README.md CHANGED
@@ -1,14 +1,14 @@
1
  ---
2
  title: JURAN
3
- emoji: 👁
4
- colorFrom: purple
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.9.1
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
- short_description: 学生時代打ち込んだことから質問を生成
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: JURAN
3
+ emoji: 🌺
4
+ colorFrom: green
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.9.1
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
+ short_description: 面接官の質問をシミュレート
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import shutil
# NOTE(review): torch/os/shutil are not used in this file as far as visible — confirm before removing.

# Hugging Face Hub repository id of the fine-tuned interviewer model.
model_name = "minoD/JURAN"

# Load the model; device_map="auto" lets accelerate place the weights on
# whatever device(s) are available (GPU if present, otherwise CPU).
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto"
)

# use_fast=False forces the slow (sentencepiece-backed) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
16
+
17
# Prompt template construction.
def generate_prompt(F):
    """Build the instruction prompt for the interviewer model.

    F is the applicant's free-form entry-sheet text; it is inserted into
    the fixed instruction/question/answer template. All newlines (both the
    template's and any inside F) are encoded as the literal token "<NL>",
    matching the representation the model was trained on.
    """
    segments = [
        "### 指示:あなたは企業の面接官です.就活生のエントリーシートを元に質問を行ってください.",
        "",
        "### 質問:",
        str(F),
        "",
        "### 回答:",
        "",
    ]
    # Join with <NL> (one per original newline), then encode any newlines
    # that F itself contained.
    return "<NL>".join(segments).replace("\n", "<NL>")
30
+
31
# Text-generation entry point.
def generate2(F=None, maxTokens=256):
    """Generate an interviewer question for the entry-sheet text ``F``.

    Args:
        F: Free-form entry-sheet text inserted into the prompt template.
        maxTokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded answer section with "<NL>" markers converted back to
        newlines, or a warning string when generation did not terminate
        with EOS or the prompt template was not echoed back.
    """
    prompt = generate_prompt(F)
    # Fix: the original called .cuda(), which crashes on CPU-only hosts even
    # though the model is loaded with device_map="auto". Move the inputs to
    # wherever the model actually lives instead.
    input_ids = tokenizer(prompt,
                          return_tensors="pt",
                          truncation=True,
                          add_special_tokens=False).input_ids.to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=maxTokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.75,
        top_k=40,
        no_repeat_ngram_size=2,
    )
    outputs = outputs[0].tolist()

    # Generation is only trusted when the EOS token was actually emitted.
    if tokenizer.eos_token_id not in outputs:
        return 'Warning: no <eos> detected ignoring output'

    # Decode up to (excluding) the first EOS token.
    eos_index = outputs.index(tokenizer.eos_token_id)
    decoded = tokenizer.decode(outputs[:eos_index])

    # Extract only the response portion after the answer sentinel.
    sentinel = "### 回答:"
    sentinelLoc = decoded.find(sentinel)
    if sentinelLoc < 0:
        return 'Warning: Expected prompt template to be emitted. Ignoring output.'
    result = decoded[sentinelLoc + len(sentinel):]
    return result.replace("<NL>", "\n")  # <NL> -> newline
66
+
67
def inference(input_text):
    """Gradio callback: produce an interviewer question for the given text."""
    return generate2(F=input_text)
69
+
70
+
71
# Assemble the Gradio UI: one free-text input, one generated-question output.
entry_sheet_box = gr.Textbox(
    lines=5,
    label="学生時代に打ち込んだこと、研究、ESを入力",
    placeholder="半導体の研究に打ち込んだ",
)
question_box = gr.Textbox(label="想定される質問")

iface = gr.Interface(
    fn=inference,
    inputs=entry_sheet_box,
    outputs=question_box,
    title="JURAN🌺",
    description="面接官モデルが回答を生成します。",
)

# Only start the server when executed as a script.
if __name__ == "__main__":
    iface.launch()
requirements.txt ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.2.1
2
+ aiofiles==23.2.1
3
+ annotated-types==0.7.0
4
+ anyio==4.7.0
5
+ bitsandbytes==0.42.0
6
+ certifi==2024.12.14
7
+ charset-normalizer==3.4.0
8
+ click==8.1.8
9
+ exceptiongroup==1.2.2
10
+ fastapi==0.115.6
11
+ ffmpy==0.5.0
12
+ filelock==3.16.1
13
+ fsspec==2024.12.0
14
+ gradio==5.9.1
15
+ gradio_client==1.5.2
16
+ h11==0.14.0
17
+ httpcore==1.0.7
18
+ httpx==0.28.1
19
+ huggingface-hub==0.27.0
20
+ idna==3.10
21
+ Jinja2==3.1.5
22
+ markdown-it-py==3.0.0
23
+ MarkupSafe==2.1.5
24
+ mdurl==0.1.2
25
+ mpmath==1.3.0
26
+ networkx==3.4.2
27
+ numpy==1.26.4
28
+ orjson==3.10.12
29
+ packaging==24.2
30
+ pandas==2.2.3
31
+ pillow==11.0.0
32
+ psutil==6.1.1
33
+ pydantic==2.10.4
34
+ pydantic_core==2.27.2
35
+ pydub==0.25.1
36
+ Pygments==2.18.0
37
+ python-dateutil==2.9.0.post0
38
+ python-multipart==0.0.20
39
+ pytz==2024.2
40
+ PyYAML==6.0.2
41
+ regex==2024.11.6
42
+ requests==2.32.3
43
+ rich==13.9.4
44
+ ruff==0.8.4
45
+ safehttpx==0.1.6
46
+ safetensors==0.4.5
47
+ scipy==1.14.1
48
+ semantic-version==2.10.0
49
+ sentencepiece==0.2.0
50
+ shellingham==1.5.4
51
+ six==1.17.0
52
+ sniffio==1.3.1
53
+ starlette==0.41.3
54
+ sympy==1.13.1
55
+ tokenizers==0.21.0
56
+ tomlkit==0.13.2
57
+ torch==2.5.1
58
+ tqdm==4.67.1
59
+ transformers==4.47.1
60
+ typer==0.15.1
61
+ typing_extensions==4.12.2
62
+ tzdata==2024.2
63
+ urllib3==2.2.3
64
+ uvicorn==0.34.0
65
+ websockets==14.1