# Spaces:
# Runtime error
# Runtime error
import hmac
import io
import os
import warnings
from datetime import datetime, timedelta, timezone

import numpy as np
import torch
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from PIL import Image
from pinecone import Pinecone
from transformers import AutoProcessor, CLIPModel
# Load variables from a .env file, if present, into the process environment.
load_dotenv()

# JWT configuration.
SECRET_KEY = os.getenv("JWT_SECRET", "default_secret")  # Use a secure secret in production
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# The fallback (or an explicitly empty) secret is public knowledge, so anyone
# could mint tokens that get_current_user accepts. Keep the fallback for
# compatibility, but make the insecure configuration loud instead of silent.
if SECRET_KEY in ("", "default_secret"):
    warnings.warn(
        "JWT_SECRET is unset, empty, or the default value; access tokens are "
        "signed with an insecure key. Set JWT_SECRET before deploying.",
        RuntimeWarning,
        stacklevel=1,
    )
# Stand-in user store keyed by username (demo only; replace with real
# authentication). The password is stored in plaintext here; use hashed
# passwords in production.
fake_users_db = dict(
    admin=dict(username="admin", password="password123"),
)
# FastAPI application instance.
app = FastAPI()

# The Pinecone API key is mandatory: fail fast at import time if it is missing.
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
if not PINECONE_API_KEY:
    raise RuntimeError("PINECONE_API_KEY is not set. Please set it in the environment or .env file.")

# Pinecone client and the index queried by search_similar_images. The index
# name can be overridden with PINECONE_INDEX_NAME; it defaults to the
# previously hard-coded "images-index".
pc = Pinecone(api_key=PINECONE_API_KEY)
index_name = os.getenv("PINECONE_INDEX_NAME", "images-index")
unsplash_index = pc.Index(index_name)
# Single CLIP checkpoint identifier shared by the processor and the model, so
# the preprocessing/tokenization can never drift from the loaded weights.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
processor = AutoProcessor.from_pretrained(CLIP_MODEL_NAME)
model.eval()  # Inference only: put the model in evaluation mode.

# OAuth2 bearer-token dependency; tokenUrl is the endpoint clients are told to
# call to obtain a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token")
def create_access_token(data: dict, expires_delta: timedelta = None):
    """Create a signed JWT containing ``data`` plus an ``exp`` expiry claim.

    Args:
        data: Claims to embed (e.g. ``{"sub": username}``). The dict is
            shallow-copied, so the caller's object is not modified.
        expires_delta: Token lifetime. ``None`` selects the default of
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. Any explicit value,
            including ``timedelta(0)``, is honoured as given.

    Returns:
        The token produced by ``jwt.encode`` using ``SECRET_KEY``/``ALGORITHM``.
    """
    # Test for None explicitly: timedelta(0) is falsy, so an ``or`` fallback
    # would silently replace a zero lifetime with the 30-minute default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    # Timezone-aware UTC "now"; datetime.utcnow() is deprecated (Python 3.12+)
    # and returns a naive value.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def authenticate_user(username: str, password: str, users=None):
    """Return the stored user record if the credentials are valid, else ``None``.

    Args:
        username: Login name to look up.
        password: Plaintext password supplied by the client.
        users: Optional mapping of username -> user record to check against.
            Defaults to the module-level ``fake_users_db``, so existing callers
            are unaffected.

    Returns:
        The matching user dict (the stored object itself) when the username
        exists and ``password`` equals its ``"password"`` entry; ``None``
        otherwise.
    """
    if users is None:
        users = fake_users_db
    user = users.get(username)
    if not user:
        return None
    # Constant-time comparison, so response timing does not leak how much of
    # the password matched. Compare UTF-8 bytes, because
    # hmac.compare_digest rejects str arguments containing non-ASCII characters.
    stored = user["password"].encode("utf-8")
    supplied = password.encode("utf-8")
    if not hmac.compare_digest(stored, supplied):
        return None
    return user
