import os
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from app.models.base import BaseModel

class Qwen(BaseModel):
    """A class to handle inference with a fine-tuned Qwen2.5 model."""
    def __init__(
        self,
        model_path: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.6,
        max_tokens: int = 2048,
    ) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        model_path = os.path.expanduser(model_path)

        # Load the weights in 8-bit to reduce GPU memory usage.
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            quantization_config=quantization_config,
            device_map="auto",
            # low_cpu_mem_usage=True
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.messages = []
    def ask(self, user_message: str, clear_history: bool = True, debug: bool = False) -> Optional[str]:
        if clear_history:
            self.messages = []
            if self.system_prompt:
                self.messages.append({"role": "system", "text": self.system_prompt})

        self.messages.append({"role": "user", "text": user_message})

        prompt_text = self._build_prompt()

        if debug:
            print(prompt_text)

        inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_tokens,
                do_sample=True,  # sampling must be enabled for temperature to take effect
                temperature=self.temperature,
            )

        # The decoded text contains the prompt followed by the new completion;
        # keep only the lines after the last assistant marker.
        response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        lines = response_text.splitlines()
        assistant_indices = [i for i, line in enumerate(lines) if "assistant" in line]
        if not assistant_indices:
            return None
        extracted_lines = lines[assistant_indices[-1] + 1:]

        response = "\n".join(extracted_lines)
        response = response.replace("<|im_end|>", "").strip()

        self.messages.append({"role": "assistant", "text": response})

        return response
    def _build_prompt(self) -> str:
        # Assemble the conversation using Qwen's ChatML-style control tokens.
        prompt = ""
        for message in self.messages:
            if message["role"] == "system":
                prompt += f"<|im_start|>system\n{message['text']}<|im_end|>\n\n"
            elif message["role"] == "user":
                prompt += f"<|im_start|>user\n{message['text']}<|im_end|>\n\n"
            elif message["role"] == "assistant":
                prompt += f"<|im_start|>assistant\n{message['text']}<|im_end|>\n\n"
        prompt += "<|im_start|>assistant\n"
        return prompt

if __name__ == "__main__":
    model_path = "your_path_here"

    system_prompt = (
        "You are a hackathon guru. You must explain users' requests clearly, "
        "thoroughly, and understandably. You must answer in a fun, zoomer-like "
        "style, with emoji, and give moral support."
    )

    qwen = Qwen(model_path=model_path, system_prompt=system_prompt)

    user_message = "How do you win a hackathon?"

    response = qwen.ask(user_message)
    print(response)