Update app.py
app.py
CHANGED
@@ -1,7 +1,7 @@
 import os
 import torch
 from fastapi import FastAPI, HTTPException
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, field_validator
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList
 import boto3
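The added StreamingResponse import is what the audio and video endpoints further down rely on to stream an in-memory buffer back to the caller. A minimal sketch of that pattern in isolation, with a toy endpoint and payload that are not part of this commit:

from io import BytesIO

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/demo-bytes")
async def demo_bytes():
    # Build an in-memory buffer, rewind it, then stream it to the client.
    buf = BytesIO(b"example payload")
    buf.seek(0)
    return StreamingResponse(buf, media_type="application/octet-stream")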
@@ -55,16 +55,16 @@ class GenerateRequest(BaseModel):
 
     @field_validator("max_new_tokens")
     def max_new_tokens_must_be_within_limit(cls, v):
-        if v >
-            raise ValueError("max_new_tokens cannot be greater than
+        if v > 500:
+            raise ValueError("max_new_tokens cannot be greater than 500.")
         return v
 
 class S3ModelLoader:
-    def
+    def __init__(self, bucket_name, s3_client):
         self.bucket_name = bucket_name
         self.s3_client = s3_client
 
-    def
+    def _get_s3_uri(self, model_name):
         return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
 
     async def load_model_and_tokenizer(self, model_name):
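A quick, self-contained check of the two pieces restored in this hunk. The input_text field, the default for max_new_tokens, the bucket name, and the None client are illustrative assumptions; only the validator and the two S3ModelLoader methods mirror the commit:

from pydantic import BaseModel, ValidationError, field_validator

class GenerateRequest(BaseModel):
    input_text: str            # assumed field, not shown in this hunk
    max_new_tokens: int = 200  # assumed default

    @field_validator("max_new_tokens")
    def max_new_tokens_must_be_within_limit(cls, v):
        if v > 500:
            raise ValueError("max_new_tokens cannot be greater than 500.")
        return v

class S3ModelLoader:
    def __init__(self, bucket_name, s3_client):
        self.bucket_name = bucket_name
        self.s3_client = s3_client

    def _get_s3_uri(self, model_name):
        return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"

# Requests over the cap are rejected before any model code runs.
try:
    GenerateRequest(input_text="hi", max_new_tokens=512)
except ValidationError as err:
    print(err)  # includes the "cannot be greater than 500" message

# Slashes in Hub repo ids are flattened into a bucket-friendly key.
loader = S3ModelLoader("my-model-bucket", s3_client=None)
print(loader._get_s3_uri("google/flan-t5-small"))
# -> s3://my-model-bucket/google-flan-t5-small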
@@ -187,7 +187,10 @@ async def generate_text_to_speech(request: GenerateRequest):
     audio = audio_generator(validated_body.input_text)[0]
 
     audio_byte_arr = BytesIO()
-    audio.
+    # It is expected that the audio is saved as wav.
+    # Saving like this will not always work. Please check how your
+    # audio_generator model is working.
+    audio_generator.save_audio(audio_byte_arr, audio)
     audio_byte_arr.seek(0)
 
     return StreamingResponse(audio_byte_arr, media_type="audio/wav")
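The save_audio call above is deliberately hedged in the comments: how the waveform comes back depends on the model behind audio_generator. For reference, one known-working way to fill the WAV buffer when the generator is a transformers text-to-speech pipeline; the checkpoint and the soundfile dependency are assumptions, not part of this commit:

from io import BytesIO

import numpy as np
import soundfile as sf
from transformers import pipeline

# Illustrative checkpoint; the text-to-speech pipeline returns a dict
# with "audio" and "sampling_rate" keys.
tts = pipeline("text-to-speech", model="suno/bark-small")
out = tts("Hello from the generate endpoint")

audio_byte_arr = BytesIO()
# Drop any leading batch dimension and write WAV; soundfile needs an
# explicit format when the target is a file-like object instead of a path.
sf.write(audio_byte_arr, np.squeeze(out["audio"]), out["sampling_rate"], format="WAV")
audio_byte_arr.seek(0)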
@@ -204,7 +207,10 @@ async def generate_video(request: GenerateRequest):
     video = video_generator(validated_body.input_text)[0]
 
     video_byte_arr = BytesIO()
-    video
+    # Same as above. Please check how your video model is returning the
+    # videos and save them accordingly.
+    # It is expected that the video is saved as MP4
+    video_generator.save_video(video_byte_arr, video)
     video_byte_arr.seek(0)
 
     return StreamingResponse(video_byte_arr, media_type="video/mp4")
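Likewise for video: the right save call depends on what video_generator returns. One possibility, if the frames come back as a list of arrays in the diffusers style, is to export them to a temporary MP4 with diffusers.utils.export_to_video and copy the bytes into the buffer; the dummy frames and the diffusers dependency below are assumptions for illustration only:

import tempfile
from io import BytesIO

import numpy as np
from diffusers.utils import export_to_video

# Dummy frames stand in for whatever the video model produces:
# a list of HxWx3 float arrays in [0, 1].
frames = [np.random.rand(256, 256, 3).astype(np.float32) for _ in range(16)]

video_byte_arr = BytesIO()
with tempfile.NamedTemporaryFile(suffix=".mp4") as tmp:
    # export_to_video only writes to a path, so go through a temp file
    # and read the encoded MP4 back into the in-memory buffer.
    export_to_video(frames, tmp.name)
    tmp.seek(0)
    video_byte_arr.write(tmp.read())
video_byte_arr.seek(0)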