File size: 5,489 Bytes
30d06f3
ec4e3bf
3180e31
4b65fd2
142b484
6c0ac6b
 
 
 
dd17730
015696a
 
 
 
 
6c0ac6b
49c0f95
cbe1d01
19a01a7
3180e31
 
 
 
 
14fa9b7
3180e31
 
 
 
 
14fa9b7
3180e31
 
 
 
 
 
14fa9b7
 
19a01a7
 
cbe1d01
31fc42e
 
 
 
 
9ca2069
 
ec4e3bf
6c0ac6b
142b484
 
19d44da
142b484
9ca2069
e8afa15
7099e7c
e8afa15
9ca2069
9198ac8
 
 
 
 
6c0ac6b
 
 
9198ac8
 
aaadcd8
 
 
 
9198ac8
aaadcd8
9198ac8
 
 
 
 
 
cbe1d01
 
3180e31
60d22a5
3180e31
cbe1d01
3180e31
 
 
 
 
cbe1d01
3180e31
 
 
9198ac8
9f48b8d
f92c145
6c0ac6b
 
 
 
 
 
 
 
 
 
 
c62697b
6c0ac6b
 
19a01a7
9198ac8
142b484
31fc42e
9ca2069
 
 
 
 
e8afa15
9ca2069
e8afa15
9ca2069
 
31fc42e
b0261c2
31fc42e
142b484
9ca2069
142b484
 
19d44da
 
 
 
 
 
 
 
 
142b484
 
 
 
 
 
 
 
 
ec4e3bf
 
4b65fd2
ec4e3bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os
import time
from langchain_core.pydantic_v1 import BaseModel, Field
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from TextGen.suno import custom_generate_audio, get_audio_information
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory,
)
from TextGen import app
from gradio_client import Client, handle_file
from typing import List

class PlayLastMusic(BaseModel):
    '''Play the latest created music.'''
    # The class docstring above is presumably what becomes the tool
    # description when this model is passed to bind_tools; keep it accurate.
    # Tool argument: the model is asked to answer "Yes" or "No". The field
    # name (including its spelling) is part of the tool schema, so it is kept.
    Desicion: str = Field(
        ..., description="Yes or No"
    )

class CreateLyrics(BaseModel):
    '''create some Lyrics for a new music'''
    # Keep the docstring a plain string literal: f-strings are never stored
    # as __doc__ (PEP 498), which would leave this model (and presumably the
    # tool built from it) without a description.
    # Tool argument: the model is asked to answer "Yes" or "No". The field
    # name (including its spelling) is part of the tool schema, so it is kept.
    Desicion: str = Field(
        ..., description="Yes or No"
    )

class CreateNewMusic(BaseModel):
    '''create a new music with the Lyrics previously computed'''
    # Keep the docstring a plain string literal: f-strings are never stored
    # as __doc__ (PEP 498), which would leave this model (and presumably the
    # tool built from it) without a description.
    # Tool argument: per its description, tags describing the new music
    # (despite the field being called "Name"; the name is part of the schema).
    Name: str = Field(
        ..., description="tags to describe the new music"
    )



class Message(BaseModel):
    """Request body for POST /api/generate."""
    # Name of the NPC to speak as; names present in main_npcs get a dedicated
    # system prompt, anything else gets a generic one.
    npc: str | None  = None
    # Conversation history, read as alternating turns: even indices are sent
    # with role "user", odd indices with role "assistant".
    messages: List[str] | None = None
    
class VoiceMessage(BaseModel):
    """Request body for POST /generate_wav."""
    # NPC name; used to select the reference voice sample.
    npc: str | None  = None
    # Text to synthesize (forwarded as the TTS prompt).
    input: str | None = None
    # Language code forwarded to the TTS call.
    language: str | None = "en"
    # Voice gender for NPCs without a dedicated sample: "Male" or "Female";
    # any other value selects the narrator voice. The name presumably means
    # "gender" (French "genre") — confirm before renaming anything.
    genre:str | None = "Male"
    
