saifeddinemk committed on
Commit
a510792
1 Parent(s): 7d0de60

Fixed app v2

Browse files
Files changed (1) hide show
  1. app.py +20 -28
app.py CHANGED
@@ -1,44 +1,36 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from transformers import pipeline, set_seed
4
- from typing import List
5
 
6
  # Initialize the FastAPI app
7
  app = FastAPI()
8
 
9
# Build the GPT-2 medium text-generation pipeline once, at import time, so
# every request reuses the same loaded model; then pin the random seed.
generator = pipeline(task='text-generation', model='gpt2-medium')
set_seed(42)
12
 
13
- # Data model for FastAPI input
14
class UserInput(BaseModel):
    """Request body for the ``/generate/`` endpoint."""

    # Conversation text accumulated so far.
    conversation: str
    # New text from the user; it is appended directly to ``conversation``.
    user_input: str
    # Forwarded to the pipeline as ``max_length``; defaults to 50.
    max_length: int = 50
    # Forwarded to the pipeline as ``num_return_sequences``; defaults to 1.
    num_return_sequences: int = 1
 
19
 
20
  @app.post("/generate/")
21
async def generate_response(user_input: UserInput):
    """Generate one or more continuations for a conversation prompt.

    The prompt is the stored conversation immediately followed by the new
    user input; no separator is inserted between the two.

    Args:
        user_input: Request body carrying the conversation so far, the new
            user text, and the generation limits.

    Returns:
        A dict with ``"responses"`` (the text of every generated sequence)
        and ``"updated_conversation"`` (the prompt extended with the last
        generated sequence).

    Raises:
        HTTPException: Status 500, with the underlying error message as the
            detail, if anything fails while generating.
    """
    try:
        prompt = f"{user_input.conversation}{user_input.user_input}"

        outputs = generator(
            prompt,
            max_length=user_input.max_length,
            num_return_sequences=user_input.num_return_sequences,
        )
        generated_texts = [output["generated_text"] for output in outputs]

        # The generated text may already begin with the prompt (presumably
        # the pipeline default -- confirm `return_full_text`). Prepending the
        # prompt again would duplicate the whole conversation, so only add it
        # when it is actually missing.
        latest = generated_texts[-1]
        if latest.startswith(prompt):
            updated_conversation = latest
        else:
            updated_conversation = f"{prompt}\n{latest}"

        return {
            "responses": generated_texts,
            "updated_conversation": updated_conversation,
        }
    except Exception as e:
        # Keep the original traceback attached for server-side debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from transformers import pipeline
4
+ from typing import List, Dict
5
 
6
  # Initialize the FastAPI app
7
  app = FastAPI()
8
 
9
# Build the text-generation pipeline once, at import time, so every request
# reuses the same loaded model.
pipe = pipeline(task="text-generation", model="CyberNative-AI/Colibri_8b_v0.1")
 
11
 
12
+ # Define the input schema for FastAPI
13
class Message(BaseModel):
    """One chat turn: who is speaking and what was said."""

    # Speaker label, passed through to the pipeline unchanged.
    role: str
    # Text of the turn.
    content: str
16
+
17
class MessagesInput(BaseModel):
    """Request body for the ``/generate/`` endpoint: an ordered chat history."""

    # Chat turns in the order they should be presented to the model.
    messages: List[Message]
19
 
20
  @app.post("/generate/")
21
async def generate_response(messages_input: MessagesInput):
    """Run the text-generation pipeline on a chat history.

    Args:
        messages_input: Request body holding the ordered list of chat
            messages (``role`` / ``content`` pairs).

    Returns:
        A dict with a single ``"response"`` key containing the
        ``generated_text`` field of the pipeline's first result.

    Raises:
        HTTPException: Status 400 if no messages were supplied; status 500,
            with the underlying error message as the detail, if generation
            fails.
    """
    # Reject an empty conversation up front, outside the try block, so the
    # client gets a clear 400 instead of an opaque 500 from the model call.
    if not messages_input.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")

    try:
        # The pipeline is given plain dicts rather than pydantic models.
        messages = [
            {"role": msg.role, "content": msg.content}
            for msg in messages_input.messages
        ]

        # NOTE(review): this is a synchronous call inside an async handler,
        # so the event loop is blocked while the model runs -- confirm this
        # is acceptable for the expected load.
        response = pipe(messages)

        # NOTE(review): for chat-style input `generated_text` may be the
        # whole message list rather than a plain string -- confirm against
        # the pipeline documentation before relying on its shape.
        generated_text = response[0]["generated_text"]

        return {
            "response": generated_text,
        }
    except Exception as e:
        # Keep the original traceback attached for server-side debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e