File size: 1,518 Bytes
5aaa320
52422d2
7d0de60
3bb8b4d
7826a83
5aaa320
52422d2
 
7d0de60
 
 
3bb8b4d
5aaa320
 
 
 
7d0de60
 
31dafcd
5aaa320
 
 
7d0de60
 
 
5aaa320
7d0de60
 
 
 
 
 
 
 
5aaa320
7d0de60
 
 
5aaa320
7d0de60
5aaa320
 
 
 
5b2d750
5aaa320
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import pipeline, set_seed

# Initialize the FastAPI app
app = FastAPI()

# Load the GPT-2-medium text-generation pipeline once at import time.
# NOTE(review): this downloads/loads the model weights at startup and can
# take a while — confirm this is acceptable for the deployment environment.
generator = pipeline('text-generation', model='gpt2-medium')
# Fixed seed so sampled generations are reproducible across runs.
set_seed(42)

# Data model for FastAPI input
class UserInput(BaseModel):
    """Request payload for the ``/generate/`` endpoint.

    Fields mirror the parameters forwarded to the HF text-generation
    pipeline; ``conversation`` is the running history the client echoes
    back on each turn.
    """

    # Prior dialogue; prepended verbatim to the prompt.
    conversation: str
    # The new user message.
    user_input: str
    # Bounds guard against values the pipeline rejects (0 / negative) and
    # against abusive requests that would exhaust CPU/memory.
    max_length: int = Field(default=50, ge=1, le=1024)
    num_return_sequences: int = Field(default=1, ge=1, le=10)

@app.post("/generate/")
async def generate_response(user_input: UserInput):
    """Generate model continuations for a conversation.

    Returns a dict with:
      - ``responses``: list of generated texts, one per requested sequence
      - ``updated_conversation``: the prompt plus the last generation,
        ready for the client to send back on the next turn

    Raises:
        HTTPException: 500 with the underlying error message if the
            generation pipeline fails.
    """
    # Prompt is the raw concatenation of history and the new message;
    # clients are expected to include any separators themselves.
    prompt = f"{user_input.conversation}{user_input.user_input}"

    try:
        # NOTE(review): `generator` is a blocking, CPU-bound call inside an
        # async endpoint, so it stalls the event loop while generating.
        # Consider a plain `def` endpoint (FastAPI runs those in a thread
        # pool) or loop.run_in_executor — confirm before changing.
        responses = generator(
            prompt,
            max_length=user_input.max_length,
            num_return_sequences=user_input.num_return_sequences
        )
    except Exception as e:
        # Keep the try block narrow: only pipeline failures become this
        # 500, so bugs in response assembly below are not masked.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Extract text from each generated sequence.
    generated_texts = [response["generated_text"] for response in responses]

    # The last sequence extends the conversation for the next turn.
    updated_conversation = f"{prompt}\n{generated_texts[-1]}"

    return {
        "responses": generated_texts,
        "updated_conversation": updated_conversation
    }

# Run the app
# To start the server, use the command: uvicorn filename:app --host 0.0.0.0 --port 8000