rajeshchoudharyt commited on
Commit
3171fa3
1 Parent(s): 97f5fbf

Add application file

Browse files
Files changed (3) hide show
  1. Dockerfile +13 -0
  2. app.py +60 -0
  3. requirements.txt +4 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.15.4
2
+
3
+ # Copy the current directory contents into the container at .
4
+ COPY . .
5
+
6
+ # Set the working directory to /
7
+ WORKDIR /
8
+
9
+ # Install requirements.txt
10
+ RUN pip install --no-cache-dir --upgrade -r /requirements.txt
11
+
12
+ # Start the FastAPI app on port 7860, the default port expected by Spaces
13
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from enum import Enum
5
+ import os
6
+ from sentence_transformers import SentenceTransformer
7
+
8
+ model = SentenceTransformer(
9
+ "dunzhang/stella_en_400M_v5",
10
+ trust_remote_code=True,
11
+ device="cpu",
12
+ config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False}
13
+ )
14
+
15
+ class Enum(str, Enum):
16
+ s2p_query = "s2p_query" # sentence-to-sentence
17
+ s2s_query = "s2s_query" # sentence-to-passage, Q&A
18
+
19
+ class Embedding(BaseModel):
20
+ input: list[str]
21
+ embedding_type: Enum = None
22
+
23
+
24
+ app = FastAPI()
25
+
26
+ app.add_middleware(
27
+ CORSMiddleware,
28
+ allow_origins=["*"],
29
+ allow_credentials=True,
30
+ allow_methods=["POST"],
31
+ allow_headers=["Authorization"]
32
+ )
33
+
34
+ def parse(data):
35
+ result = []
36
+ for dimension in data:
37
+ temp = []
38
+ for val in dimension:
39
+ temp.append(round(val, 8))
40
+ result.append(temp)
41
+ return result
42
+
43
+
44
+ @app.post("/embeddings/")
45
+ async def get_embedding(embedding: Embedding, req: Request):
46
+
47
+ token = req.headers.get("Authorization")
48
+ if os.environ.get('token') != token[7:]:
49
+ raise HTTPException(status_code=401, detail="Unauthorized.")
50
+
51
+ if model == None:
52
+ raise HTTPException(status_code=400, detail="Model load failed.")
53
+
54
+ if embedding.embedding_type == None:
55
+ data = model.encode(embedding.input).tolist()
56
+ return parse(data)
57
+ else:
58
+ data = model.encode(embedding.input, prompt_name=embedding.embedding_type).tolist()
59
+ return parse(data)
60
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastapi==0.108.0
2
+ uvicorn
3
+ pydantic
4
+ sentence_transformers==3.0.1