Leri777 committed on
Commit
7ecb022
1 Parent(s): 7096a95

Update app.py

Files changed (1)
  1. app.py +118 -124
app.py CHANGED
@@ -1,140 +1,134 @@
  import os
- import logging
- import time
- import random
- from logging.handlers import RotatingFileHandler
  import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, GemmaTokenizerFast, pipeline
- from langchain_huggingface import HuggingFacePipeline
- from langchain.prompts import PromptTemplate
- from langchain.chains import LLMChain

- # Logging setup
- log_file = '/tmp/app_debug.log'
- logger = logging.getLogger(__name__)
- logger.setLevel(logging.DEBUG)
- file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
- file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
- logger.addHandler(file_handler)

- logger.debug("Application started")

  model_id = "google/gemma-2-9b-it"
  tokenizer = GemmaTokenizerFast.from_pretrained(model_id)

- # Function to load model with GPU availability check
- def load_model():
-     max_attempts = 5
-     attempts = 0
-     while attempts < max_attempts:
-         if torch.cuda.is_available():
-             logger.debug("GPU is available. Proceeding with GPU setup.")
-             try:
-                 return AutoModelForCausalLM.from_pretrained(
-                     model_id,
-                     device_map="auto",
-                     torch_dtype=torch.bfloat16,
-                 )
-             except Exception as e:
-                 logger.error(f"Error initializing model with GPU: {e}. Retrying...")
-                 attempts += 1
-                 time.sleep(random.uniform(20, 60)) # Wait before retrying
-         else:
-             logger.warning("GPU is not available. Retrying GPU initialization...")
-             attempts += 1
-             time.sleep(random.uniform(20, 60))
-
-     # If GPU is still not available, fall back to CPU
-     logger.warning("Falling back to CPU setup after multiple attempts.")
-     return AutoModelForCausalLM.from_pretrained(
-         model_id,
-         device_map="auto",
-         low_cpu_mem_usage=True,
-         token=os.getenv('HF_TOKEN'),
-     )

- # Retry logic to load model with random delay
- model = None
- while model is None:
-     try:
-         model = load_model()
-         model.eval()
-     except Exception as e:
-         retry_delay = random.uniform(30, 60) # Increased delay between retries
-         logger.error(f"Failed to load model: {e}. Retrying in {retry_delay:.2f} seconds...")
-         time.sleep(retry_delay)
-
- # Create Hugging Face pipeline
- pipe = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     max_length=2048,
-     temperature=0.7,
-     top_k=50,
-     top_p=0.9,
-     repetition_penalty=1.2,
- )

- # Initialize HuggingFacePipeline model for LangChain
- chat_model = HuggingFacePipeline(pipeline=pipe)
-
- # Define the conversation template for LangChain
- template = """<|im_start|>system
- {system_prompt}
- <|im_end|>
- {history}
- <|im_start|>user
- {human_input}
- <|im_end|>
- <|im_start|>assistant"""
-
- # Create LangChain prompt and chain
- prompt = PromptTemplate(
-     template=template, input_variables=["system_prompt", "history", "human_input"]
- )
- chain = prompt | chat_model

- # Prediction function using LangChain and model
- def predict(message, chat_history=[]):
-     formatted_history = "\n".join(
-         [f"<|im_start|>{entry['role']}\n{entry['content']}<|im_end|>" for entry in chat_history]
      )
-     system_prompt = "You are a helpful coding assistant."
-
-     try:
-         result = chain.run({
-             "system_prompt": system_prompt,
-             "history": formatted_history,
-             "human_input": message
-         })
-         return result
-     except Exception as e:
-         logger.exception(f"Error during prediction: {e}")
-         return "An error occurred."
-
- # Gradio UI
- interface = gr.Interface(
-     fn=predict,
-     inputs=[
-         gr.Textbox(label="User input")
      ],
-     outputs="text", allow_flagging='never',
-     live=True,
  )

- # Retry logic to launch interface with random delay
- max_retries = 5
- retry_count = 0
- while retry_count < max_retries:
-     try:
-         interface.launch()
-         break
-     except Exception as e:
-         retry_delay = random.uniform(60, 120) # Increased delay between retries
-         logger.error(f"Failed to launch interface: {e}. Retrying in {retry_delay:.2f} seconds...")
-         retry_count += 1
-         time.sleep(retry_delay)
-
- logger.debug("Chat interface initialized and launched")
 
  import os
+ from threading import Thread
+ from typing import Iterator
+
  import gradio as gr
+ import spaces
  import torch
+ from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
+
+ DESCRIPTION = """\
+ # Gemma 2 9B IT
+
+ Gemma 2 is Google's latest iteration of open LLMs.
+ This is a demo of [`google/gemma-2-9b-it`](https://huggingface.co/google/gemma-2-9b-it), fine-tuned for instruction following.
+ For more details, please check [our post](https://huggingface.co/blog/gemma2).
+
+ 👉 Looking for a larger and more powerful version? Try the 27B version in [HuggingChat](https://huggingface.co/chat/models/google/gemma-2-27b-it).
+ """

+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  model_id = "google/gemma-2-9b-it"
  tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ model.config.sliding_window = 4096
+ model.eval()


+ @spaces.GPU(duration=90)
+ def generate(
+     message: str,
+     chat_history: list[dict],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     conversation = chat_history.copy()
+     conversation.append({"role": "user", "content": message})

+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)

+     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
      )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Hello there! How are you doing?"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
      ],
+     cache_examples=False,
+     type="messages",
  )

+ with gr.Blocks(css="style.css", fill_height=True) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+     chat_interface.render()
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()