htigenai committed
Commit 2153031 · verified · 1 Parent(s): 92fdb20

Update app.py

Files changed (1)
app.py +90 -92
app.py CHANGED
@@ -4,112 +4,103 @@ import torch
 import logging
 import sys
 import gc
-import time
 from contextlib import contextmanager

 # Set up logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s',
-    handlers=[logging.StreamHandler(sys.stdout)]
-)
+logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-@contextmanager
-def timer(description: str):
-    start = time.time()
-    yield
-    elapsed = time.time() - start
-    logger.info(f"{description}: {elapsed:.2f} seconds")
-
-def log_system_info():
-    """Log system information for debugging"""
-    logger.info(f"Python version: {sys.version}")
-    logger.info(f"PyTorch version: {torch.__version__}")
-    logger.info(f"Device: CPU")
-
-print("Starting application...")
-log_system_info()
+logger.info("Starting application...")
+logger.info(f"CUDA available: {torch.cuda.is_available()}")
+if torch.cuda.is_available():
+    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

 try:
-    print("Loading model and tokenizer...")
-
-    model_id = "htigenai/finetune_test"  # Replace with your chosen model ID
-
-    with timer("Loading tokenizer"):
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            use_fast=True,  # Use fast tokenizer for better performance
-            cache_dir='./cache'
-        )
-        tokenizer.pad_token = tokenizer.eos_token
+    logger.info("Loading tokenizer...")
+    model_id = "htigenai/finetune_test_2"
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_id,
+        use_fast=False  # Use slow tokenizer to save memory
+    )
+    tokenizer.pad_token = tokenizer.eos_token
     logger.info("Tokenizer loaded successfully")

-    with timer("Loading model"):
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            device_map={"": "cpu"},
-            cache_dir='./cache'
-        )
-        model.eval()
-        logger.info("Model loaded successfully")
+    logger.info("Loading model in 8-bit...")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        device_map="auto",
+        load_in_8bit=True,  # Load in 8-bit instead of 4-bit
+        torch_dtype=torch.float16,
+        low_cpu_mem_usage=True,
+        max_memory={0: "12GB", "cpu": "4GB"}  # Limit memory usage
+    )
+    model.eval()
+    logger.info("Model loaded successfully in 8-bit")
+
+    # Clear any residual memory
+    gc.collect()
+    torch.cuda.empty_cache()

     def generate_text(prompt, max_tokens=100, temperature=0.7):
-        """Generate text based on the input prompt."""
         try:
-            logger.info(f"Starting generation for prompt: {prompt[:50]}...")
-
-            with timer("Tokenization"):
-                inputs = tokenizer(
-                    prompt,
-                    return_tensors="pt",
-                    padding=True,
-                    truncation=True,
-                    max_length=256
-                ).to("cpu")  # Ensure inputs are on CPU
-
-            with timer("Generation"):
-                with torch.no_grad():
-                    outputs = model.generate(
-                        input_ids=inputs["input_ids"],
-                        attention_mask=inputs["attention_mask"],
-                        max_new_tokens=max_tokens,
-                        temperature=temperature,
-                        top_p=0.95,
-                        do_sample=True,
-                        pad_token_id=tokenizer.pad_token_id,
-                        eos_token_id=tokenizer.eos_token_id,
-                        repetition_penalty=1.1,
-                    )
-
-            with timer("Decoding"):
-                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-            logger.info("Text generation completed successfully")
-
-            # Clean up
-            with timer("Cleanup"):
-                gc.collect()
-
-            return generated_text
+            # Format the prompt
+            formatted_prompt = f"### Human: {prompt}\n\n### Assistant:"
+
+            # Generate with memory-efficient settings
+            inputs = tokenizer(
+                formatted_prompt,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=256  # Limit input length
+            ).to(model.device)
+
+            with torch.inference_mode():
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    temperature=temperature,
+                    do_sample=True,
+                    top_p=0.95,
+                    repetition_penalty=1.2,
+                    pad_token_id=tokenizer.pad_token_id,
+                    eos_token_id=tokenizer.eos_token_id,
+                    early_stopping=True,
+                    no_repeat_ngram_size=3,
+                    use_cache=True
+                )
+
+            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            # Extract only the assistant's response
+            if "### Assistant:" in response:
+                response = response.split("### Assistant:")[-1].strip()
+
+            # Clean up memory after generation
+            del outputs, inputs
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            return response

         except Exception as e:
             logger.error(f"Error during generation: {str(e)}")
-            return f"Error during generation: {str(e)}"
+            return f"Error generating response: {str(e)}"

-    # Create Gradio interface
+    # Create a more memory-efficient Gradio interface
     iface = gr.Interface(
         fn=generate_text,
         inputs=[
             gr.Textbox(
-                lines=3,
+                lines=3,
                 placeholder="Enter your prompt here...",
-                label="Input Prompt"
+                label="Input Prompt",
+                max_lines=5
             ),
             gr.Slider(
-                minimum=20,
-                maximum=200,
-                value=100,
+                minimum=10,
+                maximum=100,
+                value=50,
                 step=10,
                 label="Max Tokens"
             ),
@@ -123,19 +114,26 @@ try:
         ],
         outputs=gr.Textbox(
             label="Generated Response",
-            lines=10
+            lines=5
         ),
-        title="Text Generation Demo",
-        description="Enter a prompt to generate text.",
+        title="HTIGENAI Reflection Analyzer (8-bit)",
+        description="8-bit quantized text generation. Please keep prompts concise for best results.",
         examples=[
-            ["What are your thoughts about cats?", 50, 0.7],
-            ["Write a short story about a magical forest", 60, 0.8],
-            ["Explain quantum computing to a 5-year-old", 40, 0.5],
-        ]
+            ["What is machine learning?", 50, 0.7],
+            ["Explain quantum computing", 50, 0.7],
+        ],
+        cache_examples=False
     )

-    iface.launch()
+    # Launch with minimal memory usage
+    iface.launch(
+        server_name="0.0.0.0",
+        share=False,
+        show_error=True,
+        enable_queue=True,
+        max_threads=1
+    )

 except Exception as e:
     logger.error(f"Application startup failed: {str(e)}")
-    raise
+    raise
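
Note on the 8-bit load in this commit: depending on the installed transformers version, passing load_in_8bit=True directly to from_pretrained may emit a deprecation warning in favor of an explicit quantization config. A minimal sketch of the equivalent call using BitsAndBytesConfig, assuming bitsandbytes is installed and a GPU is available (this is not part of the commit; the model id and memory limits are copied from the diff above):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same 8-bit quantized load as in app.py, expressed via an explicit config object.
quant_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "htigenai/finetune_test_2",
    device_map="auto",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    max_memory={0: "12GB", "cpu": "4GB"},  # same per-device limits as in the commit
)
model.eval()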
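
Note on the launch settings: enable_queue is a Gradio 3.x keyword argument; in Gradio 4.x the queue is configured on the interface before launching. A small sketch of that variant, assuming Gradio 4.x, with placeholder inputs and a stub function standing in for the app's actual components (also not part of the commit):

import gradio as gr

def generate_text(prompt, max_tokens, temperature):
    # Stub; the real implementation lives in app.py above.
    return prompt

iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=3, label="Input Prompt"),
        gr.Slider(10, 100, value=50, step=10, label="Max Tokens"),
        gr.Slider(0.1, 1.0, value=0.7, label="Temperature"),  # assumed range for illustration
    ],
    outputs=gr.Textbox(lines=5, label="Generated Response"),
)

# Gradio 4.x: queue() replaces enable_queue=True; launch() keeps the remaining options.
iface.queue(max_size=16)
iface.launch(server_name="0.0.0.0", share=False, show_error=True, max_threads=1)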