# Base URL of the song API, read from the VERCEL_API environment variable
# (os.environ[...] raises KeyError at import time if it is unset).
# NOTE(review): not referenced in the code visible here — presumably used
# elsewhere or left over; confirm before removing.
song_base_api=os.environ["VERCEL_API"]

# Hugging Face token (HF_TOKEN env var, required at import time) used to
# authenticate the Gradio client below.
my_hf_token=os.environ["HF_TOKEN"]

# Gradio client for the "Jofthomas/xtts" Space, created once at import time
# and shared by every /generate_wav request.
tts_client = Client("Jofthomas/xtts",hf_token=my_hf_token)

# Reference voice sample per main NPC; the chosen path is passed to the TTS
# call as its audio_file_pth.
main_npcs={
    "Blacksmith":"./voices/Blacksmith.mp3",
    "Herbalist":"./voices/female.mp3",
    "Bard":"./voices/Bard_voice.mp3"
}
# Persona prompt per main NPC, sent as the first turn of the LLM conversation.
# Keys mirror main_npcs.
main_npc_system_prompts={
    "Blacksmith":"You are a blacksmith in a video game",
    "Herbalist":"You are an herbalist in a video game",
    "Bard":"You are a bard in a video game"
}
class Generate(BaseModel):
    """Response body of POST /api/generate: the NPC's generated reply."""
    # Content of the LLM's reply message.
    text:str

def generate_text(messages: List[str], npc:str):
    """Generate the next reply of an NPC with Gemini.

    Args:
        messages: Conversation history. Even indices are sent with role
            "user" (the player), odd indices with role "assistant" (the NPC).
            ``None`` or an empty list means "no history yet".
        npc: NPC name. Names in main_npc_system_prompts get their dedicated
            persona prompt; anything else gets a generic one. The Bard
            (matched case-insensitively) is additionally given music tools.

    Returns:
        Generate whose ``text`` is the content of the model's reply.
    """
    print(npc)
    # Known NPCs get their persona; everyone else a generic role-play prompt.
    system_prompt = main_npc_system_prompts.get(
        npc, "you're a character in a video game. Play along."
    )
    print(system_prompt)
    # The persona is sent as the opening "user" turn, not a system message.
    new_messages = [{"role": "user", "content": system_prompt}]
    # Message.messages defaults to None, so treat a missing history as empty
    # instead of crashing in enumerate().
    for index, message in enumerate(messages or []):
        role = "user" if index % 2 == 0 else "assistant"
        new_messages.append({"role": role, "content": message})
    print(new_messages)

    # Initialize the LLM (all safety filters disabled for in-game dialogue).
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-pro-latest",
        max_output_tokens=100,
        temperature=1,
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        },
    )
    # The NPC key is "Bard" (capitalised) while the check used "bard", so the
    # tools were never bound; compare case-insensitively (npc may be None).
    if isinstance(npc, str) and npc.lower() == "bard":
        llm = llm.bind_tools([PlayLastMusic, CreateNewMusic, CreateLyrics])

    llm_response = llm.invoke(new_messages)
    print(llm_response)
    return Generate(text=llm_response.content)

# Accept cross-origin requests from any origin, with any method and header,
# and allow credentials — presumably so a browser-hosted game client served
# from another origin can call this API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider restricting allow_origins to the known game host(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/", tags=["Home"])
def api_home():
    """Root route: answer with a fixed informational JSON message."""
    banner = 'Everchanging Quest backend, nothing to see here'
    return {'detail': banner}

@app.post("/api/generate", summary="Generate text from prompt", tags=["Generate"], response_model=Generate)
def inference(message: Message):
    """Return the next NPC reply for the conversation carried by *message*.

    Thin HTTP wrapper: forwards the request's history and NPC name to
    generate_text and returns its Generate result unchanged.
    """
    history, speaker = message.messages, message.npc
    return generate_text(messages=history, npc=speaker)

