Patrick Walukagga commited on
Commit
78b79c7
·
1 Parent(s): 728278b

Enable cors and new study choices endpoint

Browse files
api.py CHANGED
@@ -5,6 +5,7 @@ from typing import List, Optional
5
 
6
  from dotenv import load_dotenv
7
  from fastapi import FastAPI, HTTPException
 
8
  from fastapi.responses import FileResponse
9
  from gradio_client import Client
10
  from pydantic import BaseModel, ConfigDict, Field, constr
@@ -15,14 +16,25 @@ load_dotenv()
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
 
 
 
 
18
  app = FastAPI(
19
  title="ACRES RAG API",
20
  description=description,
21
  openapi_tags=tags_metadata,
22
  )
23
- GRADIO_URL = os.getenv("GRADIO_URL", "http://localhost:7860/")
24
- logger.info(f"GRADIO_URL: {GRADIO_URL}")
25
- client = Client(GRADIO_URL)
 
 
 
 
 
 
 
26
 
27
 
28
  class StudyVariables(str, Enum):
@@ -95,6 +107,12 @@ def process_study_variables(
95
  return {"result": result[0]}
96
 
97
 
 
 
 
 
 
 
98
  @app.post("/download_csv", tags=["zotero"])
99
  def download_csv(download_request: DownloadCSV):
100
  result = client.predict(
 
5
 
6
  from dotenv import load_dotenv
7
  from fastapi import FastAPI, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
  from fastapi.responses import FileResponse
10
  from gradio_client import Client
11
  from pydantic import BaseModel, ConfigDict, Field, constr
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
+ GRADIO_URL = os.getenv("GRADIO_URL", "http://localhost:7860/")
20
+ logger.info(f"GRADIO_URL: {GRADIO_URL}")
21
+ client = Client(GRADIO_URL)
22
+
23
  app = FastAPI(
24
  title="ACRES RAG API",
25
  description=description,
26
  openapi_tags=tags_metadata,
27
  )
28
+
29
+ origins = ["*"]
30
+
31
+ app.add_middleware(
32
+ CORSMiddleware,
33
+ allow_origins=origins,
34
+ allow_credentials=True,
35
+ allow_methods=["*"],
36
+ allow_headers=["*"],
37
+ )
38
 
39
 
40
  class StudyVariables(str, Enum):
 
107
  return {"result": result[0]}
108
 
109
 
110
+ @app.post("/new_study_choices", tags=["zotero"])
111
+ def new_study_choices():
112
+ result = client.predict(api_name="/new_study_choices")
113
+ return {"result": result}
114
+
115
+
116
  @app.post("/download_csv", tags=["zotero"])
117
  def download_csv(download_request: DownloadCSV):
118
  result = client.predict(
bin/cfn/ecs-deploy-update-api CHANGED
@@ -1,4 +1,4 @@
1
- #! /usr/bin/bash
2
 
3
  set -e
4
 
 
1
+ #! /usr/bin/env bash
2
 
3
  set -e
4
 
bin/cfn/ecs-deploy-update-gradio CHANGED
@@ -1,4 +1,4 @@
1
- #! /usr/bin/bash
2
 
3
  set -e
4
 
 
1
+ #! /usr/bin/env bash
2
 
3
  set -e
4