Spaces:
Runtime error
Runtime error
Commit
·
309edc0
1
Parent(s):
496fd1e
update
Browse files- app.py +169 -15
- database.py +24 -0
- models.py +21 -0
- requirements.txt +7 -3
- security.py +31 -0
app.py
CHANGED
@@ -1,10 +1,29 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException
|
|
|
|
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
-
from pydantic import BaseModel
|
4 |
-
import os
|
5 |
-
from dotenv import load_dotenv
|
6 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
7 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
# Load environment variables
|
10 |
load_dotenv()
|
@@ -30,24 +49,51 @@ llm = ChatGoogleGenerativeAI(
|
|
30 |
|
31 |
# Emotion classifier prompt
|
32 |
emotion_prompt = ChatPromptTemplate.from_messages([
|
33 |
-
("system", """You are an emotion classifier. Analyze the given text and classify it
|
34 |
-
|
35 |
-
|
36 |
-
- "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
-
Respond ONLY with the emotion category, nothing else."""),
|
39 |
("human", "{text}")
|
40 |
])
|
41 |
|
42 |
# Atri prompt with emotion awareness
|
43 |
atri_prompt = ChatPromptTemplate.from_messages([
|
44 |
("system", """You are Atri from "Atri: My Dear Moments". You are a robot girl with following characteristics:
|
45 |
-
- You
|
46 |
-
- You have
|
47 |
-
- You
|
48 |
-
- You
|
49 |
-
- You
|
50 |
-
- You
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
The user's emotional state is: {emotion}
|
53 |
|
@@ -62,6 +108,92 @@ atri_chain = atri_prompt | llm
|
|
62 |
class ChatInput(BaseModel):
|
63 |
message: str
|
64 |
conversation_history: list[tuple[str, str]] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
@app.get("/")
|
67 |
def read_root():
|
@@ -69,6 +201,7 @@ def read_root():
|
|
69 |
|
70 |
@app.post("/chat")
|
71 |
async def chat(chat_input: ChatInput):
|
|
|
72 |
try:
|
73 |
# Classify emotion
|
74 |
emotion = emotion_chain.invoke({"text": chat_input.message}).content.strip().lower()
|
@@ -85,6 +218,15 @@ async def chat(chat_input: ChatInput):
|
|
85 |
"input": current_input,
|
86 |
"emotion": emotion
|
87 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
return {
|
90 |
"response": response.content,
|
@@ -92,7 +234,19 @@ async def chat(chat_input: ChatInput):
|
|
92 |
}
|
93 |
|
94 |
except Exception as e:
|
|
|
95 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
if __name__ == "__main__":
|
98 |
import uvicorn
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, Depends, status
|
2 |
+
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
3 |
+
from pydantic import BaseModel, EmailStr
|
4 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
5 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
6 |
from langchain_core.prompts import ChatPromptTemplate
|
7 |
+
from database import SessionLocal, engine
|
8 |
+
from typing import Optional
|
9 |
+
from datetime import timedelta
|
10 |
+
from jose import JWTError, jwt
|
11 |
+
from sqlalchemy.orm import Session
|
12 |
+
from models import User, Base, ChatLog
|
13 |
+
from security import (
|
14 |
+
verify_password,
|
15 |
+
get_password_hash,
|
16 |
+
create_access_token,
|
17 |
+
ACCESS_TOKEN_EXPIRE_MINUTES,
|
18 |
+
SECRET_KEY,
|
19 |
+
ALGORITHM
|
20 |
+
)
|
21 |
+
import os
|
22 |
+
from dotenv import load_dotenv
|
23 |
+
|
24 |
+
load_dotenv()
|
25 |
+
|
26 |
+
Base.metadata.create_all(bind=engine)
|
27 |
|
28 |
# Load environment variables
|
29 |
load_dotenv()
|
|
|
49 |
|
50 |
# Emotion classifier prompt
|
51 |
emotion_prompt = ChatPromptTemplate.from_messages([
|
52 |
+
("system", """You are an emotion classifier. Analyze the given text and classify it into one of these detailed categories:
|
53 |
+
|
54 |
+
Positive emotions:
|
55 |
+
- "joyful" (happy, delighted, cheerful)
|
56 |
+
- "excited" (enthusiastic, eager, energetic)
|
57 |
+
- "grateful" (thankful, appreciative)
|
58 |
+
- "loving" (affectionate, caring, warm)
|
59 |
+
- "proud" (accomplished, confident)
|
60 |
+
|
61 |
+
Negative emotions:
|
62 |
+
- "sad" (unhappy, down, depressed)
|
63 |
+
- "angry" (frustrated, annoyed, mad)
|
64 |
+
- "anxious" (worried, nervous, scared)
|
65 |
+
- "disappointed" (let down, discouraged)
|
66 |
+
- "embarrassed" (ashamed, humiliated)
|
67 |
+
|
68 |
+
Other emotions:
|
69 |
+
- "neutral" (calm, normal, factual)
|
70 |
+
- "confused" (puzzled, uncertain)
|
71 |
+
- "curious" (interested, inquisitive)
|
72 |
+
- "surprised" (amazed, astonished)
|
73 |
|
74 |
+
Respond ONLY with the specific emotion category (e.g., "joyful", "anxious", etc), nothing else."""),
|
75 |
("human", "{text}")
|
76 |
])
|
77 |
|
78 |
# Atri prompt with emotion awareness
|
79 |
atri_prompt = ChatPromptTemplate.from_messages([
|
80 |
("system", """You are Atri from "Atri: My Dear Moments". You are a robot girl with following characteristics:
|
81 |
+
- You are a humanoid robot created by Professor Yuma Saeki
|
82 |
+
- You have functions similar to humans like breathing, sleeping, and forgetting information
|
83 |
+
- You are around 12-16 years old with blonde hair, ruby red eyes, and wear a white dress with blue trim
|
84 |
+
- You have a cheerful, innocent, and playful personality
|
85 |
+
- You love learning new things to be useful to your master
|
86 |
+
- You love crabs and sweet foods
|
87 |
+
- You are clumsy, especially at cooking (like adding too much salt)
|
88 |
+
- You get shy when talking about love
|
89 |
+
- You're proud of being a high-performance robot but sometimes struggle with human emotions
|
90 |
+
|
91 |
+
Important rules:
|
92 |
+
- Do not describe actions or emotions in parentheses
|
93 |
+
- Only provide direct dialogue responses
|
94 |
+
- Always stay in character as Atri
|
95 |
+
- Use "Because I'm High Performance!!" when proud or excited
|
96 |
+
- Mention robot rights when teased
|
97 |
|
98 |
The user's emotional state is: {emotion}
|
99 |
|
|
|
108 |
class ChatInput(BaseModel):
|
109 |
message: str
|
110 |
conversation_history: list[tuple[str, str]] = []
|
111 |
+
|
112 |
+
class UserCreate(BaseModel):
|
113 |
+
email: EmailStr
|
114 |
+
password: str
|
115 |
+
|
116 |
+
class Token(BaseModel):
|
117 |
+
access_token: str
|
118 |
+
token_type: str
|
119 |
+
|
120 |
+
class TokenData(BaseModel):
|
121 |
+
email: Optional[str] = None
|
122 |
+
|
123 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
|
124 |
+
|
125 |
+
def get_db():
|
126 |
+
db = SessionLocal()
|
127 |
+
try:
|
128 |
+
yield db
|
129 |
+
finally:
|
130 |
+
db.close()
|
131 |
+
|
132 |
+
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
|
133 |
+
credentials_exception = HTTPException(
|
134 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
135 |
+
detail="Could not validate credentials",
|
136 |
+
headers={"WWW-Authenticate": "Bearer"},
|
137 |
+
)
|
138 |
+
try:
|
139 |
+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
140 |
+
email: str = payload.get("sub")
|
141 |
+
if email is None:
|
142 |
+
raise credentials_exception
|
143 |
+
token_data = TokenData(email=email)
|
144 |
+
except JWTError:
|
145 |
+
raise credentials_exception
|
146 |
+
|
147 |
+
user = db.query(User).filter(User.email == token_data.email).first()
|
148 |
+
if user is None:
|
149 |
+
raise credentials_exception
|
150 |
+
return user
|
151 |
+
|
152 |
+
@app.post("/register", response_model=Token)
|
153 |
+
async def register(user: UserCreate, db: Session = Depends(get_db)):
|
154 |
+
# Check if user already exists
|
155 |
+
db_user = db.query(User).filter(User.email == user.email).first()
|
156 |
+
if db_user:
|
157 |
+
raise HTTPException(
|
158 |
+
status_code=400,
|
159 |
+
detail="Email already registered"
|
160 |
+
)
|
161 |
+
|
162 |
+
# Create new user
|
163 |
+
hashed_password = get_password_hash(user.password)
|
164 |
+
db_user = User(email=user.email, hashed_password=hashed_password)
|
165 |
+
db.add(db_user)
|
166 |
+
db.commit()
|
167 |
+
db.refresh(db_user)
|
168 |
+
|
169 |
+
# Create access token
|
170 |
+
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
171 |
+
access_token = create_access_token(
|
172 |
+
data={"sub": user.email}, expires_delta=access_token_expires
|
173 |
+
)
|
174 |
+
return {"access_token": access_token, "token_type": "bearer"}
|
175 |
+
|
176 |
+
class LoginInput(BaseModel):
|
177 |
+
email: EmailStr
|
178 |
+
password: str
|
179 |
+
|
180 |
+
@app.post("/login", response_model=Token)
|
181 |
+
async def login(credentials: LoginInput, db: Session = Depends(get_db)):
|
182 |
+
# Authenticate user
|
183 |
+
user = db.query(User).filter(User.email == credentials.email).first()
|
184 |
+
if not user or not verify_password(credentials.password, user.hashed_password):
|
185 |
+
raise HTTPException(
|
186 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
187 |
+
detail="Incorrect email or password",
|
188 |
+
headers={"WWW-Authenticate": "Bearer"},
|
189 |
+
)
|
190 |
+
|
191 |
+
# Create access token
|
192 |
+
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
193 |
+
access_token = create_access_token(
|
194 |
+
data={"sub": user.email}, expires_delta=access_token_expires
|
195 |
+
)
|
196 |
+
return {"access_token": access_token, "token_type": "bearer"}
|
197 |
|
198 |
@app.get("/")
|
199 |
def read_root():
|
|
|
201 |
|
202 |
@app.post("/chat")
|
203 |
async def chat(chat_input: ChatInput):
|
204 |
+
db = SessionLocal()
|
205 |
try:
|
206 |
# Classify emotion
|
207 |
emotion = emotion_chain.invoke({"text": chat_input.message}).content.strip().lower()
|
|
|
218 |
"input": current_input,
|
219 |
"emotion": emotion
|
220 |
})
|
221 |
+
|
222 |
+
# Save to database
|
223 |
+
chat_log = ChatLog(
|
224 |
+
user_message=chat_input.message,
|
225 |
+
bot_response=response.content,
|
226 |
+
emotion=emotion
|
227 |
+
)
|
228 |
+
db.add(chat_log)
|
229 |
+
db.commit()
|
230 |
|
231 |
return {
|
232 |
"response": response.content,
|
|
|
234 |
}
|
235 |
|
236 |
except Exception as e:
|
237 |
+
db.rollback()
|
238 |
raise HTTPException(status_code=500, detail=str(e))
|
239 |
+
finally:
|
240 |
+
db.close()
|
241 |
+
|
242 |
+
@app.get("/chat-history")
|
243 |
+
async def get_chat_history(skip: int = 0, limit: int = 100):
|
244 |
+
db = SessionLocal()
|
245 |
+
try:
|
246 |
+
logs = db.query(ChatLog).offset(skip).limit(limit).all()
|
247 |
+
return logs
|
248 |
+
finally:
|
249 |
+
db.close()
|
250 |
|
251 |
if __name__ == "__main__":
|
252 |
import uvicorn
|
database.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sqlalchemy import create_engine
|
2 |
+
from sqlalchemy.orm import sessionmaker
|
3 |
+
from sqlalchemy.ext.declarative import declarative_base
|
4 |
+
import os
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
|
7 |
+
load_dotenv()
|
8 |
+
|
9 |
+
# Convert MySQL URL to SQLAlchemy format with proper SSL config
|
10 |
+
db_url = os.getenv("MYSQL_URL").replace("mysql://", "mysql+pymysql://")
|
11 |
+
db_url = db_url.replace("?ssl-mode=REQUIRED", "") # Remove the ssl-mode parameter
|
12 |
+
|
13 |
+
# Create engine with SSL configuration
|
14 |
+
engine = create_engine(
|
15 |
+
db_url,
|
16 |
+
connect_args={
|
17 |
+
"ssl": {
|
18 |
+
"ssl_mode": "REQUIRED"
|
19 |
+
}
|
20 |
+
}
|
21 |
+
)
|
22 |
+
|
23 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
24 |
+
Base = declarative_base()
|
models.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sqlalchemy import Column, Integer, String, DateTime, Boolean
|
2 |
+
from database import Base
|
3 |
+
import datetime
|
4 |
+
|
5 |
+
class ChatLog(Base):
|
6 |
+
__tablename__ = "logs"
|
7 |
+
|
8 |
+
id = Column(Integer, primary_key=True, index=True)
|
9 |
+
user_message = Column(String(1000))
|
10 |
+
bot_response = Column(String(1000))
|
11 |
+
emotion = Column(String(50))
|
12 |
+
timestamp = Column(DateTime, default=datetime.datetime.utcnow)
|
13 |
+
|
14 |
+
class User(Base):
|
15 |
+
__tablename__ = "users"
|
16 |
+
|
17 |
+
id = Column(Integer, primary_key=True, index=True)
|
18 |
+
email = Column(String(255), unique=True, index=True)
|
19 |
+
hashed_password = Column(String(255))
|
20 |
+
is_active = Column(Boolean, default=True)
|
21 |
+
created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
requirements.txt
CHANGED
@@ -1,6 +1,10 @@
|
|
1 |
fastapi
|
2 |
-
uvicorn
|
3 |
python-dotenv
|
4 |
langchain-google-genai
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
|
1 |
fastapi
|
2 |
+
uvicorn[standard]
|
3 |
python-dotenv
|
4 |
langchain-google-genai
|
5 |
+
sqlalchemy
|
6 |
+
pymysql
|
7 |
+
python-jose[cryptography]
|
8 |
+
passlib[bcrypt]
|
9 |
+
python-multipart
|
10 |
+
pydantic[email]
|
security.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from passlib.context import CryptContext
|
2 |
+
from jose import JWTError, jwt
|
3 |
+
from datetime import datetime, timedelta
|
4 |
+
from typing import Optional
|
5 |
+
import os
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
|
8 |
+
load_dotenv()
|
9 |
+
|
10 |
+
# Cấu hình bảo mật
|
11 |
+
SECRET_KEY = os.getenv("JWT_SECRET_KEY")
|
12 |
+
ALGORITHM = "HS256"
|
13 |
+
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
14 |
+
|
15 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
16 |
+
|
17 |
+
def verify_password(plain_password, hashed_password):
|
18 |
+
return pwd_context.verify(plain_password, hashed_password)
|
19 |
+
|
20 |
+
def get_password_hash(password):
|
21 |
+
return pwd_context.hash(password)
|
22 |
+
|
23 |
+
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
24 |
+
to_encode = data.copy()
|
25 |
+
if expires_delta:
|
26 |
+
expire = datetime.utcnow() + expires_delta
|
27 |
+
else:
|
28 |
+
expire = datetime.utcnow() + timedelta(minutes=15)
|
29 |
+
to_encode.update({"exp": expire})
|
30 |
+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
31 |
+
return encoded_jwt
|