import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from typing import Optional

from app.models.base import BaseModel


class Qwen(BaseModel):
    """A class to handle inference with a fine-tuned Qwen2.5 model."""

    def __init__(
        self,
        model_path: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.6,
        max_tokens: int = 2048,
    ) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)
        model_path = os.path.expanduser(model_path)
        # Load the weights in 8-bit via bitsandbytes to reduce GPU memory usage.
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            quantization_config=quantization_config,
            device_map="auto",
            # low_cpu_mem_usage=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.messages = []

    def ask(self, user_message: str, clear_history: bool = True, debug: bool = False) -> Optional[str]:
        if clear_history:
            self.messages = []
        if self.system_prompt:
            self.messages.append({"role": "system", "text": self.system_prompt})
        self.messages.append({"role": "user", "text": user_message})
        prompt_text = self._build_prompt()

        if debug:
            print(prompt_text)

        inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_tokens,
                temperature=self.temperature,
                do_sample=True,  # temperature is ignored unless sampling is enabled
            )
        response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # The decoded text includes the prompt, so keep only the lines after the assistant header.
        lines = response_text.splitlines()
        assistant_index = next(i for i, line in enumerate(lines) if "assistant" in line)
        extracted_lines = lines[assistant_index + 1:]

        response = "\n".join(extracted_lines)
        response = response.replace("<|im_end|>", "").strip()

        self.messages.append({"role": "assistant", "text": response})

        return response

    def _build_prompt(self) -> str:
        # Assemble a ChatML-style prompt from the stored conversation history.
        prompt = ""
        for message in self.messages:
            if message["role"] == "system":
                prompt += f"<|im_start|>system\n{message['text']}<|im_end|>\n\n"
            elif message["role"] == "user":
                prompt += f"<|im_start|>user\n{message['text']}<|im_end|>\n\n"
            elif message["role"] == "assistant":
                prompt += f"<|im_start|>assistant\n{message['text']}<|im_end|>\n\n"
        prompt += "<|im_start|>assistant\n"
        return prompt


if __name__ == "__main__":
    model_path = "your_path_here"

    system_prompt = (
        "You are a hackathon guru. You must explain users' requests clearly, thoroughly, "
        "and understandably. You must answer in a fun, zoomer-style tone with emojis and "
        "provide moral support."
    )

    qwen = Qwen(model_path=model_path, system_prompt=system_prompt)
    user_message = "How do I win a hackathon?"
    response = qwen.ask(user_message)
    print(response)