Spaces:
Sleeping
Sleeping
demo
Browse files- .gitignore +8 -0
- app.py +149 -0
- gender.py +18 -0
- llm.py +48 -0
- requirements.txt +90 -0
- speechtovid.py +15 -0
- tts.py +17 -0
.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.venv
|
2 |
+
.DS_Store
|
3 |
+
__pycache__
|
4 |
+
flagged
|
5 |
+
|
6 |
+
files
|
7 |
+
|
8 |
+
.env
|
app.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import uuid
|
5 |
+
from pprint import pprint
|
6 |
+
import dotenv
|
7 |
+
dotenv.load_dotenv()
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
from llm import answer
|
12 |
+
from tts import get_audio
|
13 |
+
from speechtovid import get_video
|
14 |
+
from gender import get_gender
|
15 |
+
|
16 |
+
SECRET = os.getenv("SECRET_WORD")
|
17 |
+
|
18 |
+
supported_languages = ["English", "Chinese", "Spanish", "Hindi", "Portuguese", "French", "German", "Japanese", "Arabic", "Korean", "Indonesian", "Italian", "Dutch", "Turkish", "Polish", "Swedish", "Filipino", "Malay", "Russian", "Romanian", "Ukrainian", "Greek", "Czech", "Danish", "Finnish", "Bulgarian", "Croatian", "Slovak", "Tamil"]
|
19 |
+
|
20 |
+
# create dirs for images, audio and video files if they don't exist
|
21 |
+
if not os.path.exists('./files/audio'):
|
22 |
+
os.makedirs('./files/audio')
|
23 |
+
if not os.path.exists('./files/images'):
|
24 |
+
os.makedirs('./files/images')
|
25 |
+
if not os.path.exists('./files/video'):
|
26 |
+
os.makedirs('./files/video')
|
27 |
+
|
28 |
+
# image resizer
|
29 |
+
def resize_image(input_image_path, file_name):
|
30 |
+
with Image.open(input_image_path) as im:
|
31 |
+
original_width, original_height = im.size
|
32 |
+
desired_width = 300
|
33 |
+
aspect_ratio = original_width / original_height
|
34 |
+
new_width = desired_width
|
35 |
+
new_height = int(desired_width / aspect_ratio)
|
36 |
+
resized_image = im.resize((new_width, new_height))
|
37 |
+
resized_image.save("./files/images/"+file_name+".png")
|
38 |
+
return "./files/images/"+file_name+".png"
|
39 |
+
|
40 |
+
# main func
|
41 |
+
def holiday_card(secret, brief, lang, photo):
|
42 |
+
|
43 |
+
if not secret or secret.strip().lower() != SECRET.strip().lower():
|
44 |
+
raise gr.Error("Please use the correct secret word!")
|
45 |
+
|
46 |
+
if not brief:
|
47 |
+
raise gr.Error("Please enter the kind of greeting you want to create!")
|
48 |
+
|
49 |
+
if not photo:
|
50 |
+
raise gr.Error("Please upload a photo!")
|
51 |
+
|
52 |
+
|
53 |
+
# generate a unique id for this greeting
|
54 |
+
uid = str(uuid.uuid4())
|
55 |
+
|
56 |
+
# resize the image, otherwise it will be too big
|
57 |
+
resized_photo = resize_image(photo, uid)
|
58 |
+
|
59 |
+
# get the gender of the person in the photo - so that we can choose the voice
|
60 |
+
gender = get_gender(resized_photo)
|
61 |
+
if gender == 'female':
|
62 |
+
voice = 'Bella'
|
63 |
+
else:
|
64 |
+
voice = 'Antoni'
|
65 |
+
|
66 |
+
# generate the greeting
|
67 |
+
system_prompt = f'''
|
68 |
+
You are a native {lang} copywriter with an excellent sense of humour. You help people write text for their holiday voice messages that they will send to their friends and colleagues. You take the user brief and then write a short, joyful, funny and beautiful short speech in {lang} that people say when wishing their colleagues and friends a happy 2024. It shouldn't be more than 2-3 sentences long. Please respond with valid JSON.
|
69 |
+
|
70 |
+
If the client brief was good - please return valid JSON like this:
|
71 |
+
{{
|
72 |
+
"status": "OK",
|
73 |
+
"text": "your greeting text here"
|
74 |
+
}}
|
75 |
+
|
76 |
+
If the client brief is inappropriate - please return valid JSON like this:
|
77 |
+
{{
|
78 |
+
"status": "ERROR",
|
79 |
+
"reason": "your reason for what is wrong with the brief"
|
80 |
+
}}
|
81 |
+
|
82 |
+
Please alsways return valid JSON and nothing else
|
83 |
+
'''
|
84 |
+
|
85 |
+
# get the answer from the model
|
86 |
+
answer_text = answer(
|
87 |
+
system_message=system_prompt,
|
88 |
+
user_message=brief
|
89 |
+
)
|
90 |
+
|
91 |
+
pprint(answer_text)
|
92 |
+
|
93 |
+
try:
|
94 |
+
answer_data = json.loads(answer_text)
|
95 |
+
|
96 |
+
except Exception as e:
|
97 |
+
pprint(e)
|
98 |
+
raise gr.Error(f"Sorry, something went wrong with the AI. Please try again later")
|
99 |
+
|
100 |
+
if answer_data.get('status') == 'ERROR':
|
101 |
+
raise gr.Error(answer_data.get('reason'))
|
102 |
+
|
103 |
+
text = answer_data.get('text')
|
104 |
+
|
105 |
+
# now get audio
|
106 |
+
try:
|
107 |
+
audio_file = get_audio(uid, text, voice)
|
108 |
+
pprint(audio_file)
|
109 |
+
|
110 |
+
except Exception as e:
|
111 |
+
pprint(e)
|
112 |
+
raise gr.Error(f"Sorry, something went wrong with the audio generation AI. Please try again later")
|
113 |
+
|
114 |
+
# now get video
|
115 |
+
try:
|
116 |
+
video_url = get_video(uid, resized_photo, audio_file)
|
117 |
+
except Exception as e:
|
118 |
+
pprint(e)
|
119 |
+
raise gr.Error(f"Sorry, something went wrong with the video generation AI. Please try again later")
|
120 |
+
|
121 |
+
return text, audio_file, video_url
|
122 |
+
|
123 |
+
|
124 |
+
# set up and launch gradio interface
|
125 |
+
|
126 |
+
inputs=[
|
127 |
+
gr.Textbox(lines=1, label="What is the secret word?"),
|
128 |
+
gr.Textbox(lines=5, label="What do you want your holiday greeting to be all about?"),
|
129 |
+
gr.Dropdown(supported_languages, label="What language would you like your holiday greeting to be in?", value="French"),
|
130 |
+
gr.Image(type="filepath", label="Upload Image")
|
131 |
+
]
|
132 |
+
|
133 |
+
outputs=[
|
134 |
+
gr.Textbox(lines=5, label="Your Holiday Greeting Text"),
|
135 |
+
gr.Audio(type="filepath", label="Your Holiday Greeting Audio"),
|
136 |
+
gr.Video(label="Your Holiday Greeting Video")
|
137 |
+
]
|
138 |
+
|
139 |
+
demo = gr.Interface(
|
140 |
+
holiday_card,
|
141 |
+
inputs,
|
142 |
+
outputs,
|
143 |
+
allow_flagging="never"
|
144 |
+
)
|
145 |
+
|
146 |
+
demo.queue()
|
147 |
+
|
148 |
+
if __name__ == "__main__":
|
149 |
+
demo.launch()
|
gender.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from pprint import pprint
|
5 |
+
|
6 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
7 |
+
|
8 |
+
API_URL = "https://api-inference.huggingface.co/models/rizvandwiki/gender-classification-2"
|
9 |
+
headers = {"Authorization": "Bearer "+HF_TOKEN}
|
10 |
+
|
11 |
+
def get_gender(filename):
|
12 |
+
with open(filename, "rb") as f:
|
13 |
+
data = f.read()
|
14 |
+
response = requests.post(API_URL, headers=headers, data=data)
|
15 |
+
|
16 |
+
gender_prediction = response.json()
|
17 |
+
|
18 |
+
return gender_prediction[0]['label']
|
llm.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import openai
|
3 |
+
import json
|
4 |
+
from pprint import pprint
|
5 |
+
from tenacity import retry, wait_random_exponential, stop_after_attempt
|
6 |
+
|
7 |
+
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
8 |
+
MODEL = 'gpt-4'
|
9 |
+
|
10 |
+
'''
|
11 |
+
basic underlying method to get completions. retries built in
|
12 |
+
'''
|
13 |
+
@retry(wait=wait_random_exponential(multiplier=2, max=10), stop=stop_after_attempt(3))
|
14 |
+
def get_completion(messages, temperature=1):
|
15 |
+
|
16 |
+
try:
|
17 |
+
chat_completion = openai.ChatCompletion.create(
|
18 |
+
model=MODEL,
|
19 |
+
temperature=temperature,
|
20 |
+
messages=messages
|
21 |
+
)
|
22 |
+
|
23 |
+
return chat_completion.choices[0].message.content
|
24 |
+
|
25 |
+
except Exception as e:
|
26 |
+
print("Unable to generate ChatCompletion response")
|
27 |
+
print(f"Exception: {e}")
|
28 |
+
return e
|
29 |
+
|
30 |
+
'''
|
31 |
+
basic text completion
|
32 |
+
'''
|
33 |
+
def answer(system_message, user_message, temperature=1):
|
34 |
+
|
35 |
+
messages = [{
|
36 |
+
"role": "system",
|
37 |
+
"content": system_message
|
38 |
+
},{
|
39 |
+
"role": "user",
|
40 |
+
"content": user_message
|
41 |
+
}]
|
42 |
+
|
43 |
+
completion = get_completion(
|
44 |
+
messages=messages,
|
45 |
+
temperature = temperature
|
46 |
+
)
|
47 |
+
|
48 |
+
return completion
|
requirements.txt
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==23.2.1
|
2 |
+
aiohttp==3.8.6
|
3 |
+
aiosignal==1.3.1
|
4 |
+
altair==5.1.2
|
5 |
+
annotated-types==0.6.0
|
6 |
+
anyio==3.7.1
|
7 |
+
appnope==0.1.3
|
8 |
+
asttokens==2.4.1
|
9 |
+
async-timeout==4.0.3
|
10 |
+
attrs==23.1.0
|
11 |
+
backcall==0.2.0
|
12 |
+
certifi==2023.7.22
|
13 |
+
charset-normalizer==3.3.1
|
14 |
+
click==8.1.7
|
15 |
+
colorama==0.4.6
|
16 |
+
contourpy==1.1.1
|
17 |
+
cycler==0.12.1
|
18 |
+
decorator==5.1.1
|
19 |
+
elevenlabs==0.2.26
|
20 |
+
executing==2.0.0
|
21 |
+
fastapi==0.104.0
|
22 |
+
ffmpy==0.3.1
|
23 |
+
filelock==3.12.4
|
24 |
+
fonttools==4.43.1
|
25 |
+
frozenlist==1.4.0
|
26 |
+
fsspec==2023.10.0
|
27 |
+
gradio==3.50.2
|
28 |
+
gradio_client==0.6.1
|
29 |
+
h11==0.14.0
|
30 |
+
httpcore==0.18.0
|
31 |
+
httpx==0.25.0
|
32 |
+
huggingface-hub==0.18.0
|
33 |
+
idna==3.4
|
34 |
+
importlib-resources==6.1.0
|
35 |
+
ipython==8.16.1
|
36 |
+
jedi==0.19.1
|
37 |
+
Jinja2==3.1.2
|
38 |
+
jsonschema==4.19.1
|
39 |
+
jsonschema-specifications==2023.7.1
|
40 |
+
kiwisolver==1.4.5
|
41 |
+
livereload==2.6.3
|
42 |
+
MarkupSafe==2.1.3
|
43 |
+
matplotlib==3.8.0
|
44 |
+
matplotlib-inline==0.1.6
|
45 |
+
multidict==6.0.4
|
46 |
+
numpy==1.26.1
|
47 |
+
openai==0.28.1
|
48 |
+
orjson==3.9.10
|
49 |
+
packaging==23.2
|
50 |
+
pandas==2.1.2
|
51 |
+
parso==0.8.3
|
52 |
+
pexpect==4.8.0
|
53 |
+
pickleshare==0.7.5
|
54 |
+
Pillow==10.1.0
|
55 |
+
prompt-toolkit==3.0.39
|
56 |
+
ptyprocess==0.7.0
|
57 |
+
pure-eval==0.2.2
|
58 |
+
py-mon==1.1.1
|
59 |
+
pydantic==2.4.2
|
60 |
+
pydantic_core==2.10.1
|
61 |
+
pydub==0.25.1
|
62 |
+
Pygments==2.16.1
|
63 |
+
pyparsing==3.1.1
|
64 |
+
python-dateutil==2.8.2
|
65 |
+
python-dotenv==1.0.0
|
66 |
+
python-multipart==0.0.6
|
67 |
+
pytz==2023.3.post1
|
68 |
+
PyYAML==6.0.1
|
69 |
+
referencing==0.30.2
|
70 |
+
replicate==0.15.4
|
71 |
+
requests==2.31.0
|
72 |
+
rpds-py==0.10.6
|
73 |
+
semantic-version==2.10.0
|
74 |
+
six==1.16.0
|
75 |
+
sniffio==1.3.0
|
76 |
+
stack-data==0.6.3
|
77 |
+
starlette==0.27.0
|
78 |
+
tenacity==8.2.3
|
79 |
+
toolz==0.12.0
|
80 |
+
tornado==6.3.3
|
81 |
+
tqdm==4.66.1
|
82 |
+
traitlets==5.12.0
|
83 |
+
typing_extensions==4.8.0
|
84 |
+
tzdata==2023.3
|
85 |
+
urllib3==2.0.7
|
86 |
+
uvicorn==0.23.2
|
87 |
+
watchdog==3.0.0
|
88 |
+
wcwidth==0.2.8
|
89 |
+
websockets==11.0.3
|
90 |
+
yarl==1.9.2
|
speechtovid.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import replicate
|
3 |
+
from pprint import pprint
|
4 |
+
|
5 |
+
def get_video(uid, photo_path, audio_path):
|
6 |
+
|
7 |
+
output = replicate.run(
|
8 |
+
"cjwbw/sadtalker:3aa3dac9353cc4d6bd62a8f95957bd844003b401ca4e4a9b33baa574c549d376",
|
9 |
+
input={
|
10 |
+
"source_image": open(photo_path, "rb"),
|
11 |
+
"driven_audio": open(audio_path, "rb"),
|
12 |
+
}
|
13 |
+
)
|
14 |
+
|
15 |
+
return output
|
tts.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from elevenlabs import generate, set_api_key, save
|
2 |
+
import os
|
3 |
+
|
4 |
+
API_KEY = os.getenv('ELEVENLABS_KEY')
|
5 |
+
|
6 |
+
set_api_key(API_KEY)
|
7 |
+
|
8 |
+
def get_audio(a_uid, text, voice_id):
|
9 |
+
|
10 |
+
audio = generate(
|
11 |
+
text=text,
|
12 |
+
voice=voice_id,
|
13 |
+
model='eleven_multilingual_v2'
|
14 |
+
)
|
15 |
+
file_name = f'./files/audio/{a_uid}.mp3'
|
16 |
+
save(audio, file_name)
|
17 |
+
return file_name
|