papayaga commited on
Commit
a8467fc
·
1 Parent(s): 65fd528
Files changed (7) hide show
  1. .gitignore +8 -0
  2. app.py +149 -0
  3. gender.py +18 -0
  4. llm.py +48 -0
  5. requirements.txt +90 -0
  6. speechtovid.py +15 -0
  7. tts.py +17 -0
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ .venv
2
+ .DS_Store
3
+ __pycache__
4
+ flagged
5
+
6
+ files
7
+
8
+ .env
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import json
4
+ import uuid
5
+ from pprint import pprint
6
+ import dotenv
7
+ dotenv.load_dotenv()
8
+
9
+ from PIL import Image
10
+
11
+ from llm import answer
12
+ from tts import get_audio
13
+ from speechtovid import get_video
14
+ from gender import get_gender
15
+
16
+ SECRET = os.getenv("SECRET_WORD")
17
+
18
+ supported_languages = ["English", "Chinese", "Spanish", "Hindi", "Portuguese", "French", "German", "Japanese", "Arabic", "Korean", "Indonesian", "Italian", "Dutch", "Turkish", "Polish", "Swedish", "Filipino", "Malay", "Russian", "Romanian", "Ukrainian", "Greek", "Czech", "Danish", "Finnish", "Bulgarian", "Croatian", "Slovak", "Tamil"]
19
+
20
+ # create dirs for images, audio and video files if they don't exist
21
+ if not os.path.exists('./files/audio'):
22
+ os.makedirs('./files/audio')
23
+ if not os.path.exists('./files/images'):
24
+ os.makedirs('./files/images')
25
+ if not os.path.exists('./files/video'):
26
+ os.makedirs('./files/video')
27
+
28
+ # image resizer
29
+ def resize_image(input_image_path, file_name):
30
+ with Image.open(input_image_path) as im:
31
+ original_width, original_height = im.size
32
+ desired_width = 300
33
+ aspect_ratio = original_width / original_height
34
+ new_width = desired_width
35
+ new_height = int(desired_width / aspect_ratio)
36
+ resized_image = im.resize((new_width, new_height))
37
+ resized_image.save("./files/images/"+file_name+".png")
38
+ return "./files/images/"+file_name+".png"
39
+
40
+ # main func
41
+ def holiday_card(secret, brief, lang, photo):
42
+
43
+ if not secret or secret.strip().lower() != SECRET.strip().lower():
44
+ raise gr.Error("Please use the correct secret word!")
45
+
46
+ if not brief:
47
+ raise gr.Error("Please enter the kind of greeting you want to create!")
48
+
49
+ if not photo:
50
+ raise gr.Error("Please upload a photo!")
51
+
52
+
53
+ # generate a unique id for this greeting
54
+ uid = str(uuid.uuid4())
55
+
56
+ # resize the image, otherwise it will be too big
57
+ resized_photo = resize_image(photo, uid)
58
+
59
+ # get the gender of the person in the photo - so that we can choose the voice
60
+ gender = get_gender(resized_photo)
61
+ if gender == 'female':
62
+ voice = 'Bella'
63
+ else:
64
+ voice = 'Antoni'
65
+
66
+ # generate the greeting
67
+ system_prompt = f'''
68
+ You are a native {lang} copywriter with an excellent sense of humour. You help people write text for their holiday voice messages that they will send to their friends and colleagues. You take the user brief and then write a short, joyful, funny and beautiful short speech in {lang} that people say when wishing their colleagues and friends a happy 2024. It shouldn't be more than 2-3 sentences long. Please respond with valid JSON.
69
+
70
+ If the client brief was good - please return valid JSON like this:
71
+ {{
72
+ "status": "OK",
73
+ "text": "your greeting text here"
74
+ }}
75
+
76
+ If the client brief is inappropriate - please return valid JSON like this:
77
+ {{
78
+ "status": "ERROR",
79
+ "reason": "your reason for what is wrong with the brief"
80
+ }}
81
+
82
+ Please alsways return valid JSON and nothing else
83
+ '''
84
+
85
+ # get the answer from the model
86
+ answer_text = answer(
87
+ system_message=system_prompt,
88
+ user_message=brief
89
+ )
90
+
91
+ pprint(answer_text)
92
+
93
+ try:
94
+ answer_data = json.loads(answer_text)
95
+
96
+ except Exception as e:
97
+ pprint(e)
98
+ raise gr.Error(f"Sorry, something went wrong with the AI. Please try again later")
99
+
100
+ if answer_data.get('status') == 'ERROR':
101
+ raise gr.Error(answer_data.get('reason'))
102
+
103
+ text = answer_data.get('text')
104
+
105
+ # now get audio
106
+ try:
107
+ audio_file = get_audio(uid, text, voice)
108
+ pprint(audio_file)
109
+
110
+ except Exception as e:
111
+ pprint(e)
112
+ raise gr.Error(f"Sorry, something went wrong with the audio generation AI. Please try again later")
113
+
114
+ # now get video
115
+ try:
116
+ video_url = get_video(uid, resized_photo, audio_file)
117
+ except Exception as e:
118
+ pprint(e)
119
+ raise gr.Error(f"Sorry, something went wrong with the video generation AI. Please try again later")
120
+
121
+ return text, audio_file, video_url
122
+
123
+
124
+ # set up and launch gradio interface
125
+
126
+ inputs=[
127
+ gr.Textbox(lines=1, label="What is the secret word?"),
128
+ gr.Textbox(lines=5, label="What do you want your holiday greeting to be all about?"),
129
+ gr.Dropdown(supported_languages, label="What language would you like your holiday greeting to be in?", value="French"),
130
+ gr.Image(type="filepath", label="Upload Image")
131
+ ]
132
+
133
+ outputs=[
134
+ gr.Textbox(lines=5, label="Your Holiday Greeting Text"),
135
+ gr.Audio(type="filepath", label="Your Holiday Greeting Audio"),
136
+ gr.Video(label="Your Holiday Greeting Video")
137
+ ]
138
+
139
+ demo = gr.Interface(
140
+ holiday_card,
141
+ inputs,
142
+ outputs,
143
+ allow_flagging="never"
144
+ )
145
+
146
+ demo.queue()
147
+
148
+ if __name__ == "__main__":
149
+ demo.launch()
gender.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import os
4
+ from pprint import pprint
5
+
6
+ HF_TOKEN = os.getenv("HF_TOKEN")
7
+
8
+ API_URL = "https://api-inference.huggingface.co/models/rizvandwiki/gender-classification-2"
9
+ headers = {"Authorization": "Bearer "+HF_TOKEN}
10
+
11
+ def get_gender(filename):
12
+ with open(filename, "rb") as f:
13
+ data = f.read()
14
+ response = requests.post(API_URL, headers=headers, data=data)
15
+
16
+ gender_prediction = response.json()
17
+
18
+ return gender_prediction[0]['label']
llm.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import openai
3
+ import json
4
+ from pprint import pprint
5
+ from tenacity import retry, wait_random_exponential, stop_after_attempt
6
+
7
+ openai.api_key = os.environ.get("OPENAI_API_KEY")
8
+ MODEL = 'gpt-4'
9
+
10
+ '''
11
+ basic underlying method to get completions. retries built in
12
+ '''
13
+ @retry(wait=wait_random_exponential(multiplier=2, max=10), stop=stop_after_attempt(3))
14
+ def get_completion(messages, temperature=1):
15
+
16
+ try:
17
+ chat_completion = openai.ChatCompletion.create(
18
+ model=MODEL,
19
+ temperature=temperature,
20
+ messages=messages
21
+ )
22
+
23
+ return chat_completion.choices[0].message.content
24
+
25
+ except Exception as e:
26
+ print("Unable to generate ChatCompletion response")
27
+ print(f"Exception: {e}")
28
+ return e
29
+
30
+ '''
31
+ basic text completion
32
+ '''
33
+ def answer(system_message, user_message, temperature=1):
34
+
35
+ messages = [{
36
+ "role": "system",
37
+ "content": system_message
38
+ },{
39
+ "role": "user",
40
+ "content": user_message
41
+ }]
42
+
43
+ completion = get_completion(
44
+ messages=messages,
45
+ temperature = temperature
46
+ )
47
+
48
+ return completion
requirements.txt ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.8.6
3
+ aiosignal==1.3.1
4
+ altair==5.1.2
5
+ annotated-types==0.6.0
6
+ anyio==3.7.1
7
+ appnope==0.1.3
8
+ asttokens==2.4.1
9
+ async-timeout==4.0.3
10
+ attrs==23.1.0
11
+ backcall==0.2.0
12
+ certifi==2023.7.22
13
+ charset-normalizer==3.3.1
14
+ click==8.1.7
15
+ colorama==0.4.6
16
+ contourpy==1.1.1
17
+ cycler==0.12.1
18
+ decorator==5.1.1
19
+ elevenlabs==0.2.26
20
+ executing==2.0.0
21
+ fastapi==0.104.0
22
+ ffmpy==0.3.1
23
+ filelock==3.12.4
24
+ fonttools==4.43.1
25
+ frozenlist==1.4.0
26
+ fsspec==2023.10.0
27
+ gradio==3.50.2
28
+ gradio_client==0.6.1
29
+ h11==0.14.0
30
+ httpcore==0.18.0
31
+ httpx==0.25.0
32
+ huggingface-hub==0.18.0
33
+ idna==3.4
34
+ importlib-resources==6.1.0
35
+ ipython==8.16.1
36
+ jedi==0.19.1
37
+ Jinja2==3.1.2
38
+ jsonschema==4.19.1
39
+ jsonschema-specifications==2023.7.1
40
+ kiwisolver==1.4.5
41
+ livereload==2.6.3
42
+ MarkupSafe==2.1.3
43
+ matplotlib==3.8.0
44
+ matplotlib-inline==0.1.6
45
+ multidict==6.0.4
46
+ numpy==1.26.1
47
+ openai==0.28.1
48
+ orjson==3.9.10
49
+ packaging==23.2
50
+ pandas==2.1.2
51
+ parso==0.8.3
52
+ pexpect==4.8.0
53
+ pickleshare==0.7.5
54
+ Pillow==10.1.0
55
+ prompt-toolkit==3.0.39
56
+ ptyprocess==0.7.0
57
+ pure-eval==0.2.2
58
+ py-mon==1.1.1
59
+ pydantic==2.4.2
60
+ pydantic_core==2.10.1
61
+ pydub==0.25.1
62
+ Pygments==2.16.1
63
+ pyparsing==3.1.1
64
+ python-dateutil==2.8.2
65
+ python-dotenv==1.0.0
66
+ python-multipart==0.0.6
67
+ pytz==2023.3.post1
68
+ PyYAML==6.0.1
69
+ referencing==0.30.2
70
+ replicate==0.15.4
71
+ requests==2.31.0
72
+ rpds-py==0.10.6
73
+ semantic-version==2.10.0
74
+ six==1.16.0
75
+ sniffio==1.3.0
76
+ stack-data==0.6.3
77
+ starlette==0.27.0
78
+ tenacity==8.2.3
79
+ toolz==0.12.0
80
+ tornado==6.3.3
81
+ tqdm==4.66.1
82
+ traitlets==5.12.0
83
+ typing_extensions==4.8.0
84
+ tzdata==2023.3
85
+ urllib3==2.0.7
86
+ uvicorn==0.23.2
87
+ watchdog==3.0.0
88
+ wcwidth==0.2.8
89
+ websockets==11.0.3
90
+ yarl==1.9.2
speechtovid.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import replicate
3
+ from pprint import pprint
4
+
5
+ def get_video(uid, photo_path, audio_path):
6
+
7
+ output = replicate.run(
8
+ "cjwbw/sadtalker:3aa3dac9353cc4d6bd62a8f95957bd844003b401ca4e4a9b33baa574c549d376",
9
+ input={
10
+ "source_image": open(photo_path, "rb"),
11
+ "driven_audio": open(audio_path, "rb"),
12
+ }
13
+ )
14
+
15
+ return output
tts.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from elevenlabs import generate, set_api_key, save
2
+ import os
3
+
4
+ API_KEY = os.getenv('ELEVENLABS_KEY')
5
+
6
+ set_api_key(API_KEY)
7
+
8
+ def get_audio(a_uid, text, voice_id):
9
+
10
+ audio = generate(
11
+ text=text,
12
+ voice=voice_id,
13
+ model='eleven_multilingual_v2'
14
+ )
15
+ file_name = f'./files/audio/{a_uid}.mp3'
16
+ save(audio, file_name)
17
+ return file_name