aliMohammad16 committed
Commit 30a81ee · verified · Parent: bbde13c

Update app.py

Files changed (1): app.py (+76 -21)
app.py CHANGED
@@ -1,38 +1,93 @@
+import os
+import gradio as gr
 from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import torch
 
-app = FastAPI(title="Text Summarization API")
+os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
+os.makedirs('/tmp/transformers_cache', exist_ok=True)
 
-summarizer = pipeline(
-    "summarization",
-    model="deepseek-ai/deepseek-llm-7b-base",
-    device=0 if torch.cuda.is_available() else -1
+app = FastAPI(title="DeepSeek LLM Interface")
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+model_name = "deepseek-ai/deepseek-llm-7b-base"
+tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='/tmp/transformers_cache')
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    cache_dir='/tmp/transformers_cache',
+    torch_dtype=torch.float16,
+    device_map="auto"
 )
 
-class SummarizationRequest(BaseModel):
-    text: str
-    max_length: int = 130
-    min_length: int = 20
+def generate_response(prompt, max_length=500, temperature=0.7):
+    """Generate response using the DeepSeek model"""
+    try:
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+        outputs = model.generate(
+            **inputs,
+            max_length=max_length,
+            temperature=temperature,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
+        )
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return response
+    except Exception as e:
+        print(f"Error in generate_response: {str(e)}")
+        return f"Error generating response: {str(e)}"
+
+class GenerationRequest(BaseModel):
+    prompt: str
+    max_length: int = 500
+    temperature: float = 0.7
 
-class SummarizationResponse(BaseModel):
-    summary: str
+class GenerationResponse(BaseModel):
+    response: str
 
-@app.post("/summarize", response_model=SummarizationResponse)
-async def summarize_text(request: SummarizationRequest):
+@app.post("/generate", response_model=GenerationResponse)
+async def generate_text(request: GenerationRequest):
     try:
-        summary = summarizer(
-            request.text,
+        response = generate_response(
+            request.prompt,
             max_length=request.max_length,
-            min_length=request.min_length,
-            do_sample=False
-        )[0]['summary_text']
-
-        return SummarizationResponse(summary=summary)
+            temperature=request.temperature
+        )
+        return GenerationResponse(response=response)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/health")
 async def health_check():
     return {"status": "healthy"}
+
+def gradio_generate(prompt, max_length, temperature):
+    return generate_response(prompt, int(max_length), float(temperature))
+
+interface = gr.Interface(
+    fn=gradio_generate,
+    inputs=[
+        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
+        gr.Slider(minimum=50, maximum=1000, value=500, step=50, label="Max Length"),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
+    ],
+    outputs=gr.Textbox(label="Generated Response"),
+    title="DeepSeek LLM Interface",
+    description="Enter a prompt to generate text using DeepSeek LLM",
+    examples=[
+        ["Write a short story about a mysterious garden"],
+        ["Explain quantum computing in simple terms"],
+        ["Create a recipe for chocolate chip cookies"]
+    ]
+)
+
+app = gr.mount_gradio_app(app, interface, path="/")
+
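
The file defines app but never starts a server itself. A minimal sketch of serving it locally, assuming uvicorn is installed (the command is not part of the commit; port 7860 is the Hugging Face Spaces convention):

# Sketch: serve the combined FastAPI + Gradio app from app.py locally.
# Assumes uvicorn is installed; host/port are assumptions, not part of the commit.
import uvicorn

uvicorn.run("app:app", host="0.0.0.0", port=7860)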
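Once the app is running, the new /generate and /health endpoints can be exercised with a client sketch like the one below; the base URL and timeout are assumptions, not part of the commit:

# Sketch: client for the /generate and /health endpoints defined in app.py.
# BASE_URL is an assumption; adjust to wherever the app is actually served.
import requests

BASE_URL = "http://localhost:7860"

payload = {
    "prompt": "Explain quantum computing in simple terms",
    "max_length": 200,   # note: model.generate's max_length counts prompt + generated tokens
    "temperature": 0.7,
}

resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["response"])

print(requests.get(f"{BASE_URL}/health").json())  # expected: {"status": "healthy"}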