# hse-python-assistant/app/models/qwen.py

import os
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from app.models.base import BaseModel


class Qwen(BaseModel):
"""A class to handle inference with fine-tuned Qwen2.5 model."""
def __init__(
self,
model_path: str,
system_prompt: Optional[str] = None,
temperature: float = 0.6,
max_tokens: int = 2048,
) -> None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(self.device)
model_path = os.path.expanduser(model_path)
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16,
quantization_config=quantization_config,
device_map="auto",
# low_cpu_mem_usage=True
)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.system_prompt = system_prompt
self.temperature = temperature
self.max_tokens = max_tokens
self.messages = []
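        # Note (assumption, not in the original code): load_in_8bit requires the
        # bitsandbytes package and, in practice, a CUDA device. A minimal CPU
        # fallback for the from_pretrained call above could look like:
        #
        #   if not torch.cuda.is_available():
        #       self.model = AutoModelForCausalLM.from_pretrained(
        #           model_path, torch_dtype=torch.float32
        #       )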

    def ask(self, user_message: str, clear_history: bool = True) -> Optional[str]:
        """Send a user message to the model and return the generated reply."""
        if clear_history:
            self.messages = []
            if self.system_prompt:
                self.messages.append({"role": "system", "text": self.system_prompt})
        self.messages.append({"role": "user", "text": user_message})

        prompt_text = self._build_prompt()
        inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_tokens,
                do_sample=True,  # without this, generate() ignores temperature
                temperature=self.temperature,
            )
        # Decode only the newly generated tokens. Scanning the full decoded text
        # for an "assistant" line is fragile: it raises StopIteration when the
        # marker is absent and misfires if the prompt itself contains the word.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        self.messages.append({"role": "assistant", "text": response})
        return response

    def _build_prompt(self) -> str:
        """Assemble the conversation into Qwen's ChatML prompt format."""
        prompt = ""
        # Include assistant turns as well, so multi-turn history survives when
        # ask() is called with clear_history=False.
        for message in self.messages:
            prompt += f"<|im_start|>{message['role']}\n{message['text']}<|im_end|>\n"
        prompt += "<|im_start|>assistant\n"
        return prompt
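
    def _build_prompt_via_template(self) -> str:
        """Hypothetical alternative to _build_prompt (not in the original code):
        delegate ChatML assembly to the tokenizer's built-in chat template,
        assuming the fine-tuned checkpoint ships one. Note that
        apply_chat_template expects a "content" key, not this class's "text".
        """
        chat = [{"role": m["role"], "content": m["text"]} for m in self.messages]
        return self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )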


if __name__ == "__main__":
    model_path = "/home/ozakharov/hse_hackathon/Qwen2.5-32B-Instruct-hse_fine_tuned"
    # System prompt (Russian), in English: "You are a hackathon guru. You must
    # explain users' requests clearly, thoroughly, and understandably. You must
    # answer in a fun zoomer style, with emoji, and give moral support."
    system_prompt = "Ты - гуру хакатонов. Ты должен доходчиво, объемно и понятно объяснять пользователям их просьбы. Ты должен отвечать в веселом формате, как зумер, с эмодзи и поддерживать морально."
    qwen = Qwen(model_path=model_path, system_prompt=system_prompt)
    user_message = "Как выиграть хакатон?"  # "How do I win a hackathon?"
    response = qwen.ask(user_message)
    print(response)
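
    # Usage sketch (not in the original code): a follow-up turn that reuses the
    # accumulated history by passing clear_history=False. The question below is
    # illustrative; in English: "And what if there isn't enough time?"
    follow_up = "А что, если времени не хватает?"
    print(qwen.ask(follow_up, clear_history=False))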