def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve a bearer token to a known username (used via ``Depends``).

    The token is decoded and verified with ``SECRET_KEY``/``ALGORITHM``. Its
    ``sub`` claim must name a user present in ``fake_users_db``.

    Returns:
        The username taken from the token's ``sub`` claim.

    Raises:
        HTTPException: 401 "Invalid authentication" if the token fails to
            decode or verify, has no string ``sub`` claim, or names an unknown
            user. The response carries ``WWW-Authenticate: Bearer``, as
            required for bearer-token 401 responses.
    """
    def _unauthorized():
        return HTTPException(
            status_code=401,
            detail="Invalid authentication",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Keep the try narrow: only decoding/verification can raise JWTError.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        raise _unauthorized() from err

    username = payload.get("sub")
    # The isinstance check also rejects None and unhashable claim values, which
    # would otherwise raise TypeError on the membership test.
    if not isinstance(username, str) or username not in fake_users_db:
        raise _unauthorized()
    return username
async def login(username: str = Form(...), password: str = Form(...)):
    """Exchange form-encoded credentials for a bearer access token.

    NOTE(review): no route decorator is visible on this handler, so as written
    it is never registered on ``app``. It presumably belongs at
    ``@app.post("/token")``, matching the ``tokenUrl`` given to
    ``OAuth2PasswordBearer``. Confirm against the deployed app.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 400 when the username/password pair is rejected.
    """
    account = authenticate_user(username, password)
    if account:
        token = create_access_token(data={"sub": account["username"]})
        return {"access_token": token, "token_type": "bearer"}
    raise HTTPException(status_code=400, detail="Incorrect username or password")
def get_text_embedding(text: str):
    """Embed ``text`` with the CLIP text encoder.

    Args:
        text: The query string. It is tokenized with padding and truncation
            enabled.

    Returns:
        The text features for this single string, flattened into a plain
        Python list (the form passed to ``search_similar_images``).
    """
    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    # Inference only: disable autograd so no computation graph is built and
    # retained for a backward pass that never happens.
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
    return text_features.detach().cpu().numpy().flatten().tolist()
def get_image_embedding(image: Image.Image):
    """Embed a PIL image with the CLIP image encoder.

    Args:
        image: The image to embed. The caller in this file converts uploads to
            RGB before calling.

    Returns:
        The image features, flattened into a plain Python list (the form
        passed to ``search_similar_images``).
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: disable autograd so no computation graph is built and
    # retained for a backward pass that never happens.
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
    return image_features.detach().cpu().numpy().flatten().tolist()
def search_similar_images(embedding: list, top_k: int = 10, namespace: str = "image-search-dataset"):
    """Query the Pinecone index for the vectors nearest to ``embedding``.

    Args:
        embedding: Query vector, e.g. from ``get_text_embedding`` or
            ``get_image_embedding``.
        top_k: Maximum number of matches to request.
        namespace: Pinecone namespace to search. Defaults to the namespace
            that was previously hard-coded, so existing callers are unchanged.

    Returns:
        The ``"matches"`` entry of the query response (metadata included).
    """
    results = unsplash_index.query(
        vector=embedding,
        top_k=top_k,
        include_metadata=True,
        namespace=namespace,
    )
    return results["matches"]
async def search_by_text(query: str, user: str = Depends(get_current_user)):
    """Return the indexed images most similar to a text query.

    Requires an authenticated user (resolved by ``get_current_user``).

    NOTE(review): no route decorator is visible on this handler, so as written
    it is not registered on ``app``. Confirm the intended path and method.

    Args:
        query: Free-text search string.
        user: Username from the bearer token. It is only needed for the auth
            dependency.

    Returns:
        ``{"matches": [...]}``, where each entry has ``id``, ``score`` and
        ``url`` keys taken from the Pinecone match and its metadata.

    Raises:
        HTTPException: 400 if ``query`` is empty or whitespace-only.
    """
    # A blank query carries no search intent; reject it like an empty one.
    if not query or not query.strip():
        raise HTTPException(status_code=400, detail="Query text is required")
    # CLIP inference and the Pinecone query are blocking calls. Run them in the
    # thread pool so this async handler does not stall the event loop.
    embedding = await run_in_threadpool(get_text_embedding, query)
    matches = await run_in_threadpool(search_similar_images, embedding)
    return {"matches": [{"id": m["id"], "score": m["score"], "url": m["metadata"]["url"]} for m in matches]}
async def search_by_image(file: UploadFile = File(...), user: str = Depends(get_current_user)):
    """Return the indexed images most similar to an uploaded image.

    Requires an authenticated user (resolved by ``get_current_user``).

    NOTE(review): no route decorator is visible on this handler, so as written
    it is not registered on ``app``. Confirm the intended path and method.

    Args:
        file: The uploaded image. It is decoded with Pillow and converted to RGB.
        user: Username from the bearer token. It is only needed for the auth
            dependency.

    Returns:
        ``{"matches": [...]}``, where each entry has ``id``, ``score`` and
        ``url`` keys taken from the Pinecone match and its metadata.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image; 500
            (with the underlying error text) for any other failure.
    """
    try:
        image_data = await file.read()
        try:
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
        except OSError as err:
            # Pillow reports unrecognised or truncated data as OSError (this
            # includes UnidentifiedImageError). That is a client error, not a
            # server fault.
            raise HTTPException(status_code=400, detail=f"Invalid image file: {err}") from err
        # CLIP inference and the Pinecone query are blocking calls. Run them in
        # the thread pool so this async handler does not stall the event loop.
        embedding = await run_in_threadpool(get_image_embedding, image)
        matches = await run_in_threadpool(search_similar_images, embedding)
        return {"matches": [{"id": m["id"], "score": m["score"], "url": m["metadata"]["url"]} for m in matches]}
    except HTTPException:
        # Deliberate HTTP errors (the 400 above) pass through unchanged rather
        # than being re-wrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces. The port can be overridden with the PORT
    # environment variable and defaults to 8000, as before.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))