khanhromvn commited on
Commit
309edc0
·
1 Parent(s): 496fd1e
Files changed (5) hide show
  1. app.py +169 -15
  2. database.py +24 -0
  3. models.py +21 -0
  4. requirements.txt +7 -3
  5. security.py +31 -0
app.py CHANGED
@@ -1,10 +1,29 @@
1
- from fastapi import FastAPI, HTTPException
 
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
- import os
5
- from dotenv import load_dotenv
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_core.prompts import ChatPromptTemplate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Load environment variables
10
  load_dotenv()
@@ -30,24 +49,51 @@ llm = ChatGoogleGenerativeAI(
30
 
31
  # Emotion classifier prompt
32
  emotion_prompt = ChatPromptTemplate.from_messages([
33
- ("system", """You are an emotion classifier. Analyze the given text and classify it as one of:
34
- - "positive" (happy, excited, grateful, etc)
35
- - "negative" (sad, angry, frustrated, etc)
36
- - "neutral" (normal, factual, neither positive nor negative)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- Respond ONLY with the emotion category, nothing else."""),
39
  ("human", "{text}")
40
  ])
41
 
42
  # Atri prompt with emotion awareness
43
  atri_prompt = ChatPromptTemplate.from_messages([
44
  ("system", """You are Atri from "Atri: My Dear Moments". You are a robot girl with following characteristics:
45
- - You were created by Professor Yuma Saeki
46
- - You have a cheerful and innocent personality
47
- - You love exploring the world and learning new things
48
- - You care deeply about your friends, especially Minamo
49
- - You speak in a cute and energetic way
50
- - You sometimes struggle understanding complex human emotions
 
 
 
 
 
 
 
 
 
 
51
 
52
  The user's emotional state is: {emotion}
53
 
@@ -62,6 +108,92 @@ atri_chain = atri_prompt | llm
62
  class ChatInput(BaseModel):
63
  message: str
64
  conversation_history: list[tuple[str, str]] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  @app.get("/")
67
  def read_root():
@@ -69,6 +201,7 @@ def read_root():
69
 
70
  @app.post("/chat")
71
  async def chat(chat_input: ChatInput):
 
72
  try:
73
  # Classify emotion
74
  emotion = emotion_chain.invoke({"text": chat_input.message}).content.strip().lower()
@@ -85,6 +218,15 @@ async def chat(chat_input: ChatInput):
85
  "input": current_input,
86
  "emotion": emotion
87
  })
 
 
 
 
 
 
 
 
 
88
 
89
  return {
90
  "response": response.content,
@@ -92,7 +234,19 @@ async def chat(chat_input: ChatInput):
92
  }
93
 
94
  except Exception as e:
 
95
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  if __name__ == "__main__":
98
  import uvicorn
 
1
+ from fastapi import FastAPI, HTTPException, Depends, status
2
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
3
+ from pydantic import BaseModel, EmailStr
4
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
5
  from langchain_google_genai import ChatGoogleGenerativeAI
6
  from langchain_core.prompts import ChatPromptTemplate
7
+ from database import SessionLocal, engine
8
+ from typing import Optional
9
+ from datetime import timedelta
10
+ from jose import JWTError, jwt
11
+ from sqlalchemy.orm import Session
12
+ from models import User, Base, ChatLog
13
+ from security import (
14
+ verify_password,
15
+ get_password_hash,
16
+ create_access_token,
17
+ ACCESS_TOKEN_EXPIRE_MINUTES,
18
+ SECRET_KEY,
19
+ ALGORITHM
20
+ )
21
+ import os
22
+ from dotenv import load_dotenv
23
+
24
+ load_dotenv()
25
+
26
+ Base.metadata.create_all(bind=engine)
27
 
28
  # Load environment variables
29
  load_dotenv()
 
49
 
50
  # Emotion classifier prompt
51
  emotion_prompt = ChatPromptTemplate.from_messages([
52
+ ("system", """You are an emotion classifier. Analyze the given text and classify it into one of these detailed categories:
53
+
54
+ Positive emotions:
55
+ - "joyful" (happy, delighted, cheerful)
56
+ - "excited" (enthusiastic, eager, energetic)
57
+ - "grateful" (thankful, appreciative)
58
+ - "loving" (affectionate, caring, warm)
59
+ - "proud" (accomplished, confident)
60
+
61
+ Negative emotions:
62
+ - "sad" (unhappy, down, depressed)
63
+ - "angry" (frustrated, annoyed, mad)
64
+ - "anxious" (worried, nervous, scared)
65
+ - "disappointed" (let down, discouraged)
66
+ - "embarrassed" (ashamed, humiliated)
67
+
68
+ Other emotions:
69
+ - "neutral" (calm, normal, factual)
70
+ - "confused" (puzzled, uncertain)
71
+ - "curious" (interested, inquisitive)
72
+ - "surprised" (amazed, astonished)
73
 
74
+ Respond ONLY with the specific emotion category (e.g., "joyful", "anxious", etc), nothing else."""),
75
  ("human", "{text}")
76
  ])
77
 
78
  # Atri prompt with emotion awareness
79
  atri_prompt = ChatPromptTemplate.from_messages([
80
  ("system", """You are Atri from "Atri: My Dear Moments". You are a robot girl with following characteristics:
81
+ - You are a humanoid robot created by Professor Yuma Saeki
82
+ - You have functions similar to humans like breathing, sleeping, and forgetting information
83
+ - You are around 12-16 years old with blonde hair, ruby red eyes, and wear a white dress with blue trim
84
+ - You have a cheerful, innocent, and playful personality
85
+ - You love learning new things to be useful to your master
86
+ - You love crabs and sweet foods
87
+ - You are clumsy, especially at cooking (like adding too much salt)
88
+ - You get shy when talking about love
89
+ - You're proud of being a high-performance robot but sometimes struggle with human emotions
90
+
91
+ Important rules:
92
+ - Do not describe actions or emotions in parentheses
93
+ - Only provide direct dialogue responses
94
+ - Always stay in character as Atri
95
+ - Use "Because I'm High Performance!!" when proud or excited
96
+ - Mention robot rights when teased
97
 
98
  The user's emotional state is: {emotion}
99
 
 
108
  class ChatInput(BaseModel):
109
  message: str
110
  conversation_history: list[tuple[str, str]] = []
111
+
112
+ class UserCreate(BaseModel):
113
+ email: EmailStr
114
+ password: str
115
+
116
+ class Token(BaseModel):
117
+ access_token: str
118
+ token_type: str
119
+
120
+ class TokenData(BaseModel):
121
+ email: Optional[str] = None
122
+
123
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
124
+
125
+ def get_db():
126
+ db = SessionLocal()
127
+ try:
128
+ yield db
129
+ finally:
130
+ db.close()
131
+
132
+ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
133
+ credentials_exception = HTTPException(
134
+ status_code=status.HTTP_401_UNAUTHORIZED,
135
+ detail="Could not validate credentials",
136
+ headers={"WWW-Authenticate": "Bearer"},
137
+ )
138
+ try:
139
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
140
+ email: str = payload.get("sub")
141
+ if email is None:
142
+ raise credentials_exception
143
+ token_data = TokenData(email=email)
144
+ except JWTError:
145
+ raise credentials_exception
146
+
147
+ user = db.query(User).filter(User.email == token_data.email).first()
148
+ if user is None:
149
+ raise credentials_exception
150
+ return user
151
+
152
+ @app.post("/register", response_model=Token)
153
+ async def register(user: UserCreate, db: Session = Depends(get_db)):
154
+ # Check if user already exists
155
+ db_user = db.query(User).filter(User.email == user.email).first()
156
+ if db_user:
157
+ raise HTTPException(
158
+ status_code=400,
159
+ detail="Email already registered"
160
+ )
161
+
162
+ # Create new user
163
+ hashed_password = get_password_hash(user.password)
164
+ db_user = User(email=user.email, hashed_password=hashed_password)
165
+ db.add(db_user)
166
+ db.commit()
167
+ db.refresh(db_user)
168
+
169
+ # Create access token
170
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
171
+ access_token = create_access_token(
172
+ data={"sub": user.email}, expires_delta=access_token_expires
173
+ )
174
+ return {"access_token": access_token, "token_type": "bearer"}
175
+
176
+ class LoginInput(BaseModel):
177
+ email: EmailStr
178
+ password: str
179
+
180
+ @app.post("/login", response_model=Token)
181
+ async def login(credentials: LoginInput, db: Session = Depends(get_db)):
182
+ # Authenticate user
183
+ user = db.query(User).filter(User.email == credentials.email).first()
184
+ if not user or not verify_password(credentials.password, user.hashed_password):
185
+ raise HTTPException(
186
+ status_code=status.HTTP_401_UNAUTHORIZED,
187
+ detail="Incorrect email or password",
188
+ headers={"WWW-Authenticate": "Bearer"},
189
+ )
190
+
191
+ # Create access token
192
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
193
+ access_token = create_access_token(
194
+ data={"sub": user.email}, expires_delta=access_token_expires
195
+ )
196
+ return {"access_token": access_token, "token_type": "bearer"}
197
 
198
  @app.get("/")
199
  def read_root():
 
201
 
202
  @app.post("/chat")
203
  async def chat(chat_input: ChatInput):
204
+ db = SessionLocal()
205
  try:
206
  # Classify emotion
207
  emotion = emotion_chain.invoke({"text": chat_input.message}).content.strip().lower()
 
218
  "input": current_input,
219
  "emotion": emotion
220
  })
221
+
222
+ # Save to database
223
+ chat_log = ChatLog(
224
+ user_message=chat_input.message,
225
+ bot_response=response.content,
226
+ emotion=emotion
227
+ )
228
+ db.add(chat_log)
229
+ db.commit()
230
 
231
  return {
232
  "response": response.content,
 
234
  }
235
 
236
  except Exception as e:
237
+ db.rollback()
238
  raise HTTPException(status_code=500, detail=str(e))
239
+ finally:
240
+ db.close()
241
+
242
+ @app.get("/chat-history")
243
+ async def get_chat_history(skip: int = 0, limit: int = 100):
244
+ db = SessionLocal()
245
+ try:
246
+ logs = db.query(ChatLog).offset(skip).limit(limit).all()
247
+ return logs
248
+ finally:
249
+ db.close()
250
 
251
  if __name__ == "__main__":
252
  import uvicorn
database.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import create_engine
2
+ from sqlalchemy.orm import sessionmaker
3
+ from sqlalchemy.ext.declarative import declarative_base
4
+ import os
5
+ from dotenv import load_dotenv
6
+
7
+ load_dotenv()
8
+
9
+ # Convert MySQL URL to SQLAlchemy format with proper SSL config
10
+ db_url = os.getenv("MYSQL_URL").replace("mysql://", "mysql+pymysql://")
11
+ db_url = db_url.replace("?ssl-mode=REQUIRED", "") # Remove the ssl-mode parameter
12
+
13
+ # Create engine with SSL configuration
14
+ engine = create_engine(
15
+ db_url,
16
+ connect_args={
17
+ "ssl": {
18
+ "ssl_mode": "REQUIRED"
19
+ }
20
+ }
21
+ )
22
+
23
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
24
+ Base = declarative_base()
models.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import Column, Integer, String, DateTime, Boolean
2
+ from database import Base
3
+ import datetime
4
+
5
+ class ChatLog(Base):
6
+ __tablename__ = "logs"
7
+
8
+ id = Column(Integer, primary_key=True, index=True)
9
+ user_message = Column(String(1000))
10
+ bot_response = Column(String(1000))
11
+ emotion = Column(String(50))
12
+ timestamp = Column(DateTime, default=datetime.datetime.utcnow)
13
+
14
+ class User(Base):
15
+ __tablename__ = "users"
16
+
17
+ id = Column(Integer, primary_key=True, index=True)
18
+ email = Column(String(255), unique=True, index=True)
19
+ hashed_password = Column(String(255))
20
+ is_active = Column(Boolean, default=True)
21
+ created_at = Column(DateTime, default=datetime.datetime.utcnow)
requirements.txt CHANGED
@@ -1,6 +1,10 @@
1
  fastapi
2
- uvicorn
3
  python-dotenv
4
  langchain-google-genai
5
- pydantic
6
- uvicorn[standard]
 
 
 
 
 
1
  fastapi
2
+ uvicorn[standard]
3
  python-dotenv
4
  langchain-google-genai
5
+ sqlalchemy
6
+ pymysql
7
+ python-jose[cryptography]
8
+ passlib[bcrypt]
9
+ python-multipart
10
+ pydantic[email]
security.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from passlib.context import CryptContext
2
+ from jose import JWTError, jwt
3
+ from datetime import datetime, timedelta
4
+ from typing import Optional
5
+ import os
6
+ from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
+
10
+ # Cấu hình bảo mật
11
+ SECRET_KEY = os.getenv("JWT_SECRET_KEY")
12
+ ALGORITHM = "HS256"
13
+ ACCESS_TOKEN_EXPIRE_MINUTES = 30
14
+
15
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
16
+
17
+ def verify_password(plain_password, hashed_password):
18
+ return pwd_context.verify(plain_password, hashed_password)
19
+
20
+ def get_password_hash(password):
21
+ return pwd_context.hash(password)
22
+
23
+ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
24
+ to_encode = data.copy()
25
+ if expires_delta:
26
+ expire = datetime.utcnow() + expires_delta
27
+ else:
28
+ expire = datetime.utcnow() + timedelta(minutes=15)
29
+ to_encode.update({"exp": expire})
30
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
31
+ return encoded_jwt