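"""Gradio app for serving the htigenai/finetune_test_2 model with 8-bit quantization.

Loads the Meta-Llama-3.1-8B-Instruct base tokenizer and the fine-tuned weights,
then exposes a simple prompt -> response text-generation interface.
"""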
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import logging
import sys
import gc

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Starting application...")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

try:
    logger.info("Loading tokenizer...")
    # Use the base model's tokenizer rather than one from the fine-tuned repo
    base_model_id = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        use_fast=True,
        trust_remote_code=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    logger.info("Tokenizer loaded successfully")

    logger.info("Loading fine-tuned model in 8-bit...")
    model_id = "htigenai/finetune_test_2"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        max_memory={0: "12GB", "cpu": "4GB"}
    )
    model.eval()
    logger.info("Model loaded successfully in 8-bit")

    # Clear any residual memory
    gc.collect()
    torch.cuda.empty_cache()

    def generate_text(prompt, max_tokens=100, temperature=0.7):
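        """Generate a response for `prompt` with sampling.

        `max_tokens` caps the number of new tokens; `temperature` controls
        sampling randomness. Returns the decoded assistant reply, or an
        error message if generation fails.
        """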
        try:
            # Wrap the prompt in the ### Human / ### Assistant format
            formatted_prompt = f"### Human: {prompt}\n\n### Assistant:"
            
            inputs = tokenizer(
                formatted_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=256
            ).to(model.device)

            with torch.inference_mode():
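                # Nucleus sampling with a light repetition penalty and n-gram blocking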
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.95,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    early_stopping=True,
                    no_repeat_ngram_size=3,
                    use_cache=True
                )

            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Extract assistant's response
            if "### Assistant:" in response:
                response = response.split("### Assistant:")[-1].strip()
            
            # Clean up
            del outputs, inputs
            gc.collect()
            torch.cuda.empty_cache()
            
            return response

        except Exception as e:
            logger.error(f"Error during generation: {str(e)}")
            return f"Error generating response: {str(e)}"

    # Create Gradio interface
    iface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(
                lines=3, 
                placeholder="Enter your prompt here...",
                label="Input Prompt",
                max_lines=5
            ),
            gr.Slider(
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                label="Max Tokens"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
        ],
        outputs=gr.Textbox(
            label="Generated Response",
            lines=5
        ),
        title="HTIGENAI Reflection Analyzer (8-bit)",
        description="Uses the Llama 3.1 base tokenizer with the fine-tuned model. Keep prompts concise for best results.",
        examples=[
            ["What is machine learning?", 50, 0.7],
            ["Explain quantum computing", 50, 0.7],
        ],
        cache_examples=False
    )

    # Launch interface; queue() replaces the deprecated enable_queue launch flag
    iface.queue().launch(
        server_name="0.0.0.0",
        share=False,
        show_error=True,
        max_threads=1
    )

except Exception as e:
    logger.error(f"Application startup failed: {str(e)}")
    raise