kmiyazaki committed on
Commit
9f2bfe9
·
verified ·
1 Parent(s): d4e6408

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +114 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision.models import resnet50, ResNet50_Weights
4
+ from PIL import Image
5
+ import tempfile
6
+ from gtts import gTTS
7
+ import whisper
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
9
+
10
# ----- Image-recognition model (ResNet-50) -----
# The label names and the preprocessing pipeline are both read from the same
# weights entry that is used to build the network.
weights = ResNet50_Weights.IMAGENET1K_V2
imagenet_classes = weights.meta["categories"]
img_transform = weights.transforms()
# nn.Module.eval() returns the module itself, so construction and the switch
# to inference mode can be chained.
img_model = resnet50(weights=weights).eval()
16
+
17
+
18
def image_classify(img: Image.Image):
    """Classify an image with ResNet-50 and return its top-5 ImageNet labels.

    Args:
        img: PIL image to classify. ``None`` is accepted (the UI submits it
            when no image has been uploaded).

    Returns:
        A dict mapping class name to probability for the five most likely
        classes, or an empty dict when no image was given.
    """
    if img is None:
        return {}

    # Defensive for direct callers: force three channels so RGBA, palette and
    # grayscale images go through the same preprocessing as RGB ones.
    batch = img_transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = img_model(batch)

    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    top_prob, top_idx = torch.topk(probabilities, 5)
    # .tolist() yields plain floats/ints, so the label list is indexed with
    # real integers instead of 0-d tensors.
    return {
        imagenet_classes[idx]: prob
        for prob, idx in zip(top_prob.tolist(), top_idx.tolist())
    }
26
+
27
+
28
# ----- Language model: checkpoint, tokenizer and generation pipeline -----
model_name = "cyberagent/open-calm-1b"

# float16 halves memory on a GPU, but half-precision inference on a CPU-only
# host may be unsupported or very slow, so fall back to float32 there.
_model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=_model_dtype
)
# NOTE(review): trust_remote_code=True allows code shipped in the model repo to
# run locally. It is kept to preserve behaviour; confirm it is actually needed
# for this checkpoint and drop it if not.
tokenizer = AutoTokenizer.from_pretrained(
    model_name, use_fast=True, trust_remote_code=True
)

text_gen_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # max_new_tokens bounds only the continuation. The previous max_length=128
    # also counted the prompt, so a long prompt left no room for new text.
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
    # Reuse EOS as the padding token id for generation.
    pad_token_id=tokenizer.eos_token_id,
)
45
+
46
+
47
# ----- Language model (LM) -----
def generate_text(prompt):
    """Generate a continuation of *prompt* with the text-generation pipeline.

    Args:
        prompt: Text to continue. ``None``, empty and whitespace-only values
            are treated as "no prompt".

    Returns:
        The generated text, which includes the prompt itself followed by the
        continuation, or an empty string when no prompt was given.
    """
    # An empty prompt gives the model nothing to continue, so skip the
    # pipeline call instead of passing it an empty input.
    if prompt is None or not prompt.strip():
        return ""

    result = text_gen_pipeline(prompt, do_sample=True, num_return_sequences=1)
    # The pipeline returns the full text (prompt included); pass it through
    # unchanged.
    return result[0]["generated_text"]
54
+
55
+
56
# ----- Speech synthesis (TTS) -----
def text_to_speech(text, lang="ja"):
    """Synthesize *text* with gTTS and return the path of the resulting MP3.

    Args:
        text: Text to speak. ``None``, empty and whitespace-only values are
            treated as "nothing to speak".
        lang: Language code passed to gTTS (defaults to ``"ja"``).

    Returns:
        Path of a temporary ``.mp3`` file containing the speech, or ``None``
        when there was nothing to speak.
    """
    if text is None or not text.strip():
        return None

    tts = gTTS(text=text, lang=lang)
    # delete=False keeps the file on disk after this function returns, so the
    # caller can read it. Only the name is taken here: the handle is closed
    # before gTTS writes to the path, because re-opening a still-open
    # temporary file by name is not portable across platforms.
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as fp:
        mp3_path = fp.name
    tts.save(mp3_path)
    return mp3_path
62
+
63
+
64
# ----- Speech recognition (ASR) -----
# Loaded once at import time so every transcription request reuses the model.
whisper_model = whisper.load_model(name="small")
66
+
67
+
68
def speech_to_text(audio_file):
    """Transcribe an audio file with the module-level Whisper model.

    Args:
        audio_file: Path of the audio file to transcribe. ``None`` or an
            empty path means nothing was recorded or uploaded.

    Returns:
        The transcribed text, or an empty string when no audio was given.
    """
    # Nothing to transcribe: return early rather than handing an invalid
    # path to the model.
    if not audio_file:
        return ""

    result = whisper_model.transcribe(audio_file)
    return result["text"]
71
+
72
+
73
# ----- Gradio UI -----
def _build_image_tab():
    """Fill the image-recognition tab with an Interface around image_classify."""
    gr.Markdown("### 画像認識 (ResNet-50)")
    gr.Interface(
        fn=image_classify,
        inputs=gr.Image(type="pil"),
        outputs=gr.Label(num_top_classes=5),
        description="画像をアップロードして分類します。(ImageNet)",
    )


def _build_lm_tab():
    """Fill the language-model tab: result box, prompt box and send button."""
    gr.Markdown("### 言語モデル")
    # The result box is created before the prompt box, so it is laid out first.
    result_box = gr.Textbox(label="生成結果")
    prompt_box = gr.Textbox(label="入力テキスト")
    submit = gr.Button("送信")
    submit.click(generate_text, inputs=prompt_box, outputs=result_box)


def _build_tts_tab():
    """Fill the speech-synthesis tab: text box, audio player and button."""
    gr.Markdown("### 音声合成 (gTTS)")
    text_box = gr.Textbox(label="音声にしたいテキスト")
    audio_player = gr.Audio(label="合成音声")
    synthesize = gr.Button("合成")
    synthesize.click(text_to_speech, inputs=text_box, outputs=audio_player)


def _build_asr_tab():
    """Fill the speech-recognition tab with an Interface around speech_to_text."""
    gr.Markdown("### 音声認識 (Whisper)")
    gr.Interface(
        fn=speech_to_text,
        inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
        outputs="text",
        description="マイクから録音して文字起こし",
    )


def run():
    """Build the four-tab demo and launch it."""
    # (tab title, builder) pairs in display order.
    tabs = (
        ("画像認識", _build_image_tab),
        ("言語モデル", _build_lm_tab),
        ("音声合成", _build_tts_tab),
        ("音声認識", _build_asr_tab),
    )

    with gr.Blocks() as demo:
        gr.Markdown("# 画像認識・言語モデル・音声合成・音声認識")
        with gr.Tabs():
            for title, build_tab in tabs:
                # Components created inside the builder are attached to the
                # tab that is currently open as a context manager.
                with gr.TabItem(title):
                    build_tab()

    demo.launch()


if __name__ == "__main__":
    run()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ accelerate
5
+ gTTS
6
+ git+https://github.com/openai/whisper.git
7
+ ffmpeg-python
8
+ gradio