from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, set_seed

# Initialize the FastAPI app
app = FastAPI()

# Initialize the text-generation pipeline and fix the seed for reproducible outputs
generator = pipeline('text-generation', model='gpt2-medium')
set_seed(42)

# Data model for FastAPI input
class UserInput(BaseModel):
    conversation: str
    user_input: str
    max_length: int = 50  # default length
    num_return_sequences: int = 1  # default number of sequences

# Register the endpoint (the /generate/ route path is illustrative)
@app.post("/generate/")
async def generate_response(user_input: UserInput):
    try:
        # Construct the prompt from the conversation history and the new user input
        prompt = f"{user_input.conversation}{user_input.user_input}"

        # Generate one or more responses
        responses = generator(
            prompt,
            max_length=user_input.max_length,
            num_return_sequences=user_input.num_return_sequences
        )

        # Extract the text from each generated sequence
        generated_texts = [response["generated_text"] for response in responses]

        # Update the conversation with the last generated text
        updated_conversation = f"{prompt}\n{generated_texts[-1]}"

        return {
            "responses": generated_texts,
            "updated_conversation": updated_conversation
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Run the app
# To start the server, use the command: uvicorn filename:app --host 0.0.0.0 --port 8000
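
# Example client call (a minimal sketch, not part of the server logic above). It assumes
# the server is running locally on port 8000, that the endpoint is mounted at the
# illustrative /generate/ path used in the decorator, and that the `requests` package
# is installed. Running this file directly (python filename.py) sends one request and
# prints the first generated response; under uvicorn this block is skipped.
if __name__ == "__main__":
    import requests

    payload = {
        "conversation": "User: Hello\nAssistant: Hi! How can I help you today?\nUser: ",
        "user_input": "Tell me something interesting about language models.",
        "max_length": 80,
        "num_return_sequences": 1,
    }
    result = requests.post("http://localhost:8000/generate/", json=payload, timeout=120)
    result.raise_for_status()
    print(result.json()["responses"][0])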