Empereur-Pirate committed
Commit 6ee90aa · verified · 1 parent: 01d7910

Update main.py

Files changed (1):
  1. main.py +21 -44
main.py CHANGED
@@ -1,13 +1,10 @@
+import os
+import requests
 from fastapi import FastAPI, Request, Depends
-from fastapi.responses import FileResponse, JSONResponse
+from fastapi.responses import JSONResponse
 from fastapi.staticfiles import StaticFiles
-from transformers import pipeline
 from pydantic import BaseModel
 from typing import Optional, Any
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer, GenerationConfig
-import os
-from huggingface_hub import login
 
 # Check whether we are executing inside a Hugging Face Space
 SPACE_NAME = os.getenv("SPACE_NAME", default=None)
@@ -26,28 +23,10 @@ except KeyError:
     print('The environment variable "HF_ACCESS_TOKEN" is not found. Please configure it correctly in your Space.')
     sys.exit(1)
 
-# Packages and model loading
-import torch
-base_model_id = "152334H/miqu-1-70b-sf"
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.bfloat16
-)
-
-base_model = AutoModelForCausalLM.from_pretrained(
-    base_model_id,
-    quantization_config=bnb_config,
-    device_map="auto",
-    trust_remote_code=True,
-)
-
-# Tokenizer loading
-eval_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf", add_bos_token=True, trust_remote_code=True, use_auth_token=True)
-
-# Streamer
-streamer = TextStreamer(eval_tokenizer)
+# Set up the API endpoint and headers
+model_id = "152334H/miqu-1-70b-sf"
+endpoint = f"https://api-inference.huggingface.co/models/{model_id}"
+headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"}
 
 # App definition
 app = FastAPI()
@@ -56,22 +35,20 @@ app = FastAPI()
 async def parse_raw(request: Request):
     return await request.body()
 
-# Generate text
+# Generate text using the Inference API
 def generate_text(prompt: str) -> str:
-    model_input = eval_tokenizer(prompt, return_tensors="pt").to("cuda")
-
-    base_model.eval()
-    with torch.no_grad():
-        generated_sequences = base_model.generate(
-            **model_input,
-            max_new_tokens=4096,
-            repetition_penalty=1.1,
-            do_sample=True,
-            temperature=1,
-            streamer=streamer,
-        )
-
-    return eval_tokenizer.decode(generated_sequences[0], skip_special_tokens=True)
+    data = {
+        "inputs": prompt,
+        "options": {
+            "max_new_tokens": 200,
+            "temperature": 0.7,
+            "top_p": 0.95,
+            "use_cache": False,
+        },
+    }
+
+    response = requests.post(endpoint, headers=headers, json=data)
+    return response.json()["generated_text"]
 
 # Route for generating text
 @app.post("/generate_text")
@@ -83,4 +60,4 @@ async def generate_text_route(data: BaseModel = Depends(parse_raw)):
     return {"output": generate_text(input_text)}
 
 # Mount static files
-app.mount("/static", StaticFiles(directory="static"), name="static")
+app.mount("/static", StaticFiles(directory="static"), name="static")
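
The committed generate_text nests the sampling settings inside "options" and indexes the response as a flat dict. On the hosted Inference API, generation settings normally travel under "parameters" (with "options" reserved for flags such as use_cache and wait_for_model), and text-generation responses usually arrive as a list of objects. Below is a minimal sketch of a more defensive client under those assumptions; it reuses the model id from the commit and reads the same HF_ACCESS_TOKEN env var the Space already checks for, rather than the HUGGINGFACE_TOKEN name whose assignment is outside the diff:

import os
import requests

model_id = "152334H/miqu-1-70b-sf"
endpoint = f"https://api-inference.huggingface.co/models/{model_id}"
# Mirrors the env var the Space validates at startup.
headers = {"Authorization": f"Bearer {os.environ['HF_ACCESS_TOKEN']}"}

def generate_text(prompt: str) -> str:
    data = {
        "inputs": prompt,
        # Sampling settings go under "parameters" on the Inference API ...
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.7,
            "top_p": 0.95,
        },
        # ... while API-level flags live under "options".
        "options": {"use_cache": False, "wait_for_model": True},
    }
    response = requests.post(endpoint, headers=headers, json=data, timeout=300)
    response.raise_for_status()  # surfaces the 503 the API can return while a model this large loads
    payload = response.json()
    # Text-generation responses typically come back as a list of dicts.
    return payload[0]["generated_text"]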
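Because parse_raw hands the route the raw request body, the endpoint can be smoke-tested by POSTing plain text. A hypothetical check, assuming the Space serves on the usual port 7860 and that the elided route body decodes the raw bytes into input_text:

import requests

# Hypothetical prompt; Content-Type is plain text because the
# route consumes the raw request body via parse_raw.
resp = requests.post(
    "http://localhost:7860/generate_text",
    data="Write one sentence about lighthouses.",
    headers={"Content-Type": "text/plain"},
    timeout=300,
)
print(resp.json()["output"])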