# app.py — FastAPI text-generation service backed by GPT-2 medium.
# (Removed non-Python residue from a Hugging Face Spaces web page view:
# breadcrumb, author line, commit message, and "raw/history/blame" links.)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, set_seed
from typing import List
# Initialize the FastAPI app.
app = FastAPI()

# Initialize the text-generation pipeline once at module import.
# NOTE(review): this downloads/loads gpt2-medium at startup, which can take
# significant time and memory on first run — confirm this is acceptable for
# the deployment target.
generator = pipeline('text-generation', model='gpt2-medium')
# Fix the RNG seed so sampled generations are reproducible across restarts.
set_seed(42)
# Data model for FastAPI input.
class UserInput(BaseModel):
    """Request payload for the /generate/ endpoint.

    The prompt sent to the model is the raw concatenation of
    ``conversation`` and ``user_input`` (no separator is inserted).
    """

    conversation: str                # running conversation history so far
    user_input: str                  # the new user message to append
    max_length: int = 50             # default total length passed to the pipeline
    num_return_sequences: int = 1    # default number of candidate generations
@app.post("/generate/")
async def generate_response(user_input: UserInput):
try:
# Construct the prompt from the conversation and user input
prompt = f"{user_input.conversation}{user_input.user_input}"
# Generate response
responses = generator(
prompt,
max_length=user_input.max_length,
num_return_sequences=user_input.num_return_sequences
)
# Extract text from each generated sequence
generated_texts = [response["generated_text"] for response in responses]
# Update conversation with the last generated text
updated_conversation = f"{prompt}\n{generated_texts[-1]}"
return {
"responses": generated_texts,
"updated_conversation": updated_conversation
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Run the app
# To start the server, use the command: uvicorn filename:app --host 0.0.0.0 --port 8000