import os
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from app.models.base import BaseModel


class Qwen(BaseModel):
    """A class to handle inference with a fine-tuned Qwen2.5 model."""

    def __init__(
        self,
        model_path: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.6,
        max_tokens: int = 2048,
    ) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        model_path = os.path.expanduser(model_path)

        # Load the weights in 8-bit to fit the 32B model into GPU memory.
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            quantization_config=quantization_config,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.messages: list[dict[str, str]] = []

    def ask(self, user_message: str, clear_history: bool = True) -> Optional[str]:
        if clear_history:
            self.messages = []
        # Only add the system prompt at the start of a conversation; otherwise
        # it would be appended again on every follow-up turn.
        if self.system_prompt and not self.messages:
            self.messages.append({"role": "system", "text": self.system_prompt})
        self.messages.append({"role": "user", "text": user_message})

        prompt_text = self._build_prompt()
        inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_tokens,
                temperature=self.temperature,
                do_sample=True,  # temperature has no effect under greedy decoding
            )

        # Decode only the newly generated tokens; slicing off the prompt is more
        # robust than searching the decoded text for an "assistant" marker.
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        self.messages.append({"role": "assistant", "text": response})
        return response

    def _build_prompt(self) -> str:
        """Render the message history in ChatML format, as used by Qwen models."""
        prompt = ""
        for message in self.messages:
            prompt += f"<|im_start|>{message['role']}\n{message['text']}<|im_end|>\n"
        prompt += "<|im_start|>assistant\n"
        return prompt


if __name__ == "__main__":
    model_path = "/home/ozakharov/hse_hackathon/Qwen2.5-32B-Instruct-hse_fine_tuned_v2"
    system_prompt = (
        "You are a hackathon guru. You must explain users' requests clearly, "
        "thoroughly, and understandably. You must answer in a fun, zoomer-style "
        "format, with emoji, and provide moral support."
    )

    qwen = Qwen(model_path=model_path, system_prompt=system_prompt)

    user_message = "How do I win a hackathon?"
    response = qwen.ask(user_message)
    print(response)
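
    # A minimal multi-turn sketch: passing clear_history=False keeps the earlier
    # turns in self.messages, so the follow-up is answered with conversation
    # context. The follow-up question below is illustrative, not part of the
    # original script.
    follow_up = qwen.ask("What should the final pitch look like?", clear_history=False)
    print(follow_up)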