Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import structlog
|
3 |
+
import openai
|
4 |
+
import random
|
5 |
+
import tiktoken
|
6 |
+
import enum
|
7 |
+
import time
|
8 |
+
import retrying
|
9 |
+
import IPython.display as display
|
10 |
+
from base64 import b64decode
|
11 |
+
import base64
|
12 |
+
from io import BytesIO
|
13 |
+
import PIL
|
14 |
+
import PIL.Image
|
15 |
+
import PIL.ImageDraw
|
16 |
+
import PIL.ImageFont
|
17 |
+
from ChatPodcastGPT import Chat
|
18 |
+
|
19 |
+
logger = structlog.getLogger()
|
20 |
+
weather_api_key = os.environ['WEATHER_API']
|
21 |
+
openai.api_key = os.environ.get("OPENAI_KEY", None)
|
22 |
+
|
23 |
+
animals = [x.strip() for x in open('animals.txt').readlines()]
|
24 |
+
art_styles = [x.strip() for x in open('art_styles.txt').readlines()]
|
25 |
+
|
26 |
+
|
27 |
+
class Chat:
|
28 |
+
class Model(enum.Enum):
|
29 |
+
GPT3_5 = "gpt-3.5-turbo"
|
30 |
+
GPT_4 = "gpt-4"
|
31 |
+
|
32 |
+
def __init__(self, system, max_length=4096//2):
|
33 |
+
self._system = system
|
34 |
+
self._max_length = max_length
|
35 |
+
self._history = [
|
36 |
+
{"role": "system", "content": self._system},
|
37 |
+
]
|
38 |
+
|
39 |
+
@classmethod
|
40 |
+
def num_tokens_from_text(cls, text, model="gpt-3.5-turbo"):
|
41 |
+
"""Returns the number of tokens used by some text."""
|
42 |
+
encoding = tiktoken.encoding_for_model(model)
|
43 |
+
return len(encoding.encode(text))
|
44 |
+
|
45 |
+
@classmethod
|
46 |
+
def num_tokens_from_messages(cls, messages, model="gpt-3.5-turbo"):
|
47 |
+
"""Returns the number of tokens used by a list of messages."""
|
48 |
+
encoding = tiktoken.encoding_for_model(model)
|
49 |
+
num_tokens = 0
|
50 |
+
for message in messages:
|
51 |
+
num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
|
52 |
+
for key, value in message.items():
|
53 |
+
num_tokens += len(encoding.encode(value))
|
54 |
+
if key == "name": # if there's a name, the role is omitted
|
55 |
+
num_tokens += -1 # role is always required and always 1 token
|
56 |
+
num_tokens += 2 # every reply is primed with <im_start>assistant
|
57 |
+
return num_tokens
|
58 |
+
|
59 |
+
@retrying.retry(stop_max_attempt_number=5, wait_fixed=2000)
|
60 |
+
def _msg(self, *args, model=Model.GPT3_5.value, **kwargs):
|
61 |
+
return openai.ChatCompletion.create(
|
62 |
+
*args,
|
63 |
+
model=model,
|
64 |
+
messages=self._history,
|
65 |
+
**kwargs
|
66 |
+
)
|
67 |
+
|
68 |
+
def message(self, next_msg=None, **kwargs):
|
69 |
+
# TODO: Optimize this if slow through easy caching
|
70 |
+
while len(self._history) > 1 and self.num_tokens_from_messages(self._history) > self._max_length:
|
71 |
+
logger.info(f'Popping message: {self._history.pop(1)}')
|
72 |
+
if next_msg is not None:
|
73 |
+
self._history.append({"role": "user", "content": next_msg})
|
74 |
+
logger.info('requesting openai...')
|
75 |
+
resp = self._msg(**kwargs)
|
76 |
+
logger.info('received openai...')
|
77 |
+
text = resp.choices[0].message.content
|
78 |
+
self._history.append({"role": "assistant", "content": text})
|
79 |
+
return text
|
80 |
+
|
81 |
+
class Weather:
|
82 |
+
def __init__(self, zip_code='10001', api_key=weather_api_key):
|
83 |
+
self.zip_code = zip_code
|
84 |
+
self.api_key = api_key
|
85 |
+
|
86 |
+
def get_weather(self):
|
87 |
+
url = f"https://api.weatherapi.com/v1/forecast.json?q={self.zip_code}&days=1&lang=en&aqi=yes&key={self.api_key}"
|
88 |
+
headers = {'accept': 'application/json'}
|
89 |
+
return requests.get(url, headers=headers).json()
|
90 |
+
|
91 |
+
def get_info(self):
|
92 |
+
weather = self.get_weather()
|
93 |
+
curr_hour = None
|
94 |
+
next_hour = None
|
95 |
+
for hour_data in weather['forecast']['forecastday'][0]["hour"]:
|
96 |
+
if abs(hour_data["time_epoch"] - time.time()) < 60 * 60:
|
97 |
+
if curr_hour is None: curr_hour = hour_data
|
98 |
+
next_hour = hour_data
|
99 |
+
return {
|
100 |
+
"now": weather["current"],
|
101 |
+
"day": weather["forecast"]["forecastday"][0]["day"],
|
102 |
+
"curr_hour": curr_hour,
|
103 |
+
"next_hour": next_hour,
|
104 |
+
}
|
105 |
+
|
106 |
+
|
107 |
+
class Image:
|
108 |
+
class Size(enum.Enum):
|
109 |
+
SMALL = "256x256"
|
110 |
+
MEDIUM = "512x512"
|
111 |
+
LARGE = "1024x1024"
|
112 |
+
|
113 |
+
@classmethod
|
114 |
+
@retrying.retry(stop_max_attempt_number=5, wait_fixed=2000)
|
115 |
+
def create(cls, prompt, n=1, size=Size.SMALL):
|
116 |
+
logger.info('requesting openai.Image...')
|
117 |
+
resp = openai.Image.create(prompt=prompt, n=n, size=size.value, response_format='b64_json')
|
118 |
+
logger.info('received openai.Image...')
|
119 |
+
if n == 1: return resp["data"][0]
|
120 |
+
return resp["data"]
|
121 |
+
|
122 |
+
|
123 |
+
def overlay_text_on_image(img, text, position, text_color=(255, 255, 255), box_color=(0, 0, 0, 128)):
|
124 |
+
# Convert the base64 string back to an image
|
125 |
+
if isinstance(img, str) or isinstance(img, bytes):
|
126 |
+
img_bytes = base64.b64decode(img)
|
127 |
+
img = PIL.Image.open(BytesIO(img_bytes))
|
128 |
+
|
129 |
+
# Get image dimensions
|
130 |
+
img_width, img_height = img.size
|
131 |
+
|
132 |
+
# Create a ImageDraw object
|
133 |
+
draw = PIL.ImageDraw.Draw(img)
|
134 |
+
|
135 |
+
# Reduce the font size until it fits the image width or height
|
136 |
+
l, r = 1, 50
|
137 |
+
while l < r:
|
138 |
+
font_size = (l + r) // 2
|
139 |
+
font = PIL.ImageFont.truetype("/System/Library/Fonts/NewYork.ttf", font_size)
|
140 |
+
left, upper, right, lower = draw.textbbox((0, 0), text, font=font)
|
141 |
+
text_width = right - left
|
142 |
+
text_height = lower - upper
|
143 |
+
if text_width <= img_width and text_height <= img_height:
|
144 |
+
l = font_size + 1
|
145 |
+
else:
|
146 |
+
r = font_size - 1
|
147 |
+
font_size = max(l-1, 1)
|
148 |
+
|
149 |
+
text_width, text_height = draw.textsize(text, font=font)
|
150 |
+
|
151 |
+
if position == 'top-left':
|
152 |
+
x, y = 0, 0
|
153 |
+
elif position == 'top-right':
|
154 |
+
x, y = img_width - text_width, 0
|
155 |
+
elif position == 'bottom-left':
|
156 |
+
x, y = 0, img_height - text_height
|
157 |
+
elif position == 'bottom-right':
|
158 |
+
x, y = img_width - text_width, img_height - text_height
|
159 |
+
else:
|
160 |
+
raise ValueError("Position should be 'top-left', 'top-right', 'bottom-left' or 'bottom-right'.")
|
161 |
+
|
162 |
+
# Draw a semi-transparent box around the text
|
163 |
+
draw.rectangle([x, y, x + text_width, y + text_height], fill=box_color)
|
164 |
+
|
165 |
+
# Draw the text on the image
|
166 |
+
draw.text((x, y), text, font=font, fill=text_color)
|
167 |
+
|
168 |
+
return img
|
169 |
+
|
170 |
+
|
171 |
+
class WeatherDraw:
|
172 |
+
def clean_text(self, weather_info):
|
173 |
+
chat = Chat("Given the following weather conditions, write a very small, concise plaintext summary that will overlay on top of an image.")
|
174 |
+
text = chat.message(str(weather_info))
|
175 |
+
return text
|
176 |
+
|
177 |
+
def generate_image(self, weather_info, **kwargs):
|
178 |
+
animal = random.choice(animals)
|
179 |
+
logger.info(f"Got animal {animal}")
|
180 |
+
chat = Chat(f'''
|
181 |
+
Given the following weather conditions, write a plaintext, short, and vivid description of an
|
182 |
+
adorable {animal} in the weather conditions doing an activity matching the weather.
|
183 |
+
Only write the short description and nothing else.
|
184 |
+
Do not include specific numbers.'''.replace('\n', ' '))
|
185 |
+
description = chat.message(str(weather_info))
|
186 |
+
prompt = f'{description}. Adorable, cute, 4k, Award winning, in the style of {random.choice(art_styles)}'
|
187 |
+
logger.info(prompt)
|
188 |
+
img = Image.create(prompt, **kwargs)
|
189 |
+
return img["b64_json"]
|
190 |
+
|
191 |
+
def step_one_forecast(self, weather_info, **kwargs):
|
192 |
+
img = self.generate_image(weather_info, **kwargs)
|
193 |
+
# text = self.clean_text(weather_info)
|
194 |
+
# return overlay_text_on_image(img, text, 'bottom-left')
|
195 |
+
return img
|
196 |
+
|
197 |
+
def step(self, zip_code='10001', **kwargs):
|
198 |
+
forecast = Weather(zip_code).get_info()
|
199 |
+
return {time: overlay_text_on_image(self.step_one_forecast(data, **kwargs), time, 'top-right') for time, data in forecast.items()}
|
200 |
+
|
201 |
+
|