import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import pyttsx3  # Importing pyttsx3 for text-to-speech

# Read the Hugging Face access token from the environment
# (set a secret/variable named 'token' in your environment or Space settings)
access_token = os.getenv("token")

# Initialize the tokenizer and model with the Hugging Face access token
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    torch_dtype=torch.bfloat16,
    token=access_token,  # `token` replaces the deprecated `use_auth_token`
)
model.eval()  # Set the model to evaluation mode

# Initialize the inference client for API-based chat completion.
# Assumes the same Gemma model is reachable through the Serverless Inference API.
client = InferenceClient(model="google/gemma-2b-it", token=access_token)

# Initialize the local text-to-speech engine
# (currently unused: speech synthesis below goes through the E2-F5-TTS Space)
tts_engine = pyttsx3.init()

# Client for the hosted E2-F5-TTS Space. Note: the Spaces `Client` and the
# `handle_file` helper live in gradio_client, not huggingface_hub.
from gradio_client import Client, handle_file

# Initialize the E2-F5-TTS client
client_tts = Client("mrfakename/E2-F5-TTS")

def text_to_speech(text: str, sample: str) -> bytes:
    """Synthesize `text` with the E2-F5-TTS Space, cloning the voice of the
    reference clip at input/{sample}.mp3, and return the generated audio bytes."""
    result = client_tts.predict(
        ref_audio_input=handle_file(f"input/{sample}.mp3"),
        ref_text_input="",
        gen_text_input=text,
        remove_silence=False,
        cross_fade_duration_slider=0.15,
        speed_slider=1,
        api_name="/basic_tts",
    )
    # result[0] is the path of the generated audio file; return its raw bytes
    with open(result[0], "rb") as audio_file:
        return audio_file.read()
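# Example usage (hypothetical; assumes a reference clip exists at input/sample.mp3):
#   audio = text_to_speech("Hello from Gemma!", sample="sample")
#   with open("greeting.wav", "wb") as f:
#       f.write(audio)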

def conversation_predict(input_text):
    """Generate a response for single-turn input using the model."""
    # Tokenize the input text
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids

    # Generate a response with the model
    outputs = model.generate(input_ids, max_new_tokens=2048)

    # Decode and return the generated response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Convert the text response to speech using E2-F5-TTS
    audio_bytes = text_to_speech(response, sample="input")
    
    return response, audio_bytes
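# Example usage (hypothetical): single-turn generation with spoken output.
#   reply, audio = conversation_predict("Write a haiku about the sea.")
#   print(reply)  # `audio` holds the raw audio bytes from text_to_speech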

def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Generate a response for a multi-turn chat conversation."""
    # Prepare the messages in the correct format for the API
    messages = [{"role": "system", "content": system_message}]

    for user_input, assistant_reply in history:
        if user_input:
            messages.append({"role": "user", "content": user_input})
        if assistant_reply:
            messages.append({"role": "assistant", "content": assistant_reply})

    messages.append({"role": "user", "content": message})

    response = ""

    # Stream response tokens from the chat completion API
    for message_chunk in client.chat_completion(
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = message_chunk.choices[0].delta.content or ""  # delta content may be None
        response += token
        yield response
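# Note: `respond` is a generator; each `yield` re-emits the accumulated reply,
# which gr.ChatInterface renders as a progressively streaming message.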

# Create a Gradio ChatInterface demo
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
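# Only `respond` (API-streamed chat) is wired into the UI above;
# `conversation_predict` (local Gemma + E2-F5-TTS audio) is available for
# programmatic use and is not attached to a Gradio component here.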

if __name__ == "__main__":
    demo.launch()