logeswari commited on
Commit
e31b822
·
1 Parent(s): 94287ed
Files changed (2) hide show
  1. main.py +19 -7
  2. requirements.txt +1 -0
main.py CHANGED
@@ -1,5 +1,5 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
2
- from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
3
  from jose import JWTError, jwt
4
  from pinecone import Pinecone
5
  import os
@@ -13,7 +13,8 @@ from datetime import datetime, timedelta
13
  # Load environment variables
14
  load_dotenv()
15
 
16
- # Secret key for JWT
 
17
  ALGORITHM = "HS256"
18
  ACCESS_TOKEN_EXPIRE_MINUTES = 30
19
 
@@ -41,15 +42,18 @@ unsplash_index = pc.Index(index_name)
41
  # Load CLIP model and processor
42
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
43
  processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
44
 
45
  # OAuth2 authentication
46
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token")
47
 
 
48
  def create_access_token(data: dict, expires_delta: timedelta = None):
49
  to_encode = data.copy()
50
  expire = datetime.utcnow() + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
51
  to_encode.update({"exp": expire})
52
- return jwt.encode(to_encode, "secret", algorithm=ALGORITHM)
 
53
 
54
  def authenticate_user(username: str, password: str):
55
  user = fake_users_db.get(username)
@@ -57,9 +61,10 @@ def authenticate_user(username: str, password: str):
57
  return None
58
  return user
59
 
 
60
  def get_current_user(token: str = Depends(oauth2_scheme)):
61
  try:
62
- payload = jwt.decode(token, "secret", algorithms=[ALGORITHM])
63
  username: str = payload.get("sub")
64
  if username is None or username not in fake_users_db:
65
  raise HTTPException(status_code=401, detail="Invalid authentication")
@@ -67,24 +72,28 @@ def get_current_user(token: str = Depends(oauth2_scheme)):
67
  except JWTError:
68
  raise HTTPException(status_code=401, detail="Invalid authentication")
69
 
 
70
  @app.post("/token")
71
- async def login(form_data: OAuth2PasswordRequestForm = Depends()):
72
- user = authenticate_user(form_data.username, form_data.password)
73
  if not user:
74
  raise HTTPException(status_code=400, detail="Incorrect username or password")
75
  access_token = create_access_token(data={"sub": user["username"]})
76
  return {"access_token": access_token, "token_type": "bearer"}
77
 
 
78
  def get_text_embedding(text: str):
79
  inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
80
  text_features = model.get_text_features(**inputs)
81
  return text_features.detach().cpu().numpy().flatten().tolist()
82
 
 
83
  def get_image_embedding(image: Image.Image):
84
  inputs = processor(images=image, return_tensors="pt")
85
  image_features = model.get_image_features(**inputs)
86
  return image_features.detach().cpu().numpy().flatten().tolist()
87
 
 
88
  def search_similar_images(embedding: list, top_k: int = 10):
89
  results = unsplash_index.query(
90
  vector=embedding,
@@ -94,6 +103,7 @@ def search_similar_images(embedding: list, top_k: int = 10):
94
  )
95
  return results["matches"]
96
 
 
97
  @app.get("/search/text/")
98
  async def search_by_text(query: str, user: str = Depends(get_current_user)):
99
  if not query:
@@ -102,6 +112,7 @@ async def search_by_text(query: str, user: str = Depends(get_current_user)):
102
  matches = search_similar_images(embedding)
103
  return {"matches": [{"id": m["id"], "score": m["score"], "url": m["metadata"]["url"]} for m in matches]}
104
 
 
105
  @app.post("/search/image/")
106
  async def search_by_image(file: UploadFile = File(...), user: str = Depends(get_current_user)):
107
  try:
@@ -113,6 +124,7 @@ async def search_by_image(file: UploadFile = File(...), user: str = Depends(get_
113
  except Exception as e:
114
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
115
 
 
116
  if __name__ == "__main__":
117
  import uvicorn
118
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Form
2
+ from fastapi.security import OAuth2PasswordBearer
3
  from jose import JWTError, jwt
4
  from pinecone import Pinecone
5
  import os
 
13
  # Load environment variables
14
  load_dotenv()
15
 
16
+ # JWT Config
17
+ SECRET_KEY = os.getenv("JWT_SECRET", "default_secret") # Use a secure secret in production
18
  ALGORITHM = "HS256"
19
  ACCESS_TOKEN_EXPIRE_MINUTES = 30
20
 
 
42
  # Load CLIP model and processor
43
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
44
  processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
45
+ model.eval() # Ensure model is in evaluation mode
46
 
47
  # OAuth2 authentication
48
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token")
49
 
50
+
51
  def create_access_token(data: dict, expires_delta: timedelta = None):
52
  to_encode = data.copy()
53
  expire = datetime.utcnow() + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
54
  to_encode.update({"exp": expire})
55
+ return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
56
+
57
 
58
  def authenticate_user(username: str, password: str):
59
  user = fake_users_db.get(username)
 
61
  return None
62
  return user
63
 
64
+
65
  def get_current_user(token: str = Depends(oauth2_scheme)):
66
  try:
67
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
68
  username: str = payload.get("sub")
69
  if username is None or username not in fake_users_db:
70
  raise HTTPException(status_code=401, detail="Invalid authentication")
 
72
  except JWTError:
73
  raise HTTPException(status_code=401, detail="Invalid authentication")
74
 
75
+
76
  @app.post("/token")
77
+ async def login(username: str = Form(...), password: str = Form(...)):
78
+ user = authenticate_user(username, password)
79
  if not user:
80
  raise HTTPException(status_code=400, detail="Incorrect username or password")
81
  access_token = create_access_token(data={"sub": user["username"]})
82
  return {"access_token": access_token, "token_type": "bearer"}
83
 
84
+
85
  def get_text_embedding(text: str):
86
  inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
87
  text_features = model.get_text_features(**inputs)
88
  return text_features.detach().cpu().numpy().flatten().tolist()
89
 
90
+
91
  def get_image_embedding(image: Image.Image):
92
  inputs = processor(images=image, return_tensors="pt")
93
  image_features = model.get_image_features(**inputs)
94
  return image_features.detach().cpu().numpy().flatten().tolist()
95
 
96
+
97
  def search_similar_images(embedding: list, top_k: int = 10):
98
  results = unsplash_index.query(
99
  vector=embedding,
 
103
  )
104
  return results["matches"]
105
 
106
+
107
  @app.get("/search/text/")
108
  async def search_by_text(query: str, user: str = Depends(get_current_user)):
109
  if not query:
 
112
  matches = search_similar_images(embedding)
113
  return {"matches": [{"id": m["id"], "score": m["score"], "url": m["metadata"]["url"]} for m in matches]}
114
 
115
+
116
  @app.post("/search/image/")
117
  async def search_by_image(file: UploadFile = File(...), user: str = Depends(get_current_user)):
118
  try:
 
124
  except Exception as e:
125
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
126
 
127
+
128
  if __name__ == "__main__":
129
  import uvicorn
130
  uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt CHANGED
@@ -8,3 +8,4 @@ numpy
8
  pinecone
9
  python-jose[cryptography]
10
  python-multipart
 
 
8
  pinecone
9
  python-jose[cryptography]
10
  python-multipart
11
+ jose