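"""FastAPI gratitude-journal demo.

Accepts text, images, and audio over a WebSocket, merges them into a single
gratitude journal entry using Groq models, generates a matching image with
FLUX.1-schnell, and screens that image for NSFW content before returning it
to the client.
"""
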
import base64
import os
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import requests
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from groq import Groq
from huggingface_hub import InferenceClient

# Hugging Face Inference API endpoint used to screen generated images for
# NSFW content; a valid token must be available in the HF_TOKEN environment variable.
API_URL = "https://api-inference.huggingface.co/models/Falconsai/nsfw_image_detection"
headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}

def is_image_safe(data):
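    # Query the classifier and return True when the "nsfw" score is below 0.5;
    # return False on any API error or unexpected response shape (fail closed).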
    try:
        response = requests.post(API_URL, headers=headers, data=data)
        response.raise_for_status()  # Raise an error for HTTP issues
        result = response.json()
        print(result)

        # Look for the 'nsfw' label and extract its score
        nsfw_score = next((item["score"] for item in result if item["label"] == "nsfw"), None)

        # Ensure the 'nsfw' label exists in the response
        if nsfw_score is not None:
            return nsfw_score < 0.5  # Return True if the NSFW score is less than 0.5
        else:
            print("Error: 'nsfw' label not found in response.")
            return False
    except requests.exceptions.RequestException as e:
        print(f"Error: Failed to process the request. {e}")
        return False


app = FastAPI()

# Serve static files (HTML, CSS, JS)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize a single Groq client (reads GROQ_API_KEY from the environment),
# an inference client for FLUX.1-schnell image generation, and a thread pool
# for running blocking API calls.
client = Groq()
image_gen_client = InferenceClient("black-forest-labs/FLUX.1-schnell")
executor = ThreadPoolExecutor()

# Root path serves the single-page client.
@app.get("/", response_class=FileResponse)
async def get():
    return FileResponse("static/index.html")


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
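
    # The helpers below make blocking Groq / Hugging Face API calls; some are
    # dispatched through the thread pool executor so several inputs can be
    # processed in parallel.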

    # Turn the combined user input into a concise gratitude journal entry.
    def generate_text(input_text):
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are an assistant that generates gratitude journal entries. Focus solely on expressing gratitude in a concise, meaningful tone. Do not include introductory or concluding statements. Avoid unrelated topics or personal opinions. Keep the entries simple, neutral, and directly reflective of gratitude."},
                {"role": "user", "content": input_text},
            ],
            model="llama3-8b-8192",
        )
        print(chat_completion.choices[0].message.content)
        return chat_completion.choices[0].message.content
    
    # Turn the journal text into a short visual prompt describing the entry.
    def generate_image_prompt(input_text):
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are an assistant that generates image prompts reflecting gratitude journal entries. Create concise, vivid descriptions that visually represent gratitude and positive moments. Avoid adding any introductory or concluding statements. Keep the prompts simple, neutral, and focused solely on the imagery."},
                {"role": "user", "content": input_text},
            ],
            model="llama3-8b-8192",
        )
        print(chat_completion.choices[0].message.content)
        return chat_completion.choices[0].message.content

    # Describe an uploaded image with Groq's vision model so the description
    # can be folded into the journal text. Assumes image_data is a base64
    # data URL (e.g. "data:image/png;base64,..."), which is what the
    # OpenAI-style image_url content part expects.
    def analyze_image(image_data):
        completion = client.chat.completions.create(
            model="llama-3.2-11b-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe the contents of this image for a gratitude journal."},
                        {"type": "image_url", "image_url": {"url": image_data}},
                    ],
                }
            ],
            temperature=1,
            max_tokens=1024,
            top_p=1,
            stream=False,
            stop=None,
        )
        print(completion.choices[0].message.content)
        return completion.choices[0].message.content

    # Speech-to-text: transcribe the recorded audio with Whisper via Groq.
    # Assumes audio_data is MP3 bytes (decode from base64 first if the client
    # sends it base64-encoded).
    def transcribe_audio(audio_data):
        transcription = client.audio.transcriptions.create(
            file=("audio.mp3", audio_data),
            model="whisper-large-v3-turbo",
            response_format="json",
            language="en",
        )
        return transcription.text

    # Generate an illustrative image with FLUX.1-schnell; text_to_image
    # returns a PIL.Image object.
    def generate_image(prompt):
        print(prompt)
        return image_gen_client.text_to_image(prompt)

    try:
        while True:
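            # Each client message is a JSON object that may contain free text
            # ("text"), a list of images ("images"), and an audio clip ("audio").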
            data = await websocket.receive_json()
            input_text = data.get("text", "")
            input_images = data.get("images", [])
            input_audio = data.get("audio", None)

            journal_text = input_text
            tasks = []

            # Describe each uploaded image in a worker thread and fold the
            # description into the journal text.
            for image in input_images:
                def process_image(image=image):
                    nonlocal journal_text
                    image_description = analyze_image(image)
                    journal_text += f" {image_description}"
                tasks.append(executor.submit(process_image))

            # Transcribe the audio clip, if any, and append the transcription.
            if input_audio:
                def process_audio():
                    nonlocal journal_text
                    audio_transcription = transcribe_audio(input_audio)
                    journal_text += f" {audio_transcription}"
                tasks.append(executor.submit(process_audio))

            # Wait for all background tasks to finish. Note that .result()
            # blocks the event loop, which is acceptable for a single-user demo.
            for task in tasks:
                task.result()

            # Generate the journal entry and a matching image prompt.
            journal_output = generate_text(journal_text)
            journal_image_prompt = generate_image_prompt(journal_text)
            await websocket.send_json({"type": "journal", "content": journal_output})
            await websocket.send_json({"type": "info", "content": "Gratitude-themed image is being generated"})
            image_prompt = f"Generate a gratitude-themed artistic image based on this journal entry: {journal_image_prompt}"
            generated_image = generate_image(image_prompt)  # Returns a PIL.Image object
            print("image generated")

            # Serialize the PIL image to PNG and base64-encode it for transport.
            buffered = BytesIO()
            generated_image.save(buffered, format="PNG")
            img_str = ""
            if is_image_safe(buffered.getvalue()):
                img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

            # Send the image to the client; an empty string signals that the
            # generated image failed the NSFW safety check.
            await websocket.send_json({"type": "image", "image": img_str})

    except WebSocketDisconnect:
        # The client disconnected; the socket is already closed, nothing to do.
        pass
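

# Optional local entry point (an assumption: the app may instead be launched
# externally, e.g. `uvicorn app:app`). Port 7860 is the Hugging Face Spaces
# default; adjust the module name and port to your deployment.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)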