Spaces:
Sleeping
Sleeping
File size: 1,518 Bytes
5aaa320 52422d2 7d0de60 3bb8b4d 7826a83 5aaa320 52422d2 7d0de60 3bb8b4d 5aaa320 7d0de60 31dafcd 5aaa320 7d0de60 5aaa320 7d0de60 5aaa320 7d0de60 5aaa320 7d0de60 5aaa320 5b2d750 5aaa320 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, set_seed
from typing import List
# Application singleton served by uvicorn (see the run note at the bottom of the file).
app = FastAPI()
# Module-level pipeline: the GPT-2 medium weights are loaded once at import time
# so every request reuses the same model (first run downloads it if not cached).
generator = pipeline('text-generation', model='gpt2-medium')
# Fix the RNG seed so sampled generations are reproducible across restarts.
set_seed(42)
# Data model for FastAPI input
class UserInput(BaseModel):
    """Request body for POST /generate/: the dialogue so far plus the new user turn."""
    conversation: str  # conversation history; prepended verbatim to the prompt
    user_input: str  # latest user message, appended directly after the history
    max_length: int = 50  # default length budget forwarded to the pipeline
    num_return_sequences: int = 1  # default number of alternative continuations sampled
@app.post("/generate/")
async def generate_response(user_input: UserInput):
    """Generate GPT-2 continuations for the given conversation.

    Request body (``UserInput``):
        conversation: dialogue so far; prepended verbatim to the prompt.
        user_input: the new user message, appended after the history.
        max_length: total length budget passed to the pipeline.
        num_return_sequences: how many continuations to sample.

    Returns a JSON object with:
        responses: the full generated texts as the pipeline returns them
            (prompt echo included, unchanged from the original contract).
        updated_conversation: the prompt plus ONLY the newly generated
            continuation of the last sequence.

    Raises:
        HTTPException: status 500 with the underlying error message if
            generation fails.
    """
    # Pure string formatting cannot fail — keep it outside the try block
    # so the handler body stays minimal.
    prompt = f"{user_input.conversation}{user_input.user_input}"
    try:
        responses = generator(
            prompt,
            max_length=user_input.max_length,
            num_return_sequences=user_input.num_return_sequences,
        )
        # Extract text from each generated sequence
        generated_texts = [response["generated_text"] for response in responses]
        # BUG FIX: HF text-generation pipelines echo the prompt at the start of
        # each generated_text, so appending it raw duplicated the whole
        # conversation on every turn. Strip the echoed prompt first.
        last = generated_texts[-1]
        continuation = last[len(prompt):] if last.startswith(prompt) else last
        updated_conversation = f"{prompt}\n{continuation}"
        return {
            "responses": generated_texts,
            "updated_conversation": updated_conversation,
        }
    except Exception as e:
        # Boundary handler: surface any pipeline failure as a 500, chaining
        # the original exception so the traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Run the app
# To start the server, use the command: uvicorn filename:app --host 0.0.0.0 --port 8000
|