#Dummy function for now
def determine_vocie_from_npc(npc,genre):
    """Pick the reference voice sample used to synthesize an NPC's speech.

    Args:
        npc: NPC name. Names present in main_npcs use their dedicated sample.
        genre: Voice gender for every other NPC: "Male" or "Female". Any
            other value (including None) falls back to the narrator voice.

    Returns:
        Path of the audio file to use as the TTS reference voice.
    """
    if npc in main_npcs:
        return main_npcs[npc]
    if genre == "Male":
        # This branch previously lacked `return`, so male NPCs fell through
        # to the narrator voice.
        return "./voices/default_male.mp3"
    if genre == "Female":
        return "./voices/default_female.mp3"
    return "./voices/narator_out.wav"
    
@app.post("/generate_wav")
async def generate_wav(message:VoiceMessage):
    """Synthesize ``message.input`` as speech and return it as a WAV file.

    The reference voice is chosen by determine_vocie_from_npc from the NPC
    name and genre; the remote XTTS Gradio Space (tts_client) renders audio.

    Raises:
        HTTPException: 400 when no input text is given; 500 when voice
            selection or the TTS call fails (detail carries the error text).
    """
    import asyncio  # local import: the module header does not import asyncio

    # Validate before the try block so this 400 is not rewrapped as a 500.
    if not message.input:
        raise HTTPException(status_code=400, detail="input text is required")

    try:
        voice = determine_vocie_from_npc(message.npc, message.genre)
        # Client.predict blocks until the Space answers; run it in a worker
        # thread so this async endpoint does not stall the event loop.
        result = await asyncio.to_thread(
            tts_client.predict,
            prompt=message.input,
            language=message.language,
            audio_file_pth=handle_file(voice),
            mic_file_path=None,
            use_mic=False,
            voice_cleanup=False,
            no_lang_auto_detect=False,
            agree=True,
            api_name="/predict",
        )

        # Element 1 of the Space's output is the generated wav file path.
        wav_file_path = result[1]

        return FileResponse(wav_file_path, media_type="audio/wav", filename="output.wav")

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/generate_song")
async def generate_song(text: str):
    """Request songs for a text prompt and poll until they are playable.

    Args:
        text: Free-text prompt forwarded to custom_generate_audio.

    Returns:
        The latest result of get_audio_information for the created clips
        (presumably one dict per clip with at least "id", "status" and
        "audio_url" — confirm against TextGen.suno). If polling runs out,
        the last result is still returned so the caller can inspect status.

    Raises:
        HTTPException: 500 when submitting or polling fails.
    """
    import asyncio  # local import: the module header does not import asyncio

    # Up to 60 polls, 5 seconds apart (~5 minutes), as before.
    max_polls = 60
    poll_interval_s = 5
    # NOTE(review): "complete" is assumed to be the status that follows
    # "streaming"; without it, a missed streaming window polled for 5 minutes.
    ready_statuses = ("streaming", "complete")

    try:
        # The suno helpers are called synchronously (their result is indexed
        # directly), so run them in worker threads instead of blocking the loop.
        data = await asyncio.to_thread(custom_generate_audio, {
            "prompt": text,
            "make_instrumental": False,
            "wait_audio": False,
        })
        ids = ",".join(clip["id"] for clip in data)
        print(f"ids: {ids}")

        for _ in range(max_polls):
            data = await asyncio.to_thread(get_audio_information, ids)
            if data[0]["status"] in ready_statuses:
                for clip in data:
                    print(f"{clip['id']} ==> {clip['audio_url']}")
                break
            # Non-blocking wait so other requests keep being served.
            await asyncio.sleep(poll_interval_s)
    except Exception as e:
        # Previously a bare `except:` printed "Error" and returned None,
        # hiding the failure (and swallowing cancellation).
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return data