Update app.py
app.py
CHANGED
@@ -1,171 +1,169 @@
import os
import base64
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import requests
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from groq import Groq
from huggingface_hub import InferenceClient

# NSFW safety check via the Hugging Face Inference API
API_URL = "https://api-inference.huggingface.co/models/Falconsai/nsfw_image_detection"
headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}


def is_image_safe(data):
    """Return True if the image is classified as safe (NSFW score below 0.5)."""
    try:
        response = requests.post(API_URL, headers=headers, data=data)
        response.raise_for_status()  # Raise an error for HTTP issues
        result = response.json()
        print(result)

        # Look for the 'nsfw' label and extract its score
        nsfw_score = next((item["score"] for item in result if item["label"] == "nsfw"), None)

        if nsfw_score is not None:
            return nsfw_score < 0.5  # True if the NSFW score is below the threshold
        print("Error: 'nsfw' label not found in response.")
        return False
    except requests.exceptions.RequestException as e:
        print(f"Error: Failed to process the request. {e}")
        return False


app = FastAPI()

# Serve static files (HTML, CSS, JS)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize a single Groq client, image-generation client, and thread pool
client = Groq()
image_gen_client = InferenceClient("black-forest-labs/FLUX.1-schnell")
executor = ThreadPoolExecutor()


@app.get("/", response_class=FileResponse)
async def get():
    return FileResponse("static/index.html")


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    def generate_text(input_text):
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are an assistant that generates gratitude journal entries. Focus solely on expressing gratitude in a concise, meaningful tone. Do not include introductory or concluding statements. Avoid unrelated topics or personal opinions. Keep the entries simple, neutral, and directly reflective of gratitude."},
                {"role": "user", "content": input_text},
            ],
            model="llama3-8b-8192",
        )
        print(chat_completion.choices[0].message.content)
        return chat_completion.choices[0].message.content

    def generate_image_prompt(input_text):
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are an assistant that generates image prompts reflecting gratitude journal entries. Create concise, vivid descriptions that visually represent gratitude and positive moments. Avoid adding any introductory or concluding statements. Keep the prompts simple, neutral, and focused solely on the imagery."},
                {"role": "user", "content": input_text},
            ],
            model="llama3-8b-8192",
        )
        print(chat_completion.choices[0].message.content)
        return chat_completion.choices[0].message.content

    def analyze_image(image_data):
        # image_data is forwarded as-is; Groq's OpenAI-compatible vision endpoint
        # generally expects {"type": "image_url", "image_url": {"url": ...}} with a
        # data URL, so this payload shape may need adjusting.
        completion = client.chat.completions.create(
            model="llama-3.2-11b-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe the contents of this image for a gratitude journal."},
                        {"type": "image", "image": image_data},
                    ],
                }
            ],
            temperature=1,
            max_tokens=1024,
            top_p=1,
            stream=False,
            stop=None,
        )
        print(completion.choices[0].message.content)
        return completion.choices[0].message.content

    def transcribe_audio(audio_data):
        # Speech-to-text with Whisper
        transcription = client.audio.transcriptions.create(
            file=("audio.mp3", audio_data),
            model="whisper-large-v3-turbo",
            response_format="json",
            language="en",
        )
        return transcription.text

    def generate_image(prompt):
        print(prompt)
        return image_gen_client.text_to_image(prompt)

    try:
        while True:
            data = await websocket.receive_json()
            input_text = data.get("text", "")
            input_images = data.get("images", [])
            input_audio = data.get("audio", None)

            journal_text = input_text
            tasks = []

            # Describe each image in a worker thread; concurrent appends make
            # the order of the added text nondeterministic.
            for image in input_images:
                def process_image(image=image):
                    nonlocal journal_text
                    image_description = analyze_image(image)
                    journal_text += f" {image_description}"
                tasks.append(executor.submit(process_image))

            # Transcribe audio, if any, in a worker thread
            if input_audio:
                def process_audio():
                    nonlocal journal_text
                    audio_transcription = transcribe_audio(input_audio)
                    journal_text += f" {audio_transcription}"
                tasks.append(executor.submit(process_audio))

            # Wait for all tasks to complete (blocks the event loop until the
            # worker threads finish)
            for task in tasks:
                task.result()

            # Generate the journal entry and a matching image prompt
            journal_output = generate_text(journal_text)
            journal_image_prompt = generate_image_prompt(journal_text)
            await websocket.send_json({"type": "journal", "content": journal_output})
            await websocket.send_json({"type": "info", "content": "Gratitude-themed image is being generated"})
            image_prompt = f"Generate a gratitude-themed artistic image based on this journal entry: {journal_image_prompt}"
            generated_image = generate_image(image_prompt)  # Returns a PIL.Image object
            print("image generated")

            # Convert the PIL.Image to base64, but only ship it if it passes
            # the safety check
            buffered = BytesIO()
            generated_image.save(buffered, format="PNG")
            img_str = ""
            if is_image_safe(buffered.getvalue()):
                img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

            # Send the image as a base64 string via WebSocket
            await websocket.send_json({"type": "image", "image": img_str})

    except WebSocketDisconnect:
        await websocket.close()
-if __name__ == "__main__":
-    uvicorn.run("app:app", host="0.0.0.0", port=8090, reload=True)
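With the `if __name__ == "__main__":` launcher removed, the app no longer starts itself and must be run by an external server process. A minimal launch sketch, assuming the file is named app.py; the port is an assumption, so use whatever the deployment expects:

    uvicorn app:app --host 0.0.0.0 --port 7860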
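For a quick end-to-end check of the /ws endpoint, here is a minimal client sketch. It assumes the server is reachable at ws://localhost:7860/ws (hypothetical host and port) and that the third-party websockets package is installed; the payload keys mirror what websocket_endpoint reads, and the example text is made up:

import asyncio
import json

import websockets  # pip install websockets


async def main():
    uri = "ws://localhost:7860/ws"  # assumed host/port
    async with websockets.connect(uri) as ws:
        # Keys match data.get(...) in websocket_endpoint: text, images, audio
        await ws.send(json.dumps({
            "text": "Grateful for a quiet morning walk.",
            "images": [],
            "audio": None,
        }))
        # The server replies three times per request: the journal entry, an
        # info message, then the generated image as a base64 string.
        for _ in range(3):
            reply = json.loads(await ws.recv())
            print(reply.get("type"), str(reply)[:80])


asyncio.run(main())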