|
import gradio as gr |
|
import torch |
|
from mamba_model import MambaModel |
|
|
|
|
|
# Load BlackMamba once at import time and keep it resident on the GPU in
# half precision — every request served by the Gradio app reuses this instance.
model = (
    MambaModel.from_pretrained(pretrained_model_name="Zyphra/BlackMamba-2.8B")
    .cuda()
    .half()
)
|
|
|
|
|
def generate_output(input_text):
    """Run the model on a comma-separated list of token IDs.

    Args:
        input_text: String of comma-separated integer token IDs,
            e.g. "1, 2, 3". Blank segments (trailing commas, stray
            whitespace) are ignored.

    Returns:
        Nested Python list with the model's raw output for a batch of
        one sequence, or an "Error: ..." string describing the failure
        (Gradio renders either directly in the output textbox).
    """
    try:
        # Skip empty segments so inputs like "1, 2," don't crash int("").
        token_ids = [int(x.strip()) for x in input_text.split(",") if x.strip()]
        if not token_ids:
            return "Error: no input IDs provided"

        # Build the tensor with the right dtype up front, then move it to
        # the GPU where the model lives; unsqueeze adds the batch dimension.
        inputs = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).cuda()

        # Inference only — no gradients needed.
        with torch.no_grad():
            out = model(inputs)

        # NOTE(review): assumes the model returns a tensor — confirm against
        # MambaModel.forward.
        return out.cpu().numpy().tolist()
    except Exception as e:
        # Surface any failure (bad input, OOM, model error) in the UI
        # rather than crashing the server.
        return f"Error: {str(e)}"
|
|
|
|
|
# UI widgets: one textbox for the comma-separated IDs, one for the result.
input_component = gr.Textbox(
    label="Input IDs (comma-separated)",
    placeholder="Enter input IDs like: 1, 2",
)
output_component = gr.Textbox(label="Output")

# Wire the inference function to the widgets.
iface = gr.Interface(
    fn=generate_output,
    inputs=input_component,
    outputs=output_component,
    title="BlackMamba Model",
)

if __name__ == "__main__":
    # Start the local Gradio server when run as a script.
    iface.